summaryrefslogtreecommitdiff
path: root/ggml/src/ggml-cuda
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/ggml-cuda')
-rw-r--r--ggml/src/ggml-cuda/unary.cu36
-rw-r--r--ggml/src/ggml-cuda/unary.cuh3
2 files changed, 39 insertions, 0 deletions
diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu
index 7bc43d0f..8ffddd6d 100644
--- a/ggml/src/ggml-cuda/unary.cu
+++ b/ggml/src/ggml-cuda/unary.cu
@@ -52,6 +52,25 @@ static __global__ void fused_mul_silu_f32(const float * x, const float * y, floa
dst[i] = x[i] * y[i] / (1.0f + expf(-x[i]));
}
+static __global__ void multi_add_f32(int nused, int64_t ne0, int64_t ne1, int64_t nb1, int64_t nb01, const char * src0, char * dst) {
+ const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+ int64_t k = ne0*ne1;
+ if (i >= k) {
+ return;
+ }
+ int i1 = i / ne0;
+ int i0 = i % ne0;
+ float * result = (float *)(dst + i1*nb1);
+ const float * s = (const float *)(src0 + i1*nb01) + i0;
+ if (nused == 1) {
+ result[i0] = s[0];
+ } else {
+ float sum = s[0] + s[ne0];
+ for (int j = 2; j < nused; ++j) sum += s[j*ne0];
+ result[i0] = sum;
+ }
+}
+
static __global__ void fused_mul_relu_f32(const float * x, const float * y, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -218,6 +237,23 @@ static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_
sqrt_f32<<<num_blocks, CUDA_SQRT_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
+static void multi_add_f32_cuda(int nused, int64_t ne0, int64_t ne1, int64_t nb1, int64_t nb01, const char * src0, char * dst, cudaStream_t stream) {
+ int64_t k = ne0 * ne1;
+ const int num_blocks = (k + CUDA_MULTI_ADD_BLOCK_SIZE - 1) / CUDA_MULTI_ADD_BLOCK_SIZE;
+ multi_add_f32<<<num_blocks, CUDA_MULTI_ADD_BLOCK_SIZE, 0, stream>>>(nused, ne0, ne1, nb1, nb01, src0, dst);
+}
+
+void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->ne[2] == 1 && dst->ne[3] == 1);
+ GGML_ASSERT(dst->nb[0] == sizeof(float));
+ int nused = dst->op_params[0];
+ GGML_ASSERT(nused >= 1);
+ const char * src0 = (const char *)dst->src[0]->data;
+ cudaStream_t stream = ctx.stream();
+ multi_add_f32_cuda(nused, dst->ne[0], dst->ne[1], dst->nb[1], dst->src[0]->nb[1], src0, (char *)dst->data, stream);
+}
+
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh
index d2d478b4..0235a319 100644
--- a/ggml/src/ggml-cuda/unary.cuh
+++ b/ggml/src/ggml-cuda/unary.cuh
@@ -9,6 +9,7 @@
#define CUDA_HARDSWISH_BLOCK_SIZE 256
#define CUDA_SQR_BLOCK_SIZE 256
#define CUDA_SQRT_BLOCK_SIZE 256
+#define CUDA_MULTI_ADD_BLOCK_SIZE 256
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -35,3 +36,5 @@ void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);