Diffstat (limited to 'src/core/wrdp_thpool.c')
-rw-r--r--	src/core/wrdp_thpool.c	770
1 file changed, 770 insertions(+), 0 deletions(-)
diff --git a/src/core/wrdp_thpool.c b/src/core/wrdp_thpool.c
new file mode 100644
index 0000000..c1f8145
--- /dev/null
+++ b/src/core/wrdp_thpool.c
@@ -0,0 +1,770 @@
+/* BSD-2-Clause license
+ *
+ * Copyright (c) 2018-2023 NST <www.newinfosec.ru>, sss <sss at dark-alexandr dot net>.
+ *
+ */
+
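+/* wrdp_thpool: a small thread pool in which the pool and its worker
+ * threads exchange fixed-size messages over pipes, with a libev read
+ * watcher on each receiving pipe end driving dispatch. */
+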
+#include <errno.h>
+#include <pthread.h>
+#include <stdbool.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <unistd.h>
+
+#include <ev.h>
+
+#include "wrdp_thpool.h"
+#include "wrdp_thpool_internals.h"
+
+#include "webrdp_core_api.h"
+#include "log.h"
+
+void
+wrdp_thpool_destroy(wrdp_thpool *pool)
+{
+ /* TODO: finish this */
+ /* TODO: destroy all tasks */
+ /* TODO: call per thread custom_thread_deinit in each worker thread */
+ /* NOTE: unfinished, unused */
+ uint32_t i;
+ if (!pool)
+ {
+ return;
+ }
+	if (pool->threads)
+	{
+		for (i = 0; i < pool->thread_count; ++i)
+		{
+			/* fds may still be 0 if creation failed early */
+			if (pool->threads[i].pipe_fds[0] > 0)
+				close(pool->threads[i].pipe_fds[0]);
+			if (pool->threads[i].pipe_fds[1] > 0)
+				close(pool->threads[i].pipe_fds[1]);
+			free(pool->threads[i].tasks);
+		}
+		free(pool->threads);
+	}
+	if (pool->pipe_fds[0] > 0)
+		close(pool->pipe_fds[0]);
+	if (pool->pipe_fds[1] > 0)
+		close(pool->pipe_fds[1]);
+	free(pool->tasks_per_thread);
+	if (pool->custom_pool_destroy)
+	{
+		pool->custom_pool_destroy(pool->userdata);
+	}
+	free(pool);
+}
+
+static void *wrdp_thpool_worker_thread_loop(void *thread_);
+
+static void pipe_readable_cb(struct ev_loop *loop, ev_io *w, int revents);
+
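+/* Tagged pointers: the same pipe callback serves both the pool and the
+ * worker threads, so each watcher/message records which kind of object
+ * owns it. */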
+typedef enum
+{
+ pool_obj_pool,
+ pool_obj_thread
+} thpool_obj_type;
+
+typedef struct
+{
+ union
+ {
+ wrdp_thpool *pool;
+ wrdp_thpool_thread *thread;
+ };
+ thpool_obj_type receiver;
+} pool_receiver_ptr;
+
+typedef struct
+{
+ union
+ {
+ wrdp_thpool *pool;
+ wrdp_thpool_thread *thread;
+ };
+ thpool_obj_type sender;
+} pool_sender_ptr;
+
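+/* Creates a pool of thread_count workers, each with its own pipe pair,
+ * an array of max_tasks_per_thread task slots and a private libev loop.
+ * The pool's read watcher is attached to `loop`, or to EV_DEFAULT when
+ * `loop` is 0. Returns 0 on failure.
+ *
+ * Minimal usage sketch (callback names are illustrative, not part of
+ * this API):
+ *
+ *	wrdp_thpool *pool = wrdp_thpool_create(4, 64, 0, 0, 0, 0,
+ *	    my_pool_msg_cb, my_thread_msg_cb, 0, 0);
+ *	if (pool)
+ *		ev_run(EV_DEFAULT, 0);
+ */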
+wrdp_thpool *
+wrdp_thpool_create(uint16_t thread_count, uint64_t max_tasks_per_thread,
+ void (*custom_thread_init)(void *user_pool_data, wrdp_thpool_thread *t),
+ void (*custom_thread_deinit)(void *user_pool_data, wrdp_thpool_thread *t),
+ void (*custom_pool_create)(void *user_pool_data),
+ void (*custom_pool_destroy)(void *user_pool_data),
+ void (*pool_message_handler)(void *user_data),
+ void (*thread_message_handler)(void *user_data), struct ev_loop *loop,
+ void *user_pool_data)
+{
+ wrdp_thpool *pool = calloc(1, sizeof(wrdp_thpool));
+ if (!pool)
+ {
+ perror("calloc");
+ return 0;
+ }
+
+ pool->thread_count = thread_count;
+ pool->max_tasks = max_tasks_per_thread;
+ pool->custom_pool_destroy = custom_pool_destroy;
+ pool->custom_thread_deinit = custom_thread_deinit;
+ pool->custom_thread_init = custom_thread_init;
+ pool->userdata = user_pool_data;
+ pool->pool_message_handler = pool_message_handler;
+ pool->thread_message_handler = thread_message_handler;
+
+ if (custom_pool_create)
+ {
+ custom_pool_create(pool->userdata);
+ }
+
+ pool->tasks_per_thread = calloc(thread_count, sizeof(uint64_t));
+ if (!pool->tasks_per_thread)
+ {
+ perror("calloc");
+ goto error;
+ }
+
+ pool->threads = calloc(thread_count, sizeof(wrdp_thpool_thread));
+ if (!(pool->threads))
+ {
+ perror("calloc");
+ goto error;
+ }
+ if (pipe(pool->pipe_fds) == -1)
+ {
+ perror("pipe");
+ goto error;
+ }
+
+ /* allocate memory for threads, tasks, create threads */
+ {
+ uint32_t i;
+ for (i = 0; i < thread_count; ++i)
+ {
+ if (pipe(pool->threads[i].pipe_fds) == -1)
+ {
+ perror("pipe");
+ goto error;
+ }
+ pool->threads[i].thread_id = i;
+			pool->threads[i].tasks = calloc(
+			    max_tasks_per_thread, sizeof(wrdp_thpool_task *));
+ if (!pool->threads[i].tasks)
+ {
+ perror("calloc");
+ goto error;
+ }
+ pool->threads[i].pool = pool;
+ if (pthread_create(&(pool->threads[i].thread), 0,
+ wrdp_thpool_worker_thread_loop,
+ &(pool->threads[i]))
+ != 0)
+ {
+ goto error;
+ }
+ }
+ }
+	/* attach a pipe read watcher to the caller's loop (or the default) */
+ {
+ pool_receiver_ptr *p = calloc(1, sizeof(pool_receiver_ptr));
+ if (!p)
+ {
+ perror("calloc");
+ goto error;
+ }
+ p->receiver = pool_obj_pool;
+ p->pool = pool;
+ ev_io_init(&(pool->ev_pipe_readable), pipe_readable_cb,
+ pool->pipe_fds[0], EV_READ);
+ pool->ev_pipe_readable.data = p;
+ if (loop)
+ {
+ ev_io_start(loop, &(pool->ev_pipe_readable));
+ }
+ else
+ {
+ ev_io_start(EV_DEFAULT, &(pool->ev_pipe_readable));
+ }
+ }
+ return pool;
+error:
+ if (pool)
+ {
+ wrdp_thpool_destroy(pool);
+ }
+ return 0;
+}
+
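+/* Control-plane protocol: pool and workers exchange fixed-size
+ * thread_msg frames over pipes. Raw pointers travel inside the union,
+ * which is safe only because both ends share one address space. */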
+typedef enum
+{
+ thread_msg_task_count = 1,
+ thread_msg_push_task,
+ thread_msg_task_finished,
+ thread_msg_userdata
+} thread_msg_type;
+
+typedef struct
+{
+ thread_msg_type type;
+ union
+ {
+ wrdp_thpool_task *task;
+ uint64_t running_tasks;
+ void *user_data;
+ };
+ pool_sender_ptr sender;
+} thread_msg;
+
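+/* Writes one complete thread_msg to the given pipe fd, retrying
+ * transient errors; any other write failure is fatal, since the pool
+ * cannot recover from a broken control pipe. */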
+static void
+send_msg(int write_fd, thread_msg *msg)
+{
+	ssize_t io_size = 0;
+	size_t left = 0, struct_size = sizeof(thread_msg);
+	left = struct_size;
+	while (left)
+	{
+		io_size = write(
+		    write_fd, (char *)msg + (struct_size - left), left);
+		if (io_size == -1)
+		{
+			/* retry transient errors without touching the
+			 * remaining byte count */
+			if (errno == EAGAIN || errno == EWOULDBLOCK
+			    || errno == EINTR)
+			{
+				continue;
+			}
+			const char *msg_ = "error: thpool pipe write failure";
+			perror("write");
+			log_msg((const uint8_t *)msg_, strlen(msg_),
+			    wrdp_log_level_error, 0);
+			exit(EXIT_FAILURE);
+		}
+		left -= io_size;
+	}
+}
+
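+/* Delivers an opaque user_data pointer to one worker thread; it is
+ * passed to thread_message_handler inside that worker's event loop. */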
+bool
+wrdp_thpool_send_msg_to_thread(
+ wrdp_thpool *pool, uint32_t thread_id, void *user_data)
+{
+	thread_msg msg;
+	if (thread_id >= pool->thread_count)
+	{
+		return false;
+	}
+	memset(&msg, 0, sizeof(thread_msg));
+	msg.user_data = user_data;
+	msg.type = thread_msg_userdata;
+	/* reuse send_msg so retry/error handling lives in one place */
+	send_msg(pool->threads[thread_id].pipe_fds[1], &msg);
+ return true;
+}
+
+bool
+wrdp_thpool_send_msg_to_pool(wrdp_thpool *pool, void *user_data)
+{
+	thread_msg msg;
+	memset(&msg, 0, sizeof(thread_msg));
+	msg.user_data = user_data;
+	msg.type = thread_msg_userdata;
+	send_msg(pool->pipe_fds[1], &msg);
+ return true;
+}
+
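+/* Reads exactly one thread_msg from the pipe, retrying transient
+ * errors. The returned buffer is heap-allocated; the caller frees it. */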
+static thread_msg *
+read_msg(int read_fd)
+{
+	ssize_t io_size = 0;
+	size_t struct_size = sizeof(thread_msg), left = 0;
+	char *buf = calloc(1, struct_size);
+	if (!buf)
+	{
+		perror("calloc");
+		exit(EXIT_FAILURE);
+	}
+	left = struct_size;
+	while (left)
+	{
+		/* read into the unfilled tail of the buffer so partial
+		 * reads accumulate correctly */
+		io_size = read(read_fd, buf + (struct_size - left), left);
+		if (io_size == -1)
+		{
+			if (errno == EAGAIN || errno == EWOULDBLOCK
+			    || errno == EINTR)
+			{
+				continue;
+			}
+			const char *msg = "thpool pipe read failure";
+			perror("read");
+			log_msg((const uint8_t *)msg, strlen(msg),
+			    wrdp_log_level_error, 0);
+			exit(EXIT_FAILURE);
+		}
+		left -= io_size;
+	}
+	return (thread_msg *)buf;
+}
+
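+/* Picks the least-loaded worker using the pool's tasks_per_thread
+ * snapshot. The snapshot only changes when workers report
+ * thread_msg_task_count back, so the balance is eventually consistent
+ * rather than exact. */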
+static bool
+send_task_to_thread(wrdp_thpool *pool, wrdp_thpool_task *task)
+{
+	uint32_t i = 0, thread_id = 0;
+	uint64_t minimal_tasks = 0;
+
+	/* find the thread with the fewest running tasks, or one with none */
+	for (; i < pool->thread_count; ++i)
+	{
+		uint64_t running_tasks = pool->tasks_per_thread[i];
+		if (!running_tasks)
+		{
+			thread_id = i;
+			break;
+		}
+		if (!minimal_tasks || running_tasks < minimal_tasks)
+		{
+			minimal_tasks = running_tasks;
+			thread_id = i;
+		}
+	}
+	/* every thread is already at its max_tasks limit */
+ if (minimal_tasks >= pool->max_tasks)
+ {
+ return false;
+ }
+ {
+ thread_msg msg;
+ memset(&msg, 0, sizeof(thread_msg));
+ msg.type = thread_msg_push_task;
+ msg.task = task;
+ msg.sender.sender = pool_obj_pool;
+ msg.sender.pool = pool;
+ task->thread = &(pool->threads[thread_id]);
+ if (task->task_init_cb)
+ {
+ task->task_init_cb(task, task->userdata);
+ }
+ send_msg(pool->threads[thread_id].pipe_fds[1], &msg);
+ }
+
+ return true;
+}
+
+static bool
+send_task_to_thread_by_id(
+ wrdp_thpool *pool, wrdp_thpool_task *task, uint32_t thread_id)
+{
+ if (pool->tasks_per_thread[thread_id] >= pool->max_tasks)
+ {
+ return false;
+ }
+ else
+ {
+ thread_msg msg;
+ memset(&msg, 0, sizeof(thread_msg));
+ msg.type = thread_msg_push_task;
+ msg.task = task;
+ msg.sender.sender = pool_obj_pool;
+ msg.sender.pool = pool;
+ task->thread = &(pool->threads[thread_id]);
+ if (task->task_init_cb)
+ {
+ task->task_init_cb(task, task->userdata);
+ }
+ send_msg(pool->threads[thread_id].pipe_fds[1], &msg);
+ }
+ return true;
+}
+
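+/* Allocates a task and dispatches it to the least-loaded worker.
+ * Returns false if allocation fails or every worker is at max_tasks.
+ *
+ * Sketch (my_run, ctx and handle_backpressure are hypothetical caller
+ * code):
+ *
+ *	static void my_run(wrdp_thpool_task *t, void *ud) { ... }
+ *	...
+ *	if (!wrdp_thread_pool_add_task(pool, my_run, 0, ctx))
+ *		handle_backpressure(ctx);
+ */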
+bool
+wrdp_thread_pool_add_task(wrdp_thpool *pool,
+ void (*run_task)(wrdp_thpool_task *task, void *userdata),
+ void (*task_init_cb)(wrdp_thpool_task *task, void *userdata),
+ void *userdata)
+{
+ wrdp_thpool_task *task = calloc(1, sizeof(wrdp_thpool_task));
+ if (!task)
+ {
+ perror("malloc");
+ return false;
+ }
+ task->userdata = userdata;
+ task->run_task = run_task;
+ task->task_init_cb = task_init_cb;
+ if (!send_task_to_thread(pool, task))
+ {
+ goto cleanup;
+ }
+ return true;
+cleanup:
+ if (task)
+ {
+ free(task);
+ }
+ return false;
+}
+
+bool
+wrdp_thread_pool_add_task_to_thread(wrdp_thpool *pool,
+ void (*run_task)(wrdp_thpool_task *task, void *userdata),
+ uint32_t thread_id,
+ void (*task_init_cb)(wrdp_thpool_task *task, void *userdata),
+ void *userdata)
+{
+ wrdp_thpool_task *task = calloc(1, sizeof(wrdp_thpool_task));
+ if (!task)
+ {
+ perror("malloc");
+ return false;
+ }
+ task->userdata = userdata;
+ task->run_task = run_task;
+ task->task_init_cb = task_init_cb;
+ if (!send_task_to_thread_by_id(pool, task, thread_id))
+ {
+ goto cleanup;
+ }
+ return true;
+cleanup:
+ if (task)
+ {
+ free(task);
+ }
+ return false;
+}
+
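+/* Moves work between threads by destroying current_task and enqueueing
+ * a freshly allocated task on the target thread. Note that current_task
+ * is destroyed before the new task is sent, so if the target thread is
+ * already at max_tasks the old task is gone and false is returned. */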
+bool
+wrdp_thread_pool_move_task_to_thread(wrdp_thpool *pool,
+ void (*run_task)(wrdp_thpool_task *task, void *userdata),
+ void (*stop_task)(wrdp_thpool_task *current_task, void *userdata),
+ uint32_t thread_id,
+ void (*task_init_cb)(wrdp_thpool_task *task, void *userdata),
+ wrdp_thpool_task *current_task, void *userdata)
+{
+ wrdp_thpool_task *task = calloc(1, sizeof(wrdp_thpool_task));
+ if (!task)
+ {
+ perror("malloc");
+ return false;
+ }
+ task->userdata = userdata;
+ task->run_task = run_task;
+ task->stop_task = stop_task;
+ task->task_init_cb = task_init_cb;
+ if (stop_task)
+ {
+ stop_task(current_task, userdata);
+ }
+ wrdp_thread_pool_destroy_task(current_task, 0);
+ if (!send_task_to_thread_by_id(pool, task, thread_id))
+ {
+ goto cleanup;
+ }
+ return true;
+cleanup:
+ if (task)
+ {
+ free(task);
+ }
+ return false;
+}
+
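+/* Shared libev read callback for the pool pipe (watched by the caller's
+ * loop) and each worker pipe (watched by that worker's loop);
+ * p->receiver selects which side of the protocol this invocation is. */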
+static void
+pipe_readable_cb(struct ev_loop *loop, ev_io *w, int revents)
+{
+ pool_receiver_ptr *p = w->data;
+ thread_msg *in_msg = 0;
+ switch (p->receiver)
+ {
+ case pool_obj_thread:
+ {
+ in_msg = read_msg(p->thread->pipe_fds[0]);
+ if (!in_msg)
+ {
+ return;
+ }
+ switch (in_msg->type)
+ {
+ case thread_msg_push_task:
+ {
+ uint32_t i;
+ bool added = false;
+			for (i = 0; i < p->thread->pool->max_tasks; ++i)
+			{
+				thread_msg out_msg;
+				if (p->thread->tasks[i])
+				{
+					continue;
+				}
+				memset(&out_msg, 0, sizeof(thread_msg));
+				p->thread->tasks[i] = in_msg->task;
+				p->thread->running_task_count++;
+				if (!p->thread->tasks[i]->run_task)
+				{
+					/* TODO: error message to log */
+					break;
+				}
+				/* run the task synchronously inside this
+				 * worker's event loop */
+				p->thread->tasks[i]->run_task(
+				    p->thread->tasks[i],
+				    p->thread->tasks[i]->userdata);
+				out_msg.type = thread_msg_task_count;
+				out_msg.running_tasks
+				    = p->thread->running_task_count;
+				out_msg.sender.sender = pool_obj_thread;
+				out_msg.sender.thread = p->thread;
+				{
+					char buf[128];
+					log_msg_info mi = {0};
+					snprintf(buf, 127,
+					    "Added new task to thread"
+					    " %d slot %d",
+					    p->thread->thread_id, i);
+					mi.buf = (uint8_t *)buf;
+					mi.level = wrdp_log_level_trace;
+					mi.wrdp_thpool_task = in_msg->task;
+					log_msg_ex(&mi);
+				}
+				/* report the new task count back to the
+				 * pool so load balancing stays current */
+				send_msg(p->thread->pool->pipe_fds[1],
+				    &out_msg);
+				added = true;
+				break;
+			}
+			if (!added)
+			{
+				char buf[128];
+				snprintf(buf, 127,
+				    "Error: failed to add task to"
+				    " thread %d: no free slots",
+				    p->thread->thread_id);
+				log_msg((const uint8_t *)buf, strlen(buf),
+				    wrdp_log_level_error, 0);
+			}
+ }
+ break;
+ case thread_msg_task_finished:
+ {
+ thread_msg out_msg;
+ size_t i = 0;
+ bool task_found = false;
+ memset(&out_msg, 0, sizeof(thread_msg));
+			for (i = 0; i < p->thread->pool->max_tasks; ++i)
+			{
+				if (p->thread->tasks[i] != in_msg->task)
+				{
+					continue;
+				}
+				task_found = true;
+				p->thread->tasks[i] = 0;
+				p->thread->running_task_count--;
+				out_msg.type = thread_msg_task_count;
+				out_msg.running_tasks
+				    = p->thread->running_task_count;
+				out_msg.sender.sender = pool_obj_thread;
+				out_msg.sender.thread = p->thread;
+				{
+					char buf[128];
+					log_msg_info mi = {0};
+					snprintf(buf, 127,
+					    "Removed task from thread"
+					    " %d slot %zu",
+					    p->thread->thread_id, i);
+					mi.wrdp_thpool_task = in_msg->task;
+					mi.level = wrdp_log_level_trace;
+					mi.buf = (uint8_t *)buf;
+					log_msg_ex(&mi);
+				}
+				/* the slot was cleared above, so free the
+				 * task through the message pointer */
+				free(in_msg->task);
+				send_msg(p->thread->pool->pipe_fds[1],
+				    &out_msg);
+				break;
+			}
+			if (!task_found)
+			{
+				const char *msg_str
+				    = "wrdp_thpool: thread_msg_task_finished:"
+				      " task not found in thread";
+				log_msg((const uint8_t *)msg_str,
+				    strlen(msg_str), wrdp_log_level_error, 0);
+			}
+ }
+ break;
+ case thread_msg_userdata:
+ {
+			if (p->thread->pool->thread_message_handler)
+			{
+				p->thread->pool->thread_message_handler(
+				    in_msg->user_data);
+			}
+ }
+ break;
+ default:
+ break;
+ }
+ }
+ break;
+ case pool_obj_pool:
+ {
+ in_msg = read_msg(p->pool->pipe_fds[0]);
+ if (!in_msg)
+ {
+ return;
+ }
+ switch (in_msg->type)
+ {
+ case thread_msg_task_count:
+ {
+			/* use the count captured at send time rather than
+			 * reading the worker's live counter */
+			p->pool->tasks_per_thread
+			    [in_msg->sender.thread->thread_id]
+			    = in_msg->running_tasks;
+ }
+ break;
+ case thread_msg_userdata:
+ {
+ if (p->pool->pool_message_handler)
+ {
+ p->pool->pool_message_handler(
+ in_msg->user_data);
+ }
+ }
+ break;
+ default:
+ break;
+ }
+	}
+	break;
+ default:
+ break;
+ }
+ if (in_msg)
+ {
+ free(in_msg);
+ }
+}
+
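+/* Worker thread entry point: runs custom_thread_init (if set), creates
+ * a private libev loop and blocks in ev_run. Tasks execute inside
+ * pipe_readable_cb via run_task, so a run_task that never returns to
+ * the loop stalls message processing on its worker. */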
+static void *
+wrdp_thpool_worker_thread_loop(void *thread_)
+{
+ wrdp_thpool_thread *thread = thread_;
+ pool_receiver_ptr *p = calloc(1, sizeof(pool_receiver_ptr));
+ if (!p)
+ {
+ perror("calloc");
+ return 0;
+ }
+ p->receiver = pool_obj_thread;
+ p->thread = thread;
+ if (thread->pool->custom_thread_init)
+ {
+ thread->pool->custom_thread_init(
+ thread->pool->userdata, thread);
+ }
+ thread->ev_th_loop = ev_loop_new(EVFLAG_AUTO);
+ ev_io_init(&(thread->ev_pipe_readable), pipe_readable_cb,
+ thread->pipe_fds[0], EV_READ);
+ thread->ev_pipe_readable.data = p;
+ ev_io_start(thread->ev_th_loop, &(thread->ev_pipe_readable));
+ ev_run(thread->ev_th_loop, 0);
+ return 0;
+}
+
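+/* Reports a task as finished by sending thread_msg_task_finished to the
+ * owning worker's pipe; the task memory itself is freed in
+ * pipe_readable_cb when its slot is cleared. */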
+void
+wrdp_thread_pool_destroy_task(
+ wrdp_thpool_task *task, void (*on_task_destroy)(wrdp_thpool_task *task))
+{
+ thread_msg msg;
+
+ /* TODO: this should never happen, but for now just crashfix hack */
+ if (!task->thread)
+ {
+ return;
+ }
+
+ memset(&msg, 0, sizeof(thread_msg));
+ if (on_task_destroy)
+ {
+ on_task_destroy(task);
+ }
+ msg.type = thread_msg_task_finished;
+ msg.task = task;
+ msg.sender.sender = pool_obj_thread;
+ msg.sender.thread = task->thread;
+ send_msg(task->thread->pipe_fds[1], &msg);
+}