/* See LICENSE file for copyright and license details. */
#include "libsimple.h"
#ifndef TEST
/*
 * Option state accumulated while parsing the argument list
 * of libsimple_vmemalloc.
 */
struct memalloc_state {
	/* -1 until set by an option, afterwards 0 or 1 */
	int zero_init, if_zero;
	/* non-zero once the corresponding option has been seen;
	 * used to reject options that are given more than once */
	int round_up_size, have_size;
	/* requested alignment in bytes; 0 means no alignment requested */
	size_t alignment;
	/* elem_size: 0 means "not given" (treated as 1 later);
	 * size_prod: element count that replaces `n` when have_size is set
	 * (initialized to 1 by libsimple_vmemalloc) */
	size_t elem_size, size_prod;
};
/*
 * Parse the size factors of a *_PRODUCT_SIZE option into state->size_prod.
 *
 * `arg` is the first factor, already fetched by the caller; any further
 * factors are read from *app:
 *   - if `n` is non-zero, exactly `n` factors are consumed in total
 *     (`n - 1` more are read) and zero factors are permitted;
 *   - if `n` is zero, factors are read until a terminating 0, and the
 *     first factor must be non-zero.
 *
 * The list is taken by pointer so that the caller's list is advanced past
 * the consumed factors: a va_list passed by value and then used with
 * va_arg becomes indeterminate in the caller (C11 7.16p3). This matters,
 * for example, on AArch64, where va_list is a structure.
 *
 * Returns 0 on success. On failure returns -1 with errno set to:
 *   EINVAL  a size product was already given, or the 0-terminated list
 *           is empty;
 *   ENOMEM  the product does not fit in size_t.
 */
static int
vmemalloc_parse_size_prod(struct memalloc_state *state, size_t n, size_t arg, va_list *app)
{
	if (state->have_size++)
		goto inval;

	/* The product goes into size_prod, not elem_size: the two options
	 * are independent and libsimple_vmemalloc multiplies them. */
	state->size_prod = arg;
	if (n) {
		for (n--; n--;) {
			arg = va_arg(*app, size_t);
			/* Once the product is 0 it stays 0, but the remaining
			 * factors must still be consumed from the list */
			if (!state->size_prod)
				continue;
			if (arg > SIZE_MAX / state->size_prod) {
				errno = ENOMEM;
				return -1;
			}
			state->size_prod *= arg;
		}
	} else {
		if (!arg)
			goto inval;
		for (;;) {
			arg = va_arg(*app, size_t);
			if (!arg)
				break;
			if (arg > SIZE_MAX / state->size_prod) {
				errno = ENOMEM;
				return -1;
			}
			state->size_prod *= arg;
		}
	}

	return 0;
inval:
	errno = EINVAL;
	return -1;
}

/*
 * Worker for vmemalloc_parse_args: read options from *app into `state`
 * until LIBSIMPLE_MEMALLOC_END.
 *
 * Nested lists (LIBSIMPLE_MEMALLOC_VA_LIST and the *_VA_PRODUCT_SIZE
 * options) are supplied as `va_list *` and are parsed through that
 * pointer, so they are advanced in place.
 *
 * Returns 0 on success. On failure returns -1 with errno set: EINVAL for
 * an unknown option, a repeated option or a zero alignment or element
 * size; ENOMEM for an overflowing size product. If sysconf fails, errno
 * is whatever sysconf left.
 */
static int
vmemalloc_parse_arg_list(struct memalloc_state *state, size_t n, va_list *app)
{
	enum libsimple_memalloc_option opt;
	long int page_size;
	va_list *subapp;
	size_t arg;

	for (;;) {
		opt = va_arg(*app, enum libsimple_memalloc_option);
		switch (opt) {
		case LIBSIMPLE_MEMALLOC_END:
			return 0;

		case LIBSIMPLE_MEMALLOC_ZERO_INIT:
			if (state->zero_init >= 0)
				goto inval;
			state->zero_init = 1;
			break;

		case LIBSIMPLE_MEMALLOC_CONDITIONAL_ZERO_INIT:
			if (state->zero_init >= 0)
				goto inval;
			state->zero_init = !!va_arg(*app, int);
			break;

		case LIBSIMPLE_MEMALLOC_UNIQUE_IF_ZERO:
		case LIBSIMPLE_MEMALLOC_NULL_IF_ZERO:
			if (state->if_zero >= 0)
				goto inval;
			state->if_zero = (opt == LIBSIMPLE_MEMALLOC_UNIQUE_IF_ZERO);
			break;

		case LIBSIMPLE_MEMALLOC_ALIGNMENT:
			if (state->alignment)
				goto inval;
			state->alignment = va_arg(*app, size_t);
			if (!state->alignment)
				goto inval;
			break;

		case LIBSIMPLE_MEMALLOC_PAGE_ALIGNMENT:
			if (state->alignment)
				goto inval;
			page_size = sysconf(_SC_PAGESIZE);
			if (page_size <= 0)
				return -1;
			state->alignment = (size_t)page_size;
			break;

		case LIBSIMPLE_MEMALLOC_ROUND_UP_SIZE_TO_ALIGNMENT:
			if (state->round_up_size++)
				goto inval;
			break;

		case LIBSIMPLE_MEMALLOC_ELEMENT_SIZE:
			if (state->elem_size)
				goto inval;
			state->elem_size = va_arg(*app, size_t);
			if (!state->elem_size)
				goto inval;
			break;

		case LIBSIMPLE_MEMALLOC_PRODUCT_SIZE:
			arg = va_arg(*app, size_t);
			/* pass the list by pointer so *app is advanced past the factors */
			if (vmemalloc_parse_size_prod(state, n, arg, app))
				return -1;
			break;

		case LIBSIMPLE_MEMALLOC_VA_PRODUCT_SIZE:
			subapp = va_arg(*app, va_list *);
			arg = va_arg(*subapp, size_t);
			if (vmemalloc_parse_size_prod(state, n, arg, subapp))
				return -1;
			break;

		case LIBSIMPLE_MEMALLOC_1_VA_PRODUCT_SIZE:
			arg = va_arg(*app, size_t);
			subapp = va_arg(*app, va_list *);
			if (vmemalloc_parse_size_prod(state, n, arg, subapp))
				return -1;
			break;

		case LIBSIMPLE_MEMALLOC_VA_LIST:
			subapp = va_arg(*app, va_list *);
			if (vmemalloc_parse_arg_list(state, n, subapp))
				return -1;
			break;

		default:
			goto inval;
		}
	}

inval:
	errno = EINVAL;
	return -1;
}

/*
 * Parse the option list `ap` (terminated by LIBSIMPLE_MEMALLOC_END)
 * into `state`.
 *
 * A copy of `ap` is parsed through a pointer so that helper functions can
 * consume arguments without leaving any va_list indeterminate.
 *
 * Returns 0 on success, or -1 with errno set (see
 * vmemalloc_parse_arg_list).
 */
static int
vmemalloc_parse_args(struct memalloc_state *state, size_t n, va_list ap)
{
	va_list args;
	int ret;

	va_copy(args, ap);
	ret = vmemalloc_parse_arg_list(state, n, &args);
	va_end(args);
	return ret;
}
/*
 * Round *sizep up to the nearest multiple of `alignment` (which must be
 * non-zero).
 *
 * Returns 0 on success. If the rounded value would not fit in size_t,
 * returns -1 with errno set to ENOMEM and leaves *sizep unchanged,
 * rather than letting the addition wrap around to a small size.
 */
static int
vmemalloc_round_up(size_t *sizep, size_t alignment)
{
	size_t misalignment = *sizep % alignment;
	size_t padding;

	if (!misalignment)
		return 0;
	padding = alignment - misalignment;
	if (*sizep > SIZE_MAX - padding) {
		errno = ENOMEM;
		return -1;
	}
	*sizep += padding;
	return 0;
}

/*
 * Allocate memory as described by `n` and the option list `ap`
 * (terminated by LIBSIMPLE_MEMALLOC_END).
 *
 * How the size is computed:
 *   1. Start from `n`. If a *_PRODUCT_SIZE option was given, use the
 *      product of its factors instead.
 *   2. Multiply by the element size (default 1).
 *   3. With LIBSIMPLE_MEMALLOC_ROUND_UP_SIZE_TO_ALIGNMENT, round up to
 *      the requested alignment.
 *
 * How a zero size is handled:
 *   - NULL_IF_ZERO: return NULL without setting errno.
 *   - UNIQUE_IF_ZERO: allocate 1 byte.
 *   - Otherwise: call the underlying allocator with size 0 and return
 *     its result, which may be NULL.
 *
 * Returns the allocation, or NULL with errno set:
 *   EINVAL  invalid options, including ROUND_UP_SIZE_TO_ALIGNMENT
 *           without an alignment;
 *   ENOMEM  the size overflows size_t, or memory is exhausted;
 *   otherwise, the error reported by the allocator.
 * On success errno is restored to its value from before the allocation
 * attempt, unless the allocator left a non-zero value.
 */
void *
libsimple_vmemalloc(size_t n, va_list ap) /* TODO test ([v]{mem,array}alloc) */
{
	struct memalloc_state state;
	size_t size;
	void *ptr = NULL;
	int saved_errno;

	state.zero_init = -1;
	state.if_zero = -1;
	state.round_up_size = 0;
	state.have_size = 0;
	state.alignment = 0;
	state.elem_size = 0;
	state.size_prod = 1;

	if (vmemalloc_parse_args(&state, n, ap))
		return NULL;

	state.elem_size = state.elem_size ? state.elem_size : 1;
	state.zero_init = state.zero_init >= 0 ? state.zero_init : 0;
	n = state.have_size ? state.size_prod : n;
	if (state.elem_size > 1) {
		if (n > SIZE_MAX / state.elem_size) {
			errno = ENOMEM;
			return NULL;
		}
		n *= state.elem_size;
	}
	if (state.round_up_size) {
		if (!state.alignment) {
			errno = EINVAL;
			return NULL;
		}
		/* checked: a size near SIZE_MAX must not wrap to a tiny value */
		if (vmemalloc_round_up(&n, state.alignment))
			return NULL;
	}
	if (!n && state.if_zero == 0)
		return NULL;
	n = n ? n : (state.if_zero > 0);

	saved_errno = errno;
	errno = 0;
	if (state.alignment) {
		if (state.alignment % sizeof(void *)) {
			/* posix_memalign only accepts multiples of sizeof(void *);
			 * aligned_alloc requires the size to be a multiple of
			 * the alignment, so pad it (checked for overflow) */
			size = n;
			if (vmemalloc_round_up(&size, state.alignment))
				return NULL;
			ptr = aligned_alloc(state.alignment, size);
		} else {
			/* posix_memalign returns its error instead of setting errno */
			errno = posix_memalign(&ptr, state.alignment, n);
		}
		/* only the n requested bytes are cleared, not any padding */
		if (ptr && state.zero_init)
			memset(ptr, 0, n);
	} else {
		ptr = state.zero_init ? calloc(n, 1) : malloc(n);
	}
	if (!ptr && n) {
		if (!errno)
			errno = ENOMEM;
		return NULL;
	}
	errno = errno ? errno : saved_errno;
	return ptr;
}
#else
#include "test.h"
/*
 * Test entry point for the TEST build. No checks exist yet for
 * libsimple_vmemalloc (see the TODO on its definition), so this only
 * reports success.
 */
int
main(void)
{
	int status = 0;

	return status;
}
#endif