diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2024-09-28 13:37:25 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-09-28 13:37:25 +0300 |
commit | 737514fd814d944f8ce965620293a16e5e8a285d (patch) | |
tree | 4b4b79eec0d1cbcc413dd3c6991b6d57439edd86 /ggml/src/ggml-cuda | |
parent | 1f61e91862dd0b077ccb60459f3cc03f364ee279 (diff) |
Adding SWIGLU unary op (#65)
* Adding GGML_UNARY_OP_SWIGLU
This commit implements the ggml op and CPU compute
forward. I see ~3-4% speedup of PP-512 for Phi-3.5-mini.
* GGML_UNARY_OP_SWIGLU: CUDA implementation
I observe ~12% speedup for PP-512(Phi-3.5-mini).
* GGML_UNARY_OP_SWIGLU: Metal implementation
We get ~2% speedup for PP-512(Phi-3.5-mini).
* GGML_UNARY_OP_SWIGLU: minor improvement on Metal
* GGML_UNARY_OP_SWIGLU: cleanup
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/src/ggml-cuda')
-rw-r--r-- | ggml/src/ggml-cuda/unary.cu | 32 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/unary.cuh | 2 |
2 files changed, 34 insertions, 0 deletions
diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index f9e20801..51582ed5 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -31,6 +31,18 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) { dst[i] = x[i] / (1.0f + expf(-x[i])); } +static __global__ void swiglu_f32(const float * x, float * dst, const int k, const int ne0, const int64_t nb1) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + const int row = i/ne0; + const int idx = i%ne0; + const int j = row*nb1 + idx; + dst[i] = x[j] * x[j + ne0] / (1.0f + expf(-x[j])); +} + static __global__ void tanh_f32(const float * x, float * dst, int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -116,6 +128,11 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_ silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k); } +static void swiglu_f32_cuda(const float * x, float * dst, const int k, const int64_t ne0, const int64_t nb1, cudaStream_t stream) { + const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE; + swiglu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k, ne0, nb1); +} + static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE; tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k); @@ -184,6 +201,21 @@ void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); } +void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->ne[0] == src0->ne[0]/2); + + swiglu_f32_cuda(src0_d, dst_d, ggml_nelements(dst), dst->ne[0], src0->nb[1]/sizeof(float), stream); +} + void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index 4cfb0479..be3d6f15 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -31,3 +31,5 @@ void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); |