summary | refs | log | tree | commit | diff
path: root/ggml.c
diff options
context:
space:
mode:
author: Iwan Kawrakow <iwan.kawrakow@gmail.com> 2024-06-20 08:21:25 +0300
committer: Iwan Kawrakow <iwan.kawrakow@gmail.com> 2024-06-22 12:02:52 +0300
commit: 36374ab37dac8fadb634f802aaa3ee7b816fe727 (patch)
tree: bdfb7cedbbca6e0716aaf6deca2046a1a28f6b16 /ggml.c
parent: e73ae1f6d31074f774741a592382ec62a9de6dbf (diff)
bitnet(scale in a separate tensor): mul -> scale on the CPU
Diffstat (limited to 'ggml.c')
-rw-r--r--  ggml.c  17
1 files changed, 17 insertions, 0 deletions
diff --git a/ggml.c b/ggml.c
index 4d089c43..eac46c42 100644
--- a/ggml.c
+++ b/ggml.c
@@ -10156,6 +10156,23 @@ static void ggml_compute_forward_mul_f32(
const int ith = params->ith;
const int nth = params->nth;
+ if (ggml_nelements(dst->src[1]) == 1 && ggml_is_contiguous(dst->src[0]) && ggml_is_contiguous(dst) &&
+ dst->src[0]->type == GGML_TYPE_F32 && dst->src[1]->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ int64_t nelements = ggml_nelements(dst->src[0]);
+ int64_t n_per_thread = (nelements + nth - 1)/nth;
+ n_per_thread = MAX(1024, n_per_thread);
+ int64_t start = n_per_thread*ith;
+ if (start >= nelements) return;
+ int64_t end = MIN(nelements, start + n_per_thread);
+ const float * src = (const float *)dst->src[0]->data + start;
+ float * res = (float *)dst->data + start;
+ if (res != src) {
+ memcpy(res, src, (end - start)*sizeof(float));
+ }
+ ggml_vec_scale_f32(end - start, res, *(const float *)dst->src[1]->data);
+ return;
+ }
+
const int64_t nr = ggml_nrows(src0);
GGML_TENSOR_BINARY_OP_LOCALS