author    Georgi Gerganov <ggerganov@gmail.com>      2024-04-30 12:16:08 +0300
committer GitHub <noreply@github.com>                2024-04-30 12:16:08 +0300
commit    9c67c2773d4b706cf71d70ecf4aa180b62501960 (patch)
tree      be51cbda5b15ae1bb3a465a2551e7dbe6d3101d7 /llama.cpp
parent    952d03dbead16e4dbdd1d3458486340673cc2465 (diff)
ggml : add Flash Attention (#5021)
* ggml : add ggml_flash_attn_ext API
* ggml : fix GQA support in ggml_flash_attn_ext
* ggml : online attention (CPU)
* metal : initial implementation
* metal : f16 precision
* metal : reduce branches
* metal : specialize for head size
* wip : 8 rows per simd group
* wip : 4 rows per simd group
* wip : template for rows per warp
* metal : parallelize across KV size
* metal : parallel reduce across heads
* metal : efficient flash_attn_f16 implementation
* metal : avoid redundant loads of the attention
* metal : scale and mask in matrix form
* metal : fix comment
* llama : avoid ggml_cast, use F32 query
* metal : add parallel reduce version (disabled)
* metal : move output into local memory + optimize
  - the result from each simdgroup now stays in the registers
  - significantly reduced SRAM usage
  - more efficient skipping of -INF blocks
  - avoid simdgroup barrier in hot loop
  - add comments
* metal : add tests, fix scaling, support C > 32
* metal : improve precision
* ggml : fix f16 mad
* metal : minor
* metal : support Q > 8
* tests : add ATTN tests
* metal : disable buffer allocation logs
* tests : more
* metal : faster inner loop for C == 32
* metal : fix array initialization
* tests : ifdef
* ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext
* ggml : fix ggml_soft_max mask requirement
* cuda : fix soft_max to use correct mask size
* cuda : add flash_attn kernel (wip)
* metal : optimize softmax for C > 32
* metal : optimize softmax
* tests : minor fix
* cuda : avoid zeroing fragments
* tests : update dims
* cuda : fix __hisinf() result check
* cuda : avoid warp_reduce for smax
* cuda : use int instead of int64_t
  Noticeably improves performance (thanks to Johannes)
* cuda : make loops use the same loop values
  Thanks Johannes again for the tip
* cuda : unroll some of the loops
* cuda : avoid __hisinf branches
* cuda : use half2 in softmax
* cuda : switch to 1 warp for bs > 16
* cuda : speed-up reduce part of the kernel
* cuda : unroll Q*K^T loop
* cuda : fix -INF block check
* cuda : simplify softmax
* cuda : fix matrix names
* cuda : minor
* llama : adapt to F16 KQ_pos
* llama : adapt new models to F16 KQ_mask
* ggml : fix F16 store (ARM NEON)
* llama : fix type of KQ_mask and KQ_pos
* ggml : fix CPU soft_max
* tests : add hs=256
* cuda : fix build
* metal : improve perf via smaller int registers
* cuda : adapt soft_max to F16 mask and pos
* CUDA: faster FlashAttention, kernel for bs == 1
* 16 cols for Phi-2
* no vec for hs, no hs==256 ncols==32 for Volta
* adjust kernel selection logic
* 4 warps, 256 stride for all D
* no ncols == 64
* Multiple parallel blocks for batch size 1
* fix compile warnings
* fix excessive KQ_b loads
* fix cmake build
* fix KV cache padding, NaN from INFINITY (#6438)
* llama : flash_attn cparam + fix defrag
* server: support flash_attn param
* server: bench: enable flash_attn param
* CUDA: refactor host code, dyn. par. blocks
* fix flash_attn_vec_f16 race condition
* flush softmax exp below threshold to 0
* store temp KQ in registers
* Calculate KQ as FP32 if KQV has GGML_PREC_F32
* Add __hgt2_mask implementation for CUDA 11
* fix KQ FP32 precision fpr parallel_blocks > 1
* llama-bench : add -fa,--flash-attn arg
* metal : add BS=1 kernel for flash attention (#6508)
* metal : add BS=1 kernel for flash attention (wip)
* metal : support more than 1 warps
* metal : opts
* metal : opt
* metal : switch to parallel reduce
* metal : reduce registers
* metal : simplify
* metal : initial FA vec kernel
* metal : use F32 attention accumulators
* batched-bench : add fattn arg
* llama : simplify llama_build_kv_store ggml-ci
* llama : adapt build_olmo to changes
* ggml : fix arm fp16 store on windows
* metal : clean-up
* metal : clean-up kernel code
* metal : minor
* tests : remove benchmarks ggml-ci
* ggml : fix avx512 const correctness ggml-ci
* ggml : fix soft_max with bias on CPU ggml-ci
* common : print --flash-attn in help
* ggml : fix num dimensions in ggml_flash_attn_ext
* llama : force disable flash attention for incompatible models
* ggml : ggml_soft_max support F16/F32 mask/pos ggml-ci
* cuda : uint -> uint32_t
* cuda : "constexpr dim3" -> "const dim3" ggml-ci
* cuda : try to fix __hgt2_mask ggml-ci
* ggml : add TODO's for F16/F32 mask/pos support in other backends
* llama : replace bool need_kq_pos with use_alibi
* llama : prep ALiBi support for BERT models ggml-ci
* llama : fix n_batch requirements ggml-ci
* cont
* server : add help for --flash-attn arg
* llama : disable FA for AMD
* tests : remove TMP_ATTN_BENCH ggml-ci
* llama : support save/load state with FA enabled ggml-ci
* ci : add CUDA save-load-state tests ggml-ci
* llama : llama_kv_cache_clear zeroes data + fix save-load seq ggml-ci
* llama : fix copy-paste errors, add TODO
* llama : disallow incompatible states
* llama : update llama_state_get_size after v_trans field
* metal : remove tmp log
* llama : add static reminder for llama_state_get_size
* metal : fix max nsg ggml-ci
* ci : fix arg order ggml-ci

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
Co-authored-by: Pierrick HYMBERT <pierrick.hymbert@gmail.com>
Diffstat (limited to 'llama.cpp')
-rw-r--r--  llama.cpp  564
1 file changed, 367 insertions, 197 deletions
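Editor's note: the new flash_attn context parameter introduced by this commit is toggled from user code. A minimal sketch, assuming the llama.h API at this revision (the model path and error handling are illustrative only, not part of the patch):

#include "llama.h"

int main() {
    llama_backend_init();

    // hypothetical model file, for illustration only
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_file("model.gguf", mparams);
    if (!model) return 1;

    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx      = 4096;
    cparams.flash_attn = true; // new in this commit; may be forced off for ALiBi, Grok or HIPBLAS builds (see checks below)

    llama_context * ctx = llama_new_context_with_model(model, cparams);
    if (!ctx) return 1;

    // ... run inference ...

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}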
diff --git a/llama.cpp b/llama.cpp
index 72c10ffc..18d6297c 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -108,7 +108,6 @@
#define LLAMA_MAX_NODES 8192
#define LLAMA_MAX_EXPERTS 60
-
//
// logging
//
@@ -1846,7 +1845,7 @@ struct llama_hparams {
float f_logit_scale = 0.0f;
bool causal_attn = true;
- bool need_kq_pos = false;
+ bool use_alibi = false; // currently, we need KQ_pos data for ALiBi-based models
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
@@ -1936,6 +1935,7 @@ struct llama_cparams {
bool embeddings;
bool causal_attn;
bool offload_kqv;
+ bool flash_attn;
enum llama_pooling_type pooling_type;
@@ -2039,8 +2039,8 @@ struct llama_kv_cache {
bool has_shift = false;
bool do_defrag = false;
bool do_copy = false;
- // with recurrent state models, a cell can hold the state for more than one past token
- bool recurrent = false;
+ bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
+ bool v_trans = true; // the value tensor is transposed
// Note: The value of head isn't only used to optimize searching
// for a free KV slot. llama_decode_internal also uses it, so it
@@ -2339,11 +2339,14 @@ struct llama_context {
static bool llama_kv_cache_init(
struct llama_kv_cache & cache,
- const llama_model & model,
+ const llama_context * ctx,
ggml_type type_k,
ggml_type type_v,
uint32_t kv_size,
bool offload) {
+ const llama_model & model = ctx->model;
+ const llama_cparams & cparams = ctx->cparams;
+
const struct llama_hparams & hparams = model.hparams;
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
@@ -2354,6 +2357,7 @@ static bool llama_kv_cache_init(
// TODO: find a nicer way to add other recurrent model architectures
cache.recurrent = model.arch == LLM_ARCH_MAMBA;
+ cache.v_trans = !cparams.flash_attn;
// TODO: support mixed reccurent Transformer architectues
// NOTE: (!a || b) is a logical implication (a -> b)
@@ -2566,6 +2570,10 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
}
cache.head = 0;
cache.used = 0;
+
+ for (auto & buf : cache.bufs) {
+ ggml_backend_buffer_clear(buf, 0);
+ }
}
static bool llama_kv_cache_seq_rm(
@@ -4194,7 +4202,7 @@ static void llm_load_hparams(
model.ftype = ml.ftype;
if (hparams.f_max_alibi_bias > 0.0f) {
- hparams.need_kq_pos = true;
+ hparams.use_alibi = true;
}
hparams.rope_type = llama_rope_type(&model);
@@ -6203,37 +6211,47 @@ static struct ggml_tensor * llm_build_inp_embd(
static void llm_build_kv_store(
struct ggml_context * ctx,
const llama_hparams & hparams,
+ const llama_cparams & cparams,
const llama_kv_cache & kv,
struct ggml_cgraph * graph,
struct ggml_tensor * k_cur,
struct ggml_tensor * v_cur,
- int64_t n_ctx,
int32_t n_tokens,
int32_t kv_head,
const llm_build_cb & cb,
int64_t il) {
+ const int64_t n_ctx = cparams.n_ctx;
+
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(kv.size == n_ctx);
- // compute the transposed [n_tokens, n_embd] V matrix
- assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
- struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur);
- cb(v_cur_t, "v_cur_t", il);
-
struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa,
(ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head);
cb(k_cache_view, "k_cache_view", il);
- struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa,
- ( n_ctx)*ggml_element_size(kv.v_l[il]),
- (kv_head)*ggml_element_size(kv.v_l[il]));
+ // note: storing RoPE-ed version of K in the KV cache
+ ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
+
+ assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
+
+ struct ggml_tensor * v_cache_view = nullptr;
+
+ if (cparams.flash_attn) {
+ v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa,
+ (kv_head)*ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa));
+ } else {
+ // note: the V cache is transposed when not using flash attention
+ v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa,
+ ( n_ctx)*ggml_element_size(kv.v_l[il]),
+ (kv_head)*ggml_element_size(kv.v_l[il]));
+
+ v_cur = ggml_transpose(ctx, v_cur);
+ }
cb(v_cache_view, "v_cache_view", il);
- // important: storing RoPE-ed version of K in the KV cache!
- ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
- ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view));
+ ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
}
static struct ggml_tensor * llm_build_norm(
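Editor's note: the key change in llm_build_kv_store is the V-cache layout. Without flash attention, V is stored transposed (one row per embedding channel, values of a token n_ctx elements apart); with flash attention it is stored row-major, like the K cache. A small standalone sketch of the two addressing schemes (plain C, indices are illustrative and do not reference the actual ggml tensors):

#include <stddef.h>

// index of the V value for (token t, channel c) within one layer's V cache
static size_t v_index_rowmajor(size_t t, size_t c, size_t n_embd_v_gqa) {
    return t*n_embd_v_gqa + c; // contiguous per token, matches the ggml_view_1d copy used when flash_attn is on
}

static size_t v_index_transposed(size_t t, size_t c, size_t n_ctx) {
    return c*n_ctx + t;        // contiguous per channel, matches the ggml_view_2d with an n_ctx stride
}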
@@ -6453,11 +6471,11 @@ static struct ggml_tensor * llm_build_moe_ffn(
return moe_out;
}
-// if max_alibi_bias > 0 then apply ALiBi
static struct ggml_tensor * llm_build_kqv(
struct ggml_context * ctx,
const llama_model & model,
const llama_hparams & hparams,
+ const llama_cparams & cparams,
const llama_kv_cache & kv,
struct ggml_cgraph * graph,
struct ggml_tensor * wo,
@@ -6465,12 +6483,12 @@ static struct ggml_tensor * llm_build_kqv(
struct ggml_tensor * q_cur,
struct ggml_tensor * kq_mask,
struct ggml_tensor * kq_pos,
- int64_t n_ctx,
int32_t n_tokens,
int32_t n_kv,
float kq_scale,
const llm_build_cb & cb,
int il) {
+ const int64_t n_ctx = cparams.n_ctx;
const int64_t n_head = hparams.n_head;
const int64_t n_head_kv = hparams.n_head_kv;
const int64_t n_embd_head_k = hparams.n_embd_head_k;
@@ -6488,71 +6506,99 @@ static struct ggml_tensor * llm_build_kqv(
0);
cb(k, "k", il);
- struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
- cb(kq, "kq", il);
+ struct ggml_tensor * cur;
- if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) {
- // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
- // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
- ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
- }
+ if (cparams.flash_attn) {
+ GGML_UNUSED(model);
+ GGML_UNUSED(n_ctx);
- if (model.arch == LLM_ARCH_GROK) {
- // need to do the following:
- // multiply by attn_output_multiplyer of 0.08838834764831845
- // and then :
- // kq = 30 * tanh(kq / 30)
- // before the softmax below
+ // note: if this assert triggers, then some check has failed earlier
+ // the idea is to detect during context creation that ALiBi would be used and disable Flash Attention
+ GGML_ASSERT(kq_pos == nullptr && "ALiBi is not yet supported with Flash Attention");
- //try from phi2
- //ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
+ // split cached v into n_head heads (not transposed)
+ struct ggml_tensor * v =
+ ggml_view_3d(ctx, kv.v_l[il],
+ n_embd_head_v, n_kv, n_head_kv,
+ ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa),
+ ggml_row_size(kv.v_l[il]->type, n_embd_head_k),
+ 0);
+ cb(v, "v", il);
- kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
- kq = ggml_scale(ctx, kq, 30);
- }
+ cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale);
+
+ if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) {
+ ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
+ }
+
+ cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens);
+ } else {
+ struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
+ cb(kq, "kq", il);
+
+ if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) {
+ // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
+ // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
+ ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
+ }
+
+ if (model.arch == LLM_ARCH_GROK) {
+ // need to do the following:
+ // multiply by attn_output_multiplyer of 0.08838834764831845
+ // and then :
+ // kq = 30 * tanh(kq / 30)
+ // before the softmax below
+
+ //try from phi2
+ //ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
+
+ kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
+ kq = ggml_scale(ctx, kq, 30);
+ }
#if defined(GGML_USE_KOMPUTE)
#pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Kompute")
#pragma message(" Falling back to ggml_alibi(). Will become an error in Mar 2024")
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5488")
- if (hparams.f_max_alibi_bias > 0.0f) {
- kq = ggml_scale(ctx, kq, kq_scale);
- cb(kq, "kq_scaled", il);
+ if (hparams.use_alibi) {
+ kq = ggml_scale(ctx, kq, kq_scale);
+ cb(kq, "kq_scaled", il);
- kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, hparams.f_max_alibi_bias);
- cb(kq, "kq_scaled_alibi", il);
+ kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, hparams.f_max_alibi_bias);
+ cb(kq, "kq_scaled_alibi", il);
- kq = ggml_add(ctx, kq, kq_mask);
- cb(kq, "kq_masked", il);
+ kq = ggml_add(ctx, kq, kq_mask);
+ cb(kq, "kq_masked", il);
- kq = ggml_soft_max(ctx, kq);
- cb(kq, "kq_soft_max", il);
- } else
+ kq = ggml_soft_max(ctx, kq);
+ cb(kq, "kq_soft_max", il);
+ } else
#endif
- {
- kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_pos, kq_scale, hparams.f_max_alibi_bias);
- cb(kq, "kq_soft_max_ext", il);
- }
+ {
+ kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_pos, kq_scale, hparams.f_max_alibi_bias);
+ cb(kq, "kq_soft_max_ext", il);
+ }
- GGML_ASSERT(kv.size == n_ctx);
+ GGML_ASSERT(kv.size == n_ctx);
- // split cached v into n_head heads
- struct ggml_tensor * v =
- ggml_view_3d(ctx, kv.v_l[il],
- n_kv, n_embd_head_v, n_head_kv,
- ggml_element_size(kv.v_l[il])*n_ctx,
- ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v,
- 0);
- cb(v, "v", il);
+ // split cached v into n_head heads
+ struct ggml_tensor * v =
+ ggml_view_3d(ctx, kv.v_l[il],
+ n_kv, n_embd_head_v, n_head_kv,
+ ggml_element_size(kv.v_l[il])*n_ctx,
+ ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v,
+ 0);
+ cb(v, "v", il);
- struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
- cb(kqv, "kqv", il);
+ struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
+ cb(kqv, "kqv", il);
- struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
- cb(kqv_merged, "kqv_merged", il);
+ struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
+ cb(kqv_merged, "kqv_merged", il);
- struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens);
- cb(cur, "kqv_merged_cont", il);
+ cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens);
+ cb(cur, "kqv_merged_cont", il);
+ }
ggml_build_forward_expand(graph, cur);
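Editor's note: both branches of llm_build_kqv compute the same quantity; the flash-attention path fuses the KQ matmul, softmax and V matmul into one op. As a reference point, here is an unfused single-head attention in plain C, numerically what ggml_flash_attn_ext is expected to produce up to precision (a sketch for illustration, not the ggml implementation):

#include <float.h>
#include <math.h>

// out[d] = sum_t softmax(scale * q.K_t + mask_t) * V_t[d]
// q: [D], K: [n_kv][D], V: [n_kv][D], mask: [n_kv] (0 or -INFINITY), out: [D]
static void attn_ref(const float *q, const float *K, const float *V,
                     const float *mask, float *out, int n_kv, int D, float scale) {
    float m = -FLT_MAX, s = 0.0f;
    for (int d = 0; d < D; ++d) out[d] = 0.0f;

    // first pass: running max of the masked, scaled scores (for a stable softmax)
    for (int t = 0; t < n_kv; ++t) {
        float kq = 0.0f;
        for (int d = 0; d < D; ++d) kq += q[d]*K[t*D + d];
        kq = scale*kq + mask[t];
        if (kq > m) m = kq;
    }

    // second pass: softmax-weighted sum of the V rows
    for (int t = 0; t < n_kv; ++t) {
        float kq = 0.0f;
        for (int d = 0; d < D; ++d) kq += q[d]*K[t*D + d];
        const float w = expf(scale*kq + mask[t] - m);
        s += w;
        for (int d = 0; d < D; ++d) out[d] += w*V[t*D + d];
    }
    for (int d = 0; d < D; ++d) out[d] /= s;
}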
@@ -6572,6 +6618,7 @@ static struct ggml_tensor * llm_build_kv(
struct ggml_context * ctx,
const llama_model & model,
const llama_hparams & hparams,
+ const llama_cparams & cparams,
const llama_kv_cache & kv,
struct ggml_cgraph * graph,
struct ggml_tensor * wo,
@@ -6581,7 +6628,6 @@ static struct ggml_tensor * llm_build_kv(
struct ggml_tensor * q_cur,
struct ggml_tensor * kq_mask,
struct ggml_tensor * kq_pos,
- int64_t n_ctx,
int32_t n_tokens,
int32_t kv_head,
int32_t n_kv,
@@ -6595,12 +6641,12 @@ static struct ggml_tensor * llm_build_kv(
ggml_build_forward_expand(graph, k_cur);
ggml_build_forward_expand(graph, v_cur);
- llm_build_kv_store(ctx, hparams, kv, graph, k_cur, v_cur, n_ctx, n_tokens, kv_head, cb, il);
+ llm_build_kv_store(ctx, hparams, cparams, kv, graph, k_cur, v_cur, n_tokens, kv_head, cb, il);
struct ggml_tensor * cur;
- cur = llm_build_kqv(ctx, model, hparams, kv, graph, wo, wo_b,
- q_cur, kq_mask, kq_pos, n_ctx, n_tokens, n_kv, kq_scale, cb, il);
+ cur = llm_build_kqv(ctx, model, hparams, cparams, kv, graph, wo, wo_b,
+ q_cur, kq_mask, kq_pos, n_tokens, n_kv, kq_scale, cb, il);
cb(cur, "kqv_out", il);
return cur;
@@ -6642,6 +6688,8 @@ struct llm_build_context {
const int32_t kv_head; // index of where we store new KV data in the cache
const int32_t n_orig_ctx;
+ const bool flash_attn;
+
const enum llama_pooling_type pooling_type;
const enum llama_rope_type rope_type;
@@ -6688,6 +6736,7 @@ struct llm_build_context {
n_outputs (worst_case ? n_tokens : lctx.n_outputs),
kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
n_orig_ctx (cparams.n_yarn_orig_ctx),
+ flash_attn (cparams.flash_attn),
pooling_type (cparams.pooling_type),
rope_type (hparams.rope_type),
cb (cb),
@@ -6802,15 +6851,31 @@ struct llm_build_context {
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id));
- ggml_tensor * view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
- nm, n_embd_v_gqa,
- ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
- ggml_row_size(kv_self.v_l[il]->type, i));
+ ggml_tensor * view_v_src;
+ ggml_tensor * view_v_dst;
- ggml_tensor * view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
- nm, n_embd_v_gqa,
- ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
- ggml_row_size(kv_self.v_l[il]->type, id));
+ if (flash_attn) {
+ // NOTE: the V cache is not transposed when using flash attention
+ view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
+ n_embd_v_gqa, nm,
+ ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
+ ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i));
+
+ view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
+ n_embd_v_gqa, nm,
+ ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
+ ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id));
+ } else {
+ view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
+ nm, n_embd_v_gqa,
+ ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
+ ggml_row_size(kv_self.v_l[il]->type, i));
+
+ view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
+ nm, n_embd_v_gqa,
+ ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
+ ggml_row_size(kv_self.v_l[il]->type, id));
+ }
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
@@ -6840,20 +6905,26 @@ struct llm_build_context {
struct ggml_tensor * build_inp_KQ_mask(bool causal = true) {
if (causal) {
- lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, n_tokens);
+ lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
} else {
- lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
+ lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
}
cb(lctx.inp_KQ_mask, "KQ_mask", -1);
ggml_set_input(lctx.inp_KQ_mask);
- return lctx.inp_KQ_mask;
+ return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
}
- struct ggml_tensor * build_inp_KQ_pos() {
- lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv);
+ struct ggml_tensor * build_inp_KQ_pos(bool causal = true) {
+ if (causal) {
+ lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv);
+ } else {
+ // TODO: this will be needed for ALiBi-based BERT models
+ // https://github.com/ggerganov/llama.cpp/pull/6826
+ lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_tokens);
+ }
cb(lctx.inp_KQ_pos, "KQ_pos", -1);
ggml_set_input(lctx.inp_KQ_pos);
- return lctx.inp_KQ_pos;
+ return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_pos, GGML_TYPE_F16) : lctx.inp_KQ_pos;
}
struct ggml_tensor * build_inp_mean() {
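Editor's note: the KQ_mask (and KQ_pos) inputs are now padded so GPU kernels never read past the end of the mask, and cast to F16 when flash attention is enabled. GGML_PAD rounds up to a multiple; a minimal sketch of the same rounding (the GGML_KQ_MASK_PAD constant itself is defined on the ggml side and not repeated here):

// round x up to the next multiple of n (the rounding GGML_PAD performs)
static int pad_to_multiple(int x, int n) {
    return ((x + n - 1)/n)*n;
}

// e.g. pad_to_multiple(13, 32) == 32, pad_to_multiple(33, 32) == 64,
//      pad_to_multiple(4097, 256) == 4352 (the n_ctx / kv_self.n padding below uses 256)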
@@ -6959,9 +7030,9 @@ struct llm_build_context {
);
cb(Kcur, "Kcur", il);
- cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+ cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
- Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+ Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
@@ -7099,9 +7170,9 @@ struct llm_build_context {
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
- cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+ cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
model.layers[il].wo, NULL,
- Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+ Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
@@ -7206,9 +7277,9 @@ struct llm_build_context {
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
- cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+ cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
model.layers[il].wo, NULL,
- Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+ Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
@@ -7326,9 +7397,9 @@ struct llm_build_context {
);
cb(Kcur, "Kcur", il);
- cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+ cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
model.layers[il].wo, NULL,
- Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+ Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
@@ -7451,9 +7522,9 @@ struct llm_build_context {
);
cb(Kcur, "Kcur", il);
- cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+ cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
- Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+ Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il);
}
if (il == n_layer - 1) {
@@ -7603,9 +7674,9 @@ struct llm_build_context {
);
cb(Kcur, "Kcur", il);
- cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
- model.layers[il].wo, NULL,
- Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+ cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+ model.layers[il].wo, NULL,
+ Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
@@ -7715,9 +7786,9 @@ struct llm_build_context {
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
- cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+ cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
- Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+ Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
@@ -7919,9 +7990,9 @@ struct llm_build_context {
);
cb(Vcur, "Vcur", il);
- cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+ cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
- Kcur, Vcur, Q, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+ Kcur, Vcur, Q, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
@@ -8015,9 +8086,9 @@ struct llm_build_context {
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
cb(Qcur, "Qcur", il);
- cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+ cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
model.layers[il].wo, NULL,
- Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+ Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
@@ -8308,9 +8379,9 @@ struct llm_build_context {
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
- cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+ cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
- Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+ Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
@@ -8439,14 +8510,15 @@ struct llm_build_context {
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
- cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
- model.layers[il].wo, model.layers[il].bo,
- Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+ cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+ model.layers[il].wo, model.layers[il].bo,
+ Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
} else {
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
- cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+
+ cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
- Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+ Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
}
@@ -8588,9 +8660,9 @@ struct llm_build_context {
);
cb(Kcur, "Kcur", il);
- cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+ cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
model.layers[il].wo, NULL,
- Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+ Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
@@ -8706,9 +8778,9 @@ struct llm_build_context {
);
cb(Kcur, "Kcur", il);
- cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+ cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
model.layers[il].wo, NULL,
- Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+ Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
@@ -8819,9 +8891,9 @@ struct llm_build_context {
);
cb(Kcur, "Kcur", il);
- cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+ cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
- Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+ Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
@@ -8933,9 +9005,9 @@ struct llm_build_context {
);
cb(Kcur, "Kcur", il);
- cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+ cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
- Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+ Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
@@ -9088,9 +9160,9 @@ struct llm_build_context {
);
cb(Kcur, "Kcur", il);
- cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+ cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
- Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+ Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il);
}
if (il == n_layer - 1) {
@@ -9205,9 +9277,9 @@ struct llm_build_context {
);
cb(Kcur, "Kcur", il);
- cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
- model.layers[il].wo, NULL,
- Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+ cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+ model.layers[il].wo, model.layers[il].bo,
+ Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il);
}
if (il == n_layer - 1) {
@@ -9318,9 +9390,9 @@ struct llm_build_context {
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Kcur, "Kcur", il);
- cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+ cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
model.layers[il].wo, NULL,
- Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+ Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
struct ggml_tensor * sa_out = cur;
@@ -9421,9 +9493,9 @@ struct llm_build_context {
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
- cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+ cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
- Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+ Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
@@ -9528,9 +9600,9 @@ struct llm_build_context {
);
cb(Kcur, "Kcur", il);
- cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+ cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
- Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+ Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
@@ -9644,9 +9716,9 @@ struct llm_build_context {
);
cb(Kcur, "Kcur", il);
- cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+ cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
model.layers[il].wo, NULL,
- Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+ Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
@@ -9761,9 +9833,9 @@ struct llm_build_context {
);
cb(Kcur, "Kcur", il);
- cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+ cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
- Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+ Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
@@ -9891,9 +9963,9 @@ struct llm_build_context {
);
cb(Kcur, "Kcur", il);
- cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+ cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
- Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+ Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
@@ -10012,9 +10084,9 @@ struct llm_build_context {
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Kcur, "Kcur", il);
- cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+ cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
model.layers[il].wo, NULL,
- Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+ Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il);
}
if (il == n_layer - 1) {
@@ -10131,9 +10203,9 @@ struct llm_build_context {
);
cb(Kcur, "Kcur", il);
- cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+ cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
- Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+ Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
@@ -10421,9 +10493,9 @@ struct llm_build_context {
);
cb(Kcur, "Kcur", il);
- cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+ cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
- Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+ Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
@@ -10552,9 +10624,9 @@ struct llm_build_context {
);
cb(Kcur, "Kcur", il);
- cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+ cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
model.layers[il].wo, nullptr,
- Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+ Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
@@ -10981,7 +11053,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
- if (hparams.need_kq_pos) {
+ // ALiBi requires the KQ_pos tensor to provide the sequence position of each token in the batch
+ // this allows to process multiple sequences in parallel with ALiBi-based models
+ if (hparams.use_alibi) {
const int64_t n_kv = kv_self.n;
GGML_ASSERT(lctx.inp_KQ_pos);
@@ -11363,7 +11437,7 @@ static int llama_decode_internal(
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
- kv_self.n = std::min(kv_self.size, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
+ kv_self.n = std::min(kv_self.size, std::max(256u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 256)));
//kv_self.n = llama_kv_cache_cell_max(kv_self);
}
}
@@ -11560,7 +11634,9 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
// each move requires 6*n_layer tensors (see build_defrag)
// - source view, destination view, copy operation
// - x2 for keys and values
- const uint32_t max_moves = LLAMA_MAX_NODES/(6*n_layer);
+ //const uint32_t max_moves = LLAMA_MAX_NODES/(6*n_layer);
+ // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
+ const uint32_t max_moves = (LLAMA_MAX_NODES - 2*n_layer)/(6*n_layer);
// determine which KV cells to move where
//
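Editor's note: the temporary defrag fix simply reserves headroom of 2*n_layer graph nodes. For a sense of scale, with LLAMA_MAX_NODES = 8192 (defined above) and a hypothetical 80-layer model:

// illustrative arithmetic only; n_layer = 80 is a made-up example
const uint32_t n_layer        = 80;
const uint32_t old_max_moves  = 8192/(6*n_layer);               // = 17
const uint32_t new_max_moves  = (8192 - 2*n_layer)/(6*n_layer); // = 16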
@@ -15167,6 +15243,7 @@ struct llama_context_params llama_context_default_params() {
/*.logits_all =*/ false,
/*.embeddings =*/ false,
/*.offload_kqv =*/ true,
+ /*.flash_attn =*/ false,
/*.abort_callback =*/ nullptr,
/*.abort_callback_data =*/ nullptr,
};
@@ -15333,6 +15410,7 @@ struct llama_context * llama_new_context_with_model(
cparams.defrag_thold = params.defrag_thold;
cparams.embeddings = params.embeddings;
cparams.offload_kqv = params.offload_kqv;
+ cparams.flash_attn = params.flash_attn;
cparams.pooling_type = params.pooling_type;
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
@@ -15340,12 +15418,20 @@ struct llama_context * llama_new_context_with_model(
cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
// this is necessary due to kv_self.n being padded later during inference
- cparams.n_ctx = GGML_PAD(cparams.n_ctx, 32);
+ cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256);
// with causal attention, the batch size is limited by the context size
cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
- cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
+ // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
+ // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
+ // ref: https://github.com/ggerganov/llama.cpp/pull/5021
+ if (cparams.n_batch < GGML_KQ_MASK_PAD) {
+ LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
+ cparams.n_batch = GGML_KQ_MASK_PAD;
+ }
+
+ cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
cparams.n_yarn_orig_ctx = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx :
hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx :
@@ -15377,6 +15463,23 @@ struct llama_context * llama_new_context_with_model(
}
}
+ if (cparams.flash_attn && hparams.use_alibi) {
+ LLAMA_LOG_WARN("%s: flash_attn is not yet compatible with ALiBi - forcing off\n", __func__);
+ cparams.flash_attn = false;
+ }
+
+ if (cparams.flash_attn && model->arch == LLM_ARCH_GROK) {
+ LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
+ cparams.flash_attn = false;
+ }
+
+#ifdef GGML_USE_HIPBLAS
+ if (cparams.flash_attn) {
+ LLAMA_LOG_WARN("%s: flash_attn is not yet compatible with HIPBLAS builds - forcing off\n", __func__);
+ cparams.flash_attn = false;
+ }
+#endif
+
if (params.seed == LLAMA_DEFAULT_SEED) {
params.seed = time(NULL);
}
@@ -15384,6 +15487,7 @@ struct llama_context * llama_new_context_with_model(
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
+ LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
@@ -15512,7 +15616,7 @@ struct llama_context * llama_new_context_with_model(
}
ctx->backends.push_back(ctx->backend_cpu);
- if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, kv_size, cparams.offload_kqv)) {
+ if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) {
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
llama_free(ctx);
return nullptr;
@@ -16111,6 +16215,7 @@ size_t llama_state_get_size(const struct llama_context * ctx) {
const size_t s_kv_head = sizeof(uint32_t);
const size_t s_kv_size = sizeof(uint32_t);
const size_t s_kv_used = sizeof(uint32_t);
+ const size_t s_v_trans = sizeof(uint32_t);
const size_t s_kv = ctx->kv_self.total_size();
const size_t s_kv_cell = sizeof(llama_pos) + sizeof(size_t) + cparams.n_seq_max*sizeof(llama_seq_id);
const size_t s_kv_cells = ctx->kv_self.size * s_kv_cell;
@@ -16128,10 +16233,14 @@ size_t llama_state_get_size(const struct llama_context * ctx) {
+ s_kv_head
+ s_kv_size
+ s_kv_used
+ + s_v_trans
+ s_kv
+ s_kv_cells
);
+ // on session change it is very likely that the state size has changed - so we need to update this function
+ static_assert(LLAMA_SESSION_VERSION == 6, "So you just bumped the session version - good. But did you remember to update llama_state_get_size?");
+
return s_total;
}
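Editor's note: the saved state now carries one extra uint32_t (v_trans) in its KV-cache header, and llama_state_get_size accounts for it. A sketch of the header fields in the order they are written by llama_state_get_data_internal below (illustrative only; the code writes the fields one by one, there is no such struct in llama.cpp):

#include <stddef.h>
#include <stdint.h>

struct kv_state_header_sketch {
    size_t   kv_buf_size; // written as a size_t in the code
    uint32_t kv_head;
    uint32_t kv_size;
    uint32_t kv_used;
    uint32_t v_trans;     // new in this commit: 1 if the V cache is transposed, 0 otherwise
};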
@@ -16277,11 +16386,13 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data
const uint32_t kv_size = kv_self.size;
const size_t kv_buf_size = kv_self.total_size() / (kv_size ? kv_size : 1) * kv_head;
const uint32_t kv_used = kv_self.used;
+ const uint32_t v_trans = kv_self.v_trans ? 1 : 0;
data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
data_ctx->write(&kv_head, sizeof(kv_head));
data_ctx->write(&kv_size, sizeof(kv_size));
data_ctx->write(&kv_used, sizeof(kv_used));
+ data_ctx->write(&v_trans, sizeof(v_trans));
if (kv_buf_size) {
const size_t pre_kv_buf_size = data_ctx->get_size_written();
@@ -16294,7 +16405,7 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data
ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), 0, tmp_buf.size());
data_ctx->write(tmp_buf.data(), tmp_buf.size());
- if (kv_self.recurrent) {
+ if (kv_self.recurrent || !kv_self.v_trans) {
// v is contiguous for recurrent models
// TODO: use other tensors for state models than k and v
const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head);
@@ -16427,11 +16538,15 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
uint32_t kv_head;
uint32_t kv_size;
uint32_t kv_used;
+ uint32_t v_trans;
memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size);
memcpy(&kv_head, inp, sizeof(kv_head)); inp += sizeof(kv_head);
memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size);
memcpy(&kv_used, inp, sizeof(kv_used)); inp += sizeof(kv_used);
+ memcpy(&v_trans, inp, sizeof(v_trans)); inp += sizeof(v_trans);
+
+ GGML_ASSERT(kv_self.v_trans == (bool) v_trans); // incompatible V transposition
if (kv_self.size != kv_size) {
// the KV cache needs to be big enough to load all the KV cells from the saved state
@@ -16441,6 +16556,8 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
__func__, kv_head, kv_size, kv_self.size);
}
+ llama_kv_cache_clear(ctx);
+
if (kv_buf_size) {
const size_t pre_kv_buf_size = inp - src;
@@ -16452,7 +16569,7 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size);
inp += k_size;
- if (kv_self.recurrent) {
+ if (kv_self.recurrent || !kv_self.v_trans) {
// v is contiguous for recurrent models
// TODO: use other tensors for state models than k and v
const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head);
@@ -16474,8 +16591,6 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
GGML_ASSERT(kv_buf_size == inp - src - pre_kv_buf_size);
}
- llama_kv_cache_clear(ctx);
-
ctx->kv_self.head = kv_head;
ctx->kv_self.used = kv_used;
@@ -16735,28 +16850,49 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
}
}
- // For the values, they are transposed, so we also need the element size and get the element ranges from each row
- const uint32_t kv_size = kv_self.size;
- for (int il = 0; il < (int)n_layer; ++il) {
- // Write value type
- const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
- data_ctx.write(&v_type_i, sizeof(v_type_i));
+ // TODO: simplify, reduce copy-paste
+ if (!kv_self.v_trans) {
+ for (int il = 0; il < (int)n_layer; ++il) {
+ // Write value type
+ const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
+ data_ctx.write(&v_type_i, sizeof(v_type_i));
- // Write element size
- const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
- data_ctx.write(&v_size_el, sizeof(v_size_el));
+ // Write row size of value
+ const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
+ data_ctx.write(&v_size_row, sizeof(v_size_row));
- // For each row, we get the element values of each cell
- for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
- // Read each range of cells of v_size_el length each into tmp_buf and write out
+ // Read each range of cells of v_size length each into tmp_buf and write out
for (const auto & range : cell_ranges) {
const size_t range_size = range.second - range.first;
- const size_t src_offset = (range.first + j * kv_size) * v_size_el;
- tmp_buf.resize(range_size * v_size_el);
- ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size());
+ tmp_buf.resize(range_size * v_size_row);
+ ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), range.first * v_size_row, range_size * v_size_row);
data_ctx.write(tmp_buf.data(), tmp_buf.size());
}
}
+ } else {
+ // For the values, they are transposed, so we also need the element size and get the element ranges from each row
+ const uint32_t kv_size = kv_self.size;
+ for (int il = 0; il < (int)n_layer; ++il) {
+ // Write value type
+ const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
+ data_ctx.write(&v_type_i, sizeof(v_type_i));
+
+ // Write element size
+ const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
+ data_ctx.write(&v_size_el, sizeof(v_size_el));
+
+ // For each row, we get the element values of each cell
+ for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+ // Read each range of cells of v_size_el length each into tmp_buf and write out
+ for (const auto & range : cell_ranges) {
+ const size_t range_size = range.second - range.first;
+ const size_t src_offset = (range.first + j * kv_size) * v_size_el;
+ tmp_buf.resize(range_size * v_size_el);
+ ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size());
+ data_ctx.write(tmp_buf.data(), tmp_buf.size());
+ }
+ }
+ }
}
return data_ctx.get_size_written();
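Editor's note: the two branches above differ only in how a contiguous cell range [c0, c1) is addressed inside one layer's V tensor. A small sketch of the offset arithmetic (variable names mirror the code; not an actual helper in llama.cpp):

#include <stddef.h>

// non-transposed V: a cell range is one contiguous block of whole rows
static size_t v_range_offset_rows(size_t c0, size_t v_size_row) {
    return c0 * v_size_row;              // copy (c1 - c0) * v_size_row bytes in one read
}

// transposed V: one read per embedding channel j, since a token's values are kv_size elements apart
static size_t v_range_offset_el(size_t c0, size_t j, size_t kv_size, size_t v_size_el) {
    return (c0 + j*kv_size) * v_size_el; // copy (c1 - c0) * v_size_el bytes per channel
}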
@@ -16881,41 +17017,75 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
}
}
- // For each layer, read the values for each cell (transposed)
- for (int il = 0; il < (int)n_layer; ++il) {
- // Read type of value
- int32_t v_type_i_ref;
- memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
- inp += sizeof(v_type_i_ref);
- const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
- if (v_type_i != v_type_i_ref) {
- llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
- LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
- return 0;
- }
+ // TODO: simplify, reduce copy-paste
+ if (!kv_self.v_trans) {
+ for (int il = 0; il < (int)n_layer; ++il) {
+ // Read type of value
+ int32_t v_type_i_ref;
+ memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
+ inp += sizeof(v_type_i_ref);
+ const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
+ if (v_type_i != v_type_i_ref) {
+ llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
+ LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
+ return 0;
+ }
- // Read element size of value
- size_t v_size_el_ref;
- memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref));
- inp += sizeof(v_size_el_ref);
- const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
- if (v_size_el != v_size_el_ref) {
- llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
- LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, v_size_el_ref, il);
- return 0;
- }
+ // Read row size of value
+ size_t v_size_row_ref;
+ memcpy(&v_size_row_ref, inp, sizeof(v_size_row_ref));
+ inp += sizeof(v_size_row_ref);
+ const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
+ if (v_size_row != v_size_row_ref) {
+ llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
+ LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, v_size_row_ref, il);
+ return 0;
+ }
- if (cell_count) {
- // For each row in the transposed matrix, read the values for the whole cell range
- for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
- const size_t dst_offset = (kv_head + j * kv_size) * v_size_el;
- ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el);
- inp += cell_count * v_size_el;
+ if (cell_count) {
+ // Read and set the values for the whole cell range
+ ggml_backend_tensor_set(kv_self.v_l[il], inp, kv_head * v_size_row, cell_count * v_size_row);
+ inp += cell_count * v_size_row;
+ }
+ }
+ } else {
+ // For each layer, read the values for each cell (transposed)
+ for (int il = 0; il < (int)n_layer; ++il) {
+ // Read type of value
+ int32_t v_type_i_ref;
+ memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
+ inp += sizeof(v_type_i_ref);
+ const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
+ if (v_type_i != v_type_i_ref) {
+ llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
+ LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
+ return 0;
+ }
+
+ // Read element size of value
+ size_t v_size_el_ref;
+ memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref));
+ inp += sizeof(v_size_el_ref);
+ const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
+ if (v_size_el != v_size_el_ref) {
+ llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
+ LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, v_size_el_ref, il);
+ return 0;
+ }
+
+ if (cell_count) {
+ // For each row in the transposed matrix, read the values for the whole cell range
+ for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+ const size_t dst_offset = (kv_head + j * kv_size) * v_size_el;
+ ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el);
+ inp += cell_count * v_size_el;
+ }
}
}
}
const size_t nread = inp - src;
+
return nread;
}