summary refs log tree commit diff
path: root/ggml/src/ggml-sycl/norm.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/ggml-sycl/norm.cpp')
-rw-r--r-- ggml/src/ggml-sycl/norm.cpp 9
1 files changed, 6 insertions, 3 deletions
diff --git a/ggml/src/ggml-sycl/norm.cpp b/ggml/src/ggml-sycl/norm.cpp
index cccf87d0..b3159b9d 100644
--- a/ggml/src/ggml-sycl/norm.cpp
+++ b/ggml/src/ggml-sycl/norm.cpp
@@ -225,9 +225,8 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
}
static void group_norm_f32_sycl(const float* x, float* dst,
- const int num_groups, const int group_size,
+ const int num_groups, const float eps, const int group_size,
const int ne_elements, queue_ptr stream, int device) {
- static const float eps = 1e-6f;
if (group_size < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
stream->submit([&](sycl::handler& cgh) {
@@ -343,8 +342,12 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor*
GGML_ASSERT(dst->type == GGML_TYPE_F32);
int num_groups = dst->op_params[0];
+
+ float eps;
+ memcpy(&eps, dst->op_params + 1, sizeof(float));
+
int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
- group_norm_f32_sycl(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device);
+ group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device);
(void)src1;
(void)dst;