diff options
author | Kawrakow <48489457+ikawrakow@users.noreply.github.com> | 2024-08-27 17:40:59 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-08-27 17:40:59 +0300 |
commit | c7e99c88a2de7489ba2a1539b1a9025912010b70 (patch) | |
tree | 9976409b1e8fac1fc7486f2c5da05a33b8e229b5 /ggml/src/ggml.c | |
parent | bd99ed7d0afd2b12c0f5ff5c17b58486396dfe7e (diff) |
Faster Gemma2 (#27)
* soft_cap_max: initial CPU version of fused softcap + soft_max
With this vanilla CPU implementation I'm already getting a ~3% speedup
for Gemma-2-9b and a prompt of 8192 tokens.
* soft_cap_max: WIP - something is wrong with CUDA
* soft_cap_max: looks good on CPU and CUDA
* Add softcap to flash attention
Just CPU and CUDA for now (but, as we know, flash attention
on the CPU is useless in llama.cpp).
On CUDA this improves PP performance quite a bit, especially for
long contexts. E.g., for PP-16384, I now get 3777 t/s.
Without this change, one cannot use FA, and one gets 2300 t/s
(after fusing softcap and softmax), or 2000 t/s without the
fused softcap+softmax.
In comparison, mainline llama.cpp has PP-16384 = 1549 t/s before
PR-8542 (where Johannes Gaessler has also added softcap to FA),
and PP-16384 = 3097 t/s after this PR.
* soft_cap_max: Metal
* Flash attention with softcap: Metal
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/src/ggml.c')
-rw-r--r-- | ggml/src/ggml.c | 251 |
1 files changed, 243 insertions, 8 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 60a89591..cebac584 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2889,6 +2889,41 @@ static void ggml_vec_softcap_f32(const int n, float * x, float s_before, float s } } +static float ggml_vec_softcap_max_f32(const int n, float * x, float s_before, float s_after) { + int i = 0; + float max = -INFINITY; +#if defined(__AVX512F__) && defined(__AVX512DQ__) + __m512 vs_before = _mm512_set1_ps(2.f*s_before); + __m512 vs_after = _mm512_set1_ps(s_after); + __m512 vmax = _mm512_set1_ps(-INFINITY); + for (; i + 15 < n; i += 16) { + __m512 y = ggml_v_softcap(_mm512_loadu_ps(x + i), vs_before, vs_after); + _mm512_storeu_ps(x + i, y); + vmax = _mm512_max_ps(vmax, y); + } + max = _mm512_reduce_max_ps(vmax); +#elif defined(__AVX2__) && defined(__FMA__) + for (; i + 7 < n; i += 8) { + _mm256_storeu_ps(x + i, ggml_v_softcap(_mm256_loadu_ps(x + i), s_before, s_after)); + } +#elif defined(__SSE2__) + for (; i + 3 < n; i += 4) { + _mm_storeu_ps(x + i, ggml_v_softcap(_mm_loadu_ps(x + i), s_before, s_after)); + } +#elif defined(__ARM_NEON) && defined(__aarch64__) + float32x4_t vs_before = vdupq_n_f32(s_before); + float32x4_t vs_after = vdupq_n_f32(s_after); + for (; i + 3 < n; i += 4) { + vst1q_f32(x + i, ggml_v_softcap(vld1q_f32(x + i), vs_before, vs_after)); + } +#endif + for (; i < n; ++i) { + x[i] = s_after*tanhf(x[i]*s_before); + max = MAX(max, x[i]); + } + return max; +} + inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { const uint16_t * i16 = (const uint16_t *) x; for (int i = 0; i < n; ++i) { @@ -3144,6 +3179,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "ARGSORT", "LEAKY_RELU", "SOFTCAP", + "SOFT_CAP_MAX", "FLASH_ATTN_EXT", "FLASH_ATTN_BACK", @@ -3171,7 +3207,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75"); +static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -3233,6 +3269,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "argsort(x)", "leaky_relu(x)", "k2*tanh(k1*x)", + "soft_max(k2*tanh(k1*x))", "flash_attn_ext(x)", "flash_attn_back(x)", @@ -3260,7 +3297,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75"); +static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -5963,6 +6000,72 @@ struct ggml_tensor * ggml_softcap_inplace( return ggml_softcap_impl(ctx, a, s_before, s_after, true); } +static struct ggml_tensor * ggml_softcap_max_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * mask, + float scale, + float max_bias, + float s_before, + float s_after, + bool inplace) { + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(ggml_is_padded_1d(a)); + + if (mask) { + GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(mask)); + GGML_ASSERT(ggml_is_matrix(mask)); + GGML_ASSERT(mask->ne[0] == a->ne[0]); + GGML_ASSERT(mask->ne[1] >= a->ne[1]); + } + + if (max_bias > 0.0f) { + GGML_ASSERT(mask); + } + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + float params[4] = {scale, max_bias, s_before, s_after}; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_SOFT_CAP_MAX; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = mask; + + return result; +} + +struct ggml_tensor * ggml_softcap_max( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * mask, + float scale, + float max_bias, + float s_before, + float s_after) { + return ggml_softcap_max_impl(ctx, a, mask, scale, max_bias, s_before, s_after, false); +} + +struct ggml_tensor * ggml_softcap_max_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * mask, + float scale, + float max_bias, + float s_before, + float s_after) { + return ggml_softcap_max_impl(ctx, a, mask, scale, max_bias, s_before, s_after, true); +} + + // ggml_set static struct ggml_tensor * ggml_set_impl( @@ -7493,7 +7596,8 @@ struct ggml_tensor * ggml_flash_attn_ext( struct ggml_tensor * v, struct ggml_tensor * mask, float scale, - float max_bias) { + float max_bias, + float softcap) { GGML_ASSERT(ggml_can_mul_mat(k, q)); // TODO: check if vT can be multiplied by (k*qT) @@ -7520,7 +7624,7 @@ struct ggml_tensor * ggml_flash_attn_ext( int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - float params[] = { scale, max_bias }; + float params[] = { scale, max_bias, softcap }; ggml_set_op_params(result, params, sizeof(params)); result->op = GGML_OP_FLASH_ATTN_EXT; @@ -7540,7 +7644,7 @@ void ggml_flash_attn_ext_set_prec( const int32_t prec_i32 = (int32_t) prec; - ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second + ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second } // ggml_flash_attn_back @@ -13618,6 +13722,122 @@ static void ggml_compute_forward_softcap( } } +static void ggml_compute_forward_softcap_max_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + assert(ggml_is_contiguous(dst)); + assert(ggml_are_same_shape(src0, dst)); + + float values[4]; + memcpy(values, dst->op_params, sizeof(values)); + // values[0] -> scale + // values[1] -> max_bias + // values[2] -> s_before + // values[3] -> s_after + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + //const int64_t ne11 = src1 ? src1->ne[1] : 1; + + // TODO: is this supposed to be ceil instead of floor? + // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370 + const uint32_t n_head = ne02; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(values[1] ) / n_head_log2); + const float m1 = powf(2.0f, -(values[1] / 2.0f) / n_head_log2); + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; + + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + + for (int i1 = ir0; i1 < ir1; i1++) { + // ALiBi + const uint32_t h = (i1/ne01)%ne02; // head + const float slope = (values[1] > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; + + float * sp = (float *)((char *) src0->data + i1*src0->nb[1]); + float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); + + // broadcast the mask across rows + ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; + float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; + + ggml_vec_cpy_softcap_f32(nc, sp, wp, values[2], values[0]*values[3]); + + if (mp_f32) { + if (use_f16) { + for (int i = 0; i < nc; ++i) { + wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]); + } + } else { + for (int i = 0; i < nc; ++i) { + wp[i] += slope*mp_f32[i]; + } + } + } + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(wp[i])); + } +#endif + + float max = -INFINITY; + ggml_vec_max_f32(nc, &max, wp); + + ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max); + assert(sum > 0.0); + + sum = 1.0/sum; + ggml_vec_scale_f32(nc, dp, sum); + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + assert(!isnan(dp[i])); + assert(!isinf(dp[i])); + } +#endif + } + +} + +static void ggml_compute_forward_softcap_max( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_softcap_max_f32(params, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_set static void ggml_compute_forward_set_f32( @@ -15919,9 +16139,15 @@ static void ggml_compute_forward_flash_attn_ext_f16( float scale = 1.0f; float max_bias = 0.0f; + float softcap = 0.0f; memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); + memcpy(&softcap, (float *) dst->op_params + 2, sizeof(float)); + + if (softcap != 0.0f) { + scale /= softcap; + } const uint32_t n_head = neq2; const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); @@ -15985,7 +16211,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3); kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1); - s = s*scale + mv; // scale KQ value and apply mask + s = softcap == 0.0f ? s*scale + mv : softcap*tanhf(s*scale) + mv; // scale KQ value and apply mask const float Mold = M; @@ -15994,7 +16220,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3)); - if (v->type== GGML_TYPE_F16) { + if (v->type == GGML_TYPE_F16) { if (s > M) { // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f M = s; @@ -16061,7 +16287,7 @@ static void ggml_compute_forward_flash_attn_ext( const struct ggml_tensor * v, const struct ggml_tensor * mask, struct ggml_tensor * dst) { - switch (dst->op_params[2]) { + switch (dst->op_params[3]) { case GGML_PREC_DEFAULT: case GGML_PREC_F32: { @@ -17477,6 +17703,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_softcap(params, tensor); } break; + case GGML_OP_SOFT_CAP_MAX: + { + ggml_compute_forward_softcap_max(params, tensor); + } break; case GGML_OP_SET: { ggml_compute_forward_set(params, tensor); @@ -18227,6 +18457,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ASSERT(false); // TODO: not implemented } break; + case GGML_OP_SOFT_CAP_MAX: + { + GGML_ASSERT(false); // TODO: not implemented + } break; case GGML_OP_SET: { const size_t nb1 = ((int32_t *) tensor->op_params)[0]; @@ -19240,6 +19474,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_SCALE: case GGML_OP_SOFTCAP: case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_CAP_MAX: { n_tasks = MIN(n_threads, ggml_nrows(node->src[0])); } break; |