summary | refs | log | tree | commit | diff
path: root/ggml-kompute.cpp
diff options
context:
space:
mode:
author woachk <24752637+woachk@users.noreply.github.com> 2024-06-03 07:32:16 +0200
committer GitHub <noreply@github.com> 2024-06-03 08:32:16 +0300
commit 9e405b6e2ecb888e860f7b92720b4809e21b3915 (patch)
tree e42d87168def2b7a21a003253720a79d743f4c3b /ggml-kompute.cpp
parent 3413ae2193d0693f14bead02e5018f442cbf579b (diff)
kompute : implement op_getrows_f32 (#6403)
op_getrows_f32 is required since https://github.com/ggerganov/llama.cpp/pull/6122 for the Vulkan w/ Kompute backend to be functional. As such, implement this op to make this backend functional again.
Diffstat (limited to 'ggml-kompute.cpp')
-rw-r--r-- ggml-kompute.cpp 14
1 files changed, 13 insertions, 1 deletions
diff --git a/ggml-kompute.cpp b/ggml-kompute.cpp
index 0c51c322..eabd70d5 100644
--- a/ggml-kompute.cpp
+++ b/ggml-kompute.cpp
@@ -22,6 +22,7 @@
#include "shaderop_mul_mat_q4_1.h"
#include "shaderop_mul_mat_q6_k.h"
#include "shaderop_mul_mat_mat_f32.h"
+#include "shaderop_getrows_f32.h"
#include "shaderop_getrows_f16.h"
#include "shaderop_getrows_q4_0.h"
#include "shaderop_getrows_q4_1.h"
@@ -1147,6 +1148,14 @@ static void ggml_vk_get_rows(
}
template <typename... Args>
+static void ggml_vk_get_rows_f32(Args&&... args) {
+ const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f32_comp_spv,
+ kp::shader_data::op_getrows_f32_comp_spv_len);
+
+ ggml_vk_get_rows(spirv, "f32", sizeof(float), 0, std::forward<Args>(args)...);
+}
+
+template <typename... Args>
static void ggml_vk_get_rows_f16(Args&&... args) {
const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f16_comp_spv,
kp::shader_data::op_getrows_f16_comp_spv_len);
@@ -1371,6 +1380,7 @@ static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
return op->ne[3] == 1;
case GGML_OP_GET_ROWS:
switch (op->src[0]->type) {
+ case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
@@ -1661,7 +1671,9 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
} break;
case GGML_OP_GET_ROWS:
{
- if (src0t == GGML_TYPE_F16) {
+ if (src0t == GGML_TYPE_F32) {
+ ggml_vk_get_rows_f32(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
+ } else if (src0t == GGML_TYPE_F16) {
ggml_vk_get_rows_f16(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
} else if (src0t == GGML_TYPE_Q4_0) {
ggml_vk_get_rows_q4_0(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));