Diffstat (limited to 'ggml/src/iqk/iqk_mul_mat.cpp')
 ggml/src/iqk/iqk_mul_mat.cpp | 278 ++++++++++++++++++++++++++++++++------------
 1 file changed, 180 insertions(+), 98 deletions(-)
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index ee0af7e9..3b58495e 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -14879,10 +14879,60 @@ struct FlashQKV {
     using qkv_cache_t = float;
 #endif
 
+    template <typename VHelper>
+    inline void accumulate_qkv_1(const VHelper& vh, const FlashMS<q_step, k_step>& fms) {
+        F16::Data vq[D/F16::block_size];
+        if (fms.need_scaling[0] == 2) {
+            for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::zero();
+        } else {
+            for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::load(qkv_cache + F16::block_size*i);
+            if (fms.need_scaling[0] == 1) {
+                auto vms = F16::set1(fms.vms[0]);
+                for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::mul(vms, vq[i]);
+            }
+        }
+        //F16::Data v[8];
+        F16::Data v0, v1;
+        for (int l = 0; l < k_step; l += 4) {
+            auto vs0 = F16::set1(fms.cache[l + 0]);
+            auto vs1 = F16::set1(fms.cache[l + 1]);
+            auto vs2 = F16::set1(fms.cache[l + 2]);
+            auto vs3 = F16::set1(fms.cache[l + 3]);
+            //auto vs = F16::set4(fms.cache + l);
+            for (int i = 0; i < D/F16::block_size; i += 2) {
+                vh.load(l+0, i, v0, v1);
+                vq[i+0] = F16::fmadd(vq[i+0], v0, vs0);
+                vq[i+1] = F16::fmadd(vq[i+1], v1, vs0);
+                vh.load(l+1, i, v0, v1);
+                vq[i+0] = F16::fmadd(vq[i+0], v0, vs1);
+                vq[i+1] = F16::fmadd(vq[i+1], v1, vs1);
+                vh.load(l+2, i, v0, v1);
+                vq[i+0] = F16::fmadd(vq[i+0], v0, vs2);
+                vq[i+1] = F16::fmadd(vq[i+1], v1, vs2);
+                vh.load(l+3, i, v0, v1);
+                vq[i+0] = F16::fmadd(vq[i+0], v0, vs3);
+                vq[i+1] = F16::fmadd(vq[i+1], v1, vs3);
+                //vq[i+0] = F16::fmadd_lane0(vq[i+0], v[0], vs);
+                //vq[i+1] = F16::fmadd_lane0(vq[i+1], v[4], vs);
+                //vq[i+0] = F16::fmadd_lane1(vq[i+0], v[1], vs);
+                //vq[i+1] = F16::fmadd_lane1(vq[i+1], v[5], vs);
+                //vq[i+0] = F16::fmadd_lane2(vq[i+0], v[2], vs);
+                //vq[i+1] = F16::fmadd_lane2(vq[i+1], v[6], vs);
+                //vq[i+0] = F16::fmadd_lane3(vq[i+0], v[3], vs);
+                //vq[i+1] = F16::fmadd_lane3(vq[i+1], v[7], vs);
+            }
+        }
+        for (int i = 0; i < D/F16::block_size; ++i) F16::store(qkv_cache + F16::block_size*i, vq[i]);
+    }
+
     // This fails for head sizes of 80 and 112 as D/16 is odd, so we cannot do steps of 2
     // Hence, for now, we will not handle head sizes of 80 and 112
     template <typename VHelper>
     inline void accumulate_qkv(const VHelper& vh, const FlashMS<q_step, k_step>& fms) {
+        if constexpr (q_step == 1) {
+            accumulate_qkv_1(vh, fms);
+            return;
+        }
         F16::Data v[8];
         for (int j = 0; j < q_step; ++j) {
             auto R = qkv_cache + D*j;
@@ -14924,6 +14974,10 @@ struct FlashQKV {
 
     template <typename VHelper>
     inline void accumulate_qkv(int nq1, const VHelper& vh, const FlashMS<q_step, k_step>& fms) {
+        if (nq1 == 1) {
+            accumulate_qkv_1(vh, fms);
+            return;
+        }
         F16::Data v[8];
         for (int j = 0; j < nq1; ++j) {
             auto R = qkv_cache + D*j;
@@ -15346,13 +15400,13 @@ struct FlashQKfp32 {
     }
 };
 
-template <int D, int q_step, int k_step, typename KHelper, typename VHelper, typename KQHelper>
+template <int Dk, int Dv, int q_step, int k_step, typename KHelper, typename VHelper, typename KQHelper>
 void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
         FlashMS<q_step, k_step>& fms,
-        FlashQKV<D, q_step, k_step>& fqkv,
+        FlashQKV<Dv, q_step, k_step>& fqkv,
         const float * q, const char * mask, float * qkv) {
 #ifdef __aarch64__
-    float16_t q_f16[D*q_step];
+    float16_t q_f16[Dk*q_step];
 #endif
 
     for (int i1 = 0; i1 < nq1/q_step; ++i1) {
@@ -15365,7 +15419,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
         auto mr = mask;
         for (int k1 = 0; k1 < nk1/k_step; ++k1) {
 #ifdef __aarch64__
-            KQHelper::multiply_mask_kq(kh, D, stride_m, q_f16, mr, fms);
+            KQHelper::multiply_mask_kq(kh, Dk, stride_m, q_f16, mr, fms);
 #else
             KQHelper::multiply_mask_kq(kh, stride_q, stride_m, q, mr, fms);
 #endif
@@ -15391,7 +15445,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
         auto mr = mask;
         for (int k1 = 0; k1 < nk1/k_step; ++k1) {
 #ifdef __aarch64__
-            KQHelper::multiply_mask_kq(n_left, kh, D, stride_m, q_f16, mr, fms);
+            KQHelper::multiply_mask_kq(n_left, kh, Dk, stride_m, q_f16, mr, fms);
 #else
             KQHelper::multiply_mask_kq(n_left, kh, stride_q, stride_m, q, mr, fms);
 #endif
@@ -15404,12 +15458,12 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
     }
 }
 
-template <int D, int q_step, int k_step, typename KHelper, typename VHelper, typename KQHelper>
+template <int Dk, int Dv, int q_step, int k_step, typename KHelper, typename VHelper, typename KQHelper>
 void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
         FlashMS<q_step, k_step>& fms,
-        FlashQKV<D, q_step, k_step>& fqkv,
+        FlashQKV<Dv, q_step, k_step>& fqkv,
         const float * q, const char * mask, float * qkv) {
-    typename KHelper::block_q8 q8[q_step*(D/QK8_0)];
+    typename KHelper::block_q8 q8[q_step*(Dk/QK8_0)];
 #if FA_TIMING
     Perf perf(false);
 #endif
@@ -15420,7 +15474,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
         fms.init_qstep();
         kh.reset_block();
         vh.reset_block();
-        HelperQ80<D, QK8_0>::convert(q_step, stride_q, q, q8);
+        HelperQ80<Dk, QK8_0>::convert(q_step, stride_q, q, q8);
 #if FA_TIMING
         perf.accum_nolock(0, t1);
 #endif
@@ -15458,7 +15512,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
         fms.init_qstep();
         kh.reset_block();
         vh.reset_block();
-        HelperQ80<D, QK8_0>::convert(n_left, stride_q, q, q8);
+        HelperQ80<Dk, QK8_0>::convert(n_left, stride_q, q, q8);
         auto mr = mask;
         for (int k1 = 0; k1 < nk1/k_step; ++k1) {
             KQHelper::mul_mask_kq(n_left, kh, stride_m, q8, mr, fms);
@@ -15484,9 +15538,10 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
 // rows (if Nq is not a multiple of q_step). One could have made the number of q^T rows to
 // process a template parameter of such functions, but this would result in the compiler generating
 // q_step-1 versions of these functions for us, which I thought was too much with q_step = 8.
-template <int D, int q_step, int k_step>
+template <int Dk, int Dv, int q_step, int k_step>
 struct FlashAttn {
-    static_assert(D%F16::block_size == 0 && D <= 256);
+    static_assert(Dk%F16::block_size == 0 && Dk <= 256);
+    static_assert(Dv%F16::block_size == 0 && Dv <= 256);
     static_assert(k_step%F16::block_size == 0);
     static_assert(q_step <= 4 || q_step%4 == 0);
 
@@ -15495,35 +15550,35 @@ struct FlashAttn {
     template <typename KHelper, typename VHelper>
     void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
             const float * q, const char * mask, float * qkv) {
-        if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>> ||
-                      std::is_same_v<KHelper, HelperIQ4nl<D, k_step>> ||
-                      std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
-            compute_helper_q<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
+        if constexpr (std::is_same_v<KHelper, HelperQ40<Dk, k_step>> || std::is_same_v<KHelper, HelperQ41<Dk, k_step>> ||
+                      std::is_same_v<KHelper, HelperIQ4nl<Dk, k_step>> ||
+                      std::is_same_v<KHelper, HelperQ60<Dk, k_step>>) {
+            compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
                     kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
         }
-        else if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) {
+        else if constexpr (std::is_same_v<KHelper, HelperQ80<Dk, k_step>>) {
             if (nq1 >= 8) {
 #if FA_TIMING
                 auto t1 = Perf::cur_time();
-                HelperQ80R4<D, k_step> khr4(nk1, kh);
+                HelperQ80R4<Dk, k_step> khr4(nk1, kh);
                 Perf::instance().accum(4, t1);
 #else
-                HelperQ80R4<D, k_step> khr4(nk1, kh);
+                HelperQ80R4<Dk, k_step> khr4(nk1, kh);
 #endif
-                compute_helper_q<D, q_step, k_step, HelperQ80R4<D, k_step>, VHelper, FlashQKfp32<D, q_step, k_step>>(
+                compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R4<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
                         khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
             } else{
-                compute_helper_q<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
+                compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
                         kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
             }
         }
         else {
-            compute_helper<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
+            compute_helper<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
                     kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
         }
     }
 
-    FlashMS<q_step, k_step> fms;
-    FlashQKV<D, q_step, k_step> fqkv;
+    FlashMS<q_step, k_step>      fms;
+    FlashQKV<Dv, q_step, k_step> fqkv;
 };
@@ -15756,7 +15811,22 @@ struct FlashQKbf16 {
     static inline void multiply_mask_kq(const KHelper& kh, int stride_m,
             const ggml_bf16_t * q, const char * mask, FlashMS<q_step, k_step>& fms) {
 #endif
-    {
+        if constexpr (q_step == 1) {
+            __m512bh vq[D/32];
+            __m512bh vk[D/32];
+            __m256 sum[8];
+            for (int i = 0; i < D/32; ++i) vq[i] = __m512bh(_mm512_loadu_si512((const __m512i *)q + i));
+            for (int l = 0; l < k_step; l += 8) {
+                for (int k = 0; k < 8; ++k) {
+                    kh.load(l+k, vk);
+                    auto vsum = _mm512_setzero_ps();
+                    for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vk[i], vq[i]);
+                    sum[k] = _mm256_add_ps(_mm512_castps512_ps256(vsum), _mm512_extractf32x8_ps(vsum, 1));
+                }
+                _mm256_storeu_ps(fms.cache + l, hsum_float_8x8(sum));
+            }
+        }
+        else {
         __m512bh qv[D/32];
         if constexpr (D <= 128) {
             __m512bh vkh[D/4];
@@ -15856,9 +15926,10 @@ struct FlashQKbf16 {
     }
 };
 
-template <int D, int q_step, int k_step>
+template <int Dk, int Dv, int q_step, int k_step>
 struct FlashAttnBF16 {
-    static_assert(D%32 == 0 && D <= 256);
+    static_assert(Dk%32 == 0 && Dk <= 256);
+    static_assert(Dv%32 == 0 && Dv <= 256);
     static_assert(k_step%32 == 0);
     static_assert(q_step <= 4 || q_step%4 == 0);
 
@@ -15867,7 +15938,7 @@ struct FlashAttnBF16 {
     template <typename KHelper, typename VHelper>
     void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
             const float * q, const char * mask, float * qkv) {
-        ggml_bf16_t q_bf16[q_step*D];
+        ggml_bf16_t q_bf16[q_step*Dk];
 #if FA_TIMING
         Perf perf(false);
 #endif
@@ -15878,7 +15949,7 @@ struct FlashAttnBF16 {
         fms.init_qstep();
         kh.reset_block();
         vh.reset_block();
-        FlashQKbf16<D, q_step, k_step>::convert(stride_q, q, q_bf16);
+        FlashQKbf16<Dk, q_step, k_step>::convert(stride_q, q, q_bf16);
 #if FA_TIMING
         perf.accum_nolock(0, t1);
 #endif
@@ -15886,13 +15957,13 @@ struct FlashAttnBF16 {
         auto mr = mask;
         for (int k1 = 0; k1 < nk1/k_step; ++k1) {
 #if FA_TIMING
             //t1 = Perf::cur_time();
-            FlashQKbf16<D, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms, perf);
+            FlashQKbf16<Dk, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms, perf);
             //perf.accum_nolock(1, t1);
             t1 = Perf::cur_time();
             fqkv.accumulate_qkv(vh, fms);
             perf.accum_nolock(3, t1);
 #else
-            FlashQKbf16<D, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms);
+            FlashQKbf16<Dk, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms);
             fqkv.accumulate_qkv(vh, fms);
 #endif
             kh.next_block();
@@ -15916,10 +15987,10 @@ struct FlashAttnBF16 {
         fms.init_qstep();
         kh.reset_block();
         vh.reset_block();
-        FlashQKbf16<D, q_step, k_step>::convert(n_left, stride_q, q, q_bf16);
+        FlashQKbf16<Dk, q_step, k_step>::convert(n_left, stride_q, q, q_bf16);
         auto mr = mask;
         for (int k1 = 0; k1 < nk1/k_step; ++k1) {
-            FlashQKbf16<D, q_step, k_step>::multiply_mask_kq(n_left, kh, stride_m, q_bf16, mr, fms);
+            FlashQKbf16<Dk, q_step, k_step>::multiply_mask_kq(n_left, kh, stride_m, q_bf16, mr, fms);
             fqkv.accumulate_qkv(n_left, vh, fms);
             kh.next_block();
             vh.next_block();
@@ -15932,72 +16003,72 @@ struct FlashAttnBF16 {
 #endif
     }
 
-    FlashMS<q_step, k_step> fms;
-    FlashQKV<D, q_step, k_step> fqkv;
+    FlashMS<q_step, k_step>      fms;
+    FlashQKV<Dv, q_step, k_step> fqkv;
 };
 #endif
 
-template <int D, int k_step, typename KHelper, typename VHelper>
+template <int Dk, int Dv, int k_step, typename KHelper, typename VHelper>
 inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
         const float * q, const char * mask, float scale, float softcap, float * qkv) {
 
     if (nk1 >= 256) { //4096) {
         if (nq1 >= 64) {
-            FlashAttn<D, 64, k_step> fa(scale, softcap);
+            FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
             fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
             return;
         }
         if (nq1 >= 32) {
-            FlashAttn<D, 32, k_step> fa(scale, softcap);
+            FlashAttn<Dk, Dv, 32, k_step> fa(scale, softcap);
             fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
             return;
         }
         if (nq1 >= 16) {
-            FlashAttn<D, 16, k_step> fa(scale, softcap);
+            FlashAttn<Dk, Dv, 16, k_step> fa(scale, softcap);
             fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
             return;
        }
     }
     if (nq1 >= 8) {
-        FlashAttn<D, 8, k_step> fa(scale, softcap);
+        FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap);
         fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
     } else {
-        FlashAttn<D, 1, k_step> fa(scale, softcap);
+        FlashAttn<Dk, Dv, 1, k_step> fa(scale, softcap);
         fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
     }
 }
 
 #ifdef __AVX512BF16__
-template <int D, int k_step>
+template <int Dk, int Dv, int k_step>
 inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
         const float * q, const char * k, const char * v, const char * mask,
         float scale, float softcap, float * qkv) {
-    HelperBF16<D, k_step> kh(k, stride_k);
-    HelperBF16<D, k_step> vh(v, stride_v);
+    HelperBF16<Dk, k_step> kh(k, stride_k);
+    HelperBF16<Dv, k_step> vh(v, stride_v);
     if (nk1 >= 4096) {
         if (nq1 >= 64) {
-            FlashAttnBF16<D, 64, k_step> fa(scale, softcap);
+            FlashAttnBF16<Dk, Dv, 64, k_step> fa(scale, softcap);
             fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
             return;
         }
         else if (nq1 >= 16) {
-            FlashAttnBF16<D, 16, k_step> fa(scale, softcap);
+            FlashAttnBF16<Dk, Dv, 16, k_step> fa(scale, softcap);
             fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
             return;
         }
     }
     if (nq1 >= 8) {
-        FlashAttnBF16<D, 8, k_step> fa(scale, softcap);
+        FlashAttnBF16<Dk, Dv, 8, k_step> fa(scale, softcap);
         fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
     } else {
-        FlashAttnBF16<D, 1, k_step> fa(scale, softcap);
+        FlashAttnBF16<Dk, Dv, 1, k_step> fa(scale, softcap);
         fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
     }
 }
 #endif
 
-template <int D, int k_step, typename KHelper>
+template <int Dk, int Dv, int k_step, typename KHelper>
 inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
         int nq1, int nk1, int stride_q, int stride_v, int stride_m, int stride_qkv,
         const float * q, const char * v, const char * mask,
@@ -16005,42 +16076,42 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
     switch (type_v) {
         case GGML_TYPE_F16: {
-            HelperF16<D, k_step> vh(v, stride_v);
-            iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+            HelperF16<Dv, k_step> vh(v, stride_v);
+            iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
         } break;
 #ifdef HAVE_FANCY_SIMD
         case GGML_TYPE_BF16: {
-            HelperBF16<D, k_step> vh(v, stride_v);
-            iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+            HelperBF16<Dv, k_step> vh(v, stride_v);
+            iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
         } break;
 #endif
         case GGML_TYPE_Q8_0: {
-            HelperQ80<D, k_step> vh(v, stride_v);
-            iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+            HelperQ80<Dv, k_step> vh(v, stride_v);
+            iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
         } break;
         case GGML_TYPE_Q6_0: {
-            HelperQ60<D, k_step> vh(v, stride_v);
-            iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+            HelperQ60<Dv, k_step> vh(v, stride_v);
+            iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
         } break;
 #if GGML_IQK_FA_ALL_QUANTS
         case GGML_TYPE_Q4_0: {
-            HelperQ40<D, k_step> vh(v, stride_v);
-            iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+            HelperQ40<Dv, k_step> vh(v, stride_v);
+            iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
         } break;
         case GGML_TYPE_Q4_1: {
-            HelperQ41<D, k_step> vh(v, stride_v);
-            iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+            HelperQ41<Dv, k_step> vh(v, stride_v);
+            iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
         } break;
         case GGML_TYPE_IQ4_NL: {
-            HelperIQ4nl<D, k_step> vh(v, stride_v);
-            iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+            HelperIQ4nl<Dv, k_step> vh(v, stride_v);
+            iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
         } break;
 #endif
         default: break;
     }
 }
 
-template <int D, int k_step>
+template <int Dk, int Dv, int k_step>
 inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
         int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
         const float * q, const char * k, const char * v, const char * mask,
@@ -16048,29 +16119,29 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
     switch (type_k) {
         case GGML_TYPE_F16: {
-            HelperF16<D, k_step> kh(k, stride_k);
-            iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+            HelperF16<Dk, k_step> kh(k, stride_k);
+            iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
         } break;
         case GGML_TYPE_Q8_0: {
-            HelperQ80<D, k_step> kh(k, stride_k);
-            iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+            HelperQ80<Dk, k_step> kh(k, stride_k);
+            iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
         } break;
         case GGML_TYPE_Q6_0: {
-            HelperQ60<D, k_step> kh(k, stride_k);
-            iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+            HelperQ60<Dk, k_step> kh(k, stride_k);
+            iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
         } break;
 #if GGML_IQK_FA_ALL_QUANTS
         case GGML_TYPE_Q4_0: {
-            HelperQ40<D, k_step> kh(k, stride_k);
-            iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+            HelperQ40<Dk, k_step> kh(k, stride_k);
+            iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
         } break;
         case GGML_TYPE_Q4_1: {
-            HelperQ41<D, k_step> kh(k, stride_k);
-            iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+            HelperQ41<Dk, k_step> kh(k, stride_k);
+            iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
         } break;
         case GGML_TYPE_IQ4_NL: {
-            HelperIQ4nl<D, k_step> kh(k, stride_k);
-            iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+            HelperIQ4nl<Dk, k_step> kh(k, stride_k);
+            iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
         } break;
 #endif
         default: break;
@@ -16094,7 +16165,8 @@ inline bool flash_attn_is_supported(ggml_type type) {
 
 bool iqk_flash_attn_noalibi(int int_type_k,         // type of k
                             int int_type_v,         // type of v
-                            int D,                  // head size
+                            int Dk,                 // K head size
+                            int Dv,                 // V head size
                             int nq1,                // number of columns in q
                             int nk1,                // number of rows in k
                             int stride_q,           // distance between q columns in bytes
@@ -16114,7 +16186,9 @@ bool iqk_flash_attn_noalibi(int int_type_k,         // type of k
     auto type_v = ggml_type(int_type_v);
     if (!flash_attn_is_supported(type_k) || !flash_attn_is_supported(type_v)) return false;
     if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32
-    if (D != 64 && D != 96 && D != 128 && D != 256) return false;
+    if (Dk != Dv && (Dk != 192 || Dv != 128)) return false;
+    if (Dv != 64 && Dv != 96 && Dv != 128 && Dv != 256) return false;
+    if (Dk != 64 && Dk != 96 && Dk != 128 && Dk != 192 && Dk != 256) return false;
 
     auto ck = (const char *)k;
     auto cv = (const char *)v;
@@ -16126,30 +16200,34 @@ bool iqk_flash_attn_noalibi(int int_type_k,         // type of k
     if (type_k == GGML_TYPE_BF16) {
         if (nk1%64 == 0) {
             if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
-            switch (D) {
+            switch (Dk) {
                 case 64:
-                    iqk_flash_helper_T< 64, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+                    iqk_flash_helper_T< 64,  64, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
                 case 96:
-                    iqk_flash_helper_T< 96, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+                    iqk_flash_helper_T< 96,  96, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
                 case 128:
-                    iqk_flash_helper_T<128, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+                    iqk_flash_helper_T<128, 128, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+                case 192:
+                    iqk_flash_helper_T<192, 128, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
                 case 256:
-                    iqk_flash_helper_T<256, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+                    iqk_flash_helper_T<256, 256, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
                 default:
                     return false;
             }
            return true;
        }
        if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
-        switch (D) {
+        switch (Dk) {
            case 64:
-                iqk_flash_helper_T< 64, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+                iqk_flash_helper_T< 64,  64, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
            case 96:
-                iqk_flash_helper_T< 96, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+                iqk_flash_helper_T< 96,  96, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
            case 128:
-                iqk_flash_helper_T<128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+                iqk_flash_helper_T<128, 128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+            case 192:
+                iqk_flash_helper_T<192, 128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
            case 256:
-                iqk_flash_helper_T<256, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+                iqk_flash_helper_T<256, 256, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
            default:
                return false;
        }
@@ -16159,41 +16237,45 @@ bool iqk_flash_attn_noalibi(int int_type_k,         // type of k
 #endif
 
     if (nk1%64 == 0) {
-        switch (D) {
+        switch (Dk) {
            case 64:
-                iqk_flash_helper_T< 64, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+                iqk_flash_helper_T< 64,  64, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
            // Disable until we fix accumulate_qkv for odd D/16
            //case 80:
            //    iqk_flash_helper_T< 80, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
            case 96:
-                iqk_flash_helper_T< 96, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+                iqk_flash_helper_T< 96,  96, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
            // Disable until we fix accumulate_qkv for odd D/16
            //case 112:
            //    iqk_flash_helper_T<112, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
            case 128:
-                iqk_flash_helper_T<128, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+                iqk_flash_helper_T<128, 128, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+            case 192:
+                iqk_flash_helper_T<192, 128, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
            case 256:
-                iqk_flash_helper_T<256, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+                iqk_flash_helper_T<256, 256, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
            default:
                return false;
        }
        return true;
    }
 
-    switch (D) {
+    switch (Dk) {
        case 64:
-            iqk_flash_helper_T< 64, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+            iqk_flash_helper_T< 64,  64, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
        // Disable until we fix accumulate_qkv for odd D/16
        //case 80:
        //    iqk_flash_helper_T< 80, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
        case 96:
-            iqk_flash_helper_T< 96, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+            iqk_flash_helper_T< 96,  96, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
        // Disable until we fix accumulate_qkv for odd D/16
        //case 112:
        //    iqk_flash_helper_T<112, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
        case 128:
-            iqk_flash_helper_T<128, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+            iqk_flash_helper_T<128, 128, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+        case 192:
+            iqk_flash_helper_T<192, 128, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
        case 256:
-            iqk_flash_helper_T<256, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+            iqk_flash_helper_T<256, 256, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
        default:
            return false;
    }
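
The net effect of the new head-size gate and dispatch in iqk_flash_attn_noalibi can be summarized by the standalone predicate below. This is a minimal illustrative sketch (the name is_supported_head_sizes is hypothetical, not part of the commit), reading the intent off the checks and switch statements above: equal K/V head sizes of 64, 96, 128 or 256, plus the single mixed pair Dk = 192, Dv = 128.

#include <cstdio>

// Hypothetical restatement of the (Dk, Dv) gate in iqk_flash_attn_noalibi:
// equal K/V head sizes of 64/96/128/256 are dispatched, plus the one mixed
// pair Dk = 192, Dv = 128 that this change enables.
static bool is_supported_head_sizes(int Dk, int Dv) {
    if (Dk != Dv) return Dk == 192 && Dv == 128;            // only mixed pair in the switch
    return Dk == 64 || Dk == 96 || Dk == 128 || Dk == 256;  // 80/112 remain disabled (odd D/16)
}

int main() {
    printf("%d\n", is_supported_head_sizes(192, 128)); // 1: new mixed pair
    printf("%d\n", is_supported_head_sizes(128, 128)); // 1
    printf("%d\n", is_supported_head_sizes(192, 192)); // 0: case 192 always dispatches Dv = 128
    printf("%d\n", is_supported_head_sizes( 80,  80)); // 0: odd D/16, still unsupported
}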
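Likewise, the new accumulate_qkv_1 fast path for a single query row is easier to follow in scalar form. The sketch below is a hypothetical plain-C++ model, assuming fms.cache holds the k_step softmax weights for the current K/V block, fms.vms[0] is the running-max correction factor, and the V helper yields row-major value rows; it mirrors the logic of the SIMD code above, not its implementation.

// Hypothetical scalar model of accumulate_qkv_1 for one query row.
//   need_scaling == 2: reset the running output row
//   need_scaling == 1: rescale it by vms (running-max correction)
//   need_scaling == 0: keep it as is
void accumulate_qkv_1_scalar(int D, int k_step, int need_scaling, float vms,
                             const float* cache,      // k_step softmax weights
                             const float* v,          // k_step x D value rows (row-major)
                             float* qkv_cache) {      // D-element running output row
    if (need_scaling == 2) {
        for (int i = 0; i < D; ++i) qkv_cache[i] = 0.0f;
    } else if (need_scaling == 1) {
        for (int i = 0; i < D; ++i) qkv_cache[i] *= vms;
    }
    // One multiply-add per element; the SIMD version unrolls this over
    // 4 value rows and two F16 registers per row.
    for (int l = 0; l < k_step; ++l)
        for (int i = 0; i < D; ++i)
            qkv_cache[i] += cache[l] * v[l*D + i];
}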