summaryrefslogtreecommitdiff
path: root/ggml-vulkan.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-vulkan.cpp')
-rw-r--r--ggml-vulkan.cpp5
1 files changed, 5 insertions, 0 deletions
diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp
index 1736ab73..f712cdd5 100644
--- a/ggml-vulkan.cpp
+++ b/ggml-vulkan.cpp
@@ -3178,6 +3178,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
}
return nullptr;
case GGML_OP_SOFT_MAX:
+#pragma message("TODO: add ggml_vk_soft_max() F16 src1 and src2 support")
+#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
+ GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);
+
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && (src2 == nullptr || src2->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_soft_max_f32;
}