author | Kawrakow <iwankawrakow@gmail.com> | 2025-02-11 14:46:30 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-02-11 14:46:30 +0200 |
commit | 3c98bfb33d149a0d9d3bb91604dd12709721e3cf (patch) | |
tree | 6a1e5fc373032bb18a62ec3616625eedf1a9f1f3 | |
parent | a366a3d17d8f2de0eb8c3d9eddc7b5840fb5761a (diff) |
DeepSeek FA support (CPU only) (#200)
* Adding support for K head size != V head size
This is relevant for DeepSeek models.
At this point the ggml CPU FA implementation works; the iqk FA
kernels still need to be adapted to handle Dk != Dv.
(Illustrative sketches of the new output shape and of the scalar
CPU kernel loop follow the Co-authored-by line and the changed-file
list below.)
* iqk support for K head size != V head size
To keep compilation time from exploding, only Dk = 192, Dv = 128
is instantiated for now (the DeepSeek case; see the head-size
sketch at the end of this page).
* FA: very slightly faster for nq = 1 (token generation)
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
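The core shape change in ggml_flash_attn_ext can be illustrated with a short, self-contained sketch. `shape4` and `flash_attn_result_shape` below are hypothetical names used only for illustration; the rule itself is the one stated in the comment added to ggml.c: when Dk != Dv, the output head size comes from V, so the result's first dimension must be `v->ne[0]` rather than `q->ne[0]`.

```cpp
// Minimal sketch, assuming ggml's ne[] convention (fastest-varying dimension
// first). flash_attn_result_shape() is a hypothetical helper, not ggml API.
#include <array>
#include <cstdint>

using shape4 = std::array<int64_t, 4>;

shape4 flash_attn_result_shape(const shape4 & q_ne, const shape4 & v_ne) {
    // softmax(k^T q) has shape { k->ne[1], q->ne[2], q->ne[1], q->ne[3] };
    // multiplying by v^T contracts k->ne[1] and yields v->ne[0] values per row,
    // so the result is { v->ne[0], q->ne[2], q->ne[1], q->ne[3] }.
    // Before this commit the code used q->ne[0], which only works when Dk == Dv.
    return shape4{ v_ne[0], q_ne[2], q_ne[1], q_ne[3] };
}

int main() {
    // DeepSeek-style head sizes, used here purely for illustration: Dk = 192, Dv = 128.
    shape4 q_ne = { 192,   8, 16, 1 };  // { Dk, ... }
    shape4 v_ne = { 128, 512, 16, 1 };  // { Dv, n_kv, ... }
    shape4 r    = flash_attn_result_shape(q_ne, v_ne);  // { 128, 8, 16, 1 }
    return r[0] == 128 ? 0 : 1;
}
```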
-rw-r--r-- | ggml/src/ggml.c | 61
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 278
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.h | 3
-rw-r--r-- | src/llama.cpp | 8 |
4 files changed, 221 insertions, 129 deletions
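Before the diff itself, here is a rough scalar sketch of what the updated CPU kernel in ggml_compute_forward_flash_attn_ext_f16 does for one query row after this change: the Q·K dot product runs over Dk elements while the online-softmax V accumulator holds Dv elements, so the single head size D used previously disappears. This is an illustrative f32-only sketch with no masking, GQA broadcasting or soft-capping; the real kernel is vectorized and supports f16/quantized V.

```cpp
// Illustrative scalar sketch only, not the actual ggml kernel.
#include <cmath>
#include <vector>

void flash_attn_one_row(const float * q,      // Dk values for one query row
                        const float * K,      // n_kv x Dk, row-major
                        const float * V,      // n_kv x Dv, row-major
                        int n_kv, int Dk, int Dv, float scale,
                        float * out) {        // Dv values
    std::vector<float> vkq(Dv, 0.0f);         // accumulator sized by Dv, not Dk
    float S = 0.0f;                           // running softmax denominator
    float M = -INFINITY;                      // running maximum of the logits
    for (int ic = 0; ic < n_kv; ++ic) {
        float s = 0.0f;
        for (int d = 0; d < Dk; ++d) s += K[ic*Dk + d]*q[d];    // dot over Dk
        s *= scale;
        float ms = 1.0f, vs = 1.0f;
        if (s > M) {                          // new maximum: rescale the accumulator
            ms = std::exp(M - s);
            M  = s;
            for (int d = 0; d < Dv; ++d) vkq[d] *= ms;
        } else {
            vs = std::exp(s - M);
        }
        for (int d = 0; d < Dv; ++d) vkq[d] += vs*V[ic*Dv + d]; // mad over Dv
        S = S*ms + vs;
    }
    for (int d = 0; d < Dv; ++d) out[d] = vkq[d]/S;             // normalize
}

int main() {
    // Toy usage: one query against two KV entries, Dk = 4, Dv = 2 (illustrative).
    const float q[4] = {1, 0, 0, 0};
    const float K[8] = {1, 0, 0, 0,   0, 1, 0, 0};   // 2 x Dk
    const float V[4] = {1.0f, 2.0f,   3.0f, 4.0f};   // 2 x Dv
    float out[2];
    flash_attn_one_row(q, K, V, /*n_kv=*/2, /*Dk=*/4, /*Dv=*/2, /*scale=*/1.0f, out);
    return 0;
}
```

Correspondingly, the per-thread scratch in ggml_graph_plan grows from 3*D floats to 3*MAX(Dk, Dv) floats, as the GGML_OP_FLASH_ATTN_EXT case in the diff below shows.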
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 3867cf00..7b631177 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -8473,8 +8473,12 @@ struct ggml_tensor * ggml_flash_attn_ext( is_node = true; } + // k*q will be { k->ne[1], q->ne[2], q->ne[1], q->ne[3] } + // v^T is { v->ne[1], v->ne[0], v->ne[2], v->ne[3] } + // => result is { v->ne[0], q->ne[2], q->ne[1], q->ne[3] } // permute(0, 2, 1, 3) - int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; + //int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; + int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); float params[] = { scale, max_bias, softcap }; @@ -17436,10 +17440,11 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int ith = params->ith; const int nth = params->nth; - const int64_t D = neq0; - const int64_t N = neq1; + const int64_t Dk = nek0; + const int64_t Dv = nev0; + const int64_t N = neq1; - GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne0 == Dv); GGML_ASSERT(ne2 == N); // input tensor rows must be contiguous @@ -17447,12 +17452,12 @@ static void ggml_compute_forward_flash_attn_ext_f16( GGML_ASSERT(nbk0 == ggml_type_size(k->type)); GGML_ASSERT(nbv0 == ggml_type_size(v->type)); - GGML_ASSERT(neq0 == D); - GGML_ASSERT(nek0 == D); - GGML_ASSERT(nev0 == D); + GGML_ASSERT(neq0 == Dk); + GGML_ASSERT(nek0 == Dk); + GGML_ASSERT(nev0 == Dv); GGML_ASSERT(neq1 == N); - GGML_ASSERT(nev0 == D); + GGML_ASSERT(nev0 == Dv); // dst cannot be transposed or permuted GGML_ASSERT(nb0 == sizeof(float)); @@ -17516,7 +17521,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( int iq1 = (ith%ntg)*neq1g; int this_neq1 = MIN(neq1g, neq1-iq1); if (!iqk_flash_attn_noalibi(k->type, v->type, - D, this_neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float), + Dk, Dv, this_neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float), (const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]), (const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]), (const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]), @@ -17543,6 +17548,8 @@ IQK_Flash_Attn_NotAvailable:; ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot; ggml_to_float_t const v_to_float = type_traits[v->type].to_float; + const int64_t Dkv = MAX(Dk, Dv); + // loop over n_batch and n_head for (int ir = ir0; ir < ir1; ++ir) { // q indices @@ -17556,15 +17563,15 @@ IQK_Flash_Attn_NotAvailable:; float S = 0.0f; // sum float M = -INFINITY; // maximum KQ value - float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator - float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer - ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator - ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16 + float * VKQ32 = (float *) params->wdata + ith*(3*Dkv + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator + float * V32 = (VKQ32 + 1*Dkv); // (temporary) FP32 V buffer + ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*Dkv); // (temporary) FP16 VKQ accumulator + ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*Dkv); // (temporary) buffer for Q converted to quantized/FP16 if (v->type == GGML_TYPE_F16) { - memset(VKQ16, 0, D*sizeof(ggml_fp16_t)); + memset(VKQ16, 0, Dkv*sizeof(ggml_fp16_t)); } else { - memset(VKQ32, 0, D*sizeof(float)); + memset(VKQ32, 0, Dkv*sizeof(float)); } const 
ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL; @@ -17578,7 +17585,7 @@ IQK_Flash_Attn_NotAvailable:; const int iv2 = iq2 / rv2; const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); - q_to_vec_dot(pq, Q_q, D); + q_to_vec_dot(pq, Q_q, Dk); // online softmax / attention // loop over n_kv and n_head_kv @@ -17592,7 +17599,7 @@ IQK_Flash_Attn_NotAvailable:; float s; // KQ value const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3); - kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1); + kq_vec_dot(Dk, &s, 0, k_data, 0, Q_q, 0, 1); s = softcap == 0.0f ? s*scale + mv : softcap*tanhf(s*scale) + mv; // scale KQ value and apply mask @@ -17610,14 +17617,14 @@ IQK_Flash_Attn_NotAvailable:; ms = expf(Mold - M); // V = V*expf(Mold - M) - ggml_vec_scale_f16(D, VKQ16, ms); + ggml_vec_scale_f16(Dv, VKQ16, ms); } else { // no new maximum, ms == 1.0f, vs != 1.0f vs = expf(s - M); } // V += v*expf(s - M) - ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs); + ggml_vec_mad_f16(Dv, VKQ16, (const ggml_fp16_t *) v_data, vs); } else { if (s > M) { // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f @@ -17625,30 +17632,30 @@ IQK_Flash_Attn_NotAvailable:; ms = expf(Mold - M); // V = V*expf(Mold - M) - ggml_vec_scale_f32(D, VKQ32, ms); + ggml_vec_scale_f32(Dv, VKQ32, ms); } else { // no new maximum, ms == 1.0f, vs != 1.0f vs = expf(s - M); } - v_to_float(v_data, V32, D); + v_to_float(v_data, V32, Dv); // V += v*expf(s - M) - ggml_vec_mad_f32(D, VKQ32, V32, vs); + ggml_vec_mad_f32(Dv, VKQ32, V32, vs); } S = S*ms + vs; // scale and increment sum with partial sum } if (v->type == GGML_TYPE_F16) { - for (int64_t d = 0; d < D; ++d) { + for (int64_t d = 0; d < Dv; ++d) { VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]); } } // V /= S const float S_inv = 1.0f/S; - ggml_vec_scale_f32(D, VKQ32, S_inv); + ggml_vec_scale_f32(Dv, VKQ32, S_inv); // dst indices const int i1 = iq1; @@ -21112,9 +21119,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa } break; case GGML_OP_FLASH_ATTN_EXT: { - const int64_t ne00 = node->src[0]->ne[0]; // D + const int64_t Dk = node->src[0]->ne[0]; + const int64_t Dv = node->src[2]->ne[0]; + const int64_t D = MAX(Dk, Dv); - cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread + cur = 3*sizeof(float)*D*n_tasks; // 3x head size/thread } break; case GGML_OP_FLASH_ATTN_BACK: { diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index ee0af7e9..3b58495e 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -14879,10 +14879,60 @@ struct FlashQKV { using qkv_cache_t = float; #endif + template <typename VHelper> + inline void accumulate_qkv_1(const VHelper& vh, const FlashMS<q_step, k_step>& fms) { + F16::Data vq[D/F16::block_size]; + if (fms.need_scaling[0] == 2) { + for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::zero(); + } else { + for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::load(qkv_cache + F16::block_size*i); + if (fms.need_scaling[0] == 1) { + auto vms = F16::set1(fms.vms[0]); + for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::mul(vms, vq[i]); + } + } + //F16::Data v[8]; + F16::Data v0, v1; + for (int l = 0; l < k_step; l += 4) { + auto vs0 = F16::set1(fms.cache[l + 0]); + auto vs1 = F16::set1(fms.cache[l + 1]); + auto vs2 = F16::set1(fms.cache[l + 2]); + auto vs3 = F16::set1(fms.cache[l + 3]); + //auto vs = F16::set4(fms.cache + l); + for (int i = 0; i < D/F16::block_size; i += 2) { + 
vh.load(l+0, i, v0, v1); + vq[i+0] = F16::fmadd(vq[i+0], v0, vs0); + vq[i+1] = F16::fmadd(vq[i+1], v1, vs0); + vh.load(l+1, i, v0, v1); + vq[i+0] = F16::fmadd(vq[i+0], v0, vs1); + vq[i+1] = F16::fmadd(vq[i+1], v1, vs1); + vh.load(l+2, i, v0, v1); + vq[i+0] = F16::fmadd(vq[i+0], v0, vs2); + vq[i+1] = F16::fmadd(vq[i+1], v1, vs2); + vh.load(l+3, i, v0, v1); + vq[i+0] = F16::fmadd(vq[i+0], v0, vs3); + vq[i+1] = F16::fmadd(vq[i+1], v1, vs3); + //vq[i+0] = F16::fmadd_lane0(vq[i+0], v[0], vs); + //vq[i+1] = F16::fmadd_lane0(vq[i+1], v[4], vs); + //vq[i+0] = F16::fmadd_lane1(vq[i+0], v[1], vs); + //vq[i+1] = F16::fmadd_lane1(vq[i+1], v[5], vs); + //vq[i+0] = F16::fmadd_lane2(vq[i+0], v[2], vs); + //vq[i+1] = F16::fmadd_lane2(vq[i+1], v[6], vs); + //vq[i+0] = F16::fmadd_lane3(vq[i+0], v[3], vs); + //vq[i+1] = F16::fmadd_lane3(vq[i+1], v[7], vs); + } + } + for (int i = 0; i < D/F16::block_size; ++i) F16::store(qkv_cache + F16::block_size*i, vq[i]); + } + // This fails for head sizes of 80 and 112 as D/16 is odd, so we cannot do steps of 2 // Hence, for now, we will not handle head sizes of 80 and 112 template <typename VHelper> inline void accumulate_qkv(const VHelper& vh, const FlashMS<q_step, k_step>& fms) { + if constexpr (q_step == 1) { + accumulate_qkv_1(vh, fms); + return; + } F16::Data v[8]; for (int j = 0; j < q_step; ++j) { auto R = qkv_cache + D*j; @@ -14924,6 +14974,10 @@ struct FlashQKV { template <typename VHelper> inline void accumulate_qkv(int nq1, const VHelper& vh, const FlashMS<q_step, k_step>& fms) { + if (nq1 == 1) { + accumulate_qkv_1(vh, fms); + return; + } F16::Data v[8]; for (int j = 0; j < nq1; ++j) { auto R = qkv_cache + D*j; @@ -15346,13 +15400,13 @@ struct FlashQKfp32 { } }; -template <int D, int q_step, int k_step, typename KHelper, typename VHelper, typename KQHelper> +template <int Dk, int Dv, int q_step, int k_step, typename KHelper, typename VHelper, typename KQHelper> void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, FlashMS<q_step, k_step>& fms, - FlashQKV<D, q_step, k_step>& fqkv, + FlashQKV<Dv, q_step, k_step>& fqkv, const float * q, const char * mask, float * qkv) { #ifdef __aarch64__ - float16_t q_f16[D*q_step]; + float16_t q_f16[Dk*q_step]; #endif for (int i1 = 0; i1 < nq1/q_step; ++i1) { @@ -15365,7 +15419,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in auto mr = mask; for (int k1 = 0; k1 < nk1/k_step; ++k1) { #ifdef __aarch64__ - KQHelper::multiply_mask_kq(kh, D, stride_m, q_f16, mr, fms); + KQHelper::multiply_mask_kq(kh, Dk, stride_m, q_f16, mr, fms); #else KQHelper::multiply_mask_kq(kh, stride_q, stride_m, q, mr, fms); #endif @@ -15391,7 +15445,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in auto mr = mask; for (int k1 = 0; k1 < nk1/k_step; ++k1) { #ifdef __aarch64__ - KQHelper::multiply_mask_kq(n_left, kh, D, stride_m, q_f16, mr, fms); + KQHelper::multiply_mask_kq(n_left, kh, Dk, stride_m, q_f16, mr, fms); #else KQHelper::multiply_mask_kq(n_left, kh, stride_q, stride_m, q, mr, fms); #endif @@ -15404,12 +15458,12 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in } } -template <int D, int q_step, int k_step, typename KHelper, typename VHelper, typename KQHelper> +template <int Dk, int Dv, int q_step, int k_step, typename KHelper, typename VHelper, typename KQHelper> void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, FlashMS<q_step, 
k_step>& fms, - FlashQKV<D, q_step, k_step>& fqkv, + FlashQKV<Dv, q_step, k_step>& fqkv, const float * q, const char * mask, float * qkv) { - typename KHelper::block_q8 q8[q_step*(D/QK8_0)]; + typename KHelper::block_q8 q8[q_step*(Dk/QK8_0)]; #if FA_TIMING Perf perf(false); #endif @@ -15420,7 +15474,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, fms.init_qstep(); kh.reset_block(); vh.reset_block(); - HelperQ80<D, QK8_0>::convert(q_step, stride_q, q, q8); + HelperQ80<Dk, QK8_0>::convert(q_step, stride_q, q, q8); #if FA_TIMING perf.accum_nolock(0, t1); #endif @@ -15458,7 +15512,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, fms.init_qstep(); kh.reset_block(); vh.reset_block(); - HelperQ80<D, QK8_0>::convert(n_left, stride_q, q, q8); + HelperQ80<Dk, QK8_0>::convert(n_left, stride_q, q, q8); auto mr = mask; for (int k1 = 0; k1 < nk1/k_step; ++k1) { KQHelper::mul_mask_kq(n_left, kh, stride_m, q8, mr, fms); @@ -15484,9 +15538,10 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, // rows (if Nq is not a multiple of q_step). One could have made the number of q^T rows to // process template parameter of such functions, but this would result in the compiler generating // q_step-1 versions of these functions for us, which I though was too much with q_step = 8. -template <int D, int q_step, int k_step> +template <int Dk, int Dv, int q_step, int k_step> struct FlashAttn { - static_assert(D%F16::block_size == 0 && D <= 256); + static_assert(Dk%F16::block_size == 0 && Dk <= 256); + static_assert(Dv%F16::block_size == 0 && Dv <= 256); static_assert(k_step%F16::block_size == 0); static_assert(q_step <= 4 || q_step%4 == 0); @@ -15495,35 +15550,35 @@ struct FlashAttn { template <typename KHelper, typename VHelper> void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, const float * q, const char * mask, float * qkv) { - if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>> || - std::is_same_v<KHelper, HelperIQ4nl<D, k_step>> || - std::is_same_v<KHelper, HelperQ60<D, k_step>>) { - compute_helper_q<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>( + if constexpr (std::is_same_v<KHelper, HelperQ40<Dk, k_step>> || std::is_same_v<KHelper, HelperQ41<Dk, k_step>> || + std::is_same_v<KHelper, HelperIQ4nl<Dk, k_step>> || + std::is_same_v<KHelper, HelperQ60<Dk, k_step>>) { + compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>( kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); } - else if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) { + else if constexpr (std::is_same_v<KHelper, HelperQ80<Dk, k_step>>) { if (nq1 >= 8) { #if FA_TIMING auto t1 = Perf::cur_time(); - HelperQ80R4<D, k_step> khr4(nk1, kh); + HelperQ80R4<Dk, k_step> khr4(nk1, kh); Perf::instance().accum(4, t1); #else - HelperQ80R4<D, k_step> khr4(nk1, kh); + HelperQ80R4<Dk, k_step> khr4(nk1, kh); #endif - compute_helper_q<D, q_step, k_step, HelperQ80R4<D, k_step>, VHelper, FlashQKfp32<D, q_step, k_step>>( + compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R4<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>( khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); } else{ - compute_helper_q<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>( + compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, 
FlashQKfp32<Dk, q_step, k_step>>( kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); } } else { - compute_helper<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>( + compute_helper<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>( kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); } } - FlashMS<q_step, k_step> fms; - FlashQKV<D, q_step, k_step> fqkv; + FlashMS<q_step, k_step> fms; + FlashQKV<Dv, q_step, k_step> fqkv; }; @@ -15756,7 +15811,22 @@ struct FlashQKbf16 { static inline void multiply_mask_kq(const KHelper& kh, int stride_m, const ggml_bf16_t * q, const char * mask, FlashMS<q_step, k_step>& fms) { #endif - { + if constexpr (q_step == 1) { + __m512bh vq[D/32]; + __m512bh vk[D/32]; + __m256 sum[8]; + for (int i = 0; i < D/32; ++i) vq[i] = __m512bh(_mm512_loadu_si512((const __m512i *)q + i)); + for (int l = 0; l < k_step; l += 8) { + for (int k = 0; k < 8; ++k) { + kh.load(l+k, vk); + auto vsum = _mm512_setzero_ps(); + for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vk[i], vq[i]); + sum[k] = _mm256_add_ps(_mm512_castps512_ps256(vsum), _mm512_extractf32x8_ps(vsum, 1)); + } + _mm256_storeu_ps(fms.cache + l, hsum_float_8x8(sum)); + } + } + else { __m512bh qv[D/32]; if constexpr (D <= 128) { __m512bh vkh[D/4]; @@ -15856,9 +15926,10 @@ struct FlashQKbf16 { } }; -template <int D, int q_step, int k_step> +template <int Dk, int Dv, int q_step, int k_step> struct FlashAttnBF16 { - static_assert(D%32 == 0 && D <= 256); + static_assert(Dk%32 == 0 && Dk <= 256); + static_assert(Dv%32 == 0 && Dv <= 256); static_assert(k_step%32 == 0); static_assert(q_step <= 4 || q_step%4 == 0); @@ -15867,7 +15938,7 @@ struct FlashAttnBF16 { template <typename KHelper, typename VHelper> void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, const float * q, const char * mask, float * qkv) { - ggml_bf16_t q_bf16[q_step*D]; + ggml_bf16_t q_bf16[q_step*Dk]; #if FA_TIMING Perf perf(false); #endif @@ -15878,7 +15949,7 @@ struct FlashAttnBF16 { fms.init_qstep(); kh.reset_block(); vh.reset_block(); - FlashQKbf16<D, q_step, k_step>::convert(stride_q, q, q_bf16); + FlashQKbf16<Dk, q_step, k_step>::convert(stride_q, q, q_bf16); #if FA_TIMING perf.accum_nolock(0, t1); #endif @@ -15886,13 +15957,13 @@ struct FlashAttnBF16 { for (int k1 = 0; k1 < nk1/k_step; ++k1) { #if FA_TIMING //t1 = Perf::cur_time(); - FlashQKbf16<D, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms, perf); + FlashQKbf16<Dk, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms, perf); //perf.accum_nolock(1, t1); t1 = Perf::cur_time(); fqkv.accumulate_qkv(vh, fms); perf.accum_nolock(3, t1); #else - FlashQKbf16<D, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms); + FlashQKbf16<Dk, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms); fqkv.accumulate_qkv(vh, fms); #endif kh.next_block(); @@ -15916,10 +15987,10 @@ struct FlashAttnBF16 { fms.init_qstep(); kh.reset_block(); vh.reset_block(); - FlashQKbf16<D, q_step, k_step>::convert(n_left, stride_q, q, q_bf16); + FlashQKbf16<Dk, q_step, k_step>::convert(n_left, stride_q, q, q_bf16); auto mr = mask; for (int k1 = 0; k1 < nk1/k_step; ++k1) { - FlashQKbf16<D, q_step, k_step>::multiply_mask_kq(n_left, kh, stride_m, q_bf16, mr, fms); + FlashQKbf16<Dk, q_step, k_step>::multiply_mask_kq(n_left, kh, stride_m, q_bf16, mr, fms); fqkv.accumulate_qkv(n_left, vh, fms); kh.next_block(); vh.next_block(); @@ -15932,72 
+16003,72 @@ struct FlashAttnBF16 { #endif } - FlashMS<q_step, k_step> fms; - FlashQKV<D, q_step, k_step> fqkv; + FlashMS<q_step, k_step> fms; + FlashQKV<Dv, q_step, k_step> fqkv; }; #endif -template <int D, int k_step, typename KHelper, typename VHelper> +template <int Dk, int Dv, int k_step, typename KHelper, typename VHelper> inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, const float * q, const char * mask, float scale, float softcap, float * qkv) { if (nk1 >= 256) { //4096) { if (nq1 >= 64) { - FlashAttn<D, 64, k_step> fa(scale, softcap); + FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); return; } if (nq1 >= 32) { - FlashAttn<D, 32, k_step> fa(scale, softcap); + FlashAttn<Dk, Dv, 32, k_step> fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); return; } if (nq1 >= 16) { - FlashAttn<D, 16, k_step> fa(scale, softcap); + FlashAttn<Dk, Dv, 16, k_step> fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); return; } } if (nq1 >= 8) { - FlashAttn<D, 8, k_step> fa(scale, softcap); + FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); } else { - FlashAttn<D, 1, k_step> fa(scale, softcap); + FlashAttn<Dk, Dv, 1, k_step> fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); } } #ifdef __AVX512BF16__ -template <int D, int k_step> +template <int Dk, int Dv, int k_step> inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv, const float * q, const char * k, const char * v, const char * mask, float scale, float softcap, float * qkv) { - HelperBF16<D, k_step> kh(k, stride_k); - HelperBF16<D, k_step> vh(v, stride_v); + HelperBF16<Dk, k_step> kh(k, stride_k); + HelperBF16<Dv, k_step> vh(v, stride_v); if (nk1 >= 4096) { if (nq1 >= 64) { - FlashAttnBF16<D, 64, k_step> fa(scale, softcap); + FlashAttnBF16<Dk, Dv, 64, k_step> fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); return; } else if (nq1 >= 16) { - FlashAttnBF16<D, 16, k_step> fa(scale, softcap); + FlashAttnBF16<Dk, Dv, 16, k_step> fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); return; } } if (nq1 >= 8) { - FlashAttnBF16<D, 8, k_step> fa(scale, softcap); + FlashAttnBF16<Dk, Dv, 8, k_step> fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); } else { - FlashAttnBF16<D, 1, k_step> fa(scale, softcap); + FlashAttnBF16<Dk, Dv, 1, k_step> fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); } } #endif -template <int D, int k_step, typename KHelper> +template <int Dk, int Dv, int k_step, typename KHelper> inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v, int nq1, int nk1, int stride_q, int stride_v, int stride_m, int stride_qkv, const float * q, const char * v, const char * mask, @@ -16005,42 +16076,42 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v, switch (type_v) { case GGML_TYPE_F16: { - HelperF16<D, k_step> vh(v, stride_v); - iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, 
scale, softcap, qkv); + HelperF16<Dv, k_step> vh(v, stride_v); + iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; #ifdef HAVE_FANCY_SIMD case GGML_TYPE_BF16: { - HelperBF16<D, k_step> vh(v, stride_v); - iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + HelperBF16<Dv, k_step> vh(v, stride_v); + iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; #endif case GGML_TYPE_Q8_0: { - HelperQ80<D, k_step> vh(v, stride_v); - iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + HelperQ80<Dv, k_step> vh(v, stride_v); + iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; case GGML_TYPE_Q6_0: { - HelperQ60<D, k_step> vh(v, stride_v); - iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + HelperQ60<Dv, k_step> vh(v, stride_v); + iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; #if GGML_IQK_FA_ALL_QUANTS case GGML_TYPE_Q4_0: { - HelperQ40<D, k_step> vh(v, stride_v); - iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + HelperQ40<Dv, k_step> vh(v, stride_v); + iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; case GGML_TYPE_Q4_1: { - HelperQ41<D, k_step> vh(v, stride_v); - iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + HelperQ41<Dv, k_step> vh(v, stride_v); + iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; case GGML_TYPE_IQ4_NL: { - HelperIQ4nl<D, k_step> vh(v, stride_v); - iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + HelperIQ4nl<Dv, k_step> vh(v, stride_v); + iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; #endif default: break; } } -template <int D, int k_step> +template <int Dk, int Dv, int k_step> inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv, const float * q, const char * k, const char * v, const char * mask, @@ -16048,29 +16119,29 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, switch (type_k) { case GGML_TYPE_F16: { - HelperF16<D, k_step> kh(k, stride_k); - iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + HelperF16<Dk, k_step> kh(k, stride_k); + iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; case GGML_TYPE_Q8_0: { - HelperQ80<D, k_step> kh(k, stride_k); - iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + HelperQ80<Dk, k_step> kh(k, stride_k); + iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; case GGML_TYPE_Q6_0: { - HelperQ60<D, k_step> kh(k, stride_k); - 
iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + HelperQ60<Dk, k_step> kh(k, stride_k); + iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; #if GGML_IQK_FA_ALL_QUANTS case GGML_TYPE_Q4_0: { - HelperQ40<D, k_step> kh(k, stride_k); - iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + HelperQ40<Dk, k_step> kh(k, stride_k); + iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; case GGML_TYPE_Q4_1: { - HelperQ41<D, k_step> kh(k, stride_k); - iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + HelperQ41<Dk, k_step> kh(k, stride_k); + iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; case GGML_TYPE_IQ4_NL: { - HelperIQ4nl<D, k_step> kh(k, stride_k); - iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + HelperIQ4nl<Dk, k_step> kh(k, stride_k); + iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; #endif default: break; @@ -16094,7 +16165,8 @@ inline bool flash_attn_is_supported(ggml_type type) { bool iqk_flash_attn_noalibi(int int_type_k, // type of k int int_type_v, // type of v - int D, // head size + int Dk, // K head size + int Dv, // V head size int nq1, // number of columns in q int nk1, // number of rows in k int stride_q, // distance between q columns in bytes @@ -16114,7 +16186,9 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k auto type_v = ggml_type(int_type_v); if (!flash_attn_is_supported(type_k) || !flash_attn_is_supported(type_v)) return false; if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32 - if (D != 64 && D != 96 && D != 128 && D != 256) return false; + if (Dk != Dv && Dk != 192 && Dv != 128) return false; + if (Dv != 64 && Dv != 96 && Dv != 128 && Dv != 256) return false; + if (Dk != 64 && Dk != 96 && Dk != 128 && Dk != 192 && Dv != 256) return false; auto ck = (const char *)k; auto cv = (const char *)v; @@ -16126,30 +16200,34 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k if (type_k == GGML_TYPE_BF16) { if (nk1%64 == 0) { if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types - switch (D) { + switch (Dk) { case 64: - iqk_flash_helper_T< 64, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 64, 64, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 96: - iqk_flash_helper_T< 96, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 96, 96, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 128: - iqk_flash_helper_T<128, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<128, 128, 64>(nq1, nk1, stride_q, stride_k, 
stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + case 192: + iqk_flash_helper_T<192, 128, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 256: - iqk_flash_helper_T<256, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<256, 256, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; default: return false; } return true; } if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types - switch (D) { + switch (Dk) { case 64: - iqk_flash_helper_T< 64, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 64, 64, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 96: - iqk_flash_helper_T< 96, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 96, 96, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 128: - iqk_flash_helper_T<128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<128, 128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + case 192: + iqk_flash_helper_T<192, 128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 256: - iqk_flash_helper_T<256, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<256, 256, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; default: return false; } @@ -16159,41 +16237,45 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k #endif if (nk1%64 == 0) { - switch (D) { + switch (Dk) { case 64: - iqk_flash_helper_T< 64, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 64, 64, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; // Disable until we fix accumulate_qkv for odd D/16 //case 80: // iqk_flash_helper_T< 80, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 96: - iqk_flash_helper_T< 96, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 96, 96, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; // Disable until we fix accumulate_qkv for odd D/16 //case 112: // iqk_flash_helper_T<112, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 128: - iqk_flash_helper_T<128, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<128, 128, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + case 192: + iqk_flash_helper_T<192, 128, 
64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 256: - iqk_flash_helper_T<256, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<256, 256, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; default: return false; } return true; } - switch (D) { + switch (Dk) { case 64: - iqk_flash_helper_T< 64, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 64, 64, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; // Disable until we fix accumulate_qkv for odd D/16 //case 80: // iqk_flash_helper_T< 80, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 96: - iqk_flash_helper_T< 96, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 96, 96, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; // Disable until we fix accumulate_qkv for odd D/16 //case 112: // iqk_flash_helper_T<112, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 128: - iqk_flash_helper_T<128, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<128, 128, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + case 192: + iqk_flash_helper_T<192, 128, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 256: - iqk_flash_helper_T<256, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<256, 256, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; default: return false; } diff --git a/ggml/src/iqk/iqk_mul_mat.h b/ggml/src/iqk/iqk_mul_mat.h index 6e27c614..b24dc7b2 100644 --- a/ggml/src/iqk/iqk_mul_mat.h +++ b/ggml/src/iqk/iqk_mul_mat.h @@ -23,7 +23,8 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, bool iqk_flash_attn_noalibi(int type_k, // type of k int type_v, // type of v - int D, // head size + int Dk, // K head size + int Dv, // V head size int nq, // number of columns in q int nk, // number of rows in k int stride_q, // distance between q columns in bytes diff --git a/src/llama.cpp b/src/llama.cpp index b2553802..0817c53c 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -17768,10 +17768,10 @@ struct llama_context * llama_new_context_with_model( params.flash_attn = false; } - if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) { - LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__); - params.flash_attn = false; - } + //if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) { + // LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__); + // 
params.flash_attn = false; + //} if (params.type_v != GGML_TYPE_F16 && params.type_v != GGML_TYPE_BF16 && !params.flash_attn) { LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
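For reference, the intent of the new head-size gate in iqk_flash_attn_noalibi (in the iqk_mul_mat.cpp hunk above) can be summarized as follows. This is a sketch of the accepted combinations, not the literal code, and `iqk_fa_head_sizes_supported` is a hypothetical name used only here:

```cpp
// Sketch of which (Dk, Dv) pairs the iqk FA path is meant to accept after this
// commit: equal K/V head sizes of 64, 96, 128 or 256, plus the single mixed
// DeepSeek combination Dk = 192, Dv = 128.
bool iqk_fa_head_sizes_supported(int Dk, int Dv) {
    if (Dk == Dv) {
        return Dk == 64 || Dk == 96 || Dk == 128 || Dk == 256;
    }
    return Dk == 192 && Dv == 128;   // only mixed combination instantiated for now
}
```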