diff options
Diffstat (limited to 'ggml/src/ggml-cuda')
| -rw-r--r-- | ggml/src/ggml-cuda/unary.cu | 67 | ||||
| -rw-r--r-- | ggml/src/ggml-cuda/unary.cuh | 2 |
2 files changed, 69 insertions, 0 deletions
diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 51582ed5..7bc43d0f 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -43,6 +43,36 @@ static __global__ void swiglu_f32(const float * x, float * dst, const int k, con dst[i] = x[j] * x[j + ne0] / (1.0f + expf(-x[j])); } +static __global__ void fused_mul_silu_f32(const float * x, const float * y, float * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + dst[i] = x[i] * y[i] / (1.0f + expf(-x[i])); +} + +static __global__ void fused_mul_relu_f32(const float * x, const float * y, float * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + dst[i] = fmaxf(x[i], 0) * y[i]; +} + +static __global__ void fused_mul_gelu_f32(const float * x, const float * y, float * dst, const int k) { + constexpr float GELU_COEF_A = 0.044715f; + constexpr float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + float xi = x[i]; + dst[i] = 0.5f*xi*y[i]*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi))); +} + static __global__ void tanh_f32(const float * x, float * dst, int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -133,6 +163,21 @@ static void swiglu_f32_cuda(const float * x, float * dst, const int k, const int swiglu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k, ne0, nb1); } +static void fused_mul_silu_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE; + fused_mul_silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, y, dst, k); +} + +static void fused_mul_relu_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE; + fused_mul_relu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, y, dst, k); +} + +static void fused_mul_gelu_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE; + fused_mul_gelu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, y, dst, k); +} + static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE; tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k); @@ -216,6 +261,28 @@ void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { swiglu_f32_cuda(src0_d, dst_d, ggml_nelements(dst), dst->ne[0], src0->nb[1]/sizeof(float), stream); } +void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_are_same_shape(src0, src1)); + + cudaStream_t stream = ctx.stream(); + ggml_unary_op op = (ggml_unary_op)dst->op_params[0]; + + const float * src0_d = (const float *)src0->data; + const float * src1_d = (const float *)src1->data; + float * dst_d = (float *)dst->data; + + switch (op) { + case GGML_UNARY_OP_SILU: fused_mul_silu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break; + case GGML_UNARY_OP_RELU: fused_mul_relu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break; + case GGML_UNARY_OP_GELU: fused_mul_gelu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break; + default: GGML_ASSERT(false); + } +} + void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index be3d6f15..d2d478b4 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -33,3 +33,5 @@ void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst); |
