diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2024-09-28 13:37:25 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-09-28 13:37:25 +0300 |
commit | 737514fd814d944f8ce965620293a16e5e8a285d (patch) | |
tree | 4b4b79eec0d1cbcc413dd3c6991b6d57439edd86 /ggml/src/ggml.c | |
parent | 1f61e91862dd0b077ccb60459f3cc03f364ee279 (diff) |
Adding SWIGLU unary op (#65)
* Adding GGML_UNARY_OP_SWIGLU
This commit implements the ggml op and CPU compute
forward. I see ~3-4% speedup of PP-512 for Phi-3.5-mini.
* GGML_UNARY_OP_SWIGLU: CUDA implementation
I observe ~12% speedup for PP-512(Phi-3.5-mini).
* GGML_UNARY_OP_SWIGLU: Metal implementation
We get ~2% speedup for PP-512(Phi-3.5-mini).
* GGML_UNARY_OP_SWIGLU: minor improvement on Metal
* GGML_UNARY_OP_SWIGLU: cleanup
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/src/ggml.c')
-rw-r--r-- | ggml/src/ggml.c | 117 |
1 files changed, 116 insertions, 1 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 2804accd..184a31a8 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2867,6 +2867,30 @@ static void ggml_vec_silu_f32(const int n, float * y, const float * x) { } } +static void ggml_vec_swiglu_f32(const int n, float * y, const float * x) { + int i = 0; +#if defined(__AVX512F__) && defined(__AVX512DQ__) + for (; i + 15 < n; i += 16) { + _mm512_storeu_ps(y + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(x + i + n))); + } +#elif defined(__AVX2__) && defined(__FMA__) + for (; i + 7 < n; i += 8) { + _mm256_storeu_ps(y + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(x + i + n))); + } +#elif defined(__SSE2__) + for (; i + 3 < n; i += 4) { + _mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(x + i + n))); + } +#elif defined(__ARM_NEON) && defined(__aarch64__) + for (; i + 3 < n; i += 4) { + vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(x + i + n))); + } +#endif + for (; i < n; ++i) { + y[i] = ggml_silu_f32(x[i]) * x[i + n]; + } +} + static void ggml_vec_tanh_f32(const int n, float * y, const float * x) { int i = 0; #if defined(__AVX512F__) && defined(__AVX512DQ__) @@ -3386,9 +3410,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = { "SILU", "HARDSWISH", "HARDSIGMOID", + "SWIGLU", }; -static_assert(GGML_UNARY_OP_COUNT == 13, "GGML_UNARY_OP_COUNT != 13"); +static_assert(GGML_UNARY_OP_COUNT == 14, "GGML_UNARY_OP_COUNT != 14"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); @@ -5694,6 +5719,26 @@ struct ggml_tensor * ggml_silu_inplace( return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SILU); } +// ggml_swiglu + +struct ggml_tensor * ggml_swiglu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + GGML_ASSERT(ggml_is_contiguous_1(a)); + + int64_t ne[4] = {a->ne[0]/2, a->ne[1], a->ne[2], a->ne[3]}; + + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, ne, NULL, 0); + + ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_SWIGLU); + + result->op = GGML_OP_UNARY; + result->grad = a->grad ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + // ggml_silu_back struct ggml_tensor * ggml_silu_back( @@ -12243,6 +12288,67 @@ static void ggml_compute_forward_silu( } } } + +// ggml_compute_forward_swiglu + +static void ggml_compute_forward_swiglu_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0)); + GGML_ASSERT(dst->ne[0] == src0->ne[0]/2); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = dst->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_swiglu_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k]; + UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_swiglu( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_swiglu_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_leaky_relu static void ggml_compute_forward_leaky_relu_f32( @@ -17289,6 +17395,10 @@ static void ggml_compute_forward_unary( { ggml_compute_forward_silu(params, dst); } break; + case GGML_UNARY_OP_SWIGLU: + { + ggml_compute_forward_swiglu(params, dst); + } break; case GGML_UNARY_OP_HARDSWISH: { ggml_compute_forward_hardswish(params, dst); @@ -19155,6 +19265,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ABORT("fatal error"); // TODO: not implemented } + case GGML_UNARY_OP_SWIGLU: + { + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_UNARY_OP_SILU: { // necessary for llama @@ -19656,6 +19770,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_SWIGLU: { n_tasks = n_threads; } break; |