diff options
Diffstat (limited to '')
-rw-r--r-- | libar2simplified_init_context.c | 368 |
1 files changed, 368 insertions, 0 deletions
diff --git a/libar2simplified_init_context.c b/libar2simplified_init_context.c new file mode 100644 index 0000000..c7f8b69 --- /dev/null +++ b/libar2simplified_init_context.c @@ -0,0 +1,368 @@ +/* See LICENSE file for copyright and license details. */ +#include "common.h" +#include <pthread.h> +#include <semaphore.h> + + +struct user_data; + +struct thread_data { + size_t index; + struct user_data *master; + pthread_t thread; + sem_t semaphore; + int error; + void (*function)(void *data); + void *function_input; +}; + +struct user_data { + struct thread_data *threads; + size_t nthreads; + pthread_mutex_t mutex; + sem_t semaphore; + uint_least64_t *joined; + uint_least64_t resting[]; +}; + + +static void * +alignedalloc(size_t num, size_t size, size_t extra, size_t alignment) +{ + void *ptr; + int err; + if (num > (SIZE_MAX - extra) / size) { + errno = ENOMEM; + return NULL; + } + if (alignment < sizeof(void *)) + alignment = sizeof(void *); + err = posix_memalign(&ptr, alignment, num * size + extra); + if (err) { + errno = err; + return NULL; + } else { + return ptr; + } +} + + +static void * +allocate(size_t num, size_t size, size_t alignment, struct libar2_context *ctx) +{ + size_t pad = (alignment - ((2 * sizeof(size_t)) & (alignment - 1))) & (alignment - 1); + char *ptr = alignedalloc(num, size, pad + 2 * sizeof(size_t), alignment); + if (ptr) { + ptr = &ptr[pad]; + *(size_t *)ptr = pad; + ptr = &ptr[sizeof(size_t)]; + *(size_t *)ptr = num * size; + ptr = &ptr[sizeof(size_t)]; + } + (void) ctx; + return ptr; +} + + +static void +deallocate(void *ptr, struct libar2_context *ctx) +{ + char *p = ptr; + p -= sizeof(size_t); + libar2_erase(ptr, *(size_t *)p); + p -= sizeof(size_t); + p -= *(size_t *)p; + free(p); + (void) ctx; +} + + +static void * +thread_loop(void *data_) +{ + struct thread_data *data = data_; + int err; + + for (;;) { + if (sem_wait(&data->semaphore)) { + if (errno == EINTR) + continue; + data->error = errno; + return NULL; + } + + if (!data->function) { + data->error = ENOTRECOVERABLE; + return NULL; + } + data->function(data->function_input); + + err = pthread_mutex_lock(&data->master->mutex); + if (err) { + data->error = err; + return NULL; + } + data->master->resting[data->index / 64] |= (uint_least64_t)1 << (data->index % 64); + pthread_mutex_unlock(&data->master->mutex); + if (sem_post(&data->master->semaphore)) { + data->error = errno; + return NULL; + } + } +} + + +static int +run_thread(size_t index, void (*function)(void *arg), void *arg, struct libar2_context *ctx) +{ + struct user_data *data = ctx->user_data; + int err; + + err = pthread_mutex_lock(&data->mutex); + if (err) { + errno = err; + return -1; + } + data->resting[index / 64] ^= (uint_least64_t)1 << (index % 64); + pthread_mutex_unlock(&data->mutex); + + if (data->threads[index].error) { + errno = data->threads[index].error; + return -1; + } + + data->threads[index].function = function; + data->threads[index].function_input = arg; + if (sem_post(&data->threads[index].semaphore)) + return -1; + + return 0; +} + + +static int +destroy_thread_pool(struct libar2_context *ctx) +{ + struct user_data *data = ctx->user_data; + size_t i; + int ret = 0; + for (i = data->nthreads; i--;) + if (run_thread(i, pthread_exit, NULL, ctx)) + return -1; + for (i = data->nthreads; i--;) { + pthread_join(data->threads[i].thread, NULL); + sem_destroy(&data->threads[i].semaphore); + if (data->threads[i].error) + ret = data->threads[i].error; + } + free(data->threads); + sem_destroy(&data->semaphore); + pthread_mutex_destroy(&data->mutex); + free(data); + return ret; +} + + +static int +init_thread_pool(size_t desired, size_t *createdp, struct libar2_context *ctx) +{ + struct user_data *data; + int err; + size_t i, size; + long int nproc, nproc_limit; +#ifdef __linux__ + char path[sizeof("/sys/devices/system/cpu/cpu") + 3 * sizeof(nproc)]; +#endif +#ifdef _SC_SEM_VALUE_MAX + long int semlimit; +#endif + + if (desired < 2) { + *createdp = 0; + return 0; + } + + nproc = sysconf(_SC_NPROCESSORS_ONLN); +#ifdef __linux__ + if (nproc < 1) { + nproc_limit = desired > LONG_MAX ? LONG_MAX : (long int)desired; + for (nproc = 0; nproc < nproc_limit; nproc++) { + sprintf(path, "%s%li", "/sys/devices/system/cpu/cpu", nproc); + if (access(path, F_OK)) + break; + } + } +#endif + if (nproc < 1) + nproc = FALLBACK_NPROC; + +#ifdef _SC_SEM_VALUE_MAX + semlimit = sysconf(_SC_SEM_VALUE_MAX); + if (semlimit >= 1 && semlimit < nproc) + nproc = semlimit; +#endif + + if (nproc == 1) { + *createdp = 0; + return 0; + } + + desired = (size_t)nproc < desired ? (size_t)nproc : desired; + + if (desired > SIZE_MAX - 63 || (desired + 63) / 64 > SIZE_MAX / sizeof(uint_least64_t) / 2) { + errno = ENOMEM; + return -1; + } + size = (desired + 63) / 64; + size *= sizeof(uint_least64_t) * 2; + data = alignedalloc(1, offsetof(struct user_data, resting), size, ALIGNOF(struct user_data)); + memset(data, 0, offsetof(struct user_data, resting) + size); + data->joined = &data->resting[(desired + 63) / 64]; + ctx->user_data = data; + + *createdp = data->nthreads = desired; + + data->threads = alignedalloc(data->nthreads, sizeof(*data->threads), 0, ALIGNOF(struct thread_data)); + if (!data->threads) + return -1; + + err = pthread_mutex_init(&data->mutex, NULL); + if (err) { + free(data->threads); + return -1; + } + err = sem_init(&data->semaphore, 0, 0); + if (err) { + pthread_mutex_destroy(&data->mutex); + free(data->threads); + return -1; + } + + for (i = 0; i < data->nthreads; i++) { + memset(&data->threads[i], 0, sizeof(data->threads[i])); + data->threads[i].master = data; + data->threads[i].index = i; + data->resting[i / 64] |= (uint_least64_t)1 << (i % 64); + if (sem_init(&data->threads[i].semaphore, 0, 0)) { + err = errno; + goto fail_post_sem; + } + err = pthread_create(&data->threads[i].thread, NULL, thread_loop, &data->threads[i]); + if (err) { + sem_destroy(&data->threads[i].semaphore); + fail_post_sem: + data->nthreads = i; + destroy_thread_pool(ctx); + errno = err; + return -1; + } + } + + return 0; +} + + +#if defined(__GNUC__) +__attribute__((__const__)) +#endif +static size_t +lb(uint_least64_t x) +{ + size_t r = 0; + while (x > 1) { + x >>= 1; + r += 1; + } + return r; +} + +static size_t +await_threads(size_t *indices, size_t n, size_t require, struct libar2_context *ctx) +{ + struct user_data *data = ctx->user_data; + size_t ret = 0, i; + uint_least64_t one; + int err; + + memset(data->joined, 0, (data->nthreads + 63) / 64 * sizeof(*data->joined)); + + for (i = 0; i < data->nthreads; i += 64) { + for (;;) { + one = data->resting[i / 64]; + one ^= data->joined[i / 64]; + if (!one) + break; + one &= ~(one - 1); + data->joined[i / 64] |= one; + if (ret++ < n) + indices[ret - 1] = i + lb(one); + } + } + + for (;;) { + if (ret < require) { + if (sem_wait(&data->semaphore)) { + if (errno == EINTR) + continue; + return 0; + } + } else if (sem_trywait(&data->semaphore)) { + if (errno == EAGAIN) + break; + else + return 0; + } + + err = pthread_mutex_lock(&data->mutex); + if (err) { + errno = err; + return 0; + } + for (i = 0; i < data->nthreads; i += 64) { + one = data->resting[i / 64]; + one ^= data->joined[i / 64]; + if (!one) + continue; + one &= ~(one - 1); + data->joined[i / 64] |= one; + if (ret++ < n) + indices[ret - 1] = i + lb(one); + break; + } + pthread_mutex_unlock(&data->mutex); + } + + return ret; +} + + +static size_t +get_ready_threads(size_t *indices, size_t n, struct libar2_context *ctx) +{ + return await_threads(indices, n, 1, ctx); +} + + +static int +join_thread_pool(struct libar2_context *ctx) +{ + struct user_data *data = ctx->user_data; + if (await_threads(NULL, 0, data->nthreads, ctx)) + return 0; + destroy_thread_pool(ctx); + return -1; +} + + +void +libar2simplified_init_context(struct libar2_context *ctxp) +{ + memset(ctxp, 0, sizeof(*ctxp)); + ctxp->allocate = allocate; + ctxp->deallocate = deallocate; + ctxp->init_thread_pool = init_thread_pool; + ctxp->get_ready_threads = get_ready_threads; + ctxp->run_thread = run_thread; + ctxp->join_thread_pool = join_thread_pool; + ctxp->destroy_thread_pool = destroy_thread_pool; +} |