Diffstat (limited to 'src')
-rw-r--r-- | src/llama.cpp | 21
1 file changed, 18 insertions, 3 deletions
diff --git a/src/llama.cpp b/src/llama.cpp
index c11affb6..0a81f2b9 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -10136,6 +10136,12 @@ static struct ggml_tensor * llm_build_kqv(
                 0);
     cb(k, "k", il);
 
+#ifdef GGML_USE_VULKAN
+    constexpr bool use_f32_precision = true;
+#else
+    constexpr bool use_f32_precision = false;
+#endif
+
     struct ggml_tensor * cur;
 
     if (cparams.flash_attn) {
@@ -10157,7 +10163,7 @@ static struct ggml_tensor * llm_build_kqv(
         // Some models produced NaNs/gibberish when FA is computed with f16 precision on CUDA
         // For DeepSeek-2, it is perfectly fine with fp16 for PP, but I get gibberish when uding fp16 for TG.
         // Not sure if it is really a matter of insufficient precision, or I have made a mistake in the fattn-vec-f16 kernel.
-        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX ||
+        if (use_f32_precision || model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX ||
            (model.arch == LLM_ARCH_DEEPSEEK2 && q->ne[1] <= 8) || model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4) {
             ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
         }
@@ -10182,7 +10188,7 @@ static struct ggml_tensor * llm_build_kqv(
 
         //ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
 
-        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 ||
+        if (use_f32_precision || model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 ||
             model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4) {
             // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
             // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
@@ -15449,6 +15455,11 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_deepseek2() {
+#ifdef GGML_USE_VULKAN
+        constexpr bool use_f32_attn_precision = true;
+#else
+        constexpr bool use_f32_attn_precision = false;
+#endif
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
@@ -15678,7 +15689,7 @@ struct llm_build_context {
                     q->nb[1], q->nb[2], q->nb[2]*n_max_head*iter);
                 kqv = ggml_flash_attn_ext(ctx0, q_iter, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f);
-                if (q->ne[1] <= 8) {
+                if (use_f32_attn_precision || q->ne[1] <= 8) {
                     ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
                 }
                 cb(kqv, "kqv", il);
@@ -15720,6 +15731,10 @@ struct llm_build_context {
                 kqv_compressed = ggml_flash_attn_ext(ctx0, q, kv_cache, kv_cache_lora, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f);
                 cb(kqv_compressed, "kqv_compressed", il);
 
+                if (use_f32_attn_precision) {
+                    ggml_flash_attn_ext_set_prec(kqv_compressed, GGML_PREC_F32);
+                }
+
                 kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
                 cb(kqv_compressed, "kqv_compressed_perm", il);
             }
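
The pattern applied at every call site above is the same: when the build targets the Vulkan backend (GGML_USE_VULKAN), request F32 accumulation for the flash-attention (and KQ matmul) nodes instead of relying on the per-architecture exceptions. Below is a minimal sketch of that pattern in isolation, using the public ggml calls the patch relies on (ggml_flash_attn_ext, ggml_flash_attn_ext_set_prec); the function name build_attn_sketch and its parameters are illustrative, not part of llama.cpp.

    // Sketch only: force F32 precision for a flash-attention node on Vulkan builds.
    #include "ggml.h"

    static struct ggml_tensor * build_attn_sketch(
            struct ggml_context * ctx,
            struct ggml_tensor  * q,
            struct ggml_tensor  * k,
            struct ggml_tensor  * v,
            struct ggml_tensor  * mask,
            float                 kq_scale) {
    #ifdef GGML_USE_VULKAN
        constexpr bool use_f32_precision = true;   // Vulkan: always accumulate in F32
    #else
        constexpr bool use_f32_precision = false;  // other backends: default path (per-arch overrides omitted here)
    #endif
        // build the fused flash-attention node: softmax(scale * Q*K^T + mask) * V
        struct ggml_tensor * cur = ggml_flash_attn_ext(ctx, q, k, v, mask, kq_scale,
                                                       /*max_bias*/ 0.0f, /*logit_softcap*/ 0.0f);
        if (use_f32_precision) {
            // tag the node so the backend selects an F32-accumulating kernel
            ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
        }
        return cur;
    }

Keeping the backend check in a constexpr bool lets the override stay an ordinary runtime condition that the compiler folds away when false, so each call site only needs an extra "use_f32_precision ||" term instead of its own #ifdef block.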