Diffstat (limited to 'ggml/src/ggml.c')
-rw-r--r-- | ggml/src/ggml.c | 251
1 file changed, 243 insertions, 8 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 60a89591..cebac584 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -2889,6 +2889,41 @@ static void ggml_vec_softcap_f32(const int n, float * x, float s_before, float s
     }
 }
 
+static float ggml_vec_softcap_max_f32(const int n, float * x, float s_before, float s_after) {
+    int i = 0;
+    float max = -INFINITY;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+    __m512 vs_before = _mm512_set1_ps(2.f*s_before);
+    __m512 vs_after  = _mm512_set1_ps(s_after);
+    __m512 vmax      = _mm512_set1_ps(-INFINITY);
+    for (; i + 15 < n; i += 16) {
+        __m512 y = ggml_v_softcap(_mm512_loadu_ps(x + i), vs_before, vs_after);
+        _mm512_storeu_ps(x + i, y);
+        vmax = _mm512_max_ps(vmax, y);
+    }
+    max = _mm512_reduce_max_ps(vmax);
+#elif defined(__AVX2__) && defined(__FMA__)
+    for (; i + 7 < n; i += 8) {
+        _mm256_storeu_ps(x + i, ggml_v_softcap(_mm256_loadu_ps(x + i), s_before, s_after));
+    }
+#elif defined(__SSE2__)
+    for (; i + 3 < n; i += 4) {
+        _mm_storeu_ps(x + i, ggml_v_softcap(_mm_loadu_ps(x + i), s_before, s_after));
+    }
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+    float32x4_t vs_before = vdupq_n_f32(s_before);
+    float32x4_t vs_after  = vdupq_n_f32(s_after);
+    for (; i + 3 < n; i += 4) {
+        vst1q_f32(x + i, ggml_v_softcap(vld1q_f32(x + i), vs_before, vs_after));
+    }
+#endif
+    for (; i < n; ++i) {
+        x[i] = s_after*tanhf(x[i]*s_before);
+        max = MAX(max, x[i]);
+    }
+    return max;
+}
+
 inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
     const uint16_t * i16 = (const uint16_t *) x;
     for (int i = 0; i < n; ++i) {
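The helper above applies y = s_after*tanh(s_before*x) in place and returns the row maximum, but as committed only the AVX512 branch and the scalar tail feed `max`: on AVX2/SSE2/NEON builds the vector loop's results are never folded in, so a row whose length is a multiple of the vector width would return -INFINITY. A minimal sketch (not part of this commit) of how the AVX2 branch could accumulate the maximum as well, assuming the AVX2 variant of `ggml_v_softcap` takes scalar `s_before`/`s_after` exactly as called in the hunk:

    #include <immintrin.h>
    #include <math.h>

    // Sketch only: AVX2 softcap of one row that also tracks the running
    // maximum, mirroring the AVX512 branch of ggml_vec_softcap_max_f32.
    static float softcap_max_avx2(const int n, float * x, float s_before, float s_after) {
        int i = 0;
        __m256 vmax = _mm256_set1_ps(-INFINITY);
        for (; i + 7 < n; i += 8) {
            __m256 y = ggml_v_softcap(_mm256_loadu_ps(x + i), s_before, s_after); // assumed helper
            _mm256_storeu_ps(x + i, y);
            vmax = _mm256_max_ps(vmax, y);                 // per-lane running maximum
        }
        // horizontal max of the 8 lanes (AVX2 has no _mm256_reduce_max_ps)
        __m128 m4 = _mm_max_ps(_mm256_castps256_ps128(vmax), _mm256_extractf128_ps(vmax, 1));
        m4 = _mm_max_ps(m4, _mm_movehl_ps(m4, m4));
        m4 = _mm_max_ss(m4, _mm_shuffle_ps(m4, m4, 1));
        float max = _mm_cvtss_f32(m4);
        for (; i < n; ++i) {                               // scalar tail
            x[i] = s_after*tanhf(x[i]*s_before);
            max = max > x[i] ? max : x[i];
        }
        return max;
    }

The SSE2 and NEON branches would need the same treatment (`vmaxq_f32` plus `vmaxvq_f32` on aarch64).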
@@ -3144,6 +3179,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "ARGSORT",
     "LEAKY_RELU",
     "SOFTCAP",
+    "SOFT_CAP_MAX",
 
     "FLASH_ATTN_EXT",
     "FLASH_ATTN_BACK",
@@ -3171,7 +3207,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");
+static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3233,6 +3269,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "argsort(x)",
     "leaky_relu(x)",
     "k2*tanh(k1*x)",
+    "soft_max(k2*tanh(k1*x))",
 
     "flash_attn_ext(x)",
     "flash_attn_back(x)",
@@ -3260,7 +3297,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");
+static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -5963,6 +6000,72 @@ struct ggml_tensor * ggml_softcap_inplace(
     return ggml_softcap_impl(ctx, a, s_before, s_after, true);
 }
 
+static struct ggml_tensor * ggml_softcap_max_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * mask,
+        float                 scale,
+        float                 max_bias,
+        float                 s_before,
+        float                 s_after,
+        bool                  inplace) {
+    GGML_ASSERT(ggml_is_contiguous(a));
+    GGML_ASSERT(ggml_is_padded_1d(a));
+
+    if (mask) {
+        GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
+        GGML_ASSERT(ggml_is_contiguous(mask));
+        GGML_ASSERT(ggml_is_matrix(mask));
+        GGML_ASSERT(mask->ne[0] == a->ne[0]);
+        GGML_ASSERT(mask->ne[1] >= a->ne[1]);
+    }
+
+    if (max_bias > 0.0f) {
+        GGML_ASSERT(mask);
+    }
+
+    bool is_node = false;
+
+    if (a->grad) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    float params[4] = {scale, max_bias, s_before, s_after};
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op     = GGML_OP_SOFT_CAP_MAX;
+    result->grad   = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+    result->src[1] = mask;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_softcap_max(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * mask,
+        float                 scale,
+        float                 max_bias,
+        float                 s_before,
+        float                 s_after) {
+    return ggml_softcap_max_impl(ctx, a, mask, scale, max_bias, s_before, s_after, false);
+}
+
+struct ggml_tensor * ggml_softcap_max_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * mask,
+        float                 scale,
+        float                 max_bias,
+        float                 s_before,
+        float                 s_after) {
+    return ggml_softcap_max_impl(ctx, a, mask, scale, max_bias, s_before, s_after, true);
+}
+
+
 // ggml_set
 
 static struct ggml_tensor * ggml_set_impl(
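For graph builders, `ggml_softcap_max` fuses what previously took two nodes (`ggml_softcap` followed by `ggml_soft_max_ext`). A hypothetical wiring for Gemma-2-style attention logit capping, assuming the forward pass below computes soft_max(scale*s_after*tanh(s_before*x) + mask); `ctx`, `q`, `k`, and `kq_mask` stand in for the surrounding attention code, and the 50.0f cap and head size are illustrative values, not from this commit:

    // builds soft_max(cap*tanhf(kq_scale*kq/cap) + mask) as a single node
    const float cap      = 50.0f;               // attn_logit_softcapping-style value (illustrative)
    const float kq_scale = 1.0f/sqrtf(128.0f);  // 1/sqrt(head_dim), head_dim assumed 128

    struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
    kq = ggml_softcap_max(ctx, kq, kq_mask,
            /*scale    =*/ 1.0f,
            /*max_bias =*/ 0.0f,                // > 0.0f only for ALiBi-style models
            /*s_before =*/ kq_scale/cap,
            /*s_after  =*/ cap);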
@@ -7493,7 +7596,8 @@ struct ggml_tensor * ggml_flash_attn_ext(
         struct ggml_tensor  * v,
         struct ggml_tensor  * mask,
         float                 scale,
-        float                 max_bias) {
+        float                 max_bias,
+        float                 softcap) {
     GGML_ASSERT(ggml_can_mul_mat(k, q));
     // TODO: check if vT can be multiplied by (k*qT)
 
@@ -7520,7 +7624,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
     int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
-    float params[] = { scale, max_bias };
+    float params[] = { scale, max_bias, softcap };
     ggml_set_op_params(result, params, sizeof(params));
 
     result->op = GGML_OP_FLASH_ATTN_EXT;
@@ -7540,7 +7644,7 @@ void ggml_flash_attn_ext_set_prec(
 
     const int32_t prec_i32 = (int32_t) prec;
 
-    ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
+    ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second, softcap on third
 }
 
 // ggml_flash_attn_back
@@ -13618,6 +13722,122 @@ static void ggml_compute_forward_softcap(
     }
 }
 
+static void ggml_compute_forward_softcap_max_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    assert(ggml_is_contiguous(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    float values[4];
+    memcpy(values, dst->op_params, sizeof(values));
+    // values[0] -> scale
+    // values[1] -> max_bias
+    // values[2] -> s_before
+    // values[3] -> s_after
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    //const int64_t ne11 = src1 ? src1->ne[1] : 1;
+
+    // TODO: is this supposed to be ceil instead of floor?
+    //   https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
+    const uint32_t n_head      = ne02;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
+
+    const float m0 = powf(2.0f, -(values[1]       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(values[1] / 2.0f) / n_head_log2);
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+
+    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        // ALiBi
+        const uint32_t h = (i1/ne01)%ne02; // head
+        const float slope = (values[1] > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
+        float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float * dp = (float *)((char *)  dst->data + i1*dst->nb[1]);
+
+        // broadcast the mask across rows
+        ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
+        float       * mp_f32 = src1 ? (float       *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
+
+        ggml_vec_cpy_softcap_f32(nc, sp, wp, values[2], values[0]*values[3]);
+
+        if (mp_f32) {
+            if (use_f16) {
+                for (int i = 0; i < nc; ++i) {
+                    wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
+                }
+            } else {
+                for (int i = 0; i < nc; ++i) {
+                    wp[i] += slope*mp_f32[i];
+                }
+            }
+        }
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(wp[i]));
+        }
+#endif
+
+        float max = -INFINITY;
+        ggml_vec_max_f32(nc, &max, wp);
+
+        ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
+        assert(sum > 0.0);
+
+        sum = 1.0/sum;
+        ggml_vec_scale_f32(nc, dp, sum);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            assert(!isnan(dp[i]));
+            assert(!isinf(dp[i]));
+        }
+#endif
+    }
+
+}
+
+static void ggml_compute_forward_softcap_max(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_softcap_max_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_set
 
 static void ggml_compute_forward_set_f32(
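Putting the forward pass in one place: a scalar reference for what one row of GGML_OP_SOFT_CAP_MAX computes, useful for checking the vectorized path. It assumes, as the call above suggests, that `ggml_vec_cpy_softcap_f32(n, x, w, s_before, s_after)` writes w[i] = s_after*tanhf(s_before*x[i]); note how `scale` (values[0]) and `s_after` (values[3]) are folded into a single multiplier:

    #include <math.h>

    // Reference sketch, not the committed SIMD path: one row of
    // soft_max(scale*s_after*tanh(s_before*x) + slope*mask).
    static void softcap_max_row_ref(int n, const float * x, const float * mask, float slope,
                                    float scale, float s_before, float s_after, float * y) {
        float max = -INFINITY;
        for (int i = 0; i < n; ++i) {
            y[i] = scale*s_after*tanhf(x[i]*s_before);   // fused softcap + scale
            if (mask) y[i] += slope*mask[i];             // ALiBi slope times mask row
            if (y[i] > max) max = y[i];
        }
        float sum = 0.0f;
        for (int i = 0; i < n; ++i) {                    // max-subtracted softmax
            y[i] = expf(y[i] - max);
            sum += y[i];
        }
        for (int i = 0; i < n; ++i) {
            y[i] /= sum;
        }
    }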
@@ -15919,9 +16139,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 
     float scale    = 1.0f;
     float max_bias = 0.0f;
+    float softcap  = 0.0f;
 
     memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
     memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+    memcpy(&softcap,  (float *) dst->op_params + 2, sizeof(float));
+
+    if (softcap != 0.0f) {
+        scale /= softcap;
+    }
 
     const uint32_t n_head      = neq2;
     const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
@@ -15985,7 +16211,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                 const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
                 kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
 
-                s = s*scale + mv; // scale KQ value and apply mask
+                s = softcap == 0.0f ? s*scale + mv : softcap*tanhf(s*scale) + mv; // scale KQ value and apply mask
 
                 const float Mold = M;
@@ -15994,7 +16220,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 
                 const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
 
-                if (v->type== GGML_TYPE_F16) {
+                if (v->type == GGML_TYPE_F16) {
                     if (s > M) {
                         // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
                         M = s;
@@ -16061,7 +16287,7 @@ static void ggml_compute_forward_flash_attn_ext(
         const struct ggml_tensor * v,
         const struct ggml_tensor * mask,
         struct ggml_tensor * dst) {
-    switch (dst->op_params[2]) {
+    switch (dst->op_params[3]) {
         case GGML_PREC_DEFAULT:
         case GGML_PREC_F32:
             {
@@ -17477,6 +17703,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_softcap(params, tensor);
             } break;
+        case GGML_OP_SOFT_CAP_MAX:
+            {
+                ggml_compute_forward_softcap_max(params, tensor);
+            } break;
         case GGML_OP_SET:
             {
                 ggml_compute_forward_set(params, tensor);
@@ -18227,6 +18457,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_SOFT_CAP_MAX:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_SET:
             {
                 const size_t nb1 = ((int32_t *) tensor->op_params)[0];
@@ -19240,6 +19474,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_SCALE:
         case GGML_OP_SOFTCAP:
         case GGML_OP_SOFT_MAX:
+        case GGML_OP_SOFT_CAP_MAX:
             {
                 n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));
             } break;
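In the flash-attention kernel, the up-front rescaling is what keeps the inner loop cheap: with `scale` already divided by `softcap`, the per-element expression softcap*tanhf(s*scale) equals softcap*tanhf(scale_orig*s/softcap), i.e. the same capped logit the fused GGML_OP_SOFT_CAP_MAX path produces. A minimal sketch of the equivalence, with `scale` here being the caller's original value; softcap == 0.0f disables capping, which is why existing callers of ggml_flash_attn_ext only need to pass an extra 0.0f:

    #include <math.h>

    // Equivalent per-logit computation, written without the precomputed scale/softcap.
    static float capped_logit(float s, float scale, float softcap) {
        return softcap == 0.0f ? s*scale                           // capping disabled
                               : softcap*tanhf(s*(scale/softcap)); // == softcap*tanh(scale*s/softcap)
    }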