diff options
Diffstat (limited to 'ggml/src/ggml.c')
-rw-r--r-- | ggml/src/ggml.c | 189 |
1 files changed, 187 insertions, 2 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index d31713df..08eab23b 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2888,6 +2888,30 @@ static void ggml_vec_silu_f32(const int n, float * y, const float * x) { } } +static void ggml_vec_mul_silu_f32(const int n, float * z, const float * x, const float * y) { + int i = 0; +#if defined(__AVX512F__) && defined(__AVX512DQ__) + for (; i + 15 < n; i += 16) { + _mm512_storeu_ps(z + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(y + i))); + } +#elif defined(__AVX2__) && defined(__FMA__) + for (; i + 7 < n; i += 8) { + _mm256_storeu_ps(z + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(y + i))); + } +#elif defined(__SSE2__) + for (; i + 3 < n; i += 4) { + _mm_storeu_ps(z + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(y + i))); + } +#elif defined(__ARM_NEON) && defined(__aarch64__) + for (; i + 3 < n; i += 4) { + vst1q_f32(z + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(y + i))); + } +#endif + for (; i < n; ++i) { + z[i] = ggml_silu_f32(x[i]) * y[i]; + } +} + static void ggml_vec_swiglu_f32(const int n, float * y, const float * x) { int i = 0; #if defined(__AVX512F__) && defined(__AVX512DQ__) @@ -3100,6 +3124,47 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { } #endif } +inline static void ggml_vec_mul_gelu_f32(const int n, float * z, const float * x, const float * y) { + int i = 0; +#if defined(__AVX512F__) && defined(__AVX512DQ__) + __m512 c1 = _mm512_set1_ps(GELU_COEF_A); + __m512 c2 = _mm512_set1_ps(2.f*SQRT_2_OVER_PI); + for (; i + 15 < n; i += 16) { + _mm512_storeu_ps(z + i, _mm512_mul_ps(ggml_v_gelu(_mm512_loadu_ps(x + i), c1, c2), _mm512_loadu_ps(y + i))); + } +#elif defined __AVX2__ && defined __FMA__ + __m256 c1 = _mm256_set1_ps(GELU_COEF_A); + __m256 c2 = _mm256_set1_ps(2.f*SQRT_2_OVER_PI); + for (; i + 7 < n; i += 8) { + _mm256_storeu_ps(z + i, _mm256_mul_ps(ggml_v_gelu(_mm256_loadu_ps(x + i), c1, c2), _mm256_loadu_ps(y + i))); + } +#endif +#ifdef GGML_GELU_FP16 + uint16_t t; + for (; i < n; ++i) { + if (x[i] <= -10.0f) { + z[i] = 0.0f; + } else if (x[i] >= 10.0f) { + z[i] = x[i]*y[i]; + } else { + ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); + memcpy(&t, &fp16, sizeof(uint16_t)); + z[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t])*y[i]; + } + } +#else +#if defined __ARM_NEON + float32x4_t c1 = vdupq_n_f32(GELU_COEF_A); + float32x4_t c2 = vdupq_n_f32(2.f*SQRT_2_OVER_PI); + for (; i + 3 < n; i += 4) { + vst1q_f32(z + i, vmulq_f32(ggml_v_gelu(vld1q_f32(x + i), c1, c2), vld1q_f32(y + i))); + } +#endif + for (; i < n; ++i) { + z[i] = ggml_gelu_f32(x[i])*y[i]; + } +#endif +} static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) { int i = 0; @@ -3258,6 +3323,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "RMS_NORM_BACK", "GROUP_NORM", "FUSED_RMS_NORM", + "FUSED_MUL_UNARY", "MUL_MAT", "MUL_MAT_ID", @@ -3321,7 +3387,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77"); +static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -3349,6 +3415,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "rms_norm_back(x)", "group_norm(x)", "fused_rms_norm(x)", + "fused_mul_unary(x)", "X*Y", "X[i]*Y", @@ -3412,7 +3479,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77"); +static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -5246,6 +5313,55 @@ struct ggml_tensor * ggml_mul_inplace( struct ggml_tensor * b) { return ggml_mul_impl(ctx, a, b, true); } +// ggml_mul + +static struct ggml_tensor * ggml_fused_mul_unary_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_unary_op op, + bool inplace) { + GGML_ASSERT(ggml_are_same_shape(b, a)); + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(op == GGML_UNARY_OP_GELU || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + if (inplace) { + GGML_ASSERT(!is_node); + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + ggml_set_op_params_i32(result, 0, (int32_t) op); + + result->op = GGML_OP_FUSED_MUL_UNARY; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_fused_mul_unary( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_unary_op op) { + return ggml_fused_mul_unary_impl(ctx, a, b, op, false); +} + +struct ggml_tensor * ggml_fused_mul_unary_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_unary_op op) { + return ggml_fused_mul_unary_impl(ctx, a, b, op, true); +} // ggml_div @@ -12374,6 +12490,66 @@ static void ggml_compute_forward_swiglu( } } +// ggml_compute_forward_fused_mul_unary + +static void ggml_compute_forward_fused_mul_unary_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + enum ggml_unary_op op = (enum ggml_unary_op)dst->op_params[0]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_are_same_shape(src0, src1)); + GGML_ASSERT(op == GGML_UNARY_OP_GELU || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = dst->ne[0]; + const int nr = ggml_nrows(src0); + + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * z = (float *) ((char *) dst->data + i1*( dst->nb[1])); + const float * x = (const float *) ((char *) src0->data + i1*(src0->nb[1])); + const float * y = (const float *) ((char *) src1->data + i1*(src1->nb[1])); + switch (op) { + case GGML_UNARY_OP_GELU: ggml_vec_gelu_f32(nc, z, x); ggml_vec_mul_f32(nc, z, z, y); break; + case GGML_UNARY_OP_RELU: ggml_vec_relu_f32(nc, z, x); ggml_vec_mul_f32(nc, z, z, y); break; + case GGML_UNARY_OP_SILU: ggml_vec_mul_silu_f32(nc, z, x, y); break; + default: GGML_ABORT("fatal error"); + } + } +} + +static void ggml_compute_forward_fused_mul_unary( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_fused_mul_unary_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_leaky_relu static void ggml_compute_forward_leaky_relu_f32( @@ -17990,6 +18166,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_mul(params, tensor); } break; + case GGML_OP_FUSED_MUL_UNARY: + { + ggml_compute_forward_fused_mul_unary(params, tensor); + } break; case GGML_OP_DIV: { ggml_compute_forward_div(params, tensor); @@ -18715,6 +18895,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor zero_table); } } break; + case GGML_OP_FUSED_MUL_UNARY: + { + GGML_ABORT("fatal error"); // TODO: implement + } case GGML_OP_CONCAT: { GGML_ABORT("fatal error"); // TODO: implement @@ -19813,6 +19997,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { break; case GGML_OP_SILU_BACK: case GGML_OP_MUL: + case GGML_OP_FUSED_MUL_UNARY: case GGML_OP_DIV: case GGML_OP_NORM: case GGML_OP_RMS_NORM: |