summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorKawrakow <48489457+ikawrakow@users.noreply.github.com>2024-08-27 17:40:59 +0300
committerGitHub <noreply@github.com>2024-08-27 17:40:59 +0300
commitc7e99c88a2de7489ba2a1539b1a9025912010b70 (patch)
tree9976409b1e8fac1fc7486f2c5da05a33b8e229b5 /src
parentbd99ed7d0afd2b12c0f5ff5c17b58486396dfe7e (diff)
Faster Gemma2 (#27)
* soft_cap_max: initial CPU version of fused softcap + soft_max With this vanilla CPU implementation I'm already getting a ~3% speedup for Gemma-2-9b and a prompt of 8192 tokens. * soft_cap_max: WIP - something is wrong with CUDA * soft_cap_max: looks good on CPU and CUDA * Add softcap to flash attention Just CPU and CUDA for now (but, as we know, flash attention on the CPU is useless in llama.cpp). On CUDA this improves PP performance quite a bit, especially for long contexts. E.g., for PP-16384, I now get 3777 t/s. Without this change, one cannot use FA, and one gets 2300 t/s (after fusing softcap and softmax), or 2000 t/s without the fused softcap+softmax. In comparison, mainline llama.cpp has PP-16384 = 1549 t/s before PR-8542 (where Johannes Gaessler has also added softcap to FA), and PP-16384 = 3097 t/s after this PR. * soft_cap_max: Metal * Flash attention with softcap: Metal --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'src')
-rw-r--r--src/llama.cpp71
1 files changed, 26 insertions, 45 deletions
diff --git a/src/llama.cpp b/src/llama.cpp
index 831f98dc..8a85144e 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -8290,7 +8290,8 @@ static struct ggml_tensor * llm_build_kqv(
0);
cb(v, "v", il);
- cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+ cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
+ hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
@@ -8324,10 +8325,12 @@ static struct ggml_tensor * llm_build_kqv(
}
if (hparams.attn_soft_cap) {
- kq = ggml_softcap(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping);
+ //kq = ggml_softcap(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping);
+ kq = ggml_softcap_max(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias,
+ 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping);
+ } else {
+ kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
}
-
- kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
cb(kq, "kq_soft_max_ext", il);
GGML_ASSERT(kv.size == n_ctx);
@@ -13220,47 +13223,31 @@ struct llm_build_context {
0);
cb(k, "k", il);
- if (cparams.flash_attn) {
-
- // split cached v into n_head heads (not transposed)
- struct ggml_tensor * v =
- ggml_view_3d(ctx0, kv_self.v_l[il],
- n_embd_head_v, n_kv, n_head_kv,
- ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
- ggml_row_size(kv_self.v_l[il]->type, n_embd_head_v),
- 0);
- cb(v, "v", il);
-
- cur_attn = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
-
- cur_attn = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
- } else {
- struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
- cb(kq, "kq", il);
+ struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+ cb(kq, "kq", il);
- kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
- cb(kq, "kq_soft_max_ext", il);
+ kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
+ cb(kq, "kq_soft_max_ext", il);
- GGML_ASSERT(kv_self.size == n_ctx);
+ GGML_ASSERT(kv_self.size == n_ctx);
- // split cached v into n_head heads
- struct ggml_tensor * v =
- ggml_view_3d(ctx0, kv_self.v_l[il],
- n_kv, n_embd_head_v, n_head_kv,
- ggml_element_size(kv_self.v_l[il])*n_ctx,
- ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
- 0);
- cb(v, "v", il);
+ // split cached v into n_head heads
+ struct ggml_tensor * v =
+ ggml_view_3d(ctx0, kv_self.v_l[il],
+ n_kv, n_embd_head_v, n_head_kv,
+ ggml_element_size(kv_self.v_l[il])*n_ctx,
+ ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
+ 0);
+ cb(v, "v", il);
- struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
- cb(kqv, "kqv", il);
+ struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
+ cb(kqv, "kqv", il);
- struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
- cb(kqv_merged, "kqv_merged", il);
+ struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+ cb(kqv_merged, "kqv_merged", il);
- cur_attn = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
- cb(cur_attn, "kqv_merged_cont", il);
- }
+ cur_attn = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
+ cb(cur_attn, "kqv_merged_cont", il);
cur_attn = llm_build_norm(ctx0, cur_attn, hparams,
model.layers[il].attn_sub_norm, NULL,
@@ -16811,12 +16798,6 @@ struct llama_context * llama_new_context_with_model(
params.flash_attn = false;
}
- if (params.flash_attn && model->hparams.attn_soft_cap) {
- LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__);
- params.flash_attn = false;
- }
-
-
if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
params.flash_attn = false;