Diffstat (limited to 'ggml/src')
-rw-r--r--  ggml/src/ggml-cuda.cu        |   4
-rw-r--r--  ggml/src/ggml-cuda/unary.cu  |  32
-rw-r--r--  ggml/src/ggml-cuda/unary.cuh |   2
-rw-r--r--  ggml/src/ggml-metal.m        |  32
-rw-r--r--  ggml/src/ggml-metal.metal    |  24
-rw-r--r--  ggml/src/ggml.c              | 117
6 files changed, 210 insertions(+), 1 deletion(-)
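
The operator added by this patch is the SwiGLU activation used in gated FFN blocks: each row of the source tensor carries the gate half and the linear ("up") half side by side, and the op halves ne[0], computing silu(gate) elementwise times up. A minimal scalar sketch of the per-row computation that each backend below implements (the helper name swiglu_row_ref is illustrative, not part of the patch):

#include <math.h>

// dst has n columns; each src row x has 2*n columns: [gate | up].
static void swiglu_row_ref(const int n, float * dst, const float * x) {
    for (int i = 0; i < n; i++) {
        const float g = x[i];                 // gate half
        const float u = x[i + n];             // linear ("up") half
        dst[i] = (g / (1.0f + expf(-g))) * u; // silu(g) * u
    }
}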
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index ca57efbd..966c91c0 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -2233,6 +2233,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_UNARY_OP_SILU:
ggml_cuda_op_silu(ctx, dst);
break;
+ case GGML_UNARY_OP_SWIGLU:
+ ggml_cuda_op_swiglu(ctx, dst);
+ break;
case GGML_UNARY_OP_GELU_QUICK:
ggml_cuda_op_gelu_quick(ctx, dst);
break;
@@ -2773,6 +2776,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_SILU:
+ case GGML_UNARY_OP_SWIGLU:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_HARDSIGMOID:
diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu
index f9e20801..51582ed5 100644
--- a/ggml/src/ggml-cuda/unary.cu
+++ b/ggml/src/ggml-cuda/unary.cu
@@ -31,6 +31,18 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
dst[i] = x[i] / (1.0f + expf(-x[i]));
}

+static __global__ void swiglu_f32(const float * x, float * dst, const int k, const int ne0, const int64_t nb1) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+ const int row = i/ne0;
+ const int idx = i%ne0;
+ const int j = row*nb1 + idx;
+ dst[i] = x[j] * x[j + ne0] / (1.0f + expf(-x[j]));
+}
+
static __global__ void tanh_f32(const float * x, float * dst, int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
@@ -116,6 +128,11 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

+static void swiglu_f32_cuda(const float * x, float * dst, const int k, const int64_t ne0, const int64_t nb1, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
+ swiglu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k, ne0, nb1);
+}
+
static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -184,6 +201,21 @@ void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}

+void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(dst));
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->ne[0] == src0->ne[0]/2);
+
+ swiglu_f32_cuda(src0_d, dst_d, ggml_nelements(dst), dst->ne[0], src0->nb[1]/sizeof(float), stream);
+}
+
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
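
The CUDA kernel above launches one thread per output element: each thread splits its flat index into a row and a column, then addresses src0 through the row stride nb1, which the host passes in float units (src0->nb[1]/sizeof(float); for a contiguous src0 this equals 2*ne0). A small sketch of that index mapping, with a hypothetical helper name for illustration:

#include <stdint.h>

// Where flat output index i reads from inside src0.
static inline void swiglu_src_index(const int i, const int ne0, const int64_t nb1,
                                    int64_t * gate, int64_t * up) {
    const int row = i / ne0;   // output row
    const int idx = i % ne0;   // column within that row
    *gate = row*nb1 + idx;     // input to silu(), first half of the row
    *up   = *gate + ne0;       // multiplier, second half of the row
}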
diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh
index 4cfb0479..be3d6f15 100644
--- a/ggml/src/ggml-cuda/unary.cuh
+++ b/ggml/src/ggml-cuda/unary.cuh
@@ -31,3 +31,5 @@ void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index 02794e3c..774314df 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -63,6 +63,8 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
GGML_METAL_KERNEL_TYPE_SILU,
GGML_METAL_KERNEL_TYPE_SILU_4,
+ GGML_METAL_KERNEL_TYPE_SWIGLU,
+ GGML_METAL_KERNEL_TYPE_SWIGLU_4,
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
@@ -583,6 +585,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU_4, swiglu_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction);
@@ -884,6 +888,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_SILU:
+ case GGML_UNARY_OP_SWIGLU:
return ggml_is_contiguous(op->src[0]);
default:
return false;
@@ -1597,6 +1602,33 @@ static enum ggml_status ggml_metal_graph_compute(

[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
+ case GGML_UNARY_OP_SWIGLU:
+ {
+ int64_t n = ggml_nelements(dst);
+ GGML_ASSERT(ne0 == src0->ne[0]/2);
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ uint32_t n_per_row = ne0;
+ uint32_t stride = src0->nb[1]/sizeof(float);
+
+ if (ne0 % 4 == 0) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU_4].pipeline;
+ n /= 4;
+ n_per_row /= 4;
+ stride /= 4;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&n_per_row length:sizeof(n_per_row) atIndex:2];
+ [encoder setBytes:&stride length:sizeof(stride) atIndex:3];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
default:
{
GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
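
When ne0 is a multiple of 4 the Metal path picks the float4 kernel (kernel_swiglu_4, below) and divides the thread count, per-row width, and row stride by 4, so the same index arithmetic runs in float4 units. An illustrative consistency check (hypothetical helper, not in the patch) that the vectorized mapping addresses the same floats as the scalar one:

#include <assert.h>

// Assumes ne0 % 4 == 0 and stride % 4 == 0 (guaranteed for contiguous src0,
// where stride == 2*ne0).
static void check_swiglu_float4_mapping(const int ne0, const int stride, const int t) {
    const int w  = ne0/4, s = stride/4;     // widths in float4 units
    const int j4 = (t/w)*s + t%w;           // index used by kernel_swiglu_4
    const int i  = 4*t;                     // first scalar lane of thread t
    const int j  = (i/ne0)*stride + i%ne0;  // index used by kernel_swiglu
    assert(4*j4 == j);                      // gate halves line up
    assert(4*(j4 + w) == j + ne0);          // up halves line up
}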
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index e2e45029..c1e11047 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -398,6 +398,30 @@ kernel void kernel_silu_4(
dst[tpig] = x / (1.0f + exp(-x));
}

+kernel void kernel_swiglu(
+ device const float * src0,
+ device float * dst,
+ constant uint & ne0,
+ constant uint & stride,
+ uint tpig[[thread_position_in_grid]]) {
+ const uint row = tpig/ne0;
+ const uint idx = tpig%ne0;
+ const uint j = row*stride + idx;
+ dst[tpig] = src0[j] * src0[j + ne0] / (1.0f + exp(-src0[j]));
+}
+
+kernel void kernel_swiglu_4(
+ device const float4 * src0,
+ device float4 * dst,
+ constant uint & ne0,
+ constant uint & stride,
+ uint tpig[[thread_position_in_grid]]) {
+ const uint row = tpig/ne0;
+ const uint idx = tpig%ne0;
+ const uint j = row*stride + idx;
+ dst[tpig] = src0[j] * src0[j + ne0] / (1.0f + exp(-src0[j]));
+}
+
kernel void kernel_sqr(
device const float * src0,
device float * dst,
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 2804accd..184a31a8 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -2867,6 +2867,30 @@ static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
}
}

+static void ggml_vec_swiglu_f32(const int n, float * y, const float * x) {
+ int i = 0;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+ for (; i + 15 < n; i += 16) {
+ _mm512_storeu_ps(y + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(x + i + n)));
+ }
+#elif defined(__AVX2__) && defined(__FMA__)
+ for (; i + 7 < n; i += 8) {
+ _mm256_storeu_ps(y + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(x + i + n)));
+ }
+#elif defined(__SSE2__)
+ for (; i + 3 < n; i += 4) {
+ _mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(x + i + n)));
+ }
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+ for (; i + 3 < n; i += 4) {
+ vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(x + i + n)));
+ }
+#endif
+ for (; i < n; ++i) {
+ y[i] = ggml_silu_f32(x[i]) * x[i + n];
+ }
+}
+
static void ggml_vec_tanh_f32(const int n, float * y, const float * x) {
int i = 0;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
@@ -3386,9 +3410,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
"SILU",
"HARDSWISH",
"HARDSIGMOID",
+ "SWIGLU",
};

-static_assert(GGML_UNARY_OP_COUNT == 13, "GGML_UNARY_OP_COUNT != 13");
+static_assert(GGML_UNARY_OP_COUNT == 14, "GGML_UNARY_OP_COUNT != 14");


static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -5694,6 +5719,26 @@ struct ggml_tensor * ggml_silu_inplace(
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SILU);
}

+// ggml_swiglu
+
+struct ggml_tensor * ggml_swiglu(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ GGML_ASSERT(ggml_is_contiguous_1(a));
+
+ int64_t ne[4] = {a->ne[0]/2, a->ne[1], a->ne[2], a->ne[3]};
+
+ struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, ne, NULL, 0);
+
+ ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_SWIGLU);
+
+ result->op = GGML_OP_UNARY;
+ result->grad = a->grad ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
// ggml_silu_back

struct ggml_tensor * ggml_silu_back(
@@ -12243,6 +12288,67 @@ static void ggml_compute_forward_silu(
}
}
}
+
+// ggml_compute_forward_swiglu
+
+static void ggml_compute_forward_swiglu_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
+ GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
+ GGML_ASSERT(dst->ne[0] == src0->ne[0]/2);
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = dst->ne[0];
+ const int nr = ggml_nrows(src0);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ ggml_vec_swiglu_f32(nc,
+ (float *) ((char *) dst->data + i1*( dst->nb[1])),
+ (float *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+ for (int k = 0; k < nc; k++) {
+ const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
+ UNUSED(x);
+ assert(!isnan(x));
+ assert(!isinf(x));
+ }
+#endif
+ }
+}
+
+static void ggml_compute_forward_swiglu(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_swiglu_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+}
+
// ggml_compute_forward_leaky_relu

static void ggml_compute_forward_leaky_relu_f32(
@@ -17289,6 +17395,10 @@ static void ggml_compute_forward_unary(
{
ggml_compute_forward_silu(params, dst);
} break;
+ case GGML_UNARY_OP_SWIGLU:
+ {
+ ggml_compute_forward_swiglu(params, dst);
+ } break;
case GGML_UNARY_OP_HARDSWISH:
{
ggml_compute_forward_hardswish(params, dst);
@@ -19155,6 +19265,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
GGML_ABORT("fatal error"); // TODO: not implemented
}
+ case GGML_UNARY_OP_SWIGLU:
+ {
+ GGML_ABORT("fatal error"); // TODO: not implemented
+ }
case GGML_UNARY_OP_SILU:
{
// necessary for llama
@@ -19656,6 +19770,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_SILU:
+ case GGML_UNARY_OP_SWIGLU:
{
n_tasks = n_threads;
} break;
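
On the CPU side, ggml_swiglu builds a half-width unary node, ggml_vec_swiglu_f32 processes one row at a time with the same SIMD structure as ggml_vec_silu_f32, and ggml_get_n_tasks spreads rows across threads as it does for the other heavy unaries. A hedged end-to-end sketch of driving the new op through the public API, assuming the matching ggml_swiglu declaration is exposed in ggml.h (that header change sits outside this ggml/src diffstat):

#include "ggml.h"

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // one row of 8 floats: gate = x[0..3], up = x[4..7]
    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    for (int i = 0; i < 8; i++) ((float *) x->data)[i] = (float) i;

    struct ggml_tensor * y = ggml_swiglu(ctx, x); // y->ne[0] == 4

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    // now ((float *) y->data)[i] == silu(x[i]) * x[i + 4]
    ggml_free(ctx);
    return 0;
}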