Diffstat (limited to 'ggml/src/ggml.c')
-rw-r--r--  ggml/src/ggml.c  |  189
1 file changed, 187 insertions(+), 2 deletions(-)
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index d31713df..08eab23b 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -2888,6 +2888,30 @@ static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
}
}
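+// fused z[i] = silu(x[i])*y[i]: SIMD fast paths where available, scalar tail for the remainder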
+static void ggml_vec_mul_silu_f32(const int n, float * z, const float * x, const float * y) {
+ int i = 0;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+ for (; i + 15 < n; i += 16) {
+ _mm512_storeu_ps(z + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(y + i)));
+ }
+#elif defined(__AVX2__) && defined(__FMA__)
+ for (; i + 7 < n; i += 8) {
+ _mm256_storeu_ps(z + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(y + i)));
+ }
+#elif defined(__SSE2__)
+ for (; i + 3 < n; i += 4) {
+ _mm_storeu_ps(z + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(y + i)));
+ }
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+ for (; i + 3 < n; i += 4) {
+ vst1q_f32(z + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(y + i)));
+ }
+#endif
+ for (; i < n; ++i) {
+ z[i] = ggml_silu_f32(x[i]) * y[i];
+ }
+}
+
static void ggml_vec_swiglu_f32(const int n, float * y, const float * x) {
int i = 0;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
@@ -3100,6 +3124,47 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
}
#endif
}
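+// fused z[i] = gelu(x[i])*y[i]: the AVX512/AVX2 paths use ggml_v_gelu; the scalar tail goes
+// through the f16 GELU lookup table when GGML_GELU_FP16 is defined, else through ggml_gelu_f32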
+inline static void ggml_vec_mul_gelu_f32(const int n, float * z, const float * x, const float * y) {
+ int i = 0;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+ __m512 c1 = _mm512_set1_ps(GELU_COEF_A);
+ __m512 c2 = _mm512_set1_ps(2.f*SQRT_2_OVER_PI);
+ for (; i + 15 < n; i += 16) {
+ _mm512_storeu_ps(z + i, _mm512_mul_ps(ggml_v_gelu(_mm512_loadu_ps(x + i), c1, c2), _mm512_loadu_ps(y + i)));
+ }
+#elif defined __AVX2__ && defined __FMA__
+ __m256 c1 = _mm256_set1_ps(GELU_COEF_A);
+ __m256 c2 = _mm256_set1_ps(2.f*SQRT_2_OVER_PI);
+ for (; i + 7 < n; i += 8) {
+ _mm256_storeu_ps(z + i, _mm256_mul_ps(ggml_v_gelu(_mm256_loadu_ps(x + i), c1, c2), _mm256_loadu_ps(y + i)));
+ }
+#endif
+#ifdef GGML_GELU_FP16
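+ // scalar tail via the precomputed ggml_table_gelu_f16 lookup table; inputs with
+ // |x| >= 10 bypass the table and use the GELU asymptotes (0 and x) directly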
+ uint16_t t;
+ for (; i < n; ++i) {
+ if (x[i] <= -10.0f) {
+ z[i] = 0.0f;
+ } else if (x[i] >= 10.0f) {
+ z[i] = x[i]*y[i];
+ } else {
+ ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
+ memcpy(&t, &fp16, sizeof(uint16_t));
+ z[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t])*y[i];
+ }
+ }
+#else
+#if defined __ARM_NEON
+ float32x4_t c1 = vdupq_n_f32(GELU_COEF_A);
+ float32x4_t c2 = vdupq_n_f32(2.f*SQRT_2_OVER_PI);
+ for (; i + 3 < n; i += 4) {
+ vst1q_f32(z + i, vmulq_f32(ggml_v_gelu(vld1q_f32(x + i), c1, c2), vld1q_f32(y + i)));
+ }
+#endif
+ for (; i < n; ++i) {
+ z[i] = ggml_gelu_f32(x[i])*y[i];
+ }
+#endif
+}
static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
int i = 0;
@@ -3258,6 +3323,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"RMS_NORM_BACK",
"GROUP_NORM",
"FUSED_RMS_NORM",
+ "FUSED_MUL_UNARY",
"MUL_MAT",
"MUL_MAT_ID",
@@ -3321,7 +3387,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS_BACK",
};
-static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
+static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -3349,6 +3415,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"rms_norm_back(x)",
"group_norm(x)",
"fused_rms_norm(x)",
+ "fused_mul_unary(x)",
"X*Y",
"X[i]*Y",
@@ -3412,7 +3479,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss_back(x,y)",
};
-static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
+static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -5246,6 +5313,55 @@ struct ggml_tensor * ggml_mul_inplace(
struct ggml_tensor * b) {
return ggml_mul_impl(ctx, a, b, true);
}
+// ggml_fused_mul_unary
+
+static struct ggml_tensor * ggml_fused_mul_unary_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ enum ggml_unary_op op,
+ bool inplace) {
+ GGML_ASSERT(ggml_are_same_shape(b, a));
+ GGML_ASSERT(ggml_is_contiguous(a));
+ GGML_ASSERT(op == GGML_UNARY_OP_GELU || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU);
+
+ bool is_node = false;
+
+ if (!inplace && (a->grad || b->grad)) {
+ is_node = true;
+ }
+
+ if (inplace) {
+ GGML_ASSERT(!is_node);
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
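+ // stash the unary op in op_params[0]; the forward pass reads it back when dispatching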
+ ggml_set_op_params_i32(result, 0, (int32_t) op);
+
+ result->op = GGML_OP_FUSED_MUL_UNARY;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_fused_mul_unary(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ enum ggml_unary_op op) {
+ return ggml_fused_mul_unary_impl(ctx, a, b, op, false);
+}
+
+struct ggml_tensor * ggml_fused_mul_unary_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ enum ggml_unary_op op) {
+ return ggml_fused_mul_unary_impl(ctx, a, b, op, true);
+}
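+
+// Usage sketch (hypothetical caller): a SwiGLU-style FFN gate that would otherwise be
+// built as ggml_mul(ctx, ggml_silu(ctx, gate), up) can be expressed as one fused node:
+//
+//   struct ggml_tensor * cur = ggml_fused_mul_unary(ctx, gate, up, GGML_UNARY_OP_SILU);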
// ggml_div
@@ -12374,6 +12490,66 @@ static void ggml_compute_forward_swiglu(
}
}
+// ggml_compute_forward_fused_mul_unary
+
+static void ggml_compute_forward_fused_mul_unary_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+ enum ggml_unary_op op = (enum ggml_unary_op)dst->op_params[0];
+
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+ GGML_ASSERT(ggml_are_same_shape(src0, src1));
+ GGML_ASSERT(op == GGML_UNARY_OP_GELU || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU);
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = dst->ne[0];
+ const int nr = ggml_nrows(src0);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ float * z = (float *) ((char *) dst->data + i1*( dst->nb[1]));
+ const float * x = (const float *) ((char *) src0->data + i1*(src0->nb[1]));
+ const float * y = (const float *) ((char *) src1->data + i1*(src1->nb[1]));
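+ // GELU and SILU use the fused vector helpers; RELU applies relu and then a pointwise mul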
+ switch (op) {
+ case GGML_UNARY_OP_GELU: ggml_vec_mul_gelu_f32(nc, z, x, y); break;
+ case GGML_UNARY_OP_RELU: ggml_vec_relu_f32(nc, z, x); ggml_vec_mul_f32(nc, z, z, y); break;
+ case GGML_UNARY_OP_SILU: ggml_vec_mul_silu_f32(nc, z, x, y); break;
+ default: GGML_ABORT("fatal error");
+ }
+ }
+}
+
+static void ggml_compute_forward_fused_mul_unary(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_fused_mul_unary_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+}
+
// ggml_compute_forward_leaky_relu
static void ggml_compute_forward_leaky_relu_f32(
@@ -17990,6 +18166,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_mul(params, tensor);
} break;
+ case GGML_OP_FUSED_MUL_UNARY:
+ {
+ ggml_compute_forward_fused_mul_unary(params, tensor);
+ } break;
case GGML_OP_DIV:
{
ggml_compute_forward_div(params, tensor);
@@ -18715,6 +18895,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
zero_table);
}
} break;
+ case GGML_OP_FUSED_MUL_UNARY:
+ {
+ GGML_ABORT("fatal error"); // TODO: implement
+ }
case GGML_OP_CONCAT:
{
GGML_ABORT("fatal error"); // TODO: implement
@@ -19813,6 +19997,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
break;
case GGML_OP_SILU_BACK:
case GGML_OP_MUL:
+ case GGML_OP_FUSED_MUL_UNARY:
case GGML_OP_DIV:
case GGML_OP_NORM:
case GGML_OP_RMS_NORM: