Diffstat (limited to 'ggml-kompute.cpp')
-rw-r--r-- | ggml-kompute.cpp | 7
1 file changed, 7 insertions, 0 deletions
diff --git a/ggml-kompute.cpp b/ggml-kompute.cpp
index 407062e6..9a469821 100644
--- a/ggml-kompute.cpp
+++ b/ggml-kompute.cpp
@@ -1427,6 +1427,7 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
         for (int i = node_start; i < node_end; ++i) {
             struct ggml_tensor * src0 = gf->nodes[i]->src[0];
             struct ggml_tensor * src1 = gf->nodes[i]->src[1];
+            struct ggml_tensor * src2 = gf->nodes[i]->src[2]; GGML_UNUSED(src2);
             struct ggml_tensor * dst = gf->nodes[i];
             GGML_ASSERT(dst->data != nullptr);
 
@@ -1559,6 +1560,12 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                 {
                     float scale;
                     memcpy(&scale, dst->op_params, sizeof(float));
+
+#pragma message("TODO: add ggml_vk_soft_max() F16/F32 src1 and src2 support")
+#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
+                    GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
+                    GGML_ASSERT(src2 == nullptr);
+
                     ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
                 } break;
             case GGML_OP_DIAG_MASK_INF:
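Note for readers skimming the diff: the two new GGML_ASSERT lines act as a dispatch-time guard, rejecting soft_max inputs the Kompute shader cannot handle yet (a non-F32 mask in src1, or any src2 input at all), as flagged by the TODO pragmas. Below is a minimal standalone sketch of that guard pattern; the tensor and function names are invented for illustration and are not part of ggml.

// Standalone sketch of the guard pattern added in the diff (not the real
// backend code): accept an optional F32 mask (src1) and reject src2 until
// the kernel supports it. All names here are hypothetical.
#include <cassert>
#include <cstdio>

enum tensor_type { TYPE_F32, TYPE_F16 };

struct tensor {
    tensor_type type;
};

static void dispatch_soft_max(const tensor * src0, const tensor * src1, const tensor * src2) {
    // src0 is required and must be F32 in this sketch.
    assert(src0 && src0->type == TYPE_F32);
    // src1 (the mask) is optional; if present it must be F32,
    // mirroring GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32) above.
    assert(!src1 || src1->type == TYPE_F32);
    // src2 is not handled yet, mirroring GGML_ASSERT(src2 == nullptr).
    assert(src2 == nullptr);
    std::printf("soft_max dispatched\n");
}

int main() {
    tensor logits = { TYPE_F32 };
    tensor mask   = { TYPE_F32 };
    dispatch_soft_max(&logits, &mask, nullptr);   // accepted: F32 mask, no src2
    dispatch_soft_max(&logits, nullptr, nullptr); // accepted: mask is optional
    return 0;
}

The point of asserting rather than silently falling back is that an unsupported soft_max configuration fails loudly at graph-compute time instead of producing wrong attention results.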