author     Kawrakow <iwankawrakow@gmail.com>   2025-03-03 15:17:51 +0200
committer  GitHub <noreply@github.com>         2025-03-03 15:17:51 +0200
commit     a87e54db6ec2409284a55f029d4abe9e50990064 (patch)
tree       920bb8ce4fbd35e54bda3b61a86d0f87c2ac0ede
parent     a89adaa78f505675be7be6180f419b4b0158c15a (diff)
Flash MLA (CPU only) (#240)
* FlashMLA - it finally works (on the CPU)
* FlashMLA: allow for f16 and bf16 cache in addition to q8_0
* It works with ggml FA, not with iqk FA
* WIP
* FlashMLA: it now works with iqk
  I had forgotten to divide the Q stride by sizeof(float), and that's why, very
  confusingly, it was working for TG but not for PP.
* WIP
* FlashMLA: that should be it for now
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--  ggml/src/ggml-quants.c          8
-rw-r--r--  ggml/src/ggml.c                 6
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.cpp   94
-rw-r--r--  src/llama.cpp                 138
4 files changed, 175 insertions, 71 deletions
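
The fix the commit message points to (dividing the Q stride by sizeof(float)) shows up in the iqk_mul_mat.cpp hunk below as "stride_q /= sizeof(float)". To make the failure mode concrete, here is a minimal standalone sketch of that bug class, not the iqk kernel itself: if the caller passes a row stride in bytes (ggml's nb[1]) and the kernel indexes a float pointer with it, every row after the first lands 4x too far. With a single query row (token generation) only row 0 is read, so the bug hides; with a prompt-processing batch the reads go wrong.

    #include <cstddef>
    #include <vector>

    // Illustration only (not the iqk code): read one element of a row-major Q
    // matrix whose row stride is given in bytes, as ggml tensors do via nb[1].
    static float q_row_first(const float * q, size_t stride_q_bytes, int row) {
        // Wrong: indexing a float* with the byte stride jumps 4x too far per row:
        //   return q[row * stride_q_bytes];
        // Right: convert the byte stride to an element stride first.
        const size_t stride_q = stride_q_bytes / sizeof(float);
        return q[row * stride_q];
    }

    int main() {
        std::vector<float> q(8 * 16, 1.0f);      // 8 query rows of 16 floats (PP-like batch)
        const size_t nb1 = 16 * sizeof(float);   // byte stride between rows
        // For row 0 (TG) both variants read q[0]; for row 7 only the
        // element-stride version stays inside the tensor.
        return q_row_first(q.data(), nb1, 7) == 1.0f ? 0 : 1;
    }
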
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index e8218e76..e39cf4aa 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -11,6 +11,7 @@
#include "ggml-quants.h"
#include "ggml-impl.h"
#if GGML_USE_IQK_MULMAT
+#include "iqk/iqk_config.h"
#include "iqk/iqk_mul_mat.h"
#include "iqk/iqk_quantize.h"
#endif
@@ -5449,7 +5450,12 @@ void ggml_vec_dot_q6_0_q8_0(int n, float * restrict s, size_t bs, const void * r
void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
#if GGML_USE_IQK_MULMAT
- if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q8_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) {
+#ifdef HAVE_FANCY_SIMD
+ enum ggml_type dot_type = GGML_TYPE_Q8_1_X4;
+#else
+ enum ggml_type dot_type = GGML_TYPE_Q8_0_X4;
+#endif
+ if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q8_0, vx, bx, dot_type, vy, by, s, bs, 0, 1)) {
return;
}
#endif
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 31fbc57e..46e1a548 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -10451,7 +10451,7 @@ static void ggml_compute_forward_dup_bytes(
ne00 == ne0 &&
nb00 == type_size && nb0 == type_size) {
// copy by rows
- const size_t rs = ne00 * type_size;
+ const size_t rs = ggml_row_size(src0->type, ne00); //ne00 * type_size;
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ir0; i01 < ir1; i01++) {
@@ -17871,6 +17871,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
#if GGML_USE_IQK_MULMAT
if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) {
+ //if (ith == 0) printf("k: %ld x %ld x %ld, q: %ld x %ld x %ld, v: %ld x %ld x %ld mask: %ld x %ld x %ld\n",
+ // k->ne[0], k->ne[1], k->ne[2], q->ne[0], q->ne[1], q->ne[2], v->ne[0], v->ne[1], v->ne[2], mask->ne[0], mask->ne[1], mask->ne[2]);
// I keep changing my mind what is the best strategy to split the threads when processing
// multiple heads. This is my current thinking, the commented out code below was the previous.
int ntg = nth/simple_gcd(neq2*neq3, nth);
@@ -17906,8 +17908,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
}
return;
IQK_Flash_Attn_NotAvailable:;
+ printf("iqk_flash was rejected\n");
}
-
#endif
const uint32_t n_head = neq2;
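
On the thread-splitting comment in the hunk above: neq2*neq3 is the number of head/sequence slices and nth the thread count, so ntg = nth/gcd(neq2*neq3, nth) threads end up cooperating on each slice while gcd(...) groups of threads work on different slices in parallel, which keeps the load even for any combination of the two counts. A small standalone sketch of that arithmetic (my reading of it, not the actual ggml scheduling code):

    #include <cstdio>
    #include <numeric>

    // Sketch of the split (an interpretation, not the ggml code): with
    // nheads = neq2*neq3 head/sequence slices and nth threads, the threads form
    // gcd(nheads, nth) groups; each group walks nheads/gcd slices and its
    // nth/gcd member threads share the work inside a slice.
    int main() {
        const int cases[][2] = { {8, 16}, {12, 16}, {5, 16}, {32, 12} };
        for (const auto & c : cases) {
            const int nheads = c[0], nth = c[1];
            const int g   = std::gcd(nheads, nth);
            const int ntg = nth / g;   // threads cooperating per slice
            printf("nheads=%2d nth=%2d -> %2d group(s), %2d thread(s) per slice, %2d slice(s) per group\n",
                   nheads, nth, g, ntg, nheads / g);
        }
        return 0;
    }
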
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 0955f15d..1f18837c 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -15016,7 +15016,7 @@ template <int k_step>
struct BaseHelper {
BaseHelper(const char * data, int stride) : data(data), block(data), stride(stride) {}
- inline void set_block(int k1) { block = data + k1*k_step*stride; }
+ //inline void set_block(int k1) { block = data + k1*k_step*stride; }
inline void reset_block() { block = data; }
inline void next_block() { block += k_step*stride; }
inline const char * lblock(int l1) const { return block + l1*stride; }
@@ -16038,9 +16038,9 @@ struct FlashQKV {
}
inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int j, const qkv_cache_t * R, float * qkv) const {
- //GGML_ASSERT(fms.S[j] > 0);
- //auto norm = F16::set1(1/fms.S[j]);
- auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f);
+ GGML_ASSERT(fms.S[j] > 0);
+ auto norm = F16::set1(1/fms.S[j]);
+ //auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f);
for (int i = 0; i < D/F16::block_size; ++i) {
auto r = F16::load(R + F16::block_size*i);
F16::store(qkv + F16::block_size*i, F16::mul(norm, r));
@@ -16076,7 +16076,7 @@ struct FlashQKV {
template <int D, int q_step, int k_step>
struct FlashQKfp32 {
- static_assert(D%F16::block_size == 0 && D <= 256);
+ static_assert(D%F16::block_size == 0 && D <= 576);
static_assert(k_step%F16::block_size == 0);
static_assert(q_step <= 4 || q_step%4 == 0);
@@ -16571,8 +16571,8 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
// q_step-1 versions of these functions for us, which I though was too much with q_step = 8.
template <int Dk, int Dv, int q_step, int k_step>
struct FlashAttn {
- static_assert(Dk%F16::block_size == 0 && Dk <= 256);
- static_assert(Dv%F16::block_size == 0 && Dv <= 256);
+ static_assert(Dk%F16::block_size == 0 && Dk <= 576);
+ static_assert(Dv%F16::block_size == 0 && Dv <= 512);
static_assert(k_step%F16::block_size == 0);
static_assert(q_step <= 4 || q_step%4 == 0);
@@ -16665,7 +16665,8 @@ struct HelperBF16 final : public BaseHelper<step> {
template <int D, int q_step, int k_step>
struct FlashQKbf16 {
- static_assert(D%32 == 0 && D <= 256);
+ //static_assert(D%32 == 0 && D <= 256);
+ static_assert(D%32 == 0 && D <= 576);
static_assert(k_step%32 == 0);
static_assert(q_step <= 4 || q_step%4 == 0);
@@ -16975,8 +16976,10 @@ struct FlashQKbf16 {
template <int Dk, int Dv, int q_step, int k_step>
struct FlashAttnBF16 {
- static_assert(Dk%32 == 0 && Dk <= 256);
- static_assert(Dv%32 == 0 && Dv <= 256);
+ //static_assert(Dk%32 == 0 && Dk <= 256);
+ //static_assert(Dv%32 == 0 && Dv <= 256);
+ static_assert(Dk%32 == 0 && Dk <= 576);
+ static_assert(Dv%32 == 0 && Dv <= 512);
static_assert(k_step%32 == 0);
static_assert(q_step <= 4 || q_step%4 == 0);
@@ -17216,6 +17219,66 @@ inline bool flash_attn_is_supported(ggml_type type) {
#endif
return false;
}
+
+template <int step_k, typename KHelper, typename VHelper>
+inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh,
+ int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
+ const float * q, const char * mask, float scale, float softcap, float * qkv) {
+ if (nq1 % 8 == 0) {
+ FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv);
+ } else {
+ FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv);
+ }
+}
+
+template <int step_k>
+inline bool iqk_deepseek_helper(ggml_type type_k,
+ int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
+ const float * q, const char * k, const char * v, const char * mask,
+ float scale, float softcap, float * qkv) {
+ if (type_k == GGML_TYPE_Q8_0) {
+ HelperQ80<576, step_k> kh((const char *)k, stride_k);
+ HelperQ80<512, step_k> vh((const char *)v, stride_v);
+ iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ return true;
+ }
+ if (type_k == GGML_TYPE_Q6_0) {
+ HelperQ60<576, step_k> kh((const char *)k, stride_k);
+ HelperQ60<512, step_k> vh((const char *)v, stride_v);
+ iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ return true;
+ }
+ if (type_k == GGML_TYPE_Q8_KV) {
+ HelperQ8KV<576, step_k> kh((const char *)k, stride_k);
+ HelperQ8KV<512, step_k> vh((const char *)v, stride_v);
+ iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ return true;
+ }
+ if (type_k == GGML_TYPE_F16) {
+ HelperF16<576, step_k> kh((const char *)k, stride_k);
+ HelperF16<512, step_k> vh((const char *)v, stride_v);
+ iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ return true;
+ }
+#ifdef __AVX512BF16__
+ if (type_k == GGML_TYPE_BF16) {
+ HelperBF16<576, step_k> kh((const char *)k, stride_k);
+ HelperBF16<512, step_k> vh((const char *)v, stride_v);
+ if (nq1 % 8 == 0) {
+ FlashAttnBF16<576, 512, 8, step_k> fa(scale, softcap);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+ } else {
+ FlashAttnBF16<576, 512, 1, step_k> fa(scale, softcap);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+ }
+ return true;
+ }
+#endif
+ return false;
+}
+
}
bool iqk_flash_attn_noalibi(int int_type_k, // type of k
@@ -17237,10 +17300,19 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
float softcap, // if > 0, a "soft-cap" operation is applied before softmax
float * qkv) { // v*softmax(scale*(k*q))
+ if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32
+
auto type_k = ggml_type(int_type_k);
auto type_v = ggml_type(int_type_v);
+
+ if (Dk == 576 && Dv == 512) {
+ GGML_ASSERT(type_k == type_v);
+ stride_q /= sizeof(float); // q stride as float
+ return iqk_deepseek_helper<32>(type_k, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+ q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv);
+ }
+
if (!flash_attn_is_supported(type_k) || !flash_attn_is_supported(type_v)) return false;
- if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32
if (Dk != Dv && Dk != 192 && Dv != 128) return false;
if (Dv != 64 && Dv != 96 && Dv != 128 && Dv != 256) return false;
if (Dk != 64 && Dk != 96 && Dk != 128 && Dk != 192 && Dv != 256) return false;
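
The new iqk_deepseek_helper path above hard-codes the DeepSeek MLA head geometry: a K row is the compressed KV vector of kv_lora_rank entries plus the n_embd_head_qk_rope rotary entries (512 + 64 = 576 for DeepSeek-V2/V3 style models), and V is the compressed vector alone (512). That is why the templates are instantiated as FlashAttn<576, 512, ...> and why the gate in iqk_flash_attn_noalibi checks Dk == 576 && Dv == 512. A tiny sketch of that gate with the constants spelled out (illustrative only, not the library API):

    #include <cstdio>

    // DeepSeek-V2/V3 style MLA dimensions, as used by the diff above.
    constexpr int kv_lora_rank        = 512;
    constexpr int n_embd_head_qk_rope = 64;

    // Illustrative check: MLA flash attention stores K heads of
    // kv_lora_rank + n_embd_head_qk_rope entries and V heads of kv_lora_rank.
    constexpr bool is_mla_flash_case(int Dk, int Dv) {
        return Dk == kv_lora_rank + n_embd_head_qk_rope && Dv == kv_lora_rank;
    }

    int main() {
        printf("Dk=576 Dv=512 -> MLA path: %s\n", is_mla_flash_case(576, 512) ? "yes" : "no");
        printf("Dk=128 Dv=128 -> MLA path: %s\n", is_mla_flash_case(128, 128) ? "yes" : "no");
        return 0;
    }
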
diff --git a/src/llama.cpp b/src/llama.cpp
index 3a8b54ca..5ac44055 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -3168,7 +3168,7 @@ static bool llama_kv_cache_init(
// DeepSeek MLA
cache.kv_l.reserve(n_layer);
- if (cparams.mla_attn == 1) {
+ if (cparams.mla_attn == 1 && !cparams.flash_attn) {
cache.kvt_l.reserve(n_layer);
}
@@ -3201,14 +3201,20 @@ static bool llama_kv_cache_init(
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
const uint32_t kv_lora_rank = hparams.n_lora_kv;
LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank);
- auto kv_type = cparams.mla_attn == 1 ? cache.type_k : cache.type_v;
- ggml_tensor * kv = ggml_new_tensor_2d(ctx, kv_type, kv_lora_rank + n_embd_head_qk_rope, kv_size);
- ggml_format_name(kv, "cache_kv_l%d", i);
- cache.kv_l.push_back(kv);
- if (cparams.mla_attn == 1) {
- ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_v, kv_lora_rank*kv_size);
- ggml_format_name(kvt, "cache_kvt_l%d", i);
- cache.kvt_l.push_back(kvt);
+ if (cparams.flash_attn) {
+ ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_k, kv_lora_rank + n_embd_head_qk_rope, kv_size);
+ ggml_format_name(kv, "cache_kv_l%d", i);
+ cache.kv_l.push_back(kv);
+ } else {
+ auto kv_type = cparams.mla_attn == 1 ? cache.type_k : cache.type_v;
+ ggml_tensor * kv = ggml_new_tensor_2d(ctx, kv_type, kv_lora_rank + n_embd_head_qk_rope, kv_size);
+ ggml_format_name(kv, "cache_kv_l%d", i);
+ cache.kv_l.push_back(kv);
+ if (cparams.mla_attn == 1) {
+ ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_v, kv_lora_rank*kv_size);
+ ggml_format_name(kvt, "cache_kvt_l%d", i);
+ cache.kvt_l.push_back(kvt);
+ }
}
n_mla++;
}
@@ -13588,7 +13594,7 @@ struct llm_build_context {
ggml_tensor * kv_cache_trans;
- if (lctx.cparams.mla_attn == 1) {
+ if (lctx.cparams.mla_attn == 1 && !lctx.cparams.flash_attn) {
ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank,
ggml_row_size(kv_self.kvt_l[il]->type, kv_self.size), ggml_row_size(kv_self.kvt_l[il]->type, kv_head));
cb(kv_cache_trans_view, "kv_cache_trans_view", il);
@@ -13630,70 +13636,88 @@ struct llm_build_context {
ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0);
cb(q, "q", il);
- if (lctx.cparams.mla_attn > 1) {
+ ggml_tensor * kqv_compressed;
+
+ if (lctx.cparams.flash_attn) {
ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il],
kv_lora_rank, n_kv,
ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0);
- cb(kv_cache, "kv_cache_lora", il);
+ cb(kv_cache_lora, "kv_cache_lora", il);
- kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache_lora));
- cb(kv_cache_trans, "kv_cache_trans", il);
- }
+ //ggml_tensor * v = ggml_cont(ctx0, kv_cache_lora);
+ //kqv_compressed = ggml_flash_attn_ext(ctx0, q, kv_cache, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f);
- ggml_tensor * kqv_compressed;
- auto kq_size = kv_cache->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024); // K*Q in MiB
- if (lctx.cparams.attn_max_batch <= 0 || lctx.cparams.attn_max_batch >= kq_size) {
- if (!pp_opt) {
- q = ggml_permute(ctx0, q, 0, 2, 1, 3);
- cb(q, "q_perm", il);
+ kqv_compressed = ggml_flash_attn_ext(ctx0, q, kv_cache, kv_cache_lora, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f);
+ cb(kqv_compressed, "kqv_compressed", il);
+
+ kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
+ cb(kqv_compressed, "kqv_compressed_perm", il);
+ }
+ else {
+ if (lctx.cparams.mla_attn > 1) {
+ ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il],
+ kv_lora_rank, n_kv,
+ ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0);
+ cb(kv_cache, "kv_cache_lora", il);
+
+ kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache_lora));
+ cb(kv_cache_trans, "kv_cache_trans", il);
}
- ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q);
- cb(kq, "kq", il);
+ auto kq_size = kv_cache->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024); // K*Q in MiB
+ if (lctx.cparams.attn_max_batch <= 0 || lctx.cparams.attn_max_batch >= kq_size) {
+ if (!pp_opt) {
+ q = ggml_permute(ctx0, q, 0, 2, 1, 3);
+ cb(q, "q_perm", il);
+ }
- if (!pp_opt) {
- kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
- cb(kq, "kq_perm", il);
- }
+ ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q);
+ cb(kq, "kq", il);
- kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
- cb(kq, "kq_soft_max_ext", il);
+ if (!pp_opt) {
+ kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
+ cb(kq, "kq_perm", il);
+ }
- if (!pp_opt) {
- kq = ggml_permute(ctx0, kq, 0, 2, 1, 3);
- cb(kq, "kq_soft_max_ext_perm", il);
- }
+ kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
+ cb(kq, "kq_soft_max_ext", il);
- kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
- cb(kqv_compressed, "kqv_compressed", il);
+ if (!pp_opt) {
+ kq = ggml_permute(ctx0, kq, 0, 2, 1, 3);
+ cb(kq, "kq_soft_max_ext_perm", il);
+ }
- if (!pp_opt) {
- kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
- cb(kqv_compressed, "kqv_compressed_perm", il);
- }
+ kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
+ cb(kqv_compressed, "kqv_compressed", il);
- } else {
+ if (!pp_opt) {
+ kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
+ cb(kqv_compressed, "kqv_compressed_perm", il);
+ }
- int n_step = (kq_size + lctx.cparams.attn_max_batch - 1)/lctx.cparams.attn_max_batch;
- n_step = std::min(n_step, int(q->ne[2]));
- int n_per_step = (q->ne[2] + n_step - 1)/n_step;
+ } else {
- //printf("kq size would be %ld MiB -> splitting kqv computation into %d steps\n", kq_size, n_step);
+ int n_step = (kq_size + lctx.cparams.attn_max_batch - 1)/lctx.cparams.attn_max_batch;
+ n_step = std::min(n_step, int(q->ne[2]));
+ int n_per_step = (q->ne[2] + n_step - 1)/n_step;
- for (int i_head = 0; i_head < q->ne[2]; i_head += n_per_step) {
- int this_ne12 = i_head + n_per_step <= q->ne[2] ? n_per_step : q->ne[2] - i_head;
- ggml_tensor * q_i = ggml_view_3d(ctx0, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i_head);
- ggml_tensor * kq_i = ggml_mul_mat(ctx0, kv_cache, q_i);
- kq_i = ggml_soft_max_ext(ctx0, kq_i, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
- ggml_tensor * kqv_i = ggml_mul_mat(ctx0, kv_cache_trans, kq_i);
- if (i_head == 0) {
- kqv_compressed = kqv_i;
- } else {
- kqv_compressed = ggml_concat(ctx0, kqv_compressed, kqv_i, 2);
+ //printf("kq size would be %ld MiB -> splitting kqv computation into %d steps\n", kq_size, n_step);
+
+ for (int i_head = 0; i_head < q->ne[2]; i_head += n_per_step) {
+ int this_ne12 = i_head + n_per_step <= q->ne[2] ? n_per_step : q->ne[2] - i_head;
+ ggml_tensor * q_i = ggml_view_3d(ctx0, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i_head);
+ ggml_tensor * kq_i = ggml_mul_mat(ctx0, kv_cache, q_i);
+ kq_i = ggml_soft_max_ext(ctx0, kq_i, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
+ ggml_tensor * kqv_i = ggml_mul_mat(ctx0, kv_cache_trans, kq_i);
+ if (i_head == 0) {
+ kqv_compressed = kqv_i;
+ } else {
+ kqv_compressed = ggml_concat(ctx0, kqv_compressed, kqv_i, 2);
+ }
+ ggml_build_forward_expand(gf, kqv_compressed);
}
- ggml_build_forward_expand(gf, kqv_compressed);
+ cb(kqv_compressed, "kqv_compressed", il);
}
- cb(kqv_compressed, "kqv_compressed", il);
}
struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head,
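
For the non-flash branch above, a quick note on the attn_max_batch splitting: the full K*Q tensor would take kv_cache->ne[1] * q->ne[1] * q->ne[2] * sizeof(float) bytes, so the heads (dimension 2 of q) are processed in n_step chunks sized to keep each partial K*Q under that budget, and the partial kqv results are concatenated back along dimension 2. A standalone sketch of the same rounding with hypothetical shapes:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
        // Hypothetical shapes: 32768 cached tokens, 512 query tokens, 128 heads.
        const int64_t n_kv = 32768, n_tokens = 512, n_head = 128;
        const int64_t attn_max_batch = 1024;   // budget in MiB
        const int64_t kq_size = n_kv * n_tokens * n_head * int64_t(sizeof(float)) / (1024 * 1024); // K*Q in MiB

        // Same rounding as the llama.cpp code above: split the heads into n_step
        // chunks so each partial K*Q stays within the budget.
        int n_step = int((kq_size + attn_max_batch - 1) / attn_max_batch);
        n_step = std::min<int>(n_step, int(n_head));
        const int n_per_step = int((n_head + n_step - 1) / n_step);

        printf("K*Q would be %lld MiB -> %d step(s) of up to %d head(s) each\n",
               (long long)kq_size, n_step, n_per_step);
        return 0;
    }
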
@@ -18226,7 +18250,7 @@ struct llama_context * llama_new_context_with_model(
}
if (memory_size_kv + memory_size_kvt > 0) {
- if (cparams.mla_attn == 1) {
+ if (cparams.mla_attn == 1 && !cparams.flash_attn) {
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, c^KV (%s): %7.2f MiB, kv^T (%s): %7.2f MiB\n", __func__,
(float)(memory_size_kv + memory_size_kvt) / (1024.0f * 1024.0f),
ggml_type_name(kv_type), (float)memory_size_kv / (1024.0f * 1024.0f),
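
A consequence of the cache changes in llama_kv_cache_init above: with flash attention each MLA layer keeps a single cache row of kv_lora_rank + n_embd_head_qk_rope entries per token, whereas mla_attn == 1 without flash attention also keeps the kv_lora_rank-wide transposed kv^T copy, which is what the two log formats distinguish. A back-of-the-envelope sketch of the per-token footprint, assuming an f16 cache and a made-up layer count (quantized caches shrink both numbers):

    #include <cstdio>

    int main() {
        // DeepSeek MLA dimensions as in the diff above; n_layer is hypothetical.
        const int kv_lora_rank        = 512;
        const int n_embd_head_qk_rope = 64;
        const int n_layer             = 61;
        const double bytes_per_elem   = 2.0;   // f16 cache assumed

        // Flash attention: one (kv_lora_rank + rope)-wide row per token and layer.
        const double with_fa = (kv_lora_rank + n_embd_head_qk_rope) * bytes_per_elem * n_layer;
        // mla_attn == 1 without FA: the same row plus a kv_lora_rank-wide kv^T copy.
        const double no_fa   = (kv_lora_rank + n_embd_head_qk_rope + kv_lora_rank) * bytes_per_elem * n_layer;

        printf("per-token KV cache: %.1f KiB with FA, %.1f KiB without (mla_attn=1)\n",
               with_fa / 1024.0, no_fa / 1024.0);
        return 0;
    }
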