aboutsummaryrefslogblamecommitdiffstats
path: root/libar2simplified_hash.c
blob: bdc5dd23f6a71e2a48db6b4fb9daae92bf02b5c3 (plain) (tree)






























































































































































































































































































































































































                                                                                                             
/* See LICENSE file for copyright and license details. */
#include "common.h"
#include <pthread.h>
#include <semaphore.h>


/* State of one pool worker. The worker (thread_loop) sleeps on `semaphore`;
 * when woken it takes `function`/`function_input` under `mutex` and runs the
 * job if `function` is non-NULL. The master_* pointers refer to the shared
 * fields of struct user_data that the worker uses to signal completion. */
struct thread_data {
	pthread_t thread;                      /* the worker thread */
	pthread_mutex_t mutex, *master_mutex;  /* own lock; lock shared with the pool */
	sem_t semaphore, *master_semaphore;    /* wakes this worker; wakes the master */
	int error, *master_needs_a_thread;     /* nonzero once the worker has stopped on failure; shared flag */
	void (*function)(void *data);          /* pending/running job, NULL when idle */
	void *function_input;                  /* argument passed to `function` */
};

/* Pool-wide state, stored in libar2_context.user_data. */
struct user_data {
	struct thread_data *threads;   /* array of `nthreads` workers */
	size_t nthreads;
	pthread_mutex_t master_mutex;  /* guards `need_a_thread` */
	sem_t master_semaphore;        /* posted by a worker that finishes while `need_a_thread` is set */
	int need_a_thread;             /* set by the master while it waits for an idle worker */
};


/* Allocate `num * size + extra` bytes with posix_memalign().
 *
 * `alignment` is raised to sizeof(void *) if smaller, because posix_memalign()
 * requires at least that. It is assumed to be a power of two; posix_memalign()
 * reports EINVAL otherwise.
 *
 * Returns the memory (release with free()), or NULL with errno set:
 * ENOMEM if the requested size overflows size_t, otherwise the error
 * returned by posix_memalign(). */
static void *
alignedalloc(size_t num, size_t size, size_t alignment, size_t extra)
{
	void *ptr;
	int err;
	/* `size` is tested first so that a zero element size cannot cause a
	 * division by zero; with size == 0 the product cannot overflow anyway */
	if (size && num > (SIZE_MAX - extra) / size) {
		errno = ENOMEM;
		return NULL;
	}
	if (alignment < sizeof(void *))
		alignment = sizeof(void *);
	err = posix_memalign(&ptr, alignment, num * size + extra);
	if (err) {
		errno = err;
		return NULL;
	}
	return ptr;
}


/* libar2 allocation callback.
 *
 * Lays the block out as [pad bytes][size_t pad][size_t length][user data], so
 * that deallocate() can find both the user length (to erase it) and the start
 * of the underlying posix_memalign() block. `pad` is chosen so that the user
 * data lands on an `alignment` boundary; `alignment` is assumed to be a power
 * of two. Returns NULL (errno set by alignedalloc) on failure. */
static void *
allocate(size_t num, size_t size, size_t alignment, struct libar2_context *ctx)
{
	size_t mask = alignment - 1;
	size_t header = 2 * sizeof(size_t);
	size_t pad = (alignment - (header & mask)) & mask;
	char *base;
	size_t *meta;

	(void) ctx;

	base = alignedalloc(num, size, alignment, pad + header);
	if (!base)
		return NULL;

	meta = (size_t *)&base[pad];
	meta[0] = pad;
	meta[1] = num * size;
	return &meta[2];
}


/* libar2 deallocation callback for memory returned by allocate().
 *
 * Reads the two size_t header words stored just before `ptr`. The word at
 * [-1] is the user length, which is wiped with libar2_erase(). The word at
 * [-2] is the padding that precedes the header, so stepping back over it
 * yields the pointer originally obtained from posix_memalign(). */
static void
deallocate(void *ptr, struct libar2_context *ctx)
{
	size_t *header = (size_t *)ptr;
	size_t length = header[-1];
	size_t padding;

	libar2_erase(ptr, length);
	padding = header[-2];
	free((char *)&header[-2] - padding);
	(void) ctx;
}


/* Body of each pool worker.
 *
 * Repeatedly:
 * - waits on its own semaphore;
 * - fetches the job that run_thread() stored under `data->mutex` and runs it
 *   if one is set;
 * - marks itself idle by clearing the job fields;
 * - if the master has set `*master_needs_a_thread`, clears that flag and
 *   posts `master_semaphore` to wake it.
 *
 * destroy_thread_pool() stops a worker by giving it pthread_exit as its job,
 * so the thread terminates from inside function(). The loop only returns
 * (NULL) on failure, after recording the error code in data->error. */
static void *
thread_loop(void *data_)
{
	struct thread_data *data = data_;
	int err;
	void (*function)(void *data);
	void *function_input;

	for (;;) {
		/* A signal handler may interrupt sem_wait() with EINTR; that is not a
		 * failure, so keep waiting instead of killing the worker */
		while (sem_wait(&data->semaphore)) {
			if (errno != EINTR) {
				data->error = errno;
				return NULL;
			}
		}

		err = pthread_mutex_lock(&data->mutex);
		if (err) {
			data->error = err;
			return NULL;
		}
		function_input = data->function_input;
		function = data->function;
		pthread_mutex_unlock(&data->mutex);

		if (function) {
			function(function_input);

			/* master_mutex is taken before the worker's own mutex, so
			 * becoming idle and checking need_a_thread happen atomically
			 * with respect to the master */
			err = pthread_mutex_lock(data->master_mutex);
			if (err) {
				data->error = err;
				return NULL;
			}

			err = pthread_mutex_lock(&data->mutex);
			if (err) {
				pthread_mutex_unlock(data->master_mutex);
				data->error = err;
				return NULL;
			}
			data->function = NULL;
			data->function_input = NULL;
			pthread_mutex_unlock(&data->mutex);
			if (*data->master_needs_a_thread) {
				*data->master_needs_a_thread = 0;
				if (sem_post(data->master_semaphore)) {
					err = errno;
					pthread_mutex_unlock(data->master_mutex);
					data->error = err;
					return NULL;
				}
			}
			pthread_mutex_unlock(data->master_mutex);
		}
	}
}


static int
run_thread(size_t index, void (*function)(void *arg), void *arg, struct libar2_context *ctx)
{
	struct user_data *data = ctx->user_data;
	int err;
	err = pthread_mutex_lock(&data->threads[index].mutex);
	if (err) {
		errno = err;
		return -1;
	}
	if (data->threads[index].error) {
		err = data->threads[index].error;
		pthread_mutex_unlock(&data->threads[index].mutex);
		errno = err;
		return -1;
	}
	data->threads[index].function_input = arg;
	data->threads[index].function = function;
	if (sem_post(&data->threads[index].semaphore)) {
		return -1;
	}
	pthread_mutex_unlock(&data->threads[index].mutex);
	return 0;
}


/* libar2 callback: stop and join every worker, then release all pool
 * resources.
 *
 * Each worker is stopped by scheduling pthread_exit as its next job.
 * Returns 0 on success, or -1 with errno set to the error a worker had
 * recorded. This matches the -1/errno convention of the other callbacks in
 * this file. If a worker cannot be given the exit job, -1 is returned
 * immediately. In that case nothing is joined or freed, because joining or
 * freeing a still-running worker would hang or be unsafe. */
static int
destroy_thread_pool(struct libar2_context *ctx)
{
	struct user_data *data = ctx->user_data;
	size_t i;
	int ret = 0;

	for (i = data->nthreads; i--;)
		if (run_thread(i, pthread_exit, NULL, ctx))
			return -1;

	for (i = data->nthreads; i--;) {
		/* pthread_join() synchronises memory with the terminated worker, and
		 * no other thread touches this slot any more. Its fields can
		 * therefore be read without taking (or failing to take) its mutex. */
		pthread_join(data->threads[i].thread, NULL);
		if (data->threads[i].error)
			ret = data->threads[i].error;
		sem_destroy(&data->threads[i].semaphore);
		pthread_mutex_destroy(&data->threads[i].mutex);
	}
	free(data->threads);
	sem_destroy(&data->master_semaphore);
	pthread_mutex_destroy(&data->master_mutex);

	if (ret) {
		errno = ret;
		return -1;
	}
	return 0;
}


/* libar2 callback: start a pool of min(online CPUs, desired) workers.
 *
 * The CPU count comes from sysconf(_SC_NPROCESSORS_ONLN). On Linux, if that
 * fails, /sys/devices/system/cpu/cpuN entries are counted instead (up to
 * `desired`). The last resort is FALLBACK_NPROC.
 *
 * With a single CPU, or when no threads are desired, no pool is created and
 * *createdp is set to 0. Otherwise *createdp receives the worker count.
 * Returns 0 on success, or -1 with errno set; on failure everything already
 * created is torn down. */
static int
init_thread_pool(size_t desired, size_t *createdp, struct libar2_context *ctx)
{
	struct user_data *data = ctx->user_data;
	int err;
	size_t i;
	long int nproc;
#ifdef __linux__
	char path[sizeof("/sys/devices/system/cpu/cpu") + 3 * sizeof(nproc)];
#endif

#ifdef TODO
	if (desired < 2) {
		*createdp = 0;
		return 0;
	}
#endif

	nproc = sysconf(_SC_NPROCESSORS_ONLN);
#ifdef __linux__
	if (nproc < 1) {
		long int nproc_limit = desired > LONG_MAX ? LONG_MAX : (long int)desired;
		for (nproc = 0; nproc < nproc_limit; nproc++) {
			snprintf(path, sizeof(path), "%s%li", "/sys/devices/system/cpu/cpu", nproc);
			if (access(path, F_OK))
				break;
		}
	}
#endif
	if (nproc < 1)
		nproc = FALLBACK_NPROC;

	/* Zero requested workers is handled like a single CPU. The original code
	 * would have allocated an empty pool whose mutex and semaphore nothing
	 * ever destroys. */
	if (nproc == 1 || !desired) {
		*createdp = 0;
		return 0;
	}

	data->nthreads = (size_t)nproc < desired ? (size_t)nproc : desired;
	*createdp = data->nthreads;

	data->threads = alignedalloc(data->nthreads, sizeof(*data->threads), ALIGNOF(struct thread_data), 0);
	if (!data->threads)
		return -1;

	/* pthread_mutex_init() returns its error instead of setting errno */
	err = pthread_mutex_init(&data->master_mutex, NULL);
	if (err) {
		free(data->threads);
		errno = err;
		return -1;
	}
	/* sem_init() sets errno; save it before the cleanup calls */
	if (sem_init(&data->master_semaphore, 0, 0)) {
		err = errno;
		pthread_mutex_destroy(&data->master_mutex);
		free(data->threads);
		errno = err;
		return -1;
	}
	data->need_a_thread = 0;

	for (i = 0; i < data->nthreads; i++) {
		memset(&data->threads[i], 0, sizeof(data->threads[i]));
		data->threads[i].master_mutex = &data->master_mutex;
		data->threads[i].master_semaphore = &data->master_semaphore;
		data->threads[i].master_needs_a_thread = &data->need_a_thread;
		err = pthread_mutex_init(&data->threads[i].mutex, NULL);
		if (err)
			goto fail_post_mutex;
		if (sem_init(&data->threads[i].semaphore, 0, 0)) {
			err = errno;
			goto fail_post_sem;
		}
		err = pthread_create(&data->threads[i].thread, NULL, thread_loop, &data->threads[i]);
		if (err) {
			sem_destroy(&data->threads[i].semaphore);
		fail_post_sem:
			pthread_mutex_destroy(&data->threads[i].mutex);
		fail_post_mutex:
			/* only workers [0, i) are fully set up; tear those down */
			data->nthreads = i;
			destroy_thread_pool(ctx);
			errno = err;
			return -1;
		}
	}

	return 0;
}


/* Set the shared "master is waiting for an idle worker" flag under
 * master_mutex. Returns 0, or -1 with errno set if the mutex cannot be
 * taken. */
static int
set_need_a_thread(struct user_data *data, int need)
{
	int err = pthread_mutex_lock(&data->master_mutex);
	if (!err) {
		data->need_a_thread = need;
		pthread_mutex_unlock(&data->master_mutex);
		return 0;
	}
	errno = err;
	return -1;
}


/* Block until a worker finishes a job, unless one already has.
 *
 * Workers clear need_a_thread and post master_semaphore when they finish
 * while the flag is set. If the flag is already clear, such a worker has
 * come and gone since the master set it, so there is no need to sleep.
 * Returns 0, or -1 with errno set. */
static int
await_some_thread(struct user_data *data)
{
	int err, need_a_thread;

	err = pthread_mutex_lock(&data->master_mutex);
	if (err) {
		errno = err;
		return -1;
	}
	need_a_thread = data->need_a_thread;
	pthread_mutex_unlock(&data->master_mutex);

	if (need_a_thread) {
		/* master_mutex is NOT held here, so it must not be unlocked on
		 * failure. An EINTR caused by a signal handler is simply retried. */
		while (sem_wait(&data->master_semaphore))
			if (errno != EINTR)
				return -1;
	}
	return 0;
}


static size_t
await_threads(size_t *indices, size_t n, size_t require, struct libar2_context *ctx)
{
	struct user_data *data = ctx->user_data;
	size_t i, ret = 0, first = 0;
	int err;
	for (;;) {
		if (set_need_a_thread(data, 1))
			return 0;
		for (i = first; i < data->nthreads; i++) {
			err = pthread_mutex_lock(&data->threads[i].mutex);
			if (err) {
				errno = err;
				return 0;
			}
			if (!data->threads[i].function) {
				if (ret++ < n)
					indices[ret - 1] = i;
				first += (i == first);
			}
			if (data->threads[i].error) {
				errno = data->threads[i].error;
				return 0;
			}
			pthread_mutex_unlock(&data->threads[i].mutex);
		}
		if (ret >= require) {
			if (set_need_a_thread(data, 0))
				return 0;
			return ret;
		}
		if (await_some_thread(data))
			return 0;
	}
}


/* libar2 callback: block until at least one worker is idle. Up to `n` idle
 * worker indices are written to `indices`. Returns the number of idle
 * workers, or 0 on failure. */
static size_t
get_ready_threads(size_t *indices, size_t n, struct libar2_context *ctx)
{
	const size_t at_least_one = 1;
	return await_threads(indices, n, at_least_one, ctx);
}


/* libar2 callback: block until every worker in the pool is idle.
 * Returns 0 on success and -1 on failure. */
static int
join_thread_pool(struct libar2_context *ctx)
{
	struct user_data *data = ctx->user_data;
	size_t idle = await_threads(NULL, 0, data->nthreads, ctx);
	if (!idle)
		return -1;
	return 0;
}


int
libar2simplified_hash(void *hash, void *msg, size_t msglen, struct libar2_argon2_parameters *params)
{
	struct user_data ctx_data;
	struct libar2_context ctx;

	memset(&ctx, 0, sizeof(ctx));
	ctx.user_data = &ctx_data;
	ctx.autoerase_message = 1;
	ctx.autoerase_salt = 1;
	ctx.allocate = allocate;
	ctx.deallocate = deallocate;
	ctx.init_thread_pool = init_thread_pool;
	ctx.get_ready_threads = get_ready_threads;
	ctx.run_thread = run_thread;
	ctx.join_thread_pool = join_thread_pool;
	ctx.destroy_thread_pool = destroy_thread_pool;

	return libar2_hash(hash, msg, msglen, params, &ctx);
}