 -rw-r--r--  ggml/include/ggml.h            13
 -rw-r--r--  ggml/src/ggml-cuda.cu           4
 -rw-r--r--  ggml/src/ggml-cuda/unary.cu    67
 -rw-r--r--  ggml/src/ggml-cuda/unary.cuh    2
 -rw-r--r--  ggml/src/ggml-metal.m          32
 -rw-r--r--  ggml/src/ggml-metal.metal      50
 -rw-r--r--  ggml/src/ggml.c               189
 -rw-r--r--  src/llama.cpp                   8
 8 files changed, 363 insertions(+), 2 deletions(-)
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 08fe6a3e..b1aebd21 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -487,6 +487,7 @@ extern "C" {
GGML_OP_RMS_NORM_BACK,
GGML_OP_GROUP_NORM,
GGML_OP_FUSED_RMS_NORM,
+ GGML_OP_FUSED_MUL_UNARY,
GGML_OP_MUL_MAT,
GGML_OP_MUL_MAT_ID,
@@ -963,6 +964,18 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);
+ GGML_API struct ggml_tensor * ggml_fused_mul_unary(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ enum ggml_unary_op op);
+
+ GGML_API struct ggml_tensor * ggml_fused_mul_unary_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ enum ggml_unary_op op);
+
GGML_API struct ggml_tensor * ggml_div(
struct ggml_context * ctx,
struct ggml_tensor * a,
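
The new API fuses an elementwise unary activation of a with an elementwise multiplication by b into a single graph node. The following is a minimal usage sketch, not part of the patch, assuming the standard ggml CPU graph helpers (ggml_init, ggml_new_graph, ggml_graph_compute_with_ctx); buffer size and thread count are illustrative only.

#include "ggml.h"

static void example_fused_mul_silu(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,   // illustrative scratch size
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024); // activation input
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024); // multiplier (gate) input

    // single node computing silu(a) * b, replacing ggml_silu(a) followed by ggml_mul(.., b)
    struct ggml_tensor * out = ggml_fused_mul_unary(ctx, a, b, GGML_UNARY_OP_SILU);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 4);

    ggml_free(ctx);
}
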
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 64cc7592..871d4007 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -2222,6 +2222,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_MUL:
ggml_cuda_op_mul(ctx, dst);
break;
+ case GGML_OP_FUSED_MUL_UNARY:
+ ggml_cuda_op_fused_mul_unary(ctx, dst);
+ break;
case GGML_OP_DIV:
ggml_cuda_op_div(ctx, dst);
break;
@@ -2788,6 +2791,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
return false;
}
break;
+ case GGML_OP_FUSED_MUL_UNARY: return ggml_is_contiguous(op->src[0]);
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
{
diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu
index 51582ed5..7bc43d0f 100644
--- a/ggml/src/ggml-cuda/unary.cu
+++ b/ggml/src/ggml-cuda/unary.cu
@@ -43,6 +43,36 @@ static __global__ void swiglu_f32(const float * x, float * dst, const int k, con
dst[i] = x[j] * x[j + ne0] / (1.0f + expf(-x[j]));
}
+static __global__ void fused_mul_silu_f32(const float * x, const float * y, float * dst, const int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+ dst[i] = x[i] * y[i] / (1.0f + expf(-x[i]));
+}
+
+static __global__ void fused_mul_relu_f32(const float * x, const float * y, float * dst, const int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+ dst[i] = fmaxf(x[i], 0) * y[i];
+}
+
+static __global__ void fused_mul_gelu_f32(const float * x, const float * y, float * dst, const int k) {
+ constexpr float GELU_COEF_A = 0.044715f;
+ constexpr float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+ float xi = x[i];
+ dst[i] = 0.5f*xi*y[i]*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi)));
+}
+
static __global__ void tanh_f32(const float * x, float * dst, int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
@@ -133,6 +163,21 @@ static void swiglu_f32_cuda(const float * x, float * dst, const int k, const int
swiglu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k, ne0, nb1);
}
+static void fused_mul_silu_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
+ fused_mul_silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
+}
+
+static void fused_mul_relu_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
+ fused_mul_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
+}
+
+static void fused_mul_gelu_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
+ fused_mul_gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
+}
+
static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -216,6 +261,28 @@ void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
swiglu_f32_cuda(src0_d, dst_d, ggml_nelements(dst), dst->ne[0], src0->nb[1]/sizeof(float), stream);
}
+void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+ GGML_ASSERT(ggml_are_same_shape(src0, src1));
+
+ cudaStream_t stream = ctx.stream();
+ ggml_unary_op op = (ggml_unary_op)dst->op_params[0];
+
+ const float * src0_d = (const float *)src0->data;
+ const float * src1_d = (const float *)src1->data;
+ float * dst_d = (float *)dst->data;
+
+ switch (op) {
+ case GGML_UNARY_OP_SILU: fused_mul_silu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break;
+ case GGML_UNARY_OP_RELU: fused_mul_relu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break;
+ case GGML_UNARY_OP_GELU: fused_mul_gelu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break;
+ default: GGML_ASSERT(false);
+ }
+}
+
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
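
For reference, the fused CUDA kernels above all follow the same elementwise pattern: the unary activation is applied to src0 and the result is multiplied by src1. A scalar sketch of that contract (illustration only, not part of the patch; assumes <math.h> and the ggml_unary_op enum from ggml.h):

static inline float fused_mul_unary_ref(float x, float y, enum ggml_unary_op op) {
    // coefficients of the tanh approximation of GELU, matching fused_mul_gelu_f32 above
    const float GELU_COEF_A    = 0.044715f;
    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
    switch (op) {
        case GGML_UNARY_OP_SILU: return x * y / (1.0f + expf(-x));
        case GGML_UNARY_OP_RELU: return fmaxf(x, 0.0f) * y;
        case GGML_UNARY_OP_GELU: return 0.5f*x*y*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
        default:                 return 0.0f; // other unary ops are not handled by the fused path
    }
}
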
diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh
index be3d6f15..d2d478b4 100644
--- a/ggml/src/ggml-cuda/unary.cuh
+++ b/ggml/src/ggml-cuda/unary.cuh
@@ -33,3 +33,5 @@ void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index dcdd0efe..4badc7a7 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -56,13 +56,18 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_CLAMP,
GGML_METAL_KERNEL_TYPE_TANH,
GGML_METAL_KERNEL_TYPE_RELU,
+ GGML_METAL_KERNEL_TYPE_MUL_RELU,
GGML_METAL_KERNEL_TYPE_SIGMOID,
GGML_METAL_KERNEL_TYPE_GELU,
GGML_METAL_KERNEL_TYPE_GELU_4,
+ GGML_METAL_KERNEL_TYPE_MUL_GELU,
+ GGML_METAL_KERNEL_TYPE_MUL_GELU_4,
GGML_METAL_KERNEL_TYPE_GELU_QUICK,
GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
GGML_METAL_KERNEL_TYPE_SILU,
GGML_METAL_KERNEL_TYPE_SILU_4,
+ GGML_METAL_KERNEL_TYPE_MUL_SILU,
+ GGML_METAL_KERNEL_TYPE_MUL_SILU_4,
GGML_METAL_KERNEL_TYPE_SWIGLU,
GGML_METAL_KERNEL_TYPE_SWIGLU_4,
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
@@ -584,13 +589,18 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_RELU, mul_relu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_GELU, mul_gelu, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_GELU_4, mul_gelu_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_SILU, mul_silu, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_SILU_4, mul_silu_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU_4, swiglu_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction);
@@ -921,6 +931,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
case GGML_OP_SQR:
case GGML_OP_SUM_ROWS:
return true;
+ case GGML_OP_FUSED_MUL_UNARY:
+ return ggml_is_contiguous(op->src[0]);
case GGML_OP_SOFTCAP:
case GGML_OP_SOFT_CAP_MAX:
return true; //ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op);
@@ -1648,6 +1660,26 @@ static enum ggml_status ggml_metal_graph_compute(
GGML_ABORT("fatal error");
}
} break;
+ case GGML_OP_FUSED_MUL_UNARY:
+ {
+ int64_t n = ggml_nelements(dst);
+ enum ggml_unary_op op = (enum ggml_unary_op)dst->op_params[0];
+ id<MTLComputePipelineState> pipeline = nil;
+ if (n % 4 == 0 && op != GGML_UNARY_OP_RELU) {
+ pipeline = op == GGML_UNARY_OP_GELU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_GELU_4].pipeline
+ : ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_SILU_4].pipeline;
+ n /= 4;
+ } else {
+ pipeline = op == GGML_UNARY_OP_GELU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_GELU].pipeline
+ : op == GGML_UNARY_OP_SILU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_SILU].pipeline
+ : ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_RELU].pipeline;
+ }
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
case GGML_OP_SQR:
{
GGML_ASSERT(ggml_is_contiguous(src0));
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index 225fa5f1..4dbfa089 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -323,6 +323,14 @@ kernel void kernel_relu(
dst[tpig] = max(0.0f, src0[tpig]);
}
+kernel void kernel_mul_relu(
+ device const float * src0,
+ device const float * src1,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = max(0.0f, src0[tpig]) * src1[tpig];
+}
+
kernel void kernel_sigmoid(
device const float * src0,
device float * dst,
@@ -364,6 +372,30 @@ kernel void kernel_gelu_4(
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}
+kernel void kernel_mul_gelu(
+ device const float * src0,
+ device const float * src1,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float & x = src0[tpig];
+
+ dst[tpig] = 0.5f*x*src1[tpig]*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+kernel void kernel_mul_gelu_4(
+ device const float4 * src0,
+ device const float4 * src1,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float4 & x = src0[tpig];
+
+ // BEWARE !!!
+ // Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs!
+ // This was observed with Falcon 7B and 40B models
+ //
+ dst[tpig] = 0.5f*x*src1[tpig]*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
kernel void kernel_gelu_quick(
device const float * src0,
device float * dst,
@@ -398,6 +430,24 @@ kernel void kernel_silu_4(
dst[tpig] = x / (1.0f + exp(-x));
}
+kernel void kernel_mul_silu(
+ device const float * src0,
+ device const float * src1,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float & x = src0[tpig];
+ dst[tpig] = x * src1[tpig] / (1.0f + exp(-x));
+}
+
+kernel void kernel_mul_silu_4(
+ device const float4 * src0,
+ device const float4 * src1,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float4 & x = src0[tpig];
+ dst[tpig] = x * src1[tpig] / (1.0f + exp(-x));
+}
+
kernel void kernel_swiglu(
device const float * src0,
device float * dst,
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index d31713df..08eab23b 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -2888,6 +2888,30 @@ static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
}
}
+static void ggml_vec_mul_silu_f32(const int n, float * z, const float * x, const float * y) {
+ int i = 0;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+ for (; i + 15 < n; i += 16) {
+ _mm512_storeu_ps(z + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(y + i)));
+ }
+#elif defined(__AVX2__) && defined(__FMA__)
+ for (; i + 7 < n; i += 8) {
+ _mm256_storeu_ps(z + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(y + i)));
+ }
+#elif defined(__SSE2__)
+ for (; i + 3 < n; i += 4) {
+ _mm_storeu_ps(z + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(y + i)));
+ }
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+ for (; i + 3 < n; i += 4) {
+ vst1q_f32(z + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(y + i)));
+ }
+#endif
+ for (; i < n; ++i) {
+ z[i] = ggml_silu_f32(x[i]) * y[i];
+ }
+}
+
static void ggml_vec_swiglu_f32(const int n, float * y, const float * x) {
int i = 0;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
@@ -3100,6 +3124,47 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
}
#endif
}
+inline static void ggml_vec_mul_gelu_f32(const int n, float * z, const float * x, const float * y) {
+ int i = 0;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+ __m512 c1 = _mm512_set1_ps(GELU_COEF_A);
+ __m512 c2 = _mm512_set1_ps(2.f*SQRT_2_OVER_PI);
+ for (; i + 15 < n; i += 16) {
+ _mm512_storeu_ps(z + i, _mm512_mul_ps(ggml_v_gelu(_mm512_loadu_ps(x + i), c1, c2), _mm512_loadu_ps(y + i)));
+ }
+#elif defined __AVX2__ && defined __FMA__
+ __m256 c1 = _mm256_set1_ps(GELU_COEF_A);
+ __m256 c2 = _mm256_set1_ps(2.f*SQRT_2_OVER_PI);
+ for (; i + 7 < n; i += 8) {
+ _mm256_storeu_ps(z + i, _mm256_mul_ps(ggml_v_gelu(_mm256_loadu_ps(x + i), c1, c2), _mm256_loadu_ps(y + i)));
+ }
+#endif
+#ifdef GGML_GELU_FP16
+ uint16_t t;
+ for (; i < n; ++i) {
+ if (x[i] <= -10.0f) {
+ z[i] = 0.0f;
+ } else if (x[i] >= 10.0f) {
+ z[i] = x[i]*y[i];
+ } else {
+ ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
+ memcpy(&t, &fp16, sizeof(uint16_t));
+ z[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t])*y[i];
+ }
+ }
+#else
+#if defined __ARM_NEON
+ float32x4_t c1 = vdupq_n_f32(GELU_COEF_A);
+ float32x4_t c2 = vdupq_n_f32(2.f*SQRT_2_OVER_PI);
+ for (; i + 3 < n; i += 4) {
+ vst1q_f32(z + i, vmulq_f32(ggml_v_gelu(vld1q_f32(x + i), c1, c2), vld1q_f32(y + i)));
+ }
+#endif
+ for (; i < n; ++i) {
+ z[i] = ggml_gelu_f32(x[i])*y[i];
+ }
+#endif
+}
static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
int i = 0;
@@ -3258,6 +3323,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"RMS_NORM_BACK",
"GROUP_NORM",
"FUSED_RMS_NORM",
+ "FUSED_MUL_UNARY",
"MUL_MAT",
"MUL_MAT_ID",
@@ -3321,7 +3387,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS_BACK",
};
-static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
+static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -3349,6 +3415,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"rms_norm_back(x)",
"group_norm(x)",
"fused_rms_norm(x)",
+ "fused_mul_unary(x)",
"X*Y",
"X[i]*Y",
@@ -3412,7 +3479,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss_back(x,y)",
};
-static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
+static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -5246,6 +5313,55 @@ struct ggml_tensor * ggml_mul_inplace(
struct ggml_tensor * b) {
return ggml_mul_impl(ctx, a, b, true);
}
+// ggml_fused_mul_unary
+
+static struct ggml_tensor * ggml_fused_mul_unary_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ enum ggml_unary_op op,
+ bool inplace) {
+ GGML_ASSERT(ggml_are_same_shape(b, a));
+ GGML_ASSERT(ggml_is_contiguous(a));
+ GGML_ASSERT(op == GGML_UNARY_OP_GELU || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU);
+
+ bool is_node = false;
+
+ if (!inplace && (a->grad || b->grad)) {
+ is_node = true;
+ }
+
+ if (inplace) {
+ GGML_ASSERT(!is_node);
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ ggml_set_op_params_i32(result, 0, (int32_t) op);
+
+ result->op = GGML_OP_FUSED_MUL_UNARY;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_fused_mul_unary(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ enum ggml_unary_op op) {
+ return ggml_fused_mul_unary_impl(ctx, a, b, op, false);
+}
+
+struct ggml_tensor * ggml_fused_mul_unary_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ enum ggml_unary_op op) {
+ return ggml_fused_mul_unary_impl(ctx, a, b, op, true);
+}
// ggml_div
@@ -12374,6 +12490,66 @@ static void ggml_compute_forward_swiglu(
}
}
+// ggml_compute_forward_fused_mul_unary
+
+static void ggml_compute_forward_fused_mul_unary_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+ enum ggml_unary_op op = (enum ggml_unary_op)dst->op_params[0];
+
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+ GGML_ASSERT(ggml_are_same_shape(src0, src1));
+ GGML_ASSERT(op == GGML_UNARY_OP_GELU || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU);
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = dst->ne[0];
+ const int nr = ggml_nrows(src0);
+
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ float * z = (float *) ((char *) dst->data + i1*( dst->nb[1]));
+ const float * x = (const float *) ((char *) src0->data + i1*(src0->nb[1]));
+ const float * y = (const float *) ((char *) src1->data + i1*(src1->nb[1]));
+ switch (op) {
+ case GGML_UNARY_OP_GELU: ggml_vec_gelu_f32(nc, z, x); ggml_vec_mul_f32(nc, z, z, y); break;
+ case GGML_UNARY_OP_RELU: ggml_vec_relu_f32(nc, z, x); ggml_vec_mul_f32(nc, z, z, y); break;
+ case GGML_UNARY_OP_SILU: ggml_vec_mul_silu_f32(nc, z, x, y); break;
+ default: GGML_ABORT("fatal error");
+ }
+ }
+}
+
+static void ggml_compute_forward_fused_mul_unary(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_fused_mul_unary_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+}
+
// ggml_compute_forward_leaky_relu
static void ggml_compute_forward_leaky_relu_f32(
@@ -17990,6 +18166,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_mul(params, tensor);
} break;
+ case GGML_OP_FUSED_MUL_UNARY:
+ {
+ ggml_compute_forward_fused_mul_unary(params, tensor);
+ } break;
case GGML_OP_DIV:
{
ggml_compute_forward_div(params, tensor);
@@ -18715,6 +18895,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
zero_table);
}
} break;
+ case GGML_OP_FUSED_MUL_UNARY:
+ {
+ GGML_ABORT("fatal error"); // TODO: implement
+ }
case GGML_OP_CONCAT:
{
GGML_ABORT("fatal error"); // TODO: implement
@@ -19813,6 +19997,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
break;
case GGML_OP_SILU_BACK:
case GGML_OP_MUL:
+ case GGML_OP_FUSED_MUL_UNARY:
case GGML_OP_DIV:
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
diff --git a/src/llama.cpp b/src/llama.cpp
index eb982125..9ed109c6 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -8083,6 +8083,13 @@ static struct ggml_tensor * llm_build_ffn(
cur = tmp;
}
+ if (type_gate == LLM_FFN_PAR &&
+ (type_op == LLM_FFN_SILU || type_op == LLM_FFN_RELU || (type_op == LLM_FFN_GELU && !act_scales))) {
+ cur = ggml_fused_mul_unary(ctx, cur, tmp, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU :
+ type_op == LLM_FFN_RELU ? GGML_UNARY_OP_RELU : GGML_UNARY_OP_GELU);
+ }
+ else {
+
switch (type_op) {
case LLM_FFN_SILU:
{
@@ -8122,6 +8129,7 @@ static struct ggml_tensor * llm_build_ffn(
cur = ggml_mul(ctx, cur, tmp);
cb(cur, "ffn_gate_par", il);
}
+ }
if (down) {
cur = llm_build_lora_mm(lctx, ctx, down, cur);
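
For the parallel-gate FFN path this keeps the existing math but collapses two graph nodes into one. A sketch of the equivalence (illustration only; in this helper cur holds the gate projection and tmp the up projection):

// previous form, still used by the else branch for the remaining activation types:
//     cur = ggml_silu(ctx, cur);       // activation on the gate projection
//     cur = ggml_mul (ctx, cur, tmp);  // elementwise product with the up projection
// fused form introduced by this patch:
//     cur = ggml_fused_mul_unary(ctx, cur, tmp, GGML_UNARY_OP_SILU);
// saving one intermediate tensor and one kernel launch per FFN block.
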