diff options
Diffstat (limited to 'src/core/wrdp_thpool.c')
-rw-r--r-- | src/core/wrdp_thpool.c | 770 |
1 files changed, 770 insertions, 0 deletions
diff --git a/src/core/wrdp_thpool.c b/src/core/wrdp_thpool.c new file mode 100644 index 0000000..c1f8145 --- /dev/null +++ b/src/core/wrdp_thpool.c @@ -0,0 +1,770 @@ +/* BSD-2-Clause license + * + * Copyright (c) 2018-2023 NST <www.newinfosec.ru>, sss <sss at dark-alexandr dot net>. + * + */ + +#include <errno.h> +#include <pthread.h> +#include <stdbool.h> +#include <stdio.h> +#include <stdlib.h> +#include <unistd.h> + +#include <ev.h> + +#include "wrdp_thpool.h" +#include "wrdp_thpool_internals.h" + +#include "webrdp_core_api.h" +#include "log.h" + +void +wrdp_thpool_destroy(wrdp_thpool *pool) +{ + /* TODO: finish this */ + /* TODO: destroy all tasks */ + /* TODO: call per thread custom_thread_deinit in each worker thread */ + /* NOTE: unfinished, unused */ + uint32_t i; + if (!pool) + { + return; + } + if (pool->threads) + { + for (i = 0; i < pool->thread_count; ++i) + { + close(pool->threads[i].pipe_fds[0]); + close(pool->threads[i].pipe_fds[1]); + } + } + if (pool->custom_pool_destroy) + { + pool->custom_pool_destroy(pool->userdata); + } + free(pool); +} + +static void *wrdp_thpool_worker_thread_loop(void *thread_); + +static void pipe_readable_cb(struct ev_loop *loop, ev_io *w, int revents); + +typedef enum +{ + pool_obj_pool, + pool_obj_thread +} thpool_obj_type; + +typedef struct +{ + union + { + wrdp_thpool *pool; + wrdp_thpool_thread *thread; + }; + thpool_obj_type receiver; +} pool_receiver_ptr; + +typedef struct +{ + union + { + wrdp_thpool *pool; + wrdp_thpool_thread *thread; + }; + thpool_obj_type sender; +} pool_sender_ptr; + +wrdp_thpool * +wrdp_thpool_create(uint16_t thread_count, uint64_t max_tasks_per_thread, + void (*custom_thread_init)(void *user_pool_data, wrdp_thpool_thread *t), + void (*custom_thread_deinit)(void *user_pool_data, wrdp_thpool_thread *t), + void (*custom_pool_create)(void *user_pool_data), + void (*custom_pool_destroy)(void *user_pool_data), + void (*pool_message_handler)(void *user_data), + void (*thread_message_handler)(void *user_data), struct ev_loop *loop, + void *user_pool_data) +{ + wrdp_thpool *pool = calloc(1, sizeof(wrdp_thpool)); + if (!pool) + { + perror("calloc"); + return 0; + } + + pool->thread_count = thread_count; + pool->max_tasks = max_tasks_per_thread; + pool->custom_pool_destroy = custom_pool_destroy; + pool->custom_thread_deinit = custom_thread_deinit; + pool->custom_thread_init = custom_thread_init; + pool->userdata = user_pool_data; + pool->pool_message_handler = pool_message_handler; + pool->thread_message_handler = thread_message_handler; + + if (custom_pool_create) + { + custom_pool_create(pool->userdata); + } + + pool->tasks_per_thread = calloc(thread_count, sizeof(uint64_t)); + if (!pool->tasks_per_thread) + { + perror("calloc"); + goto error; + } + + pool->threads = calloc(thread_count, sizeof(wrdp_thpool_thread)); + if (!(pool->threads)) + { + perror("calloc"); + goto error; + } + if (pipe(pool->pipe_fds) == -1) + { + perror("pipe"); + goto error; + } + + /* allocate memory for threads, tasks, create threads */ + { + uint32_t i; + for (i = 0; i < thread_count; ++i) + { + if (pipe(pool->threads[i].pipe_fds) == -1) + { + perror("pipe"); + goto error; + } + pool->threads[i].thread_id = i; + pool->threads[i].tasks = calloc( + sizeof(wrdp_thpool_task *), max_tasks_per_thread); + if (!pool->threads[i].tasks) + { + perror("calloc"); + goto error; + } + pool->threads[i].pool = pool; + if (pthread_create(&(pool->threads[i].thread), 0, + wrdp_thpool_worker_thread_loop, + &(pool->threads[i])) + != 0) + { + goto error; + } + } + } + /* attach pipe reed watcher to default event loop */ + { + pool_receiver_ptr *p = calloc(1, sizeof(pool_receiver_ptr)); + if (!p) + { + perror("calloc"); + goto error; + } + p->receiver = pool_obj_pool; + p->pool = pool; + ev_io_init(&(pool->ev_pipe_readable), pipe_readable_cb, + pool->pipe_fds[0], EV_READ); + pool->ev_pipe_readable.data = p; + if (loop) + { + ev_io_start(loop, &(pool->ev_pipe_readable)); + } + else + { + ev_io_start(EV_DEFAULT, &(pool->ev_pipe_readable)); + } + } + return pool; +error: + if (pool) + { + wrdp_thpool_destroy(pool); + } + return 0; +} + +typedef enum +{ + thread_msg_task_count = 1, + thread_msg_push_task, + thread_msg_task_finished, + thread_msg_userdata +} thread_msg_type; + +typedef struct +{ + thread_msg_type type; + union + { + wrdp_thpool_task *task; + uint64_t running_tasks; + void *user_data; + }; + pool_sender_ptr sender; +} thread_msg; + +static void +send_msg(int write_fd, thread_msg *msg) +{ + size_t io_size = 0, left = 0, struct_size = sizeof(thread_msg); + left = struct_size; + while (left) + { + io_size + = write(write_fd, (char *)msg + (struct_size - left), left); + if (io_size == -1 + && (errno != EAGAIN && errno != EWOULDBLOCK + && errno != EINTR)) + { + const char *msg_ = "error: thpool pipe write failure"; + perror("write"); + log_msg((const uint8_t *)msg_, strlen(msg_), + wrdp_log_level_error, 0); + exit(EXIT_FAILURE); + } + else + { + left -= io_size; + } + } +} + +bool +wrdp_thpool_send_msg_to_thread( + wrdp_thpool *pool, uint32_t thread_id, void *user_data) +{ + size_t io_size = 0, left = 0, struct_size = sizeof(thread_msg); + thread_msg msg; + memset(&msg, 0, struct_size); + msg.user_data = user_data; + msg.type = thread_msg_userdata; + left = struct_size; + if (thread_id >= pool->thread_count) + { + return false; + } + while (left) + { + io_size = write(pool->threads[thread_id].pipe_fds[1], + (char *)&msg + (struct_size - left), left); + if (io_size == -1 + && (errno != EAGAIN && errno != EWOULDBLOCK + && errno != EINTR)) + { + const char *msg = "thpool pipe write failure"; + perror("write"); + log_msg((const uint8_t *)msg, strlen(msg), + wrdp_log_level_error, 0); + exit(EXIT_FAILURE); + } + else + { + left -= io_size; + } + } + return true; +} + +bool +wrdp_thpool_send_msg_to_pool(wrdp_thpool *pool, void *user_data) +{ + size_t io_size = 0, left = 0, struct_size = sizeof(thread_msg); + thread_msg msg; + memset(&msg, 0, struct_size); + msg.user_data = user_data; + msg.type = thread_msg_userdata; + left = struct_size; + while (left) + { + io_size = write(pool->pipe_fds[1], + (char *)&msg + (struct_size - left), left); + if (io_size == -1 + && (errno != EAGAIN && errno != EWOULDBLOCK + && errno != EINTR)) + { + const char *msg = "thpool pipe write failure"; + perror("write"); + log_msg((const uint8_t *)msg, strlen(msg), + wrdp_log_level_error, 0); + exit(EXIT_FAILURE); + } + else + { + left -= io_size; + } + } + return true; +} + +static thread_msg * +read_msg(int read_fd) +{ + size_t io_size = 0, struct_size = sizeof(thread_msg), left = 0; + void *buf = calloc(1, struct_size); + if (!buf) + { + perror("malloc"); + exit(EXIT_FAILURE); + } + left = struct_size; + while (left) + { + io_size = read(read_fd, buf, left); + if (io_size == -1 + && (errno != EAGAIN && errno != EWOULDBLOCK + && errno != EINTR)) + { + const char *msg = "thpool pipe read failure"; + perror("read"); + log_msg((const uint8_t *)msg, strlen(msg), + wrdp_log_level_error, 0); + exit(EXIT_FAILURE); + } + else + { + left -= io_size; + } + } + return buf; +} + +static bool +send_task_to_thread(wrdp_thpool *pool, wrdp_thpool_task *task) +{ + uint32_t i = 0, thread_id = 0, minimal_tasks = 0; + + //find thread with minimal number of running tasks or no tasks at all + for (; i < pool->thread_count; ++i) + { + uint64_t running_tasks = pool->tasks_per_thread[i]; + if (!running_tasks) + { + thread_id = i; + break; + } + if (!minimal_tasks) + { + minimal_tasks = running_tasks; + thread_id = i; + } + if (running_tasks < minimal_tasks) + { + minimal_tasks = running_tasks; + thread_id = i; + } + } + //all threads have maximum tasks + if (minimal_tasks >= pool->max_tasks) + { + return false; + } + { + thread_msg msg; + memset(&msg, 0, sizeof(thread_msg)); + msg.type = thread_msg_push_task; + msg.task = task; + msg.sender.sender = pool_obj_pool; + msg.sender.pool = pool; + task->thread = &(pool->threads[thread_id]); + if (task->task_init_cb) + { + task->task_init_cb(task, task->userdata); + } + send_msg(pool->threads[thread_id].pipe_fds[1], &msg); + } + + return true; +} + +static bool +send_task_to_thread_by_id( + wrdp_thpool *pool, wrdp_thpool_task *task, uint32_t thread_id) +{ + if (pool->tasks_per_thread[thread_id] >= pool->max_tasks) + { + return false; + } + else + { + thread_msg msg; + memset(&msg, 0, sizeof(thread_msg)); + msg.type = thread_msg_push_task; + msg.task = task; + msg.sender.sender = pool_obj_pool; + msg.sender.pool = pool; + task->thread = &(pool->threads[thread_id]); + if (task->task_init_cb) + { + task->task_init_cb(task, task->userdata); + } + send_msg(pool->threads[thread_id].pipe_fds[1], &msg); + } + return true; +} + +bool +wrdp_thread_pool_add_task(wrdp_thpool *pool, + void (*run_task)(wrdp_thpool_task *task, void *userdata), + void (*task_init_cb)(wrdp_thpool_task *task, void *userdata), + void *userdata) +{ + wrdp_thpool_task *task = calloc(1, sizeof(wrdp_thpool_task)); + if (!task) + { + perror("malloc"); + return false; + } + task->userdata = userdata; + task->run_task = run_task; + task->task_init_cb = task_init_cb; + if (!send_task_to_thread(pool, task)) + { + goto cleanup; + } + return true; +cleanup: + if (task) + { + free(task); + } + return false; +} + +bool +wrdp_thread_pool_add_task_to_thread(wrdp_thpool *pool, + void (*run_task)(wrdp_thpool_task *task, void *userdata), + uint32_t thread_id, + void (*task_init_cb)(wrdp_thpool_task *task, void *userdata), + void *userdata) +{ + wrdp_thpool_task *task = calloc(1, sizeof(wrdp_thpool_task)); + if (!task) + { + perror("malloc"); + return false; + } + task->userdata = userdata; + task->run_task = run_task; + task->task_init_cb = task_init_cb; + if (!send_task_to_thread_by_id(pool, task, thread_id)) + { + goto cleanup; + } + return true; +cleanup: + if (task) + { + free(task); + } + return false; +} + +bool +wrdp_thread_pool_move_task_to_thread(wrdp_thpool *pool, + void (*run_task)(wrdp_thpool_task *task, void *userdata), + void (*stop_task)(wrdp_thpool_task *current_task, void *userdata), + uint32_t thread_id, + void (*task_init_cb)(wrdp_thpool_task *task, void *userdata), + wrdp_thpool_task *current_task, void *userdata) +{ + wrdp_thpool_task *task = calloc(1, sizeof(wrdp_thpool_task)); + if (!task) + { + perror("malloc"); + return false; + } + task->userdata = userdata; + task->run_task = run_task; + task->stop_task = stop_task; + task->task_init_cb = task_init_cb; + if (stop_task) + { + stop_task(current_task, userdata); + } + wrdp_thread_pool_destroy_task(current_task, 0); + if (!send_task_to_thread_by_id(pool, task, thread_id)) + { + goto cleanup; + } + return true; +cleanup: + if (task) + { + free(task); + } + return false; +} + +static void +pipe_readable_cb(struct ev_loop *loop, ev_io *w, int revents) +{ + pool_receiver_ptr *p = w->data; + thread_msg *in_msg = 0; + switch (p->receiver) + { + case pool_obj_thread: + { + in_msg = read_msg(p->thread->pipe_fds[0]); + if (!in_msg) + { + return; + } + switch (in_msg->type) + { + case thread_msg_push_task: + { + uint32_t i; + bool added = false; + for (i = 0; + i < p->thread->pool->max_tasks; + ++i) + { + if (p->thread->tasks[i]) + { + continue; + } + thread_msg out_msg; + memset(&out_msg, 0, + sizeof(thread_msg)); + p->thread->tasks[i] + = in_msg->task; + p->thread->running_task_count++; + if (!p->thread->tasks[i] + ->run_task) + { + /* TODO: error message + * to log */ + break; + } + p->thread->tasks[i]->run_task( + p->thread->tasks[i], + (p->thread->tasks[i] + ->userdata)); + out_msg.type + = thread_msg_task_count; + out_msg.running_tasks + = p->thread + ->running_task_count; + out_msg.sender.sender + = pool_obj_thread; + out_msg.sender.thread + = p->thread; + { + char buf[128]; + log_msg_info mi = {0}; + snprintf(buf, 127, + "Added new task to " + "thread" + " %d slot %d", + p->thread + ->thread_id, + i); + mi.buf = (uint8_t *)buf; + mi.level + = wrdp_log_level_trace; + mi.wrdp_thpool_task + = in_msg->task; + log_msg_ex(&mi); + } + send_msg(p->thread->pool + ->pipe_fds[1], + &out_msg); + added = true; + break; + } + if (!added) + { + char buf[128]; + snprintf(buf, 127, + "Error: failed to add task " + "to" + " thread %d :" + "no free slots", + p->thread->thread_id); + log_msg((const uint8_t *)buf, + strlen(buf), + wrdp_log_level_error, 0); + } + } + break; + case thread_msg_task_finished: + { + thread_msg out_msg; + size_t i = 0; + bool task_found = false; + memset(&out_msg, 0, sizeof(thread_msg)); + for (i = 0; + i < p->thread->pool->max_tasks; + ++i) + { + if (p->thread->tasks[i] + == in_msg->task) + { + task_found = true; + p->thread->tasks[i] = 0; + p->thread + ->running_task_count--; + out_msg.type + = thread_msg_task_count; + out_msg.running_tasks + = p->thread + ->running_task_count; + out_msg.sender.sender + = pool_obj_thread; + out_msg.sender.thread + = p->thread; + { + char buf[128]; + log_msg_info mi + = {0}; + snprintf(buf, + 127, + "Removed " + "task from " + "thread" + " %d slot " + "%zd", + p->thread + ->thread_id, + i); + mi.wrdp_thpool_task + = in_msg + ->task; + mi.level + = wrdp_log_level_trace; + mi.buf + = (uint8_t + *) + buf; + log_msg_ex(&mi); + } + free(p->thread + ->tasks[i]); + send_msg( + p->thread->pool + ->pipe_fds[1], + &out_msg); + break; + } + } + if (!task_found) + { + const char *msg_str + = "wrdp_thpool: " + "thread_msg_task_" + "finished: task not " + "found " + "in thread"; + log_msg( + (const uint8_t *)msg_str, + strlen(msg_str), + wrdp_log_level_error, 0); + } + } + break; + case thread_msg_userdata: + { + if (p->thread->pool + ->thread_message_handler) + { + p->thread->pool + ->thread_message_handler( + in_msg->user_data); + } + } + break; + default: + break; + } + } + break; + case pool_obj_pool: + { + in_msg = read_msg(p->pool->pipe_fds[0]); + if (!in_msg) + { + return; + } + switch (in_msg->type) + { + case thread_msg_task_count: + { + p->pool->tasks_per_thread + [in_msg->sender.thread->thread_id] + = in_msg->sender.thread + ->running_task_count; + } + break; + case thread_msg_userdata: + { + if (p->pool->pool_message_handler) + { + p->pool->pool_message_handler( + in_msg->user_data); + } + } + break; + default: + break; + } + } + default: + break; + } + if (in_msg) + { + free(in_msg); + } +} + +static void * +wrdp_thpool_worker_thread_loop(void *thread_) +{ + wrdp_thpool_thread *thread = thread_; + pool_receiver_ptr *p = calloc(1, sizeof(pool_receiver_ptr)); + if (!p) + { + perror("calloc"); + return 0; + } + p->receiver = pool_obj_thread; + p->thread = thread; + if (thread->pool->custom_thread_init) + { + thread->pool->custom_thread_init( + thread->pool->userdata, thread); + } + thread->ev_th_loop = ev_loop_new(EVFLAG_AUTO); + ev_io_init(&(thread->ev_pipe_readable), pipe_readable_cb, + thread->pipe_fds[0], EV_READ); + thread->ev_pipe_readable.data = p; + ev_io_start(thread->ev_th_loop, &(thread->ev_pipe_readable)); + ev_run(thread->ev_th_loop, 0); + return 0; +} + +void +wrdp_thread_pool_destroy_task( + wrdp_thpool_task *task, void (*on_task_destroy)(wrdp_thpool_task *task)) +{ + thread_msg msg; + + /* TODO: this should never happen, but for now just crashfix hack */ + if (!task->thread) + { + return; + } + + memset(&msg, 0, sizeof(thread_msg)); + if (on_task_destroy) + { + on_task_destroy(task); + } + msg.type = thread_msg_task_finished; + msg.task = task; + msg.sender.sender = pool_obj_thread; + msg.sender.thread = task->thread; + send_msg(task->thread->pipe_fds[1], &msg); +} |