Diffstat (limited to 'src/llama.cpp')
-rw-r--r--  src/llama.cpp  |  71
1 file changed, 26 insertions(+), 45 deletions(-)
diff --git a/src/llama.cpp b/src/llama.cpp
index 831f98dc..8a85144e 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -8290,7 +8290,8 @@ static struct ggml_tensor * llm_build_kqv(
0);
cb(v, "v", il);
- cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+ cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
+ hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
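The hunk above forwards the model's attention logit soft-cap into ggml_flash_attn_ext as an extra argument instead of leaving flash attention unaware of it. For reference, a minimal non-ggml sketch of the soft-capping transform itself, assuming the cap * tanh(x / cap) convention implied by the ggml_softcap(ctx, kq, 1.0f/cap, cap) call in the next hunk; softcap_inplace is an illustrative name, not part of the patch:

#include <math.h>

// Squash attention logits into (-cap, cap) while staying near-linear around
// zero: y = cap * tanhf(x / cap). Sketch only.
static void softcap_inplace(float * logits, int n, float cap) {
    for (int i = 0; i < n; ++i) {
        logits[i] = cap * tanhf(logits[i] / cap);
    }
}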
@@ -8324,10 +8325,12 @@ static struct ggml_tensor * llm_build_kqv(
}
if (hparams.attn_soft_cap) {
- kq = ggml_softcap(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping);
+ //kq = ggml_softcap(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping);
+ kq = ggml_softcap_max(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias,
+ 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping);
+ } else {
+ kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
}
-
- kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
cb(kq, "kq_soft_max_ext", il);
GGML_ASSERT(kv.size == n_ctx);
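Here the separate ggml_softcap + ggml_soft_max_ext pair becomes a single fused ggml_softcap_max call that also receives the KQ mask, scale and ALiBi bias. A hedged single-row sketch of what such a fused kernel would compute, assuming it is equivalent to the two-step path it replaces (soft-cap, then scale, mask and softmax); the exact internal ordering and the ALiBi slope handling are assumptions, and softcap_max_row is an illustrative name:

#include <math.h>

// One row of KQ: soft-cap each logit, apply scale and additive mask, then a
// numerically stable softmax. Sketch only; ALiBi slopes omitted for brevity.
static void softcap_max_row(float * x, const float * mask, int n,
                            float scale, float s_before, float s_after) {
    float max_val = -INFINITY;
    for (int i = 0; i < n; ++i) {
        x[i] = s_after * tanhf(x[i] * s_before);        // soft-cap
        x[i] = scale * x[i] + (mask ? mask[i] : 0.0f);  // scale + mask
        if (x[i] > max_val) max_val = x[i];
    }
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        x[i] = expf(x[i] - max_val);                    // stable exponentiation
        sum += x[i];
    }
    for (int i = 0; i < n; ++i) {
        x[i] /= sum;                                    // normalize the row
    }
}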
@@ -13220,47 +13223,31 @@ struct llm_build_context {
0);
cb(k, "k", il);
- if (cparams.flash_attn) {
-
- // split cached v into n_head heads (not transposed)
- struct ggml_tensor * v =
- ggml_view_3d(ctx0, kv_self.v_l[il],
- n_embd_head_v, n_kv, n_head_kv,
- ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
- ggml_row_size(kv_self.v_l[il]->type, n_embd_head_v),
- 0);
- cb(v, "v", il);
-
- cur_attn = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
-
- cur_attn = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
- } else {
- struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
- cb(kq, "kq", il);
+ struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+ cb(kq, "kq", il);
- kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
- cb(kq, "kq_soft_max_ext", il);
+ kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
+ cb(kq, "kq_soft_max_ext", il);
- GGML_ASSERT(kv_self.size == n_ctx);
+ GGML_ASSERT(kv_self.size == n_ctx);
- // split cached v into n_head heads
- struct ggml_tensor * v =
- ggml_view_3d(ctx0, kv_self.v_l[il],
- n_kv, n_embd_head_v, n_head_kv,
- ggml_element_size(kv_self.v_l[il])*n_ctx,
- ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
- 0);
- cb(v, "v", il);
+ // split cached v into n_head heads
+ struct ggml_tensor * v =
+ ggml_view_3d(ctx0, kv_self.v_l[il],
+ n_kv, n_embd_head_v, n_head_kv,
+ ggml_element_size(kv_self.v_l[il])*n_ctx,
+ ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
+ 0);
+ cb(v, "v", il);
- struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
- cb(kqv, "kqv", il);
+ struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
+ cb(kqv, "kqv", il);
- struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
- cb(kqv_merged, "kqv_merged", il);
+ struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+ cb(kqv_merged, "kqv_merged", il);
- cur_attn = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
- cb(cur_attn, "kqv_merged_cont", il);
- }
+ cur_attn = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
+ cb(cur_attn, "kqv_merged_cont", il);
cur_attn = llm_build_norm(ctx0, cur_attn, hparams,
model.layers[il].attn_sub_norm, NULL,
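The hunk above drops the cparams.flash_attn branch from this builder and keeps only the explicit KQ path, which reads V from the cache through a transposed 3-D view. A small sketch of how that view's strides address the cache, assuming standard ggml view semantics (byte offset = i0*nb0 + i1*nb1 + i2*nb2) and reusing the strides passed to ggml_view_3d above; v_cache_offset is an illustrative helper, not part of the patch:

#include <stddef.h>
#include <stdint.h>

// Byte offset of element (i0 = cached token, i1 = head dimension, i2 = KV head)
// inside kv_self.v_l[il] for the transposed view built above. Sketch only.
static size_t v_cache_offset(size_t elsize, int64_t n_ctx, int64_t n_embd_head_v,
                             int64_t i0, int64_t i1, int64_t i2) {
    const size_t nb0 = elsize;                          // contiguous along tokens
    const size_t nb1 = elsize * n_ctx;                  // stride between head dims
    const size_t nb2 = elsize * n_ctx * n_embd_head_v;  // stride between KV heads
    return (size_t) i0*nb0 + (size_t) i1*nb1 + (size_t) i2*nb2;
}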
@@ -16811,12 +16798,6 @@ struct llama_context * llama_new_context_with_model(
params.flash_attn = false;
}
- if (params.flash_attn && model->hparams.attn_soft_cap) {
- LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__);
- params.flash_attn = false;
- }
-
-
if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
params.flash_attn = false;
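With the attn_soft_cap guard removed above, flash attention is no longer forced off for models whose hparams set attn_soft_cap. A minimal usage sketch against the public API, assuming a llama_model * named model has already been loaded; make_ctx is an illustrative helper and error handling is omitted:

#include "llama.h"

// Request flash attention explicitly; with this patch it stays enabled even
// when the model uses attention logit soft-capping.
static llama_context * make_ctx(llama_model * model) {
    llama_context_params cparams = llama_context_default_params();
    cparams.flash_attn = true;
    return llama_new_context_with_model(model, cparams);
}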