summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJustina Cho <justcho5@gmail.com>2024-05-01 14:44:26 -0700
committerGeorgi Gerganov <ggerganov@gmail.com>2024-05-11 15:38:34 +0300
commitf5ef34e428f3886544590ecb2d532e4d333c114c (patch)
tree55fccc222e344916be2574642656f05649be3344
parentef0d5e3ec9f99003af3ff326384816c02850ea3f (diff)
feat: implemented sigmoid function (ggml/806)
* added sigmoid function * implemented metal kernel for sigmoid * implemented cuda kernel for sigmoid * added sigmoid unary op and incremented count
-rw-r--r--ggml-cuda.cu4
-rw-r--r--ggml-cuda/unary.cu26
-rw-r--r--ggml-cuda/unary.cuh3
-rw-r--r--ggml-metal.m15
-rw-r--r--ggml-metal.metal7
-rw-r--r--ggml.c73
-rw-r--r--ggml.h9
7 files changed, 136 insertions, 1 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index c5c77879..5b6c9091 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -2204,6 +2204,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_UNARY_OP_RELU:
ggml_cuda_op_relu(ctx, dst);
break;
+ case GGML_UNARY_OP_SIGMOID:
+ ggml_cuda_op_sigmoid(ctx, dst);
+ break;
case GGML_UNARY_OP_HARDSIGMOID:
ggml_cuda_op_hardsigmoid(ctx, dst);
break;
@@ -2716,6 +2719,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_RELU:
+ case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_GELU_QUICK:
diff --git a/ggml-cuda/unary.cu b/ggml-cuda/unary.cu
index 1a7f0946..ac03d5c6 100644
--- a/ggml-cuda/unary.cu
+++ b/ggml-cuda/unary.cu
@@ -48,6 +48,15 @@ static __global__ void relu_f32(const float * x, float * dst, const int k) {
dst[i] = fmaxf(x[i], 0);
}
+static __global__ void sigmoid_f32(const float * x, float * dst, const int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+ dst[i] = 1.0f / (1.0f + expf(-x[i]));
+}
+
static __global__ void hardsigmoid_f32(const float * x, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -108,6 +117,11 @@ static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
+static void sigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_SIGMOID_BLOCK_SIZE - 1) / CUDA_SIGMOID_BLOCK_SIZE;
+ sigmoid_f32<<<num_blocks, CUDA_SIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
static void hardsigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_HARDSIGMOID_BLOCK_SIZE - 1) / CUDA_HARDSIGMOID_BLOCK_SIZE;
hardsigmoid_f32<<<num_blocks, CUDA_HARDSIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -188,6 +202,18 @@ void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
relu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
+void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ sigmoid_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
diff --git a/ggml-cuda/unary.cuh b/ggml-cuda/unary.cuh
index 2002ed98..a1d07c04 100644
--- a/ggml-cuda/unary.cuh
+++ b/ggml-cuda/unary.cuh
@@ -4,6 +4,7 @@
#define CUDA_SILU_BLOCK_SIZE 256
#define CUDA_TANH_BLOCK_SIZE 256
#define CUDA_RELU_BLOCK_SIZE 256
+#define CUDA_SIGMOID_BLOCK_SIZE 256
#define CUDA_HARDSIGMOID_BLOCK_SIZE 256
#define CUDA_HARDSWISH_BLOCK_SIZE 256
#define CUDA_SQR_BLOCK_SIZE 256
@@ -18,6 +19,8 @@ void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml-metal.m b/ggml-metal.m
index 1bbb8fb4..66c398d5 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -40,6 +40,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_CLAMP,
GGML_METAL_KERNEL_TYPE_TANH,
GGML_METAL_KERNEL_TYPE_RELU,
+ GGML_METAL_KERNEL_TYPE_SIGMOID,
GGML_METAL_KERNEL_TYPE_GELU,
GGML_METAL_KERNEL_TYPE_GELU_4,
GGML_METAL_KERNEL_TYPE_GELU_QUICK,
@@ -493,6 +494,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
@@ -730,6 +732,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_RELU:
+ case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_SILU:
@@ -1239,6 +1242,18 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
+ case GGML_UNARY_OP_SIGMOID:
+ {
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
case GGML_UNARY_OP_GELU:
{
int64_t n = ggml_nelements(dst);
diff --git a/ggml-metal.metal b/ggml-metal.metal
index ee9de57a..0c6d3279 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -229,6 +229,13 @@ kernel void kernel_relu(
dst[tpig] = max(0.0f, src0[tpig]);
}
+kernel void kernel_sigmoid(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
+}
+
kernel void kernel_tanh(
device const float * src0,
device float * dst,
diff --git a/ggml.c b/ggml.c
index 4ee5d24a..4f301158 100644
--- a/ggml.c
+++ b/ggml.c
@@ -1949,6 +1949,7 @@ inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) {
inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }
+inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); }
// TODO: optimize performance
inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
@@ -2329,6 +2330,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
"TANH",
"ELU",
"RELU",
+ "SIGMOID",
"GELU",
"GELU_QUICK",
"SILU",
@@ -2336,7 +2338,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
"HARDSIGMOID",
};
-static_assert(GGML_UNARY_OP_COUNT == 12, "GGML_UNARY_OP_COUNT != 12");
+static_assert(GGML_UNARY_OP_COUNT == 13, "GGML_UNARY_OP_COUNT != 13");
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -4561,6 +4563,20 @@ struct ggml_tensor * ggml_leaky_relu(
return result;
}
+// ggml_sigmoid
+
+struct ggml_tensor * ggml_sigmoid(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary(ctx, a, GGML_UNARY_OP_SIGMOID);
+}
+
+struct ggml_tensor * ggml_sigmoid_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SIGMOID);
+}
+
// ggml_gelu
struct ggml_tensor * ggml_gelu(
@@ -10852,6 +10868,52 @@ static void ggml_compute_forward_relu(
}
}
+// ggml_compute_forward_sigmoid
+
+static void ggml_compute_forward_sigmoid_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ assert(params->ith == 0);
+ assert(ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+ return;
+ }
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ assert(dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_sigmoid_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+static void ggml_compute_forward_sigmoid(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_sigmoid_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
// ggml_compute_forward_gelu
static void ggml_compute_forward_gelu_f32(
@@ -16617,6 +16679,10 @@ static void ggml_compute_forward_unary(
{
ggml_compute_forward_relu(params, dst);
} break;
+ case GGML_UNARY_OP_SIGMOID:
+ {
+ ggml_compute_forward_sigmoid(params, dst);
+ } break;
case GGML_UNARY_OP_GELU:
{
ggml_compute_forward_gelu(params, dst);
@@ -18601,6 +18667,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
zero_table);
}
} break;
+ case GGML_UNARY_OP_SIGMOID:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
case GGML_UNARY_OP_GELU:
{
GGML_ASSERT(false); // TODO: not implemented
@@ -19130,6 +19200,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_ELU:
case GGML_UNARY_OP_RELU:
+ case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_HARDSWISH: // to opt for multiple threads
case GGML_UNARY_OP_HARDSIGMOID: // to opt for multiple threads
{
diff --git a/ggml.h b/ggml.h
index 76c33283..3fe95ed5 100644
--- a/ggml.h
+++ b/ggml.h
@@ -519,6 +519,7 @@ extern "C" {
GGML_UNARY_OP_TANH,
GGML_UNARY_OP_ELU,
GGML_UNARY_OP_RELU,
+ GGML_UNARY_OP_SIGMOID,
GGML_UNARY_OP_GELU,
GGML_UNARY_OP_GELU_QUICK,
GGML_UNARY_OP_SILU,
@@ -1073,6 +1074,14 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
+ GGML_API struct ggml_tensor * ggml_sigmoid(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_sigmoid_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
GGML_API struct ggml_tensor * ggml_gelu(
struct ggml_context * ctx,
struct ggml_tensor * a);