Diffstat (limited to 'ggml/src/ggml.c')
 -rw-r--r--  ggml/src/ggml.c | 251
 1 file changed, 243 insertions(+), 8 deletions(-)
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 60a89591..cebac584 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -2889,6 +2889,41 @@ static void ggml_vec_softcap_f32(const int n, float * x, float s_before, float s
}
}
+static float ggml_vec_softcap_max_f32(const int n, float * x, float s_before, float s_after) {
+    int i = 0;
+    float max = -INFINITY;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+    __m512 vs_before = _mm512_set1_ps(2.f*s_before);
+    __m512 vs_after  = _mm512_set1_ps(s_after);
+    __m512 vmax      = _mm512_set1_ps(-INFINITY);
+    for (; i + 15 < n; i += 16) {
+        __m512 y = ggml_v_softcap(_mm512_loadu_ps(x + i), vs_before, vs_after);
+        _mm512_storeu_ps(x + i, y);
+        vmax = _mm512_max_ps(vmax, y);
+    }
+    max = _mm512_reduce_max_ps(vmax);
+#elif defined(__AVX2__) && defined(__FMA__)
+    __m256 vmax = _mm256_set1_ps(-INFINITY);
+    for (; i + 7 < n; i += 8) {
+        __m256 y = ggml_v_softcap(_mm256_loadu_ps(x + i), s_before, s_after);
+        _mm256_storeu_ps(x + i, y);
+        vmax = _mm256_max_ps(vmax, y);
+    }
+    // horizontal max of the 8 lanes so the vector and scalar maxima can be merged
+    __m128 v4 = _mm_max_ps(_mm256_castps256_ps128(vmax), _mm256_extractf128_ps(vmax, 1));
+    v4 = _mm_max_ps(v4, _mm_movehl_ps(v4, v4));
+    v4 = _mm_max_ss(v4, _mm_shuffle_ps(v4, v4, 1));
+    max = _mm_cvtss_f32(v4);
+#elif defined(__SSE2__)
+    __m128 vmax = _mm_set1_ps(-INFINITY);
+    for (; i + 3 < n; i += 4) {
+        __m128 y = ggml_v_softcap(_mm_loadu_ps(x + i), s_before, s_after);
+        _mm_storeu_ps(x + i, y);
+        vmax = _mm_max_ps(vmax, y);
+    }
+    // horizontal max of the 4 lanes
+    vmax = _mm_max_ps(vmax, _mm_shuffle_ps(vmax, vmax, _MM_SHUFFLE(1, 0, 3, 2)));
+    vmax = _mm_max_ps(vmax, _mm_shuffle_ps(vmax, vmax, _MM_SHUFFLE(2, 3, 0, 1)));
+    max = _mm_cvtss_f32(vmax);
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+    float32x4_t vs_before = vdupq_n_f32(s_before);
+    float32x4_t vs_after  = vdupq_n_f32(s_after);
+    float32x4_t vmax      = vdupq_n_f32(-INFINITY);
+    for (; i + 3 < n; i += 4) {
+        float32x4_t y = ggml_v_softcap(vld1q_f32(x + i), vs_before, vs_after);
+        vst1q_f32(x + i, y);
+        vmax = vmaxq_f32(vmax, y);
+    }
+    max = vmaxvq_f32(vmax);
+#endif
+    for (; i < n; ++i) {
+        x[i] = s_after*tanhf(x[i]*s_before);
+        max = MAX(max, x[i]);
+    }
+    return max;
+}
+
inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
const uint16_t * i16 = (const uint16_t *) x;
for (int i = 0; i < n; ++i) {
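Every SIMD branch above must agree with the plain scalar form of the update. A minimal scalar reference one could test the vectorized paths against (the ggml_vec_softcap_max_ref name is hypothetical, not part of this diff):

    // Hypothetical scalar reference: apply y = s_after*tanh(s_before*x) in place
    // and return the row maximum, exactly what each SIMD branch must produce.
    static float ggml_vec_softcap_max_ref(const int n, float * x, float s_before, float s_after) {
        float max = -INFINITY;
        for (int i = 0; i < n; ++i) {
            x[i] = s_after*tanhf(x[i]*s_before);
            max = MAX(max, x[i]);
        }
        return max;
    }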
@@ -3144,6 +3179,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"ARGSORT",
"LEAKY_RELU",
"SOFTCAP",
+ "SOFT_CAP_MAX",
"FLASH_ATTN_EXT",
"FLASH_ATTN_BACK",
@@ -3171,7 +3207,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS_BACK",
};
-static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");
+static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -3233,6 +3269,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"argsort(x)",
"leaky_relu(x)",
"k2*tanh(k1*x)",
+ "soft_max(k2*tanh(k1*x))",
"flash_attn_ext(x)",
"flash_attn_back(x)",
@@ -3260,7 +3297,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss_back(x,y)",
};
-static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");
+static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -5963,6 +6000,72 @@ struct ggml_tensor * ggml_softcap_inplace(
return ggml_softcap_impl(ctx, a, s_before, s_after, true);
}
+static struct ggml_tensor * ggml_softcap_max_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * mask,
+        float                 scale,
+        float                 max_bias,
+        float                 s_before,
+        float                 s_after,
+        bool                  inplace) {
+    GGML_ASSERT(ggml_is_contiguous(a));
+    GGML_ASSERT(ggml_is_padded_1d(a));
+
+    if (mask) {
+        GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
+        GGML_ASSERT(ggml_is_contiguous(mask));
+        GGML_ASSERT(ggml_is_matrix(mask));
+        GGML_ASSERT(mask->ne[0] == a->ne[0]);
+        GGML_ASSERT(mask->ne[1] >= a->ne[1]);
+    }
+
+    if (max_bias > 0.0f) {
+        GGML_ASSERT(mask);
+    }
+
+    bool is_node = false;
+
+    if (a->grad) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    float params[4] = {scale, max_bias, s_before, s_after};
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op     = GGML_OP_SOFT_CAP_MAX;
+    result->grad   = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+    result->src[1] = mask;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_softcap_max(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * mask,
+        float                 scale,
+        float                 max_bias,
+        float                 s_before,
+        float                 s_after) {
+    return ggml_softcap_max_impl(ctx, a, mask, scale, max_bias, s_before, s_after, false);
+}
+
+struct ggml_tensor * ggml_softcap_max_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * mask,
+        float                 scale,
+        float                 max_bias,
+        float                 s_before,
+        float                 s_after) {
+    return ggml_softcap_max_impl(ctx, a, mask, scale, max_bias, s_before, s_after, true);
+}
+
// ggml_set
static struct ggml_tensor * ggml_set_impl(
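A hypothetical call site (the tensor names, n_embd_head, and the 30.0f cap are illustrative assumptions, not part of this diff). The compute kernel below evaluates soft_max(scale*s_after*tanh(s_before*x) + slope*mask), so a Gemma-2-style logit cap of 30 combined with the usual attention scale could be expressed as a single fused node:

    // Sketch only: k, q, kq_mask, and n_embd_head are assumed to exist.
    // Computes soft_max(kq_scale * 30.0f*tanh(kq/30.0f) + mask) in one op
    // instead of separate scale/tanh/scale/soft_max graph nodes.
    const float kq_scale = 1.0f/sqrtf((float) n_embd_head);
    struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
    kq = ggml_softcap_max(ctx, kq, kq_mask, kq_scale, 0.0f, 1.0f/30.0f, 30.0f);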
@@ -7493,7 +7596,8 @@ struct ggml_tensor * ggml_flash_attn_ext(
struct ggml_tensor * v,
struct ggml_tensor * mask,
float scale,
-        float                 max_bias) {
+        float                 max_bias,
+        float                 softcap) {
GGML_ASSERT(ggml_can_mul_mat(k, q));
// TODO: check if vT can be multiplied by (k*qT)
@@ -7520,7 +7624,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
-    float params[] = { scale, max_bias };
+    float params[] = { scale, max_bias, softcap };
ggml_set_op_params(result, params, sizeof(params));
result->op = GGML_OP_FLASH_ATTN_EXT;
@@ -7540,7 +7644,7 @@ void ggml_flash_attn_ext_set_prec(
const int32_t prec_i32 = (int32_t) prec;
-    ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
+    ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second, softcap on third
}
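With the extra float, the op_params layout for GGML_OP_FLASH_ATTN_EXT implied by this diff is:

    // op_params layout for GGML_OP_FLASH_ATTN_EXT after this change:
    //   [0] scale (f32), [1] max_bias (f32), [2] softcap (f32), [3] precision (i32)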
// ggml_flash_attn_back
@@ -13618,6 +13722,122 @@ static void ggml_compute_forward_softcap(
}
}
+static void ggml_compute_forward_softcap_max_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    assert(ggml_is_contiguous(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    float values[4];
+    memcpy(values, dst->op_params, sizeof(values));
+    // values[0] -> scale
+    // values[1] -> max_bias
+    // values[2] -> s_before
+    // values[3] -> s_after
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    //const int64_t ne11 = src1 ? src1->ne[1] : 1;
+
+    // TODO: is this supposed to be ceil instead of floor?
+    //       https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
+    const uint32_t n_head      = ne02;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
+
+    const float m0 = powf(2.0f, -(values[1]       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(values[1] / 2.0f) / n_head_log2);
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+
+    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        // ALiBi
+        const uint32_t h = (i1/ne01)%ne02; // head
+        const float slope = (values[1] > 0.0f) ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f;
+
+        float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float * dp = (float *)((char *)  dst->data +  i1*dst->nb[1]);
+
+        // broadcast the mask across rows
+        ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
+        float       * mp_f32 = src1 ? (float       *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
+
+        ggml_vec_cpy_softcap_f32(nc, sp, wp, values[2], values[0]*values[3]);
+
+        if (mp_f32) {
+            if (use_f16) {
+                for (int i = 0; i < nc; ++i) {
+                    wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
+                }
+            } else {
+                for (int i = 0; i < nc; ++i) {
+                    wp[i] += slope*mp_f32[i];
+                }
+            }
+        }
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(wp[i]));
+        }
+#endif
+
+        float max = -INFINITY;
+        ggml_vec_max_f32(nc, &max, wp);
+
+        ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
+        assert(sum > 0.0);
+
+        sum = 1.0/sum;
+        ggml_vec_scale_f32(nc, dp, sum);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            assert(!isnan(dp[i]));
+            assert(!isinf(dp[i]));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_softcap_max(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_softcap_max_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
// ggml_compute_forward_set
static void ggml_compute_forward_set_f32(
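A worked example of the ALiBi slope selection above, with illustrative numbers: for values[1] = max_bias = 8.0f and n_head = 20, n_head_log2 = 16, so m0 = 2^(-8/16) ≈ 0.7071 and m1 = 2^(-4/16) ≈ 0.8409. Heads h < 16 get slope powf(m0, h + 1); heads 16..19 get powf(m1, 2*(h - 16) + 1), the standard ALiBi fallback for non-power-of-two head counts. With max_bias = 0 the slope is 1.0f and the mask, if present, is added unscaled.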
@@ -15919,9 +16139,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     float scale    = 1.0f;
     float max_bias = 0.0f;
+    float softcap  = 0.0f;
     memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
     memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+    memcpy(&softcap,  (float *) dst->op_params + 2, sizeof(float));
+
+    if (softcap != 0.0f) {
+        scale /= softcap;
+    }
const uint32_t n_head = neq2;
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
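The scale /= softcap pre-division folds the logit scale into tanh's argument so the inner loop (next hunk) can apply the cap in a single expression: writing scale' = scale/softcap,

    softcap*tanhf(s*scale') == softcap*tanhf((s*scale)/softcap)

which bounds the scaled score s*scale to (-softcap, softcap) before the mask value mv is added.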
@@ -15985,7 +16211,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
-                s = s*scale + mv; // scale KQ value and apply mask
+                s = softcap == 0.0f ? s*scale + mv : softcap*tanhf(s*scale) + mv; // scale KQ value and apply mask
const float Mold = M;
@@ -15994,7 +16220,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
- if (v->type== GGML_TYPE_F16) {
+ if (v->type == GGML_TYPE_F16) {
if (s > M) {
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
M = s;
@@ -16061,7 +16287,7 @@ static void ggml_compute_forward_flash_attn_ext(
const struct ggml_tensor * v,
const struct ggml_tensor * mask,
struct ggml_tensor * dst) {
- switch (dst->op_params[2]) {
+ switch (dst->op_params[3]) {
case GGML_PREC_DEFAULT:
case GGML_PREC_F32:
{
@@ -17477,6 +17703,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_softcap(params, tensor);
} break;
+        case GGML_OP_SOFT_CAP_MAX:
+            {
+                ggml_compute_forward_softcap_max(params, tensor);
+            } break;
case GGML_OP_SET:
{
ggml_compute_forward_set(params, tensor);
@@ -18227,6 +18457,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
GGML_ASSERT(false); // TODO: not implemented
} break;
+        case GGML_OP_SOFT_CAP_MAX:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
case GGML_OP_SET:
{
const size_t nb1 = ((int32_t *) tensor->op_params)[0];
@@ -19240,6 +19474,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_SCALE:
case GGML_OP_SOFTCAP:
case GGML_OP_SOFT_MAX:
+        case GGML_OP_SOFT_CAP_MAX:
{
n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));
} break;