diff options
Diffstat (limited to 'ggml.c')
-rw-r--r-- | ggml.c | 17 |
1 files changed, 17 insertions, 0 deletions
@@ -10156,6 +10156,23 @@ static void ggml_compute_forward_mul_f32( const int ith = params->ith; const int nth = params->nth; + if (ggml_nelements(dst->src[1]) == 1 && ggml_is_contiguous(dst->src[0]) && ggml_is_contiguous(dst) && + dst->src[0]->type == GGML_TYPE_F32 && dst->src[1]->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + int64_t nelements = ggml_nelements(dst->src[0]); + int64_t n_per_thread = (nelements + nth - 1)/nth; + n_per_thread = MAX(1024, n_per_thread); + int64_t start = n_per_thread*ith; + if (start >= nelements) return; + int64_t end = MIN(nelements, start + n_per_thread); + const float * src = (const float *)dst->src[0]->data + start; + float * res = (float *)dst->data + start; + if (res != src) { + memcpy(res, src, (end - start)*sizeof(float)); + } + ggml_vec_scale_f32(end - start, res, *(const float *)dst->src[1]->data); + return; + } + const int64_t nr = ggml_nrows(src0); GGML_TENSOR_BINARY_OP_LOCALS |