Diffstat (limited to 'ggml/src')
-rw-r--r--  ggml/src/ggml-cuda.cu        |   4
-rw-r--r--  ggml/src/ggml-cuda/unary.cu  |  32
-rw-r--r--  ggml/src/ggml-cuda/unary.cuh |   2
-rw-r--r--  ggml/src/ggml-metal.m        |  32
-rw-r--r--  ggml/src/ggml-metal.metal    |  24
-rw-r--r--  ggml/src/ggml.c              | 117
6 files changed, 210 insertions(+), 1 deletion(-)
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index ca57efbd..966c91c0 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -2233,6 +2233,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_UNARY_OP_SILU:
             ggml_cuda_op_silu(ctx, dst);
             break;
+        case GGML_UNARY_OP_SWIGLU:
+            ggml_cuda_op_swiglu(ctx, dst);
+            break;
         case GGML_UNARY_OP_GELU_QUICK:
             ggml_cuda_op_gelu_quick(ctx, dst);
             break;
@@ -2773,6 +2776,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
             switch (ggml_get_unary_op(op)) {
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_SILU:
+                case GGML_UNARY_OP_SWIGLU:
                 case GGML_UNARY_OP_RELU:
                 case GGML_UNARY_OP_SIGMOID:
                 case GGML_UNARY_OP_HARDSIGMOID:
diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu
index f9e20801..51582ed5 100644
--- a/ggml/src/ggml-cuda/unary.cu
+++ b/ggml/src/ggml-cuda/unary.cu
@@ -31,6 +31,18 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] / (1.0f + expf(-x[i]));
 }
 
+static __global__ void swiglu_f32(const float * x, float * dst, const int k, const int ne0, const int64_t nb1) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    const int row = i/ne0;
+    const int idx = i%ne0;
+    const int j = row*nb1 + idx;
+    dst[i] = x[j] * x[j + ne0] / (1.0f + expf(-x[j]));
+}
+
 static __global__ void tanh_f32(const float * x, float * dst, int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
     if (i >= k) {
@@ -116,6 +128,11 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
     silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void swiglu_f32_cuda(const float * x, float * dst, const int k, const int64_t ne0, const int64_t nb1, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
+    swiglu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k, ne0, nb1);
+}
+
 static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
     tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -184,6 +201,21 @@ void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
 }
 
+void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->ne[0] == src0->ne[0]/2);
+
+    swiglu_f32_cuda(src0_d, dst_d, ggml_nelements(dst), dst->ne[0], src0->nb[1]/sizeof(float), stream);
+}
+
 void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;
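Note on the CUDA kernel above: swiglu_f32 launches one thread per output element (k = nrows*ne0). Thread i maps to row i/ne0 and column i%ne0 of dst, reads the gate value at source offset j and the paired "up" value at j + ne0 within the same row (src0 rows are nb1 floats apart), and writes silu(gate)*up. A scalar host-side sketch of the same indexing, for illustration only (swiglu_ref is a hypothetical helper, not part of the patch):

#include <math.h>
#include <stdint.h>

// Reference model of swiglu_f32: x holds rows of 2*ne0 floats ([gate | up],
// row stride nb1 floats); dst holds contiguous rows of ne0 floats.
static void swiglu_ref(const float * x, float * dst, int k, int ne0, int64_t nb1) {
    for (int i = 0; i < k; ++i) {
        const int row = i/ne0;                             // destination row
        const int idx = i%ne0;                             // column within the row
        const int64_t j = row*nb1 + idx;                   // offset of the gate value
        dst[i] = x[j] * x[j + ne0] / (1.0f + expf(-x[j])); // silu(gate) * up
    }
}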
diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh
index 4cfb0479..be3d6f15 100644
--- a/ggml/src/ggml-cuda/unary.cuh
+++ b/ggml/src/ggml-cuda/unary.cuh
@@ -31,3 +31,5 @@ void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
 void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index 02794e3c..774314df 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -63,6 +63,8 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
     GGML_METAL_KERNEL_TYPE_SILU,
     GGML_METAL_KERNEL_TYPE_SILU_4,
+    GGML_METAL_KERNEL_TYPE_SWIGLU,
+    GGML_METAL_KERNEL_TYPE_SWIGLU_4,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
@@ -583,6 +585,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU_4, swiglu_4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction);
@@ -884,6 +888,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
             case GGML_UNARY_OP_GELU:
             case GGML_UNARY_OP_GELU_QUICK:
             case GGML_UNARY_OP_SILU:
+            case GGML_UNARY_OP_SWIGLU:
                 return ggml_is_contiguous(op->src[0]);
             default:
                 return false;
@@ -1597,6 +1602,33 @@ static enum ggml_status ggml_metal_graph_compute(
 
                         [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;
+                case GGML_UNARY_OP_SWIGLU:
+                    {
+                        int64_t n = ggml_nelements(dst);
+                        GGML_ASSERT(ne0 == src0->ne[0]/2);
+
+                        id<MTLComputePipelineState> pipeline = nil;
+
+                        uint32_t n_per_row = ne0;
+                        uint32_t stride = src0->nb[1]/sizeof(float);
+
+                        if (ne0 % 4 == 0) {
+                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU_4].pipeline;
+                            n /= 4;
+                            n_per_row /= 4;
+                            stride /= 4;
+                        } else {
+                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
+                        }
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                        [encoder setBytes:&n_per_row length:sizeof(n_per_row) atIndex:2];
+                        [encoder setBytes:&stride    length:sizeof(stride)    atIndex:3];
+
+                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    } break;
                 default:
                     {
                         GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
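In the graph-compute hunk above, when ne0 is a multiple of 4 the host picks the float4 kernel and divides the thread count, row length, and row stride by 4, re-expressing all indexing in float4 units; the kernel's j + ne0 offset then still lands on the first float4 of the "up" half. A small walk-through of that arithmetic under an assumed shape (contiguous f32 src0 of 8192 columns and 2 rows, so dst ne0 = 4096; illustration only, not part of the patch):

#include <stdint.h>
#include <stdio.h>

int main(void) {
    int64_t  n         = 4096*2; // ggml_nelements(dst)
    uint32_t n_per_row = 4096;   // dst->ne[0]
    uint32_t stride    = 8192;   // src0->nb[1]/sizeof(float)

    if (n_per_row % 4 == 0) {    // float4 path: sizes become float4 counts
        n /= 4; n_per_row /= 4; stride /= 4;
    }

    // prints: threads=2048 per_row=1024 stride=2048
    printf("threads=%lld per_row=%u stride=%u\n", (long long) n, n_per_row, stride);
    return 0;
}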
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index e2e45029..c1e11047 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -398,6 +398,30 @@ kernel void kernel_silu_4(
     dst[tpig] = x / (1.0f + exp(-x));
 }
 
+kernel void kernel_swiglu(
+        device const float * src0,
+        device       float * dst,
+        constant     uint  & ne0,
+        constant     uint  & stride,
+        uint tpig[[thread_position_in_grid]]) {
+    const uint row = tpig/ne0;
+    const uint idx = tpig%ne0;
+    const uint j = row*stride + idx;
+    dst[tpig] = src0[j] * src0[j + ne0] / (1.0f + exp(-src0[j]));
+}
+
+kernel void kernel_swiglu_4(
+        device const float4 * src0,
+        device       float4 * dst,
+        constant     uint   & ne0,
+        constant     uint   & stride,
+        uint tpig[[thread_position_in_grid]]) {
+    const uint row = tpig/ne0;
+    const uint idx = tpig%ne0;
+    const uint j = row*stride + idx;
+    dst[tpig] = src0[j] * src0[j + ne0] / (1.0f + exp(-src0[j]));
+}
+
 kernel void kernel_sqr(
         device const float * src0,
         device       float * dst,
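The ggml.c changes below add the CPU path. ggml_vec_swiglu_f32 processes one fused row at a time: the SIMD branches reuse the existing ggml_v_silu helpers and multiply by the second half of the row loaded from x + i + n, and a scalar loop handles the remainder. A scalar model of what every branch computes, for reference (vec_swiglu_ref is a hypothetical name, not part of the patch):

#include <math.h>

// x points at one row of 2*n floats laid out as [gate | up];
// y receives the n-float result row.
static void vec_swiglu_ref(const int n, float * y, const float * x) {
    for (int i = 0; i < n; ++i) {
        const float g = x[i];                    // gate half
        y[i] = (g/(1.0f + expf(-g))) * x[i + n]; // silu(gate) * up
    }
}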
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 2804accd..184a31a8 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -2867,6 +2867,30 @@ static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
     }
 }
 
+static void ggml_vec_swiglu_f32(const int n, float * y, const float * x) {
+    int i = 0;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+    for (; i + 15 < n; i += 16) {
+        _mm512_storeu_ps(y + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(x + i + n)));
+    }
+#elif defined(__AVX2__) && defined(__FMA__)
+    for (; i + 7 < n; i += 8) {
+        _mm256_storeu_ps(y + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(x + i + n)));
+    }
+#elif defined(__SSE2__)
+    for (; i + 3 < n; i += 4) {
+        _mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(x + i + n)));
+    }
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+    for (; i + 3 < n; i += 4) {
+        vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(x + i + n)));
+    }
+#endif
+    for (; i < n; ++i) {
+        y[i] = ggml_silu_f32(x[i]) * x[i + n];
+    }
+}
+
 static void ggml_vec_tanh_f32(const int n, float * y, const float * x) {
     int i = 0;
 #if defined(__AVX512F__) && defined(__AVX512DQ__)
@@ -3386,9 +3410,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
     "SILU",
     "HARDSWISH",
     "HARDSIGMOID",
+    "SWIGLU",
 };
 
-static_assert(GGML_UNARY_OP_COUNT == 13, "GGML_UNARY_OP_COUNT != 13");
+static_assert(GGML_UNARY_OP_COUNT == 14, "GGML_UNARY_OP_COUNT != 14");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 
@@ -5694,6 +5719,26 @@ struct ggml_tensor * ggml_silu_inplace(
     return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SILU);
 }
 
+// ggml_swiglu
+
+struct ggml_tensor * ggml_swiglu(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    GGML_ASSERT(ggml_is_contiguous_1(a));
+
+    int64_t ne[4] = {a->ne[0]/2, a->ne[1], a->ne[2], a->ne[3]};
+
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, ne, NULL, 0);
+
+    ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_SWIGLU);
+
+    result->op   = GGML_OP_UNARY;
+    result->grad = a->grad ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
+    return result;
+}
+
 // ggml_silu_back
 
 struct ggml_tensor * ggml_silu_back(
@@ -12243,6 +12288,67 @@ static void ggml_compute_forward_silu(
         }
     }
 }
+
+// ggml_compute_forward_swiglu
+
+static void ggml_compute_forward_swiglu_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+    GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
+    GGML_ASSERT(dst->ne[0] == src0->ne[0]/2);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = dst->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_swiglu_f32(nc,
+                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
+            UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_swiglu(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_swiglu_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_leaky_relu
 
 static void ggml_compute_forward_leaky_relu_f32(
@@ -17289,6 +17395,10 @@ static void ggml_compute_forward_unary(
             {
                 ggml_compute_forward_silu(params, dst);
             } break;
+        case GGML_UNARY_OP_SWIGLU:
+            {
+                ggml_compute_forward_swiglu(params, dst);
+            } break;
         case GGML_UNARY_OP_HARDSWISH:
             {
                 ggml_compute_forward_hardswish(params, dst);
@@ -19155,6 +19265,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ABORT("fatal error"); // TODO: not implemented
             }
+        case GGML_UNARY_OP_SWIGLU:
+            {
+                GGML_ABORT("fatal error"); // TODO: not implemented
+            }
         case GGML_UNARY_OP_SILU:
             {
                 // necessary for llama
@@ -19656,6 +19770,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             case GGML_UNARY_OP_GELU:
             case GGML_UNARY_OP_GELU_QUICK:
             case GGML_UNARY_OP_SILU:
+            case GGML_UNARY_OP_SWIGLU:
                 {
                     n_tasks = n_threads;
                 } break;
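For context, a hedged sketch of how the new op might be wired into a SwiGLU feed-forward block: instead of two matmuls followed by ggml_mul(ctx, ggml_silu(ctx, gate), up), the gate and up weights can be concatenated so a single matmul feeds ggml_swiglu, which halves the activation width. The tensor names (ffn_gate_up, ffn_down) and the fused weight layout are assumptions, not from this patch:

#include "ggml.h"

// Hypothetical SwiGLU FFN block using the new op.
// ffn_gate_up holds the gate projection rows followed by the up projection rows.
static struct ggml_tensor * build_swiglu_ffn(
        struct ggml_context * ctx,
        struct ggml_tensor  * inp,           // [n_embd, n_tokens]
        struct ggml_tensor  * ffn_gate_up,   // [n_embd, 2*n_ff]
        struct ggml_tensor  * ffn_down) {    // [n_ff,   n_embd]
    struct ggml_tensor * cur = ggml_mul_mat(ctx, ffn_gate_up, inp); // [2*n_ff, n_tokens]
    cur = ggml_swiglu(ctx, cur);                                    // [n_ff,   n_tokens]
    return ggml_mul_mat(ctx, ffn_down, cur);                        // [n_embd, n_tokens]
}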