summaryrefslogtreecommitdiff
path: root/ggml/src/ggml-cuda/norm.cu
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/ggml-cuda/norm.cu')
-rw-r--r--ggml/src/ggml-cuda/norm.cu9
1 files changed, 6 insertions, 3 deletions
diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu
index 30866d51..133e219f 100644
--- a/ggml/src/ggml-cuda/norm.cu
+++ b/ggml/src/ggml-cuda/norm.cu
@@ -142,8 +142,7 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i
}
}
-static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const int group_size, const int ne_elements, cudaStream_t stream) {
- static const float eps = 1e-6f;
+static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
if (group_size < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
@@ -196,8 +195,12 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
GGML_ASSERT( dst->type == GGML_TYPE_F32);
int num_groups = dst->op_params[0];
+
+ float eps;
+ memcpy(&eps, dst->op_params + 1, sizeof(float));
+
int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
- group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], group_size, ggml_nelements(src0), stream);
+ group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], eps, group_size, ggml_nelements(src0), stream);
}
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {