diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2025-03-03 15:17:51 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-03-03 15:17:51 +0200 |
commit | a87e54db6ec2409284a55f029d4abe9e50990064 (patch) | |
tree | 920bb8ce4fbd35e54bda3b61a86d0f87c2ac0ede | |
parent | a89adaa78f505675be7be6180f419b4b0158c15a (diff) |
Flash MLA (CPU only) (#240)
* FlashMLA - it finally works (on the CPU)
* FlashMLA: allow for f16 and bf16 cache in addition to q8_0
* It works with ggml FA, not with iqk FA
* WIP
* FlashMLA: it now works with iqk
I had forgotten to divide the Q stride by sizeof(float) and
that's why, very cobfusingly, it was working for TG but not for PP.
* WIP
* FlashMLA: that should be it for now
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | ggml/src/ggml-quants.c | 8 | ||||
-rw-r--r-- | ggml/src/ggml.c | 6 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 94 | ||||
-rw-r--r-- | src/llama.cpp | 138 |
4 files changed, 175 insertions, 71 deletions
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index e8218e76..e39cf4aa 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -11,6 +11,7 @@ #include "ggml-quants.h" #include "ggml-impl.h" #if GGML_USE_IQK_MULMAT +#include "iqk/iqk_config.h" #include "iqk/iqk_mul_mat.h" #include "iqk/iqk_quantize.h" #endif @@ -5449,7 +5450,12 @@ void ggml_vec_dot_q6_0_q8_0(int n, float * restrict s, size_t bs, const void * r void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { #if GGML_USE_IQK_MULMAT - if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q8_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) { +#ifdef HAVE_FANCY_SIMD + enum ggml_type dot_type = GGML_TYPE_Q8_1_X4; +#else + enum ggml_type dot_type = GGML_TYPE_Q8_0_X4; +#endif + if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q8_0, vx, bx, dot_type, vy, by, s, bs, 0, 1)) { return; } #endif diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 31fbc57e..46e1a548 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -10451,7 +10451,7 @@ static void ggml_compute_forward_dup_bytes( ne00 == ne0 && nb00 == type_size && nb0 == type_size) { // copy by rows - const size_t rs = ne00 * type_size; + const size_t rs = ggml_row_size(src0->type, ne00); //ne00 * type_size; for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = ir0; i01 < ir1; i01++) { @@ -17871,6 +17871,8 @@ static void ggml_compute_forward_flash_attn_ext_f16( #if GGML_USE_IQK_MULMAT if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) { + //if (ith == 0) printf("k: %ld x %ld x %ld, q: %ld x %ld x %ld, v: %ld x %ld x %ld mask: %ld x %ld x %ld\n", + // k->ne[0], k->ne[1], k->ne[2], q->ne[0], q->ne[1], q->ne[2], v->ne[0], v->ne[1], v->ne[2], mask->ne[0], mask->ne[1], mask->ne[2]); // I keep changing my mind what is the best strategy to split the threads when processing // multiple heads. This is my current thinking, the commented out code below was the previous. int ntg = nth/simple_gcd(neq2*neq3, nth); @@ -17906,8 +17908,8 @@ static void ggml_compute_forward_flash_attn_ext_f16( } return; IQK_Flash_Attn_NotAvailable:; + printf("iqk_flash was rejected\n"); } - #endif const uint32_t n_head = neq2; diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 0955f15d..1f18837c 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -15016,7 +15016,7 @@ template <int k_step> struct BaseHelper { BaseHelper(const char * data, int stride) : data(data), block(data), stride(stride) {} - inline void set_block(int k1) { block = data + k1*k_step*stride; } + //inline void set_block(int k1) { block = data + k1*k_step*stride; } inline void reset_block() { block = data; } inline void next_block() { block += k_step*stride; } inline const char * lblock(int l1) const { return block + l1*stride; } @@ -16038,9 +16038,9 @@ struct FlashQKV { } inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int j, const qkv_cache_t * R, float * qkv) const { - //GGML_ASSERT(fms.S[j] > 0); - //auto norm = F16::set1(1/fms.S[j]); - auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f); + GGML_ASSERT(fms.S[j] > 0); + auto norm = F16::set1(1/fms.S[j]); + //auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f); for (int i = 0; i < D/F16::block_size; ++i) { auto r = F16::load(R + F16::block_size*i); F16::store(qkv + F16::block_size*i, F16::mul(norm, r)); @@ -16076,7 +16076,7 @@ struct FlashQKV { template <int D, int q_step, int k_step> struct FlashQKfp32 { - static_assert(D%F16::block_size == 0 && D <= 256); + static_assert(D%F16::block_size == 0 && D <= 576); static_assert(k_step%F16::block_size == 0); static_assert(q_step <= 4 || q_step%4 == 0); @@ -16571,8 +16571,8 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, // q_step-1 versions of these functions for us, which I though was too much with q_step = 8. template <int Dk, int Dv, int q_step, int k_step> struct FlashAttn { - static_assert(Dk%F16::block_size == 0 && Dk <= 256); - static_assert(Dv%F16::block_size == 0 && Dv <= 256); + static_assert(Dk%F16::block_size == 0 && Dk <= 576); + static_assert(Dv%F16::block_size == 0 && Dv <= 512); static_assert(k_step%F16::block_size == 0); static_assert(q_step <= 4 || q_step%4 == 0); @@ -16665,7 +16665,8 @@ struct HelperBF16 final : public BaseHelper<step> { template <int D, int q_step, int k_step> struct FlashQKbf16 { - static_assert(D%32 == 0 && D <= 256); + //static_assert(D%32 == 0 && D <= 256); + static_assert(D%32 == 0 && D <= 576); static_assert(k_step%32 == 0); static_assert(q_step <= 4 || q_step%4 == 0); @@ -16975,8 +16976,10 @@ struct FlashQKbf16 { template <int Dk, int Dv, int q_step, int k_step> struct FlashAttnBF16 { - static_assert(Dk%32 == 0 && Dk <= 256); - static_assert(Dv%32 == 0 && Dv <= 256); + //static_assert(Dk%32 == 0 && Dk <= 256); + //static_assert(Dv%32 == 0 && Dv <= 256); + static_assert(Dk%32 == 0 && Dk <= 576); + static_assert(Dv%32 == 0 && Dv <= 512); static_assert(k_step%32 == 0); static_assert(q_step <= 4 || q_step%4 == 0); @@ -17216,6 +17219,66 @@ inline bool flash_attn_is_supported(ggml_type type) { #endif return false; } + +template <int step_k, typename KHelper, typename VHelper> +inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh, + int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, + const float * q, const char * mask, float scale, float softcap, float * qkv) { + if (nq1 % 8 == 0) { + FlashAttn<576, 512, 8, step_k> fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv); + } else { + FlashAttn<576, 512, 1, step_k> fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv); + } +} + +template <int step_k> +inline bool iqk_deepseek_helper(ggml_type type_k, + int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv, + const float * q, const char * k, const char * v, const char * mask, + float scale, float softcap, float * qkv) { + if (type_k == GGML_TYPE_Q8_0) { + HelperQ80<576, step_k> kh((const char *)k, stride_k); + HelperQ80<512, step_k> vh((const char *)v, stride_v); + iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + return true; + } + if (type_k == GGML_TYPE_Q6_0) { + HelperQ60<576, step_k> kh((const char *)k, stride_k); + HelperQ60<512, step_k> vh((const char *)v, stride_v); + iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + return true; + } + if (type_k == GGML_TYPE_Q8_KV) { + HelperQ8KV<576, step_k> kh((const char *)k, stride_k); + HelperQ8KV<512, step_k> vh((const char *)v, stride_v); + iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + return true; + } + if (type_k == GGML_TYPE_F16) { + HelperF16<576, step_k> kh((const char *)k, stride_k); + HelperF16<512, step_k> vh((const char *)v, stride_v); + iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + return true; + } +#ifdef __AVX512BF16__ + if (type_k == GGML_TYPE_BF16) { + HelperBF16<576, step_k> kh((const char *)k, stride_k); + HelperBF16<512, step_k> vh((const char *)v, stride_v); + if (nq1 % 8 == 0) { + FlashAttnBF16<576, 512, 8, step_k> fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + } else { + FlashAttnBF16<576, 512, 1, step_k> fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + } + return true; + } +#endif + return false; +} + } bool iqk_flash_attn_noalibi(int int_type_k, // type of k @@ -17237,10 +17300,19 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k float softcap, // if > 0, a "soft-cap" operation is applied before softmax float * qkv) { // v*softmax(scale*(k*q)) + if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32 + auto type_k = ggml_type(int_type_k); auto type_v = ggml_type(int_type_v); + + if (Dk == 576 && Dv == 512) { + GGML_ASSERT(type_k == type_v); + stride_q /= sizeof(float); // q stride as float + return iqk_deepseek_helper<32>(type_k, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, + q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv); + } + if (!flash_attn_is_supported(type_k) || !flash_attn_is_supported(type_v)) return false; - if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32 if (Dk != Dv && Dk != 192 && Dv != 128) return false; if (Dv != 64 && Dv != 96 && Dv != 128 && Dv != 256) return false; if (Dk != 64 && Dk != 96 && Dk != 128 && Dk != 192 && Dv != 256) return false; diff --git a/src/llama.cpp b/src/llama.cpp index 3a8b54ca..5ac44055 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3168,7 +3168,7 @@ static bool llama_kv_cache_init( // DeepSeek MLA cache.kv_l.reserve(n_layer); - if (cparams.mla_attn == 1) { + if (cparams.mla_attn == 1 && !cparams.flash_attn) { cache.kvt_l.reserve(n_layer); } @@ -3201,14 +3201,20 @@ static bool llama_kv_cache_init( const uint32_t n_embd_head_qk_rope = hparams.n_rot; const uint32_t kv_lora_rank = hparams.n_lora_kv; LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank); - auto kv_type = cparams.mla_attn == 1 ? cache.type_k : cache.type_v; - ggml_tensor * kv = ggml_new_tensor_2d(ctx, kv_type, kv_lora_rank + n_embd_head_qk_rope, kv_size); - ggml_format_name(kv, "cache_kv_l%d", i); - cache.kv_l.push_back(kv); - if (cparams.mla_attn == 1) { - ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_v, kv_lora_rank*kv_size); - ggml_format_name(kvt, "cache_kvt_l%d", i); - cache.kvt_l.push_back(kvt); + if (cparams.flash_attn) { + ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_k, kv_lora_rank + n_embd_head_qk_rope, kv_size); + ggml_format_name(kv, "cache_kv_l%d", i); + cache.kv_l.push_back(kv); + } else { + auto kv_type = cparams.mla_attn == 1 ? cache.type_k : cache.type_v; + ggml_tensor * kv = ggml_new_tensor_2d(ctx, kv_type, kv_lora_rank + n_embd_head_qk_rope, kv_size); + ggml_format_name(kv, "cache_kv_l%d", i); + cache.kv_l.push_back(kv); + if (cparams.mla_attn == 1) { + ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_v, kv_lora_rank*kv_size); + ggml_format_name(kvt, "cache_kvt_l%d", i); + cache.kvt_l.push_back(kvt); + } } n_mla++; } @@ -13588,7 +13594,7 @@ struct llm_build_context { ggml_tensor * kv_cache_trans; - if (lctx.cparams.mla_attn == 1) { + if (lctx.cparams.mla_attn == 1 && !lctx.cparams.flash_attn) { ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank, ggml_row_size(kv_self.kvt_l[il]->type, kv_self.size), ggml_row_size(kv_self.kvt_l[il]->type, kv_head)); cb(kv_cache_trans_view, "kv_cache_trans_view", il); @@ -13630,70 +13636,88 @@ struct llm_build_context { ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0); cb(q, "q", il); - if (lctx.cparams.mla_attn > 1) { + ggml_tensor * kqv_compressed; + + if (lctx.cparams.flash_attn) { ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0); - cb(kv_cache, "kv_cache_lora", il); + cb(kv_cache_lora, "kv_cache_lora", il); - kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache_lora)); - cb(kv_cache_trans, "kv_cache_trans", il); - } + //ggml_tensor * v = ggml_cont(ctx0, kv_cache_lora); + //kqv_compressed = ggml_flash_attn_ext(ctx0, q, kv_cache, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f); - ggml_tensor * kqv_compressed; - auto kq_size = kv_cache->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024); // K*Q in MiB - if (lctx.cparams.attn_max_batch <= 0 || lctx.cparams.attn_max_batch >= kq_size) { - if (!pp_opt) { - q = ggml_permute(ctx0, q, 0, 2, 1, 3); - cb(q, "q_perm", il); + kqv_compressed = ggml_flash_attn_ext(ctx0, q, kv_cache, kv_cache_lora, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f); + cb(kqv_compressed, "kqv_compressed", il); + + kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); + cb(kqv_compressed, "kqv_compressed_perm", il); + } + else { + if (lctx.cparams.mla_attn > 1) { + ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il], + kv_lora_rank, n_kv, + ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0); + cb(kv_cache, "kv_cache_lora", il); + + kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache_lora)); + cb(kv_cache_trans, "kv_cache_trans", il); } - ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q); - cb(kq, "kq", il); + auto kq_size = kv_cache->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024); // K*Q in MiB + if (lctx.cparams.attn_max_batch <= 0 || lctx.cparams.attn_max_batch >= kq_size) { + if (!pp_opt) { + q = ggml_permute(ctx0, q, 0, 2, 1, 3); + cb(q, "q_perm", il); + } - if (!pp_opt) { - kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3)); - cb(kq, "kq_perm", il); - } + ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q); + cb(kq, "kq", il); - kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias); - cb(kq, "kq_soft_max_ext", il); + if (!pp_opt) { + kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3)); + cb(kq, "kq_perm", il); + } - if (!pp_opt) { - kq = ggml_permute(ctx0, kq, 0, 2, 1, 3); - cb(kq, "kq_soft_max_ext_perm", il); - } + kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); - kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq); - cb(kqv_compressed, "kqv_compressed", il); + if (!pp_opt) { + kq = ggml_permute(ctx0, kq, 0, 2, 1, 3); + cb(kq, "kq_soft_max_ext_perm", il); + } - if (!pp_opt) { - kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); - cb(kqv_compressed, "kqv_compressed_perm", il); - } + kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq); + cb(kqv_compressed, "kqv_compressed", il); - } else { + if (!pp_opt) { + kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); + cb(kqv_compressed, "kqv_compressed_perm", il); + } - int n_step = (kq_size + lctx.cparams.attn_max_batch - 1)/lctx.cparams.attn_max_batch; - n_step = std::min(n_step, int(q->ne[2])); - int n_per_step = (q->ne[2] + n_step - 1)/n_step; + } else { - //printf("kq size would be %ld MiB -> splitting kqv computation into %d steps\n", kq_size, n_step); + int n_step = (kq_size + lctx.cparams.attn_max_batch - 1)/lctx.cparams.attn_max_batch; + n_step = std::min(n_step, int(q->ne[2])); + int n_per_step = (q->ne[2] + n_step - 1)/n_step; - for (int i_head = 0; i_head < q->ne[2]; i_head += n_per_step) { - int this_ne12 = i_head + n_per_step <= q->ne[2] ? n_per_step : q->ne[2] - i_head; - ggml_tensor * q_i = ggml_view_3d(ctx0, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i_head); - ggml_tensor * kq_i = ggml_mul_mat(ctx0, kv_cache, q_i); - kq_i = ggml_soft_max_ext(ctx0, kq_i, KQ_mask, kq_scale, hparams.f_max_alibi_bias); - ggml_tensor * kqv_i = ggml_mul_mat(ctx0, kv_cache_trans, kq_i); - if (i_head == 0) { - kqv_compressed = kqv_i; - } else { - kqv_compressed = ggml_concat(ctx0, kqv_compressed, kqv_i, 2); + //printf("kq size would be %ld MiB -> splitting kqv computation into %d steps\n", kq_size, n_step); + + for (int i_head = 0; i_head < q->ne[2]; i_head += n_per_step) { + int this_ne12 = i_head + n_per_step <= q->ne[2] ? n_per_step : q->ne[2] - i_head; + ggml_tensor * q_i = ggml_view_3d(ctx0, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i_head); + ggml_tensor * kq_i = ggml_mul_mat(ctx0, kv_cache, q_i); + kq_i = ggml_soft_max_ext(ctx0, kq_i, KQ_mask, kq_scale, hparams.f_max_alibi_bias); + ggml_tensor * kqv_i = ggml_mul_mat(ctx0, kv_cache_trans, kq_i); + if (i_head == 0) { + kqv_compressed = kqv_i; + } else { + kqv_compressed = ggml_concat(ctx0, kqv_compressed, kqv_i, 2); + } + ggml_build_forward_expand(gf, kqv_compressed); } - ggml_build_forward_expand(gf, kqv_compressed); + cb(kqv_compressed, "kqv_compressed", il); } - cb(kqv_compressed, "kqv_compressed", il); } struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, @@ -18226,7 +18250,7 @@ struct llama_context * llama_new_context_with_model( } if (memory_size_kv + memory_size_kvt > 0) { - if (cparams.mla_attn == 1) { + if (cparams.mla_attn == 1 && !cparams.flash_attn) { LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, c^KV (%s): %7.2f MiB, kv^T (%s): %7.2f MiB\n", __func__, (float)(memory_size_kv + memory_size_kvt) / (1024.0f * 1024.0f), ggml_type_name(kv_type), (float)memory_size_kv / (1024.0f * 1024.0f), |