summary | refs | log | tree | commit | diff
path: root/ggml/src/ggml-sycl
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/ggml-sycl')
-rw-r--r--  ggml/src/ggml-sycl/backend.hpp     |  2
-rw-r--r--  ggml/src/ggml-sycl/common.hpp      |  2
-rw-r--r--  ggml/src/ggml-sycl/conv.cpp        | 99
-rw-r--r--  ggml/src/ggml-sycl/conv.hpp        | 21
-rw-r--r--  ggml/src/ggml-sycl/dmmv.cpp        |  2
-rw-r--r--  ggml/src/ggml-sycl/dpct/helper.hpp | 21
-rw-r--r--  ggml/src/ggml-sycl/mmq.cpp         | 22
-rw-r--r--  ggml/src/ggml-sycl/mmvq.cpp        |  4
-rw-r--r--  ggml/src/ggml-sycl/norm.cpp        |  9
-rw-r--r--  ggml/src/ggml-sycl/presets.hpp     |  2
-rw-r--r--  ggml/src/ggml-sycl/rope.cpp        |  4
-rw-r--r--  ggml/src/ggml-sycl/tsembd.cpp      | 71
-rw-r--r--  ggml/src/ggml-sycl/tsembd.hpp      | 21
13 files changed, 256 insertions(+), 24 deletions(-)
diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp
index 067181de..58dd9c9a 100644
--- a/ggml/src/ggml-sycl/backend.hpp
+++ b/ggml/src/ggml-sycl/backend.hpp
@@ -15,6 +15,7 @@
#include "concat.hpp"
#include "common.hpp"
+#include "conv.hpp"
#include "convert.hpp"
#include "dequantize.hpp"
#include "dmmv.hpp"
@@ -23,5 +24,6 @@
#include "rope.hpp"
#include "norm.hpp"
#include "softmax.hpp"
+#include "tsembd.hpp"
#endif // GGML_SYCL_BACKEND_HPP
diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp
index 397bd98d..86d8b40e 100644
--- a/ggml/src/ggml-sycl/common.hpp
+++ b/ggml/src/ggml-sycl/common.hpp
@@ -100,7 +100,7 @@ static void crash() {
const char* msg) {
fprintf(stderr, "SYCL error: %s: %s\n", stmt, msg);
fprintf(stderr, " in function %s at %s:%d\n", func, file, line);
- GGML_ASSERT(!"SYCL error");
+ GGML_ABORT("SYCL error");
}
#define SYCL_CHECK(err) \
diff --git a/ggml/src/ggml-sycl/conv.cpp b/ggml/src/ggml-sycl/conv.cpp
new file mode 100644
index 00000000..bc4ab1dd
--- /dev/null
+++ b/ggml/src/ggml-sycl/conv.cpp
@@ -0,0 +1,99 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#include "conv.hpp"
+
+static void conv_transpose_1d_kernel(
+ const int s0, const int output_size,
+ const int src0_ne0, const int src0_ne1, const int src0_ne2,
+ const int src1_ne0, const int dst_ne0,
+ const float * src0, const float * src1, float * dst,
+ const sycl::nd_item<3> &item_ct1) {
+ int global_index = item_ct1.get_local_id(2) +
+ item_ct1.get_group(2) * item_ct1.get_local_range(2);
+ if (global_index >= output_size) {
+ return;
+ }
+
+ int out_index = global_index / dst_ne0;
+
+ float accumulator = 0;
+
+ for (int c = 0; c < src0_ne2; c++) {
+ int idx = global_index % dst_ne0;
+
+ int kernel_offset = (src0_ne0 * src0_ne1 * c) + (out_index * src0_ne0);
+ int input_offset = src1_ne0 * c;
+
+ for (int i = 0; i < src1_ne0; i++) {
+ if (!(idx >= i*s0 && idx < i*s0 + src0_ne0)) {
+ continue;
+ }
+ int weight_idx = idx - i*s0;
+
+ float kernel_weight = src0[kernel_offset + weight_idx];
+ float input_value = src1[input_offset+i];
+
+ accumulator += kernel_weight * input_value;
+ }
+ }
+ dst[global_index] = accumulator;
+}
+
+static void conv_transpose_1d_f32_f32_sycl(
+ const int s0, const int output_size,
+ const int src0_ne0, const int src0_ne1, const int src0_ne2,
+ const int src1_ne0, const int dst_ne0,
+ const float *src0, const float *src1, float *dst,
+ const queue_ptr& stream) {
+
+ const int num_blocks = (output_size + SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE;
+ const sycl::range<3> block_dims(1, 1, SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE);
+ const sycl::range<3> block_nums(1, 1, num_blocks);
+ stream->parallel_for(
+ sycl::nd_range<3>(
+ block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ conv_transpose_1d_kernel(
+ s0, output_size,
+ src0_ne0, src0_ne1, src0_ne2,
+ src1_ne0, dst_ne0,
+ src0, src1, dst, item_ct1);
+ });
+}
+
+void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1, ggml_tensor *dst) {
+ const float * src0_d = (const float *)src0->data;
+ const float * src1_d = (const float *)src1->data;
+
+ float * dst_d = (float *)dst->data;
+ dpct::queue_ptr stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(src1));
+
+ const int32_t * opts = (const int32_t *)dst->op_params;
+
+ const int s0 = opts[0];
+
+ const int64_t output_size = ggml_nelements(dst);
+
+ conv_transpose_1d_f32_f32_sycl(s0, output_size,
+ src0->ne[0], src0->ne[1], src0->ne[2],
+ src1->ne[0], dst->ne[0],
+ src0_d, src1_d, dst_d, stream);
+}
+
diff --git a/ggml/src/ggml-sycl/conv.hpp b/ggml/src/ggml-sycl/conv.hpp
new file mode 100644
index 00000000..eb20730f
--- /dev/null
+++ b/ggml/src/ggml-sycl/conv.hpp
@@ -0,0 +1,21 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_CONV_HPP
+#define GGML_SYCL_CONV_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1, ggml_tensor *dst);
+
+#endif // GGML_SYCL_CONV_HPP
diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp
index 70a94fc1..ae45630e 100644
--- a/ggml/src/ggml-sycl/dmmv.cpp
+++ b/ggml/src/ggml-sycl/dmmv.cpp
@@ -1011,7 +1011,7 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
break;
default:
printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type);
- GGML_ASSERT(false);
+ GGML_ABORT("fatal error");
break;
}
diff --git a/ggml/src/ggml-sycl/dpct/helper.hpp b/ggml/src/ggml-sycl/dpct/helper.hpp
index 4aaa76bf..fe4a8f74 100644
--- a/ggml/src/ggml-sycl/dpct/helper.hpp
+++ b/ggml/src/ggml-sycl/dpct/helper.hpp
@@ -874,7 +874,7 @@ namespace dpct
inline std::string get_preferred_gpu_platform_name() {
std::string result;
- std::string filter = "level-zero";
+ std::string filter = "";
char* env = getenv("ONEAPI_DEVICE_SELECTOR");
if (env) {
if (std::strstr(env, "level_zero")) {
@@ -892,11 +892,24 @@ namespace dpct
else {
throw std::runtime_error("invalid device filter: " + std::string(env));
}
+ } else {
+ auto default_device = sycl::device(sycl::default_selector_v);
+ auto default_platform_name = default_device.get_platform().get_info<sycl::info::platform::name>();
+
+ if (std::strstr(default_platform_name.c_str(), "Level-Zero") || default_device.is_cpu()) {
+ filter = "level-zero";
+ }
+ else if (std::strstr(default_platform_name.c_str(), "CUDA")) {
+ filter = "cuda";
+ }
+ else if (std::strstr(default_platform_name.c_str(), "HIP")) {
+ filter = "hip";
+ }
}
- auto plaform_list = sycl::platform::get_platforms();
+ auto platform_list = sycl::platform::get_platforms();
- for (const auto& platform : plaform_list) {
+ for (const auto& platform : platform_list) {
auto devices = platform.get_devices();
auto gpu_dev = std::find_if(devices.begin(), devices.end(), [](const sycl::device& d) {
return d.is_gpu();
@@ -975,7 +988,7 @@ namespace dpct
if (backend == "opencl:cpu") return 4;
if (backend == "opencl:acc") return 5;
printf("convert_backend_index: can't handle backend=%s\n", backend.c_str());
- GGML_ASSERT(false);
+ GGML_ABORT("fatal error");
}
static bool compare_backend(std::string &backend1, std::string &backend2) {
return convert_backend_index(backend1) < convert_backend_index(backend2);
diff --git a/ggml/src/ggml-sycl/mmq.cpp b/ggml/src/ggml-sycl/mmq.cpp
index 3107ba91..e952533d 100644
--- a/ggml/src/ggml-sycl/mmq.cpp
+++ b/ggml/src/ggml-sycl/mmq.cpp
@@ -1799,7 +1799,7 @@ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
mmq_y = MMQ_Y_Q4_0_PASCAL;
nwarps = NWARPS_Q4_0_PASCAL;
} else {
- GGML_ASSERT(false);
+ GGML_ABORT("fatal error");
}
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@@ -1914,7 +1914,7 @@ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
mmq_y = MMQ_Y_Q4_1_PASCAL;
nwarps = NWARPS_Q4_1_PASCAL;
} else {
- GGML_ASSERT(false);
+ GGML_ABORT("fatal error");
}
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@@ -2029,7 +2029,7 @@ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
mmq_y = MMQ_Y_Q5_0_PASCAL;
nwarps = NWARPS_Q5_0_PASCAL;
} else {
- GGML_ASSERT(false);
+ GGML_ABORT("fatal error");
}
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@@ -2144,7 +2144,7 @@ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
mmq_y = MMQ_Y_Q5_1_PASCAL;
nwarps = NWARPS_Q5_1_PASCAL;
} else {
- GGML_ASSERT(false);
+ GGML_ABORT("fatal error");
}
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@@ -2259,7 +2259,7 @@ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
mmq_y = MMQ_Y_Q8_0_PASCAL;
nwarps = NWARPS_Q8_0_PASCAL;
} else {
- GGML_ASSERT(false);
+ GGML_ABORT("fatal error");
}
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@@ -2374,7 +2374,7 @@ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
mmq_y = MMQ_Y_Q2_K_PASCAL;
nwarps = NWARPS_Q2_K_PASCAL;
} else {
- GGML_ASSERT(false);
+ GGML_ABORT("fatal error");
}
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@@ -2497,7 +2497,7 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
mmq_y = MMQ_Y_Q3_K_PASCAL;
nwarps = NWARPS_Q3_K_PASCAL;
} else {
- GGML_ASSERT(false);
+ GGML_ABORT("fatal error");
}
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@@ -2625,7 +2625,7 @@ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
mmq_y = MMQ_Y_Q4_K_PASCAL;
nwarps = NWARPS_Q4_K_PASCAL;
} else {
- GGML_ASSERT(false);
+ GGML_ABORT("fatal error");
}
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@@ -2746,7 +2746,7 @@ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
mmq_y = MMQ_Y_Q5_K_PASCAL;
nwarps = NWARPS_Q5_K_PASCAL;
} else {
- GGML_ASSERT(false);
+ GGML_ABORT("fatal error");
}
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@@ -2867,7 +2867,7 @@ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
mmq_y = MMQ_Y_Q6_K_PASCAL;
nwarps = NWARPS_Q6_K_PASCAL;
} else {
- GGML_ASSERT(false);
+ GGML_ABORT("fatal error");
}
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@@ -3016,7 +3016,7 @@ void ggml_sycl_op_mul_mat_q(
ggml_mul_mat_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
break;
default:
- GGML_ASSERT(false);
+ GGML_ABORT("fatal error");
break;
}
diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp
index 3fbc4dd6..1b96925e 100644
--- a/ggml/src/ggml-sycl/mmvq.cpp
+++ b/ggml/src/ggml-sycl/mmvq.cpp
@@ -902,7 +902,7 @@ static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
- mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 1>(
+ mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(
vx, vy, dst, ncols, nrows, item_ct1);
});
});
@@ -1017,7 +1017,7 @@ void ggml_sycl_op_mul_mat_vec_q(
mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
break;
default:
- GGML_ASSERT(false);
+ GGML_ABORT("fatal error");
break;
}
}
diff --git a/ggml/src/ggml-sycl/norm.cpp b/ggml/src/ggml-sycl/norm.cpp
index cccf87d0..b3159b9d 100644
--- a/ggml/src/ggml-sycl/norm.cpp
+++ b/ggml/src/ggml-sycl/norm.cpp
@@ -225,9 +225,8 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
}
static void group_norm_f32_sycl(const float* x, float* dst,
- const int num_groups, const int group_size,
+ const int num_groups, const float eps, const int group_size,
const int ne_elements, queue_ptr stream, int device) {
- static const float eps = 1e-6f;
if (group_size < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
stream->submit([&](sycl::handler& cgh) {
@@ -343,8 +342,12 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor*
GGML_ASSERT(dst->type == GGML_TYPE_F32);
int num_groups = dst->op_params[0];
+
+ float eps;
+ memcpy(&eps, dst->op_params + 1, sizeof(float));
+
int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
- group_norm_f32_sycl(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device);
+ group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device);
(void)src1;
(void)dst;
diff --git a/ggml/src/ggml-sycl/presets.hpp b/ggml/src/ggml-sycl/presets.hpp
index 15ddcac1..340ab8e9 100644
--- a/ggml/src/ggml-sycl/presets.hpp
+++ b/ggml/src/ggml-sycl/presets.hpp
@@ -41,6 +41,8 @@
#define SYCL_ACC_BLOCK_SIZE 256
#define SYCL_IM2COL_BLOCK_SIZE 256
#define SYCL_POOL2D_BLOCK_SIZE 256
+#define SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE 256
+#define SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE 256
// dmmv = dequantize_mul_mat_vec
#ifndef GGML_SYCL_DMMV_X
diff --git a/ggml/src/ggml-sycl/rope.cpp b/ggml/src/ggml-sycl/rope.cpp
index 6f507941..c7545bcc 100644
--- a/ggml/src/ggml-sycl/rope.cpp
+++ b/ggml/src/ggml-sycl/rope.cpp
@@ -251,7 +251,7 @@ void ggml_sycl_op_rope(
attn_factor, corr_dims, freq_factors, main_stream
);
} else {
- GGML_ASSERT(false);
+ GGML_ABORT("fatal error");
}
} else {
if (src0->type == GGML_TYPE_F32) {
@@ -265,7 +265,7 @@ void ggml_sycl_op_rope(
attn_factor, corr_dims, freq_factors, main_stream
);
} else {
- GGML_ASSERT(false);
+ GGML_ABORT("fatal error");
}
}
diff --git a/ggml/src/ggml-sycl/tsembd.cpp b/ggml/src/ggml-sycl/tsembd.cpp
new file mode 100644
index 00000000..d5c227cd
--- /dev/null
+++ b/ggml/src/ggml-sycl/tsembd.cpp
@@ -0,0 +1,71 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#include "tsembd.hpp"
+
+static void timestep_embedding_f32(
+ const float * timesteps, float * dst, const int nb1,
+ const int dim, const int max_period, const sycl::nd_item<3> &item_ct1) {
+ // item_ct1.get_group(1)(blockIDx.y): idx of timesteps->ne[0]
+ // item_ct1.get_group(2) (blockIDx.x): idx of ((dim + 1) / 2) / BLOCK_SIZE
+ int i = item_ct1.get_group(1);
+ int j = item_ct1.get_local_id(2) + item_ct1.get_group(2) * item_ct1.get_local_range(2);
+ float * embed_data = (float *)((char *)dst + i*nb1);
+
+ if (dim % 2 != 0 && j == ((dim + 1) / 2)) {
+ embed_data[dim] = 0.f;
+ }
+
+ int half = dim / 2;
+ if (j >= half) {
+ return;
+ }
+
+ float timestep = timesteps[i];
+ float freq = (float)sycl::native::exp(-(sycl::log((float)max_period)) * j / half);
+ float arg = timestep * freq;
+ embed_data[j] = sycl::cos(arg);
+ embed_data[j + half] = sycl::sin(arg);
+}
+
+static void timestep_embedding_f32_sycl(
+ const float * x, float * dst, const int ne00, const int nb1,
+ const int dim, const int max_period, const queue_ptr& stream) {
+ // As the kernel returns when thread.idx is larger than dim/2, the half_ceil does not need to pad
+ int half_ceil = dim / 2;
+ int num_blocks = (half_ceil + SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE;
+ sycl::range<3> block_dims(1, 1, SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE);
+ sycl::range<3> gridDim(1, ne00, num_blocks);
+ stream->parallel_for(
+ sycl::nd_range<3>(
+ gridDim * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ timestep_embedding_f32(
+ x, dst, nb1, dim, max_period, item_ct1
+ );
+ });
+}
+
+void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1, ggml_tensor * dst) {
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ dpct::queue_ptr stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ const int dim = dst->op_params[0];
+ const int max_period = dst->op_params[1];
+
+ timestep_embedding_f32_sycl(src0_d, dst_d, src0->ne[0], dst->nb[1], dim, max_period, stream);
+}
diff --git a/ggml/src/ggml-sycl/tsembd.hpp b/ggml/src/ggml-sycl/tsembd.hpp
new file mode 100644
index 00000000..ff854c33
--- /dev/null
+++ b/ggml/src/ggml-sycl/tsembd.hpp
@@ -0,0 +1,21 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_TSEMBD_HPP
+#define GGML_SYCL_TSEMBD_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1, ggml_tensor * dst);
+
+#endif // GGML_SYCL_TSEMBD_HPP