summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorslaren <slarengh@gmail.com>2024-05-29 13:36:39 +0200
committerGitHub <noreply@github.com>2024-05-29 13:36:39 +0200
commit87bdf2a199acd62e19814d7a4d0500a04a7f09f3 (patch)
treec072bcbaeb0b2bcd0f90da543a9c588124503f48
parent00281b7be32462754618c42ed93f95743af46627 (diff)
ggml : use atomic_flag for critical section (#7598)
* ggml : use atomic_flag for critical section * add windows shims
-rw-r--r--ggml.c23
1 files changed, 14 insertions, 9 deletions
diff --git a/ggml.c b/ggml.c
index 5025ec23..d8f74f3c 100644
--- a/ggml.c
+++ b/ggml.c
@@ -60,6 +60,9 @@
typedef volatile LONG atomic_int;
typedef atomic_int atomic_bool;
+typedef atomic_int atomic_flag;
+
+#define ATOMIC_FLAG_INIT 0
static void atomic_store(atomic_int * ptr, LONG val) {
InterlockedExchange(ptr, val);
@@ -73,6 +76,12 @@ static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {
static LONG atomic_fetch_sub(atomic_int * ptr, LONG dec) {
return atomic_fetch_add(ptr, -(dec));
}
+static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) {
+ return InterlockedExchange(ptr, 1);
+}
+static void atomic_flag_clear(atomic_flag * ptr) {
+ InterlockedExchange(ptr, 0);
+}
typedef HANDLE pthread_t;
@@ -2883,24 +2892,20 @@ struct ggml_state {
// global state
static struct ggml_state g_state;
-static atomic_int g_state_barrier = 0;
+static atomic_flag g_state_critical = ATOMIC_FLAG_INIT;
// barrier via spin lock
inline static void ggml_critical_section_start(void) {
- int processing = atomic_fetch_add(&g_state_barrier, 1);
-
- while (processing > 0) {
- // wait for other threads to finish
- atomic_fetch_sub(&g_state_barrier, 1);
- sched_yield(); // TODO: reconsider this
- processing = atomic_fetch_add(&g_state_barrier, 1);
+ while (atomic_flag_test_and_set(&g_state_critical)) {
+ // spin
+ sched_yield();
}
}
// TODO: make this somehow automatically executed
// some sort of "sentry" mechanism
inline static void ggml_critical_section_end(void) {
- atomic_fetch_sub(&g_state_barrier, 1);
+ atomic_flag_clear(&g_state_critical);
}
#if defined(__gnu_linux__)