summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAidanBeltonS <87009434+AidanBeltonS@users.noreply.github.com>2024-02-03 08:11:37 +0000
committerGitHub <noreply@github.com>2024-02-03 16:11:37 +0800
commita305dba8ff642e57f538f42010868fe0bc5262a1 (patch)
treeebf4b7242afe82d9b3cb551e821676d3fefdbf4b
parent191221178f51b6e81122c5bda0fd79620e547d07 (diff)
Fix im2col with 32fp (#5286)
-rw-r--r--ggml-sycl.cpp17
1 files changed, 11 insertions, 6 deletions
diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp
index ac75f8e1..51445b5e 100644
--- a/ggml-sycl.cpp
+++ b/ggml-sycl.cpp
@@ -8247,7 +8247,8 @@ static void clamp_f32(const float * x, float * dst, const float min, const float
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
}
-static void im2col_f32_f16(const float *x, sycl::half *dst, int offset_delta,
+template <typename T>
+static void im2col_kernel(const float *x, T *dst, int offset_delta,
int IW, int IH, int OW, int KW, int KH,
int pelements, int CHW, int s0, int s1, int p0,
int p1, int d0, int d1,
@@ -11019,7 +11020,8 @@ static void soft_max_f32_sycl(const float *x, const float *y, float *dst,
});
}
-static void im2col_f32_f16_sycl(const float *x, sycl::half *dst, int IW, int IH,
+template <typename T>
+static void im2col_sycl(const float *x, T *dst, int IW, int IH,
int OW, int OH, int KW, int KH, int IC,
int offset_delta, int s0, int s1, int p0,
int p1, int d0, int d1,
@@ -11036,7 +11038,7 @@ static void im2col_f32_f16_sycl(const float *x, sycl::half *dst, int IW, int IH,
sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
- im2col_f32_f16(x, dst, offset_delta, IW, IH, OW, KW, KH,
+ im2col_kernel(x, dst, offset_delta, IW, IH, OW, KW, KH,
parallel_elements, (IC * KH * KW), s0, s1, p0,
p1, d0, d1, item_ct1);
});
@@ -12424,7 +12426,7 @@ inline void ggml_sycl_op_im2col(const ggml_tensor *src0,
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_F16);
+ GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
@@ -12447,8 +12449,11 @@ inline void ggml_sycl_op_im2col(const ggml_tensor *src0,
const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
- im2col_f32_f16_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH,
- IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+ if (dst->type == GGML_TYPE_F16) {
+ im2col_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+ } else {
+ im2col_sycl(src1_dd, (float *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+ }
(void) src0;
(void) src0_dd;