summaryrefslogtreecommitdiff
path: root/ggml-vulkan.cpp
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2024-05-11 10:32:41 +0300
committerGitHub <noreply@github.com>2024-05-11 10:32:41 +0300
commit9cb317f77e53067f7a138cc89ef7657148eae8e6 (patch)
tree3ba1d2d80d1d7c8b4ab01f6396a3febaae26e91b /ggml-vulkan.cpp
parente849648888a11de13aaaa4cb2eda3f5a9c7b444d (diff)
ggml : full ALiBi support (#7192)
* ggml : full ALiBi support * ggml : update ggml_soft_max_ext() CUDA, SYCL * ggml : ggml_flash_attn_ext() support ALiBi (CPU) * ggml : ggml_flash_attn_ext() support ALiBi (Metal) * ggml : fix warning * ggml : ggml_flash_attn_ext() support ALiBi (CUDA) ggml-ci * ggml : fix assert message * vulkan : add dev notes * ggml : require mask when using ALiBi ggml-ci * convert : fix convert for refact models
Diffstat (limited to 'ggml-vulkan.cpp')
-rw-r--r--ggml-vulkan.cpp6
1 files changed, 4 insertions, 2 deletions
diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp
index 95f71897..b9449be0 100644
--- a/ggml-vulkan.cpp
+++ b/ggml-vulkan.cpp
@@ -3830,9 +3830,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return nullptr;
case GGML_OP_SOFT_MAX:
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
- GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32 || src2->type == GGML_TYPE_F16);
- if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && (src2 == nullptr || src2->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
+ if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_soft_max_f32;
}
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && src2->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
@@ -4286,6 +4285,9 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx,
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+#pragma message("TODO: src2 is no longer used in soft_max - should be removed and ALiBi calculation should be updated")
+#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192")
+
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, {
ncols,
src1 != nullptr ? nrows_y : (uint32_t)0,