summaryrefslogtreecommitdiff
path: root/ggml-cuda
diff options
context:
space:
mode:
authorCalvin Laurenson <calvin@laurenson.dev>2024-06-16 15:23:04 -0700
committerGitHub <noreply@github.com>2024-06-17 00:23:04 +0200
commit43b35e38ba371f9a7faa6dca4c5d1e8f698ffd87 (patch)
tree11f250899027f3249c9ee15ffaff2048c9b81268 /ggml-cuda
parent19b7a836f6658e18e973af532a5cc6ad6b3a27f8 (diff)
Add support for sqrt on CUDA (#7953)
* cuda sqrt support * enable cuda in pca * fix comments in pca * add test * add sqrt to ggml_backend_cuda_supports_op * fix test * new line * Use F32 sqrtf instead of F64 sqrt Co-authored-by: Johannes Gäßler <johannesg@5d6.de> --------- Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
Diffstat (limited to 'ggml-cuda')
-rw-r--r--ggml-cuda/unary.cu28
-rw-r--r--ggml-cuda/unary.cuh3
2 files changed, 31 insertions, 0 deletions
diff --git a/ggml-cuda/unary.cu b/ggml-cuda/unary.cu
index a5ff9632..f9e20801 100644
--- a/ggml-cuda/unary.cu
+++ b/ggml-cuda/unary.cu
@@ -92,6 +92,15 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
dst[i] = x[i] * x[i];
}
+static __global__ void sqrt_f32(const float * x, float * dst, const int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+ dst[i] = sqrtf(x[i]);
+}
+
static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -142,6 +151,11 @@ static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t
sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
+static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_SQRT_BLOCK_SIZE - 1) / CUDA_SQRT_BLOCK_SIZE;
+ sqrt_f32<<<num_blocks, CUDA_SQRT_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
@@ -284,3 +298,17 @@ void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
sqr_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
+
+void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
diff --git a/ggml-cuda/unary.cuh b/ggml-cuda/unary.cuh
index a1d07c04..4cfb0479 100644
--- a/ggml-cuda/unary.cuh
+++ b/ggml-cuda/unary.cuh
@@ -8,6 +8,7 @@
#define CUDA_HARDSIGMOID_BLOCK_SIZE 256
#define CUDA_HARDSWISH_BLOCK_SIZE 256
#define CUDA_SQR_BLOCK_SIZE 256
+#define CUDA_SQRT_BLOCK_SIZE 256
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -28,3 +29,5 @@ void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);