author    | Kawrakow <iwankawrakow@gmail.com> | 2025-04-29 07:19:43 +0200
committer | GitHub <noreply@github.com>       | 2025-04-29 07:19:43 +0200
commit    | cda24b58cbef34154651d0083910fed860a506c1
tree      | 90cd3bd7f772c3b240a6553eca5e50edf95c53da
parent    | baeefb4731fb24cdace168f6dbc74516d470efc0
CPU FA improvements (#351)
* FA: provide work buffer for K repacking
* Add header to avoid compiler warnings
* WIP
* WIP
* WIP
* WIP
* Slightly better
* WIP (Zen4)
* WIP
* Try to improve for unusual number of heads/number of threads
* Use mul_mat_qX_0_q8_2_Tx for q6_0 in FA
* Use mul_mat_qX_0_q8_2_Tx for q4_0 in FA
* Use Sum4q4 for q4_0
* WIP
* WIP
* Much better FA TG with q8_0 KV cache
Just repack it even for TG, but do the repacking for k_step rows,
not the whole K tensor (see the sketch after the commit message).
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
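
The key change called out above is that, for token generation with a Q8_0 KV cache, K is repacked into the row-interleaved Q8_0_R8 layout one k_step chunk at a time inside the attention loop, instead of repacking the whole K tensor up front. Below is a minimal scalar sketch of that idea, assuming a contiguous Q8_0 K cache with n_kv divisible by k_step and k_step divisible by 8. The names (Q8Block, R8Block, repack_8_rows, flash_attn_tg) are illustrative, not the identifiers used in ggml/src/iqk, and the real kernels further interleave the quants at 4-byte granularity with AVX2/NEON shuffles.

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

constexpr int QK8_0 = 32;

struct Q8Block { uint16_t d;    int8_t qs[QK8_0];   };  // one Q8_0 block: fp16 scale + 32 int8 quants
struct R8Block { uint16_t d[8]; int8_t qs[8*QK8_0]; };  // 8 rows of one block index, grouped together

// Repack 8 consecutive Q8_0 rows (nblock blocks per row) into the row-grouped
// layout: for each block index, store the 8 scales followed by the 8 rows'
// quants, so a GEMM kernel can process 8 K rows against one Q vector at once.
static void repack_8_rows(const Q8Block* rows[8], int nblock, R8Block* out) {
    for (int ib = 0; ib < nblock; ++ib) {
        for (int k = 0; k < 8; ++k) {
            out[ib].d[k] = rows[k][ib].d;
            std::memcpy(out[ib].qs + k*QK8_0, rows[k][ib].qs, QK8_0);
        }
    }
}

// Token generation: rather than repacking the entire K cache (extra memory and
// up-front work proportional to n_kv), repack only the k_step rows of the
// current chunk into a small scratch buffer and run the FA block on that.
static void flash_attn_tg(const Q8Block* k_cache, int nblock_per_row,
                          int n_kv, int k_step) {
    std::vector<R8Block> scratch(nblock_per_row * (k_step/8));
    for (int k0 = 0; k0 < n_kv; k0 += k_step) {
        for (int r = 0; r < k_step; r += 8) {
            const Q8Block* rows[8];
            for (int k = 0; k < 8; ++k)
                rows[k] = k_cache + (k0 + r + k)*nblock_per_row;
            repack_8_rows(rows, nblock_per_row, scratch.data() + (r/8)*nblock_per_row);
        }
        // ... compute K*Q, softmax and the V accumulation for this chunk on `scratch` ...
    }
}
```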
-rw-r--r-- | ggml/src/ggml.c                 |  29
-rw-r--r-- | ggml/src/iqk/iqk_flash_attn.cpp | 147
-rw-r--r-- | ggml/src/iqk/iqk_flash_impl.h   |   4
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp    | 708
4 files changed, 763 insertions, 125 deletions
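
In iqk_flash_attn.cpp the diff also deduplicates the combination of per-thread partial results into an accumulate_qkv() helper. The merge rule is the standard streaming-softmax update: each partial carries its row maximum Mj, normalizer Sj and unnormalized output Rj, and partials are folded in by rescaling with exp() of the max difference. The hedged sketch below mirrors that helper with explanatory comments; it is not the verbatim implementation.

```cpp
#include <cmath>
#include <cstring>

// Merge one partial flash-attention result (Mj, Sj, R) into the running
// accumulator (M, S, Racc) for a single output row of length Dv.
static void accumulate_partial(int Dv, float& M, float& S, float Mj, float Sj,
                               float* Racc, const float* R) {
    if (Mj == -INFINITY) return;              // this partial saw only masked-out keys
    if (Mj > M) {
        if (M == -INFINITY) {                 // first non-empty partial: just copy it
            std::memcpy(Racc, R, Dv*sizeof(float));
            S = Sj;
        } else {                              // rescale the accumulator to the new max
            float c = std::exp(M - Mj);
            S = c*S + Sj;
            for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i];
        }
        M = Mj;
    } else {                                  // rescale the incoming partial instead
        float c = std::exp(Mj - M);
        S += c*Sj;
        for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i];
    }
}
// After all partials are merged, the caller normalizes Racc by 1/S (or 1 if S == 0).
```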
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index f3cfd9a0..4cd18a28 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -21786,15 +21786,36 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur = 3*sizeof(float)*D*n_tasks; // 3x head size/thread #if GGML_USE_IQK_MULMAT + size_t qsize = 0; const struct ggml_tensor * q = node->src[0]; const struct ggml_tensor * k = node->src[1]; + if (k->type == GGML_TYPE_Q8_0) { + qsize = ggml_nrows(k)*ggml_row_size(k->type, k->ne[0]); + } if (q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1 && n_tasks > 1 && k->ne[1]/32 > 1) { if (k->ne[2] > 1) { - int nk = MAX(1, 32 * (k->ne[2]*k->ne[1]/(32*n_tasks))); + int gcd = simple_gcd(k->ne[2], n_tasks); + int nth_k = n_tasks/gcd; + int nek2_k = k->ne[2]/gcd; + int nchunk = nek2_k*k->ne[1]/32; + int npt = (nchunk + nth_k - 1)/nth_k; + int nk; + if (npt*nth_k == nchunk) { + nk = 32 * (k->ne[1]*k->ne[2]/(32*n_tasks)); + } else { + //int nm = std::max(1, npt/8); + int nm = 1; + while (true) { + if (nm*4 >= npt) break; + nm *= 2; + } + nk = 32*nm; + } + //int nk = 32 * (k->ne[2]*k->ne[1]/(32*n_tasks)); int nstep_k = k->ne[2]*k->ne[1]/nk; size_t result_size = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float); size_t size = nstep_k*result_size; - cur = MAX(cur, size); + cur = MAX(cur, size+qsize); } else { int nstep_k = k->ne[1]/32; int gcd_k = simple_gcd(nstep_k, n_tasks); @@ -21808,9 +21829,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]); size += q->ne[2]*row_size; } - cur = MAX(cur, size); + cur = MAX(cur, size+qsize); } } + } else { + cur = MAX(cur, qsize); } #endif } break; diff --git a/ggml/src/iqk/iqk_flash_attn.cpp b/ggml/src/iqk/iqk_flash_attn.cpp index 0de68b94..fd0d5dd0 100644 --- a/ggml/src/iqk/iqk_flash_attn.cpp +++ b/ggml/src/iqk/iqk_flash_attn.cpp @@ -25,6 +25,24 @@ inline uint32_t simple_gcd(uint32_t a, uint32_t b) { } return a; } +inline void accumulate_qkv(int Dv, float& M, float& S, float Mj, float Sj, float * Racc, const float * R) { + if (Mj == -INFINITY) return; + if (Mj > M) { + if (M == -INFINITY) { + std::memcpy(Racc, R, Dv*sizeof(float)); + S = Sj; + } else { + float c = exp(M - Mj); + S = c*S + Sj; + for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i]; + } + M = Mj; + } else { + float c = exp(Mj - M); + S += c*Sj; + for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i]; + } +} } // TODO: get the ggml_type enum here without polution @@ -34,7 +52,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float int nek3, int nek2, long nbk3, long nbk2, int nev3, int nev2, long nbv3, long nbv2, int ne2, int ne1, long nb1, - int int_type_k, // type of k + int int_type_k_in, // type of k int int_type_v, // type of v int Dk, // K head size int Dv, // V head size @@ -51,7 +69,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float float scale, // scale applied before softmax float softcap, // if > 0, a "soft-cap" operation is applied before softmax float * qkv, // v*softmax(scale*(k*q)) - [[maybe_unused]] void * work_buffer, [[maybe_unused]] barrier_t barrier, [[maybe_unused]] void * barrier_data, + [[maybe_unused]] void * work_buffer_in, [[maybe_unused]] barrier_t barrier, [[maybe_unused]] void * barrier_data, int ith, int nth) { if (type_q != 0 || type_mask != 1 || max_bias > 0) return false; @@ -61,6 +79,29 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float int rk3 = neq3/nek3; int rv3 = 
neq3/nev3; + int int_type_k = int_type_k_in; + auto work_buffer = work_buffer_in; + if (neq1 >= 8 || rk2 >= 8) { + uint64_t row_size = 0; + work_buffer = iqk_repack_k(int_type_k, Dk, nek1, nek2, nek3, stride_k, nbk2, nbk3, k, work_buffer_in, ith, nth, int_type_k, row_size); + if (int_type_k != int_type_k_in) { + stride_k = row_size; + nbk2 = stride_k*nek1; + nbk3 = nbk2*nek2; + k = work_buffer_in; + barrier(barrier_data); + } + } + //uint64_t row_size = 0; + //auto work_buffer = iqk_repack_k(int_type_k, Dk, nek1, nek2, nek3, stride_k, nbk2, nbk3, k, work_buffer_in, ith, nth, int_type_k, row_size); + //if (int_type_k != int_type_k_in) { + // stride_k = row_size; + // nbk2 = stride_k*nek1; + // nbk3 = nbk2*nek2; + // k = work_buffer_in; + // barrier(barrier_data); + //} + // Getting confused all the time about where to load data from and store the results to // (especially when combining the results from the threads). // So, for now, making it work just for MLA (nek2 = 1). @@ -128,22 +169,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float auto Mj = R + Dv*nq_this_j; auto Sj = Mj + nq_this_j; R += jj*Dv; - if (Mj[jj] == -INFINITY) continue; - if (Mj[jj] > M) { - if (M == -INFINITY) { - std::memcpy(Racc, R, Dv*sizeof(float)); - S = Sj[jj]; - } else { - float c = exp(M - Mj[jj]); - S = c*S + Sj[jj]; - for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i]; - } - M = Mj[jj]; - } else { - float c = exp(Mj[jj] - M); - S += c*Sj[jj]; - for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i]; - } + accumulate_qkv(Dv, M, S, Mj[jj], Sj[jj], Racc, R); } float norm = S > 0 ? 1/S : 1; for (int i = 0; i < Dv; ++i) Racc[i] *= norm; @@ -154,10 +180,72 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float } if (neq3 == 1 && rk2 > 1 && rk2 == rv2 && neq1 == 1 && nth >= 1 && nek2*nek1 >= 32*nth) { - int nk = std::max(1, 32 * (nek2*nek1/(32*nth))); + auto result_size = (Dv + 16)*rk2*sizeof(float); + int gcd = simple_gcd(nek2, nth); + if (false && gcd > 1) { + int nth_g = nth/gcd; + int ith_g = ith%nth_g; + int nek1_32 = nek1/32; + int nek1_pt = (nek1_32 + nth_g - 1)/nth_g; + int ith_mid = nth_g; + if (nek1_pt*nth_g > nek1_32) { + ith_mid = nek1_32 - nth_g*(nek1_pt - 1); + } + nek1_pt *= 32; + int nek1_mid = ith_mid*nek1_pt; + int nek1_thread = ith_g < ith_mid ? nek1_pt : nek1_pt - 32; + for (int ik02 = ith/nth_g; ik02 < nek2; ik02 += gcd) { + int ik01 = ith_g < ith_mid ? 
ith_g*nek1_pt : nek1_mid + (ith_g - ith_mid)*nek1_thread; + auto this_result = (float *)((char *)work_buffer + (ik02*nth_g + ith_g)*result_size); + auto this_q = (const float *)((const char *)q + ik02*rk2*nbq2); + auto this_k = (const char *)k + ik01*stride_k + ik02*nbk2; + auto this_v = (const char *)v + ik01*stride_v + ik02*nbv2; + auto this_m = (const char *)mask + ik01*sizeof(uint16_t); // we don't have ggml_half available here + if (!iqk_flash_attn_impl(int_type_k, int_type_v, + Dk, Dv, rk2, nek1_thread, nbq2, stride_k, stride_v, 0, Dv, + this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m, + scale, softcap, this_result, this_result + (Dv+0)*rk2, this_result + (Dv+1)*rk2)) return false; + } + + barrier(barrier_data); + + for (int iq2 = ith; iq2 < neq2; iq2 += nth) { + int ik02 = iq2/rk2; + int il = iq2 - ik02*rk2; + auto Racc = qkv + iq2*nb1/sizeof(float); + float M = -INFINITY, S = 0; + for (int ig = 0; ig < nth_g; ++ig) { + int istep_k = ik02*nth_g + ig; + auto this_result = (float *)((char *)work_buffer + istep_k*result_size); + const float * R = this_result + il*Dv; + const float * Mj = this_result + Dv*rk2; + const float * Sj = Mj + rk2; + accumulate_qkv(Dv, M, S, Mj[il], Sj[il], Racc, R); + } + float norm = S > 0 ? 1/S : 1; + for (int i = 0; i < Dv; ++i) Racc[i] *= norm; + } + return true; + } + int nth_k = nth/gcd; + int nek2_k = nek2/gcd; + int nchunk = nek2_k*nek1/32; + int npt = (nchunk + nth_k - 1)/nth_k; + int nk; + if (npt*nth_k == nchunk) { + nk = 32 * (nek2*nek1/(32*nth)); + } else { + //int nm = std::max(1, npt/8); + int nm = 1; + while (true) { + if (nm*4 >= npt) break; + nm *= 2; + } + nk = 32*nm; + } + //int nk = 32 * (nek2*nek1/(32*nth)); int nkk = (nek1 + nk - 1)/nk; int nstep_k = nek2*nkk; - auto result_size = (Dv + 16)*rk2*sizeof(float); //if (ith == 0) printf("rk2 = %d, nek1 = %d, nek2 = %d, nk = %d, nkk = %d, nstep_k = %d\n", (int)rk2, (int)nek1, (int)nek2, nk, nkk, nstep_k); for (int istep_k = ith; istep_k < nstep_k; istep_k += nth) { int ik02 = istep_k/nkk; @@ -183,7 +271,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float int ik02 = iq2/rk2; int il = iq2 - ik02*rk2; auto Racc = qkv + iq2*nb1/sizeof(float); - std::memset(Racc, 0, Dv*sizeof(float)); + //std::memset(Racc, 0, Dv*sizeof(float)); float M = -INFINITY, S = 0; for (int ikk = 0; ikk < nkk; ++ikk) { int istep_k = ik02*nkk + ikk; @@ -191,22 +279,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float const float * R = this_result + il*Dv; const float * Mj = this_result + Dv*rk2; const float * Sj = Mj + rk2; - if (Mj[il] == -INFINITY) continue; - if (Mj[il] > M) { - if (M == -INFINITY) { - std::memcpy(Racc, R, Dv*sizeof(float)); - S = Sj[il]; - } else { - float c = exp(M - Mj[il]); - S = c*S + Sj[il]; - for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i]; - } - M = Mj[il]; - } else { - float c = exp(Mj[il] - M); - S += c*Sj[il]; - for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i]; - } + accumulate_qkv(Dv, M, S, Mj[il], Sj[il], Racc, R); } float norm = S > 0 ? 
1/S : 1; for (int i = 0; i < Dv; ++i) Racc[i] *= norm; diff --git a/ggml/src/iqk/iqk_flash_impl.h b/ggml/src/iqk/iqk_flash_impl.h index 68802927..6f62e56b 100644 --- a/ggml/src/iqk/iqk_flash_impl.h +++ b/ggml/src/iqk/iqk_flash_impl.h @@ -6,6 +6,8 @@ #pragma once +#include <cstdint> + bool iqk_flash_attn_impl(int type_k, // type of k int type_v, // type of v int Dk, // K head size @@ -27,3 +29,5 @@ bool iqk_flash_attn_impl(int type_k, // type of k float * M, float * S); +void * iqk_repack_k(int type_k, int nek0, int nek1, int nek2, int nek3, long nbk1, long nbk2, long nbk3, + const void * k, void * work, int ith, int nth, int& repacked_type, uint64_t& row_size); diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index e7ab2e5b..5f916584 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -19,6 +19,7 @@ #include "ggml-quants.h" #include "iqk_mul_mat.h" #include "iqk_quantize.h" +#include "iqk_flash_impl.h" #define GGML_COMMON_IMPL_C #include "ggml-common.h" @@ -6639,6 +6640,84 @@ static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataI } } +#ifdef HAVE_FANCY_SIMD +template <int nrc_y> +static void mul_mat_q8_KV_q8_KV_8(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%8 == 0); + GGML_ASSERT(n%32 == 0); + __m512i qx[4]; + __m512i acc[nrc_y <= 4 ? 2*nrc_y : nrc_y] = {}; + float dy[nrc_y]; + int32_t sy[nrc_y]; + const int8_t * q8y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + auto dptr = (const float *)info.src1_row(iy); + dy[iy] = dptr[0]; + auto iptr = (const int32_t *)(dptr + 1); + sy[iy] = -64*iptr[0]; + q8y[iy] = (const int8_t *)(dptr + 2); + } + const int8_t * q8x[8]; + float dx[8]; + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int kx = 0; kx < 8; ++kx) { + auto dptr = (const float *)((const char *)vx + (ix+kx)*bx); + dx[kx] = dptr[0]; + q8x[kx] = (const int8_t *)(dptr + 2); + } + for (int i = 0; i < n/32; ++i) { + for (int kx = 0; kx < 4; ++kx) { + qx[kx] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8x[kx+0] + i)), + _mm256_loadu_si256((const __m256i *)q8x[kx+4] + i), 1); + } + auto t0 = _mm512_unpacklo_epi32(qx[0], qx[1]); + auto t1 = _mm512_unpacklo_epi32(qx[2], qx[3]); + auto t2 = _mm512_unpackhi_epi32(qx[0], qx[1]); + auto t3 = _mm512_unpackhi_epi32(qx[2], qx[3]); + qx[0] = _mm512_xor_si512(_mm512_unpacklo_epi64(t0, t1), _mm512_set1_epi8(-128)); + qx[1] = _mm512_xor_si512(_mm512_unpackhi_epi64(t0, t1), _mm512_set1_epi8(-128)); + qx[2] = _mm512_xor_si512(_mm512_unpacklo_epi64(t2, t3), _mm512_set1_epi8(-128)); + qx[3] = _mm512_xor_si512(_mm512_unpackhi_epi64(t2, t3), _mm512_set1_epi8(-128)); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y256 = _mm256_loadu_si256((const __m256i *)q8y[iy] + i); + auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1); + if constexpr (nrc_y <= 4) { + acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + } else { + acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + acc[iy] = _mm512_dpbusd_epi32(acc[iy], 
qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + } + } + } + auto scales_x = _mm256_loadu_ps(dx); + for (int iy = 0; iy < nrc_y; ++iy) { + if constexpr (nrc_y <= 4) { + auto ss = _mm512_add_epi32(_mm512_add_epi32(acc[2*iy+0], acc[2*iy+1]), _mm512_set1_epi32(sy[iy])); + auto sum1 = _mm_add_epi32(_mm512_extracti32x4_epi32(ss, 0), _mm512_extracti32x4_epi32(ss, 1)); + auto sum2 = _mm_add_epi32(_mm512_extracti32x4_epi32(ss, 2), _mm512_extracti32x4_epi32(ss, 3)); + auto scale = _mm256_mul_ps(scales_x, _mm256_set1_ps(dy[iy])); + info.store(ix+0, iy, _mm_mul_ps(_mm256_castps256_ps128(scale), _mm_cvtepi32_ps(sum1))); + info.store(ix+4, iy, _mm_mul_ps(_mm256_extractf128_ps(scale, 1), _mm_cvtepi32_ps(sum2))); + acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_si512(); + } else { + acc[iy] = _mm512_add_epi32(acc[iy], _mm512_set1_epi32(sy[iy])); + auto sum1 = _mm_add_epi32(_mm512_extracti32x4_epi32(acc[iy], 0), _mm512_extracti32x4_epi32(acc[iy], 1)); + auto sum2 = _mm_add_epi32(_mm512_extracti32x4_epi32(acc[iy], 2), _mm512_extracti32x4_epi32(acc[iy], 3)); + auto scale = _mm256_mul_ps(scales_x, _mm256_set1_ps(dy[iy])); + info.store(ix+0, iy, _mm_mul_ps(_mm256_castps256_ps128(scale), _mm_cvtepi32_ps(sum1))); + info.store(ix+4, iy, _mm_mul_ps(_mm256_extractf128_ps(scale, 1), _mm_cvtepi32_ps(sum2))); + acc[iy] = _mm512_setzero_si512(); + } + } + } +} +#endif + template <int nrc_y> static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); @@ -8208,6 +8287,22 @@ template <typename Q8, typename Q8x4, typename Dot, bool can_pack = true> struct return _mm256_add_epi32(_mm256_unpacklo_epi64(p01, p23), _mm256_unpackhi_epi64(p01, p23)); // 0,1,2,3, 0,1,2,3 } } + inline __m256i compute(__m256i x, __m256i y) const { return dot.compute(x, y); } +}; + +template <typename Q8, typename Q8x4> struct Sum4q4 { + inline __m256i compute(const __m256i * qx, const Q8 * y) const { + const Q8x4 * y4 = (const Q8x4 *)y; + auto p0 = _mm256_maddubs_epi16(qx[0], _mm256_loadu_si256((const __m256i *)y4->qs+0)); // 16x block 0 + auto p1 = _mm256_maddubs_epi16(qx[1], _mm256_loadu_si256((const __m256i *)y4->qs+1)); // 16x block 1 + auto p2 = _mm256_maddubs_epi16(qx[2], _mm256_loadu_si256((const __m256i *)y4->qs+2)); // 16x block 2 + auto p3 = _mm256_maddubs_epi16(qx[3], _mm256_loadu_si256((const __m256i *)y4->qs+3)); // 16x block 3 + auto p01 = _mm256_add_epi16(_mm256_unpacklo_epi32(p0, p1), _mm256_unpackhi_epi32(p0, p1)); // 0,0, 1,1, 0,0, 1,1, 0,0, 1,1, 0,0, 1,1 + auto p23 = _mm256_add_epi16(_mm256_unpacklo_epi32(p2, p3), _mm256_unpackhi_epi32(p2, p3)); // 2,2, 3,3, 2,2, 3,3, 2,2, 3,3, 2,2, 3,3 + auto p0123 = _mm256_add_epi16(_mm256_unpacklo_epi64(p01, p23), _mm256_unpackhi_epi64(p01, p23)); // 0,0, 1,1, 2,2, 3,3, 0,0, 1,1, 2,2, 3,3 + return _mm256_madd_epi16(_mm256_set1_epi16(1), p0123); + } + inline __m256i compute(__m256i x, __m256i y) const { return _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(x, y)); } }; struct ScaleHelperQ8_0 { @@ -8362,6 +8457,7 @@ struct MinusType0 { inline __m256 compute(__m128 d, int) const { return _mm256_set_m128(d, d); } inline float compute(float d, int) const { return d; } inline float result(__m256 acc, int) const { return hsum_float_8(acc); } + inline __m256 vresult(__m256 acc, int) const { return acc; } }; template <int nrc_y> struct MinusType1 { @@ -8381,6 +8477,9 @@ template <int nrc_y> struct MinusType1 { const __m128 
sum = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1)); return hsum_float_4(_mm_add_ps(sum, accm[iy])); } + inline __m256 vresult(__m256 acc, int iy) const { + return _mm256_add_ps(acc, _mm256_insertf128_ps(_mm256_setzero_ps(), accm[iy], 0)); + } }; template <typename Minus, int nrc_y, bool is_multiple_of_4> struct AccumT { @@ -8408,7 +8507,7 @@ template <typename Minus, int nrc_y, bool is_multiple_of_4> struct AccumT { for (int iy = 0; iy < nrc_y; ++iy) { auto s12 = scales.prepare1(other_scales, y[iy] + i); auto d = accm.compute(s12, iy); - const __m256i p0 = sum.dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs)); + const __m256i p0 = sum.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs)); acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(p0), acc[iy]); } } @@ -8417,6 +8516,36 @@ template <typename Minus, int nrc_y, bool is_multiple_of_4> struct AccumT { info.store(ix, iy, accm.result(acc[iy], iy)); } } + template <typename Unpacker, typename Scales, typename Sum, typename Q8> + inline void compute(int nb, Unpacker& unp, Scales& scales, Sum& sum, const Q8 ** y, __m256 * result) { + auto qx = unp.quants(); + __m256 dall[nrc_y]; + for (int i = 0; i < nb/4; ++i) { + auto other_scales = unp.set_block_4(i); + for (int iy = 0; iy < nrc_y; ++iy) { + auto s12 = scales.prepare4(other_scales, y[iy] + 4*i); + dall[iy] = accm.compute(s12, iy); + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto pall = sum.compute(qx, y[iy] + 4*i); + acc[iy] = _mm256_fmadd_ps(dall[iy], _mm256_cvtepi32_ps(pall), acc[iy]); + } + } + if (!is_multiple_of_4) { + for (int i = 4*(nb/4); i < nb; ++i) { + auto other_scales = unp.set_block(i); + for (int iy = 0; iy < nrc_y; ++iy) { + auto s12 = scales.prepare1(other_scales, y[iy] + i); + auto d = accm.compute(s12, iy); + const __m256i p0 = sum.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs)); + acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(p0), acc[iy]); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + result[iy] = accm.vresult(acc[iy], iy); + } + } }; template <int nrc_y, bool is_multiple_of_4> @@ -8425,10 +8554,7 @@ using AccumType0 = AccumT<MinusType0, nrc_y, is_multiple_of_4>; template <int nrc_y, bool is_multiple_of_4> using AccumType1 = AccumT<MinusType1<nrc_y>, nrc_y, is_multiple_of_4>; -using Sum4Type0 = Sum4<block_q8_0, block_q8_0_x4, SignedDot>; -using Sum4Type1 = Sum4<block_q8_1, block_q8_1_x4, UnsignedDot>; using Sum4TypeQ80 = Sum4<block_q8_0, block_q8_0_x4, SignedDot, false>; -//using Sum4TypeQ81 = Sum4<block_q8_1, block_q8_1_x4, UnsignedDot, false>; using Sum4TypeQ82 = Sum4<block_q8_2, block_q8_2_x4, UnsignedDot, false>; template <typename Unpacker, typename AccumType, typename Scales, typename Q8, int nrc_y> @@ -8443,6 +8569,19 @@ void mul_mat_qX_q8_Helper(int nb, const void * vx, size_t bx, const DataInfo& in } } +template <typename Unpacker, typename AccumType, typename Scales, typename Q8, int nrc_y> +void mul_mat_qX_q8_Helper_x2(int nb, const void * vx, size_t bx, const DataInfo& info, const Q8 ** y, int nrc_x) { + GGML_ASSERT(nrc_x%2 == 0); + Unpacker unp(vx, bx); + typename Unpacker::Sum4T sum4; + Scales scales; + for (int ix = 0; ix < nrc_x; ix += 2) { + unp.set_row(ix); + AccumType accum; + accum.compute(nb, unp, scales, sum4, y, info, ix); + } +} + template <typename Unpacker, int nrc_y> void mul_mat_qX_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n%Unpacker::block_size() == 0); @@ -8459,6 +8598,63 @@ void 
mul_mat_qX_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info } } +inline __m256 hsum_float_8x8(__m256 * accm) { + for (int i = 0; i < 4; ++i) { + accm[i] = _mm256_add_ps(_mm256_permute2f128_ps(accm[i], accm[i+4], 0x20), _mm256_permute2f128_ps(accm[i], accm[i+4], 0x31)); + //accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)), + // _mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1))); + } + for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2])); + return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1])); +} + +template <typename Unpacker, int nrc_y, int nrc_x> +void mul_mat_qX_0_q8_0_Tx(int n, const void * vx, size_t bx, const DataInfo& info, int) { + static_assert(8%nrc_y == 0); + Q8<nrc_y, block_q8_0> q8(info); + int nb = n/Unpacker::block_size(); + Unpacker unp(vx, bx); + typename Unpacker::Sum4T sum4; + ScaleHelperQ8_0 scales; + __m256 result[8]; + auto store = [&info, &result] (int ix0) { + if constexpr (nrc_y == 1) { + info.store(ix0, 0, hsum_float_8x8(result)); + } + else if constexpr (nrc_y == 2) { + auto value = hsum_float_8x8(result); + auto value1 = _mm256_extractf128_ps(value, 1); + info.store(ix0, 0, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0x88)); + info.store(ix0, 1, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0xdd)); + } + else { + float val[8]; + _mm256_storeu_ps(val, hsum_float_8x8(result)); + for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]); + } + }; + if (nb%4 == 0) { + for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) { + for (int ix = 0; ix < 8/nrc_y; ++ix) { + unp.set_row(ix0 + ix); + AccumType0<nrc_y, true> accum; + accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix); + } + store(ix0); + } + } else { + for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) { + for (int ix = 0; ix < 8/nrc_y; ++ix) { + unp.set_row(ix0 + ix); + AccumType0<nrc_y, false> accum; + accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix); + } + store(ix0); + } + } +} + + template <typename Unpacker, int nrc_y> void mul_mat_qX_1_q8_1_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n%Unpacker::block_size() == 0); @@ -8491,6 +8687,52 @@ void mul_mat_qX_1_q8_2_T(int n, const void * vx, size_t bx, const DataInfo& info } } +template <typename Unpacker, int nrc_y, int nrc_x> +void mul_mat_qX_0_q8_2_Tx(int n, const void * vx, size_t bx, const DataInfo& info, int) { + static_assert(8%nrc_y == 0); + Q8<nrc_y, block_q8_2> q8(info); + int nb = n/Unpacker::block_size(); + Unpacker unp(vx, bx); + typename Unpacker::Sum4T sum4; + ScaleHelperQ8_2 scales; + __m256 result[8]; + auto store = [&info, &result] (int ix0) { + if constexpr (nrc_y == 1) { + info.store(ix0, 0, hsum_float_8x8(result)); + } + else if constexpr (nrc_y == 2) { + auto value = hsum_float_8x8(result); + auto value1 = _mm256_extractf128_ps(value, 1); + info.store(ix0, 0, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0x88)); + info.store(ix0, 1, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0xdd)); + } + else { + float val[8]; + _mm256_storeu_ps(val, hsum_float_8x8(result)); + for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]); + } + }; + if (nb%4 == 0) { + for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) { + for (int ix = 0; ix < 
8/nrc_y; ++ix) { + unp.set_row(ix0 + ix); + AccumType1<nrc_y, true> accum; + accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix); + } + store(ix0); + } + } else { + for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) { + for (int ix = 0; ix < 8/nrc_y; ++ix) { + unp.set_row(ix0 + ix); + AccumType1<nrc_y, false> accum; + accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix); + } + store(ix0); + } + } +} + struct Dequantizer4bit { const __m256i m4 = _mm256_set1_epi8(0xf); inline __m256i dequant(const uint8_t * qs) const { @@ -8640,7 +8882,8 @@ struct Q4_0_Unpacker final : public Q_Unpacker<block_q4_0, ScaleHelperQ_0, Q4_0_ }; struct Q4_0_1_Unpacker final : public Q_Unpacker<block_q4_0, ScaleHelperQ_0_1<8>, Q4_0_1_Dequantizer> { Q4_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} - using Sum4T = Sum4TypeQ82; + //using Sum4T = Sum4TypeQ82; + using Sum4T = Sum4q4<block_q8_2, block_q8_2_x4>; inline static int block_size() { return QK4_0; } }; #ifdef HAVE_FANCY_SIMD @@ -15168,6 +15411,13 @@ struct F16 { auto v256 = _mm256_set_m128(v128, v128); return _mm512_insertf32x8(_mm512_castps256_ps512(v256), v256, 1); } + static inline void set4(const float * ptr, Data * vs) { + auto v = set4(ptr); + vs[0] = _mm512_shuffle_ps(v, v, 0x00); + vs[1] = _mm512_shuffle_ps(v, v, 0x55); + vs[2] = _mm512_shuffle_ps(v, v, 0xaa); + vs[3] = _mm512_shuffle_ps(v, v, 0xff); + } static inline Data fmadd_lane0(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0x00), prev); } static inline Data fmadd_lane1(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0x55), prev); } static inline Data fmadd_lane2(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0xaa), prev); } @@ -15193,6 +15443,13 @@ struct F16 { auto v128 = _mm_loadu_ps(ptr); return _mm256_set_m128(v128, v128); } + static inline void set4(const float * ptr, Data * vs) { + auto v = set4(ptr); + vs[0] = _mm256_shuffle_ps(v, v, 0x00); + vs[1] = _mm256_shuffle_ps(v, v, 0x55); + vs[2] = _mm256_shuffle_ps(v, v, 0xaa); + vs[3] = _mm256_shuffle_ps(v, v, 0xff); + } static inline Data fmadd_lane0(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0x00), prev); } static inline Data fmadd_lane1(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0x55), prev); } static inline Data fmadd_lane2(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0xaa), prev); } @@ -15388,7 +15645,119 @@ struct HelperQ80 final : public BaseHelper<step> { } } }; +} + +void * iqk_repack_k(int int_type_k, int nek0, int nek1, int nek2, int nek3, long nbk1, long nbk2, long nbk3, + const void * data, void * work, int ith, int nth, int& repacked_type, uint64_t& row_size) { + repacked_type = int_type_k; + auto type_k = ggml_type(int_type_k); + if (type_k != GGML_TYPE_Q8_0 || nek0%QK8_0 != 0) return work; + int nrows = nek1*nek2*nek3; + if (nrows%8 != 0) return work; + repacked_type = int(GGML_TYPE_Q8_0_R8); + row_size = ggml_row_size(GGML_TYPE_Q8_0, nek0); + void * result = (char *)work + nrows*row_size; + int npt = 8*((nrows/8 + nth - 1)/nth); + int first = npt*ith; + if (first >= nrows) return result; + int last = std::min(first + npt, nrows); + const block_q8_0 * x8[8]; + auto y = (block_q8_0_r8 *)((char *)work + first*row_size); + int nblock = nek0/QK8_0; +#ifdef __ARM_NEON + int8x16x2_t m0, m1, m2, m3; +#endif + for (int row = first; row < last; row += 8) { + int ik3 = 
row/(nek1*nek2); + int ik2 = (row - ik3*nek1*nek2)/nek1; + int ik1 = row - ik3*nek1*nek2 - ik2*nek1; + auto this_data = (const char *)data + ik1*nbk1 + ik2*nbk2 + ik3*nbk3; + for (int k = 0; k < 8; ++k) x8[k] = (const block_q8_0 *)(this_data + k*nbk1); + for (int ib = 0; ib < nblock; ++ib) { + for (int k = 0; k < 8; ++k) y[ib].d[k] = x8[k][ib].d; +#ifdef __AVX2__ + auto m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4][ib].qs), _mm_loadu_si128((const __m128i *)x8[0][ib].qs)); + auto m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5][ib].qs), _mm_loadu_si128((const __m128i *)x8[1][ib].qs)); + auto m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6][ib].qs), _mm_loadu_si128((const __m128i *)x8[2][ib].qs)); + auto m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7][ib].qs), _mm_loadu_si128((const __m128i *)x8[3][ib].qs)); + auto t0 = _mm256_unpacklo_epi32(m0, m1); + auto t1 = _mm256_unpacklo_epi32(m2, m3); + auto t2 = _mm256_unpackhi_epi32(m0, m1); + auto t3 = _mm256_unpackhi_epi32(m2, m3); + m0 = _mm256_unpacklo_epi64(t0, t1); + m1 = _mm256_unpackhi_epi64(t0, t1); + m2 = _mm256_unpacklo_epi64(t2, t3); + m3 = _mm256_unpackhi_epi64(t2, t3); + //#ifdef HAVE_FANCY_SIMD + // m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127)); + // m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127)); + // m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127)); + // m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127)); + //#endif + _mm256_storeu_si256((__m256i *)y[ib].qs + 0, m0); + _mm256_storeu_si256((__m256i *)y[ib].qs + 1, m1); + _mm256_storeu_si256((__m256i *)y[ib].qs + 2, m2); + _mm256_storeu_si256((__m256i *)y[ib].qs + 3, m3); + m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[0][ib].qs+1)); + m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[1][ib].qs+1)); + m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[2][ib].qs+1)); + m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[3][ib].qs+1)); + t0 = _mm256_unpacklo_epi32(m0, m1); + t1 = _mm256_unpacklo_epi32(m2, m3); + t2 = _mm256_unpackhi_epi32(m0, m1); + t3 = _mm256_unpackhi_epi32(m2, m3); + m0 = _mm256_unpacklo_epi64(t0, t1); + m1 = _mm256_unpackhi_epi64(t0, t1); + m2 = _mm256_unpacklo_epi64(t2, t3); + m3 = _mm256_unpackhi_epi64(t2, t3); + //#ifdef HAVE_FANCY_SIMD + // m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127)); + // m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127)); + // m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127)); + // m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127)); + //#endif + _mm256_storeu_si256((__m256i *)y[ib].qs + 4, m0); + _mm256_storeu_si256((__m256i *)y[ib].qs + 5, m1); + _mm256_storeu_si256((__m256i *)y[ib].qs + 6, m2); + _mm256_storeu_si256((__m256i *)y[ib].qs + 7, m3); +#elif defined __ARM_NEON + for (int l = 0; l < 2; ++l) { + m0.val[0] = vld1q_s8(x8[0][ib].qs+16*l); m0.val[1] = vld1q_s8(x8[4][ib].qs+16*l); + m1.val[0] = vld1q_s8(x8[1][ib].qs+16*l); m1.val[1] = vld1q_s8(x8[5][ib].qs+16*l); + m2.val[0] = vld1q_s8(x8[2][ib].qs+16*l); m2.val[1] = vld1q_s8(x8[6][ib].qs+16*l); + m3.val[0] = vld1q_s8(x8[3][ib].qs+16*l); m3.val[1] = vld1q_s8(x8[7][ib].qs+16*l); + auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0])); + auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0])); + m0.val[0] = 
vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1])); + row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1])); + m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + vst1q_s8_x2(y[ib].qs + 0 + 128*l, m0); + vst1q_s8_x2(y[ib].qs + 32 + 128*l, m1); + vst1q_s8_x2(y[ib].qs + 64 + 128*l, m2); + vst1q_s8_x2(y[ib].qs + 96 + 128*l, m3); + } +#else + for (int l = 0; l < 4; ++l) { + for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) { + y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0]; + y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16]; + } + } +#endif + } + y += nblock; + } + return result; +} +namespace { template <int D, int step> struct HelperQ80R8 : public BaseHelper<step> { using Base = BaseHelper<step>; @@ -15399,24 +15768,21 @@ struct HelperQ80R8 : public BaseHelper<step> { constexpr static int block_size_q = QK8_0; using block_q8 = block_q8_0; #endif + HelperQ80R8(const char * data, int stride) : Base(data, stride) {} HelperQ80R8(int nk, const HelperQ80<D, step>& q8) : Base(q8.data, q8.stride) { r4 = repack(nk, q8); Base::data = (const char *)r4.data(); Base::stride = (D/QK8_0)*sizeof(block_q8_0); } - static std::vector<block_q8_0_r8> repack(int nk, const HelperQ80<D, step>& q8) { - static_assert(D%QK8_0 == 0); - GGML_ASSERT(nk%8 == 0); + static void repack(int nk, const char * q8_data, int q8_stride, block_q8_0_r8 * y) { constexpr int nblock = D/QK8_0; - std::vector<block_q8_0_r8> result(nblock * nk/8); - auto y = result.data(); const block_q8_0 * x8[8]; #ifdef __ARM_NEON int8x16x2_t m0, m1, m2, m3; #endif for (int row = 0; row < nk; row += 8) { - for (int k = 0; k < 8; ++k) x8[k] = (const block_q8_0 *)(q8.data + (row + k)*q8.stride); + for (int k = 0; k < 8; ++k) x8[k] = (const block_q8_0 *)(q8_data + (row + k)*q8_stride); for (int ib = 0; ib < nblock; ++ib) { for (int k = 0; k < 8; ++k) y[ib].d[k] = x8[k][ib].d; #ifdef __AVX2__ @@ -15498,6 +15864,15 @@ struct HelperQ80R8 : public BaseHelper<step> { } y += nblock; } + } + + static std::vector<block_q8_0_r8> repack(int nk, const HelperQ80<D, step>& q8) { + static_assert(D%QK8_0 == 0); + GGML_ASSERT(nk%8 == 0); + constexpr int nblock = D/QK8_0; + std::vector<block_q8_0_r8> result(nblock * nk/8); + auto y = result.data(); + repack(nk, q8.data, q8.stride, y); return result; } @@ -15952,12 +16327,13 @@ struct FlashMS { } return F16::reduce_max<k_step>(vk); } - static inline __m256 apply_mask(int l, const char * mask, __m256 val, __m256 vinf) { - auto m128 = _mm_loadu_si128((const __m128i *)mask+l); - m128 = _mm_cmpeq_epi16(m128, _mm_setzero_si128()); - auto m256 = 
_mm256_cvtepi16_epi32(m128); - auto mf = _mm256_castsi256_ps(_mm256_or_si256(m256, _mm256_slli_epi32(m256, 16))); - return _mm256_or_ps(_mm256_and_ps(mf, val), _mm256_andnot_ps(mf, vinf)); + static inline __m256 apply_mask(int l, const char * mask, __m256 val, [[maybe_unused]] __m256 vinf) { + return _mm256_add_ps(val, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)mask+l))); + //auto m128 = _mm_loadu_si128((const __m128i *)mask+l); + //m128 = _mm_cmpeq_epi16(m128, _mm_setzero_si128()); + //auto m256 = _mm256_cvtepi16_epi32(m128); + //auto mf = _mm256_castsi256_ps(_mm256_or_si256(m256, _mm256_slli_epi32(m256, 16))); + //return _mm256_or_ps(_mm256_and_ps(mf, val), _mm256_andnot_ps(mf, vinf)); } #ifdef __AVX512F__ static inline __m512 apply_mask(int l, const char * mask, __m512 val, __m512 vinf) { @@ -16087,7 +16463,6 @@ struct FlashQKV { accumulate_qkv_1(vh, fms); return; } - F16::Data v[8]; for (int j = 0; j < q_step; ++j) { auto R = qkv_cache + D*j; if (fms.need_scaling[j] == 2) { @@ -16100,6 +16475,43 @@ struct FlashQKV { } } } +#ifdef __AVX512F__ + if constexpr ((D/F16::block_size)%4 == 0) { + F16::Data v[16]; + F16::Data vs[4]; + for (int i = 0; i < D/F16::block_size; i += 4) { + for (int l = 0; l < k_step; l += 4) { + for (int k = 0; k < 4; ++k) { + vh.load(l+k, i+0, v[4*k+0], v[4*k+1]); + vh.load(l+k, i+2, v[4*k+2], v[4*k+3]); + } + for (int j = 0; j < q_step; ++j) { + auto R = qkv_cache + D*j; + auto s1 = F16::load(R + F16::block_size*(i+0)); + auto s2 = F16::load(R + F16::block_size*(i+1)); + auto s3 = F16::load(R + F16::block_size*(i+2)); + auto s4 = F16::load(R + F16::block_size*(i+3)); + F16::set4(fms.cache + k_step*j + l, vs); + for (int k = 0; k < 4; ++k) { + s1 = F16::fmadd(s1, v[4*k+0], vs[k]); + s2 = F16::fmadd(s2, v[4*k+1], vs[k]); + s3 = F16::fmadd(s3, v[4*k+2], vs[k]); + s4 = F16::fmadd(s4, v[4*k+3], vs[k]); + } + F16::store(R + F16::block_size*(i+0), s1); + F16::store(R + F16::block_size*(i+1), s2); + F16::store(R + F16::block_size*(i+2), s3); + F16::store(R + F16::block_size*(i+3), s4); + } + } + } + return; + } +#endif + F16::Data v[8]; +#ifdef __AVX2__ + F16::Data vs[4]; +#endif for (int i = 0; i < D/F16::block_size; i += 2) { for (int l = 0; l < k_step; l += 4) { vh.load(l+0, i, v[0], v[4]); @@ -16110,6 +16522,13 @@ struct FlashQKV { auto R = qkv_cache + D*j; auto s1 = F16::load(R + F16::block_size*(i+0)); auto s2 = F16::load(R + F16::block_size*(i+1)); +#ifdef __AVX2__ + F16::set4(fms.cache + k_step*j + l, vs); + for (int k = 0; k < 4; ++k) { + s1 = F16::fmadd(s1, v[k+0], vs[k]); + s2 = F16::fmadd(s2, v[k+4], vs[k]); + } +#else auto vs = F16::set4(fms.cache + k_step*j + l); s1 = F16::fmadd_lane0(s1, v[0], vs); s2 = F16::fmadd_lane0(s2, v[4], vs); @@ -16119,6 +16538,7 @@ struct FlashQKV { s2 = F16::fmadd_lane2(s2, v[6], vs); s1 = F16::fmadd_lane3(s1, v[3], vs); s2 = F16::fmadd_lane3(s2, v[7], vs); +#endif F16::store(R + F16::block_size*(i+0), s1); F16::store(R + F16::block_size*(i+1), s2); } @@ -16239,7 +16659,8 @@ struct FlashQKV { // As a result, we get an infinite stream of warnings about uninitialized variable use (one for each // combination of D, q_step, k_step), which is extremely annoying. Hence, I succumb to the trend of // constantly being saved by others (the compiler in this case), and add this 100% unnecessary initialization. 
- qkv_cache_t qkv_cache[D*q_step] = {}; + qkv_cache_t qkv_cache[D*q_step]; // = {}; + //qkv_cache_t * qkv_cache; }; template <int D, int q_step, int k_step> @@ -16481,8 +16902,14 @@ struct FlashQKfp32 { MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80, nq); #else #ifdef HAVE_FANCY_SIMD + if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q8_0_1_Unpacker, 1, k_step>, 1); + if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q8_0_1_Unpacker, 2, k_step>, 2); + if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q8_0_1_Unpacker, 4, k_step>, 4); MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q8_0_1_Unpacker, nq); #else + if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_0_Tx<Q8_0_Unpacker, 1, k_step>, 1); + if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_0_Tx<Q8_0_Unpacker, 2, k_step>, 2); + if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_0_Tx<Q8_0_Unpacker, 4, k_step>, 4); MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, nq); #endif #endif @@ -16493,10 +16920,15 @@ struct FlashQKfp32 { if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1, 1); MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq); #else + if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1<1>, 1); #ifdef HAVE_FANCY_SIMD - if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16); + if constexpr (D%32 == 0 && k_step%8 == 0) { + if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV_8<16>, 16); + MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV_8, nq); + } else { + if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16); + } #endif - if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1<1>, 1); MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq); #endif } @@ -16514,17 +16946,23 @@ struct FlashQKfp32 { #ifdef __aarch64__ MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ60, nq); #else + if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q6_0_1_Unpacker, 1, k_step>, 1); + if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q6_0_1_Unpacker, 2, k_step>, 2); + if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q6_0_1_Unpacker, 4, k_step>, 4); MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q6_0_1_Unpacker, nq); #endif } -#if GGML_IQK_FA_ALL_QUANTS else if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) { #ifdef __aarch64__ MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ40, nq); #else + if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q4_0_1_Unpacker, 1, k_step>, 1); + if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q4_0_1_Unpacker, 2, k_step>, 2); + if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q4_0_1_Unpacker, 4, k_step>, 4); MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q4_0_1_Unpacker, nq); #endif } +#if GGML_IQK_FA_ALL_QUANTS else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) { #ifdef __aarch64__ MAKE_FUNCS(mul_mat_qX_1_q8_1<DequantizerQ41, nq); @@ -16664,8 +17102,29 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, FlashMS<q_step, k_step>& fms, FlashQKV<Dv, q_step, k_step>& fqkv, const float * q, const char * mask, float * qkv, - float * M, float * S) { - typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)]; + float * M, float * S, char * qptr) { + auto q8 = (typename KHelper::block_q8 *)qptr; + if constexpr (q_step > 1 && std::is_same_v<KHelper, HelperQ80<Dk, k_step>>) { + if (nq1 == q_step) { + fms.init_qstep(); + kh.reset_block(); + vh.reset_block(); + block_q8_0_r8 q8r8[Dk/QK8_0 * k_step/8]; + HelperQ80R8<Dk, k_step> khr8((const char *)q8r8, Dk/QK8_0*sizeof(block_q8_0)); + HelperQ80<Dk, QK8_0>::convert(q_step, stride_q, q, q8); + auto mr = mask; + for (int k1 
= 0; k1 < nk1/k_step; ++k1) { + HelperQ80R8<Dk, k_step>::repack(k_step, kh.data, kh.stride, q8r8); + KQHelper::mul_mask_kq(khr8, stride_m, q8, mr, fms); + fqkv.accumulate_qkv(vh, fms); + kh.next_block(); + vh.next_block(); + mr += k_step*sizeof(ggml_half); + } + fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S); + return; + } + } #if FA_TIMING Perf perf(false); #endif @@ -16731,6 +17190,12 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, #endif } +char * get_q_storage(size_t size) { + thread_local std::vector<char> q_storage; + if (q_storage.size() < size) q_storage.resize(size); + return q_storage.data(); +} + // Some of the methods in FlashAttn have two identical implementations that only differ by // one version using a loop over the template parameter q_step, while the other using a loop // over an input parameter nq (these are loops over the rows of q^T). I dislike this a lot, @@ -16753,44 +17218,57 @@ struct FlashAttn { template <typename KHelper, typename VHelper> void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, const float * q, const char * mask, float * qkv, [[maybe_unused]] float * M, [[maybe_unused]] float * S) { - if constexpr (std::is_same_v<KHelper, HelperQ40<Dk, k_step>> || std::is_same_v<KHelper, HelperQ41<Dk, k_step>> || + if constexpr (std::is_same_v<KHelper, HelperQ40<Dk, k_step>> || + std::is_same_v<KHelper, HelperQ41<Dk, k_step>> || std::is_same_v<KHelper, HelperIQ4nl<Dk, k_step>> || - std::is_same_v<KHelper, HelperQ60<Dk, k_step>>) { - compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>( - kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S); - } - else if constexpr (std::is_same_v<KHelper, HelperQ80<Dk, k_step>>) { - if (nq1 >= 8) { + std::is_same_v<KHelper, HelperQ60<Dk, k_step>> || + std::is_same_v<KHelper, HelperQ80R8<Dk, k_step>> || + std::is_same_v<KHelper, HelperQ80<Dk, k_step>> || + std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>> || + std::is_same_v<KHelper, HelperQ8KVR8<Dk, k_step>>) { + constexpr size_t kMaxOnStackSize = 576; + auto q_size = q_step*(Dk/KHelper::block_size_q)*sizeof(typename KHelper::block_q8); + q_size = GGML_PAD(q_size, 64); + if (q_size > kMaxOnStackSize) { + auto qptr = get_q_storage(q_size); + if (nq1 >= 8) { + if constexpr (std::is_same_v<KHelper, HelperQ80<Dk, k_step>>) { #if FA_TIMING - auto t1 = Perf::cur_time(); - HelperQ80R8<Dk, k_step> khr4(nk1, kh); - Perf::instance().accum(4, t1); + auto t1 = Perf::cur_time(); + HelperQ80R8<Dk, k_step> khr4(nk1, kh); + Perf::instance().accum(4, t1); #else - HelperQ80R8<Dk, k_step> khr4(nk1, kh); + HelperQ80R8<Dk, k_step> khr4(nk1, kh); #endif - compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>( - khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S); - } else{ - compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>( - kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S); - } - } - else if constexpr (std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>>) { - if (nq1 >= 8) { + compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>( + khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr); + return; + + } + if constexpr (std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>>) { #if FA_TIMING - auto t1 = Perf::cur_time(); 
- HelperQ8KVR8<Dk, k_step> khr4(nk1, kh); - Perf::instance().accum(4, t1); + auto t1 = Perf::cur_time(); + HelperQ8KVR8<Dk, k_step> khr4(nk1, kh); + Perf::instance().accum(4, t1); #else - HelperQ8KVR8<Dk, k_step> khr4(nk1, kh); + HelperQ8KVR8<Dk, k_step> khr4(nk1, kh); #endif - compute_helper_q<Dk, Dv, q_step, k_step, HelperQ8KVR8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>( - khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S); - } else{ + compute_helper_q<Dk, Dv, q_step, k_step, HelperQ8KVR8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>( + khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr); + return; + } + } compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>( - kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S); + kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr); + } - } else { + else { + typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)]; + compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>( + kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, (char *)q8); + } + } + else { compute_helper<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>( kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S); } @@ -17234,39 +17712,61 @@ template <int Dk, int Dv, int k_step, typename KHelper, typename VHelper> inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) { - if (nk1 >= 256) { //4096) { + auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) { + nq1 -= n; + if (nq1 == 0) return true; + q += n*stride_q; + mask += n*stride_m; + qkv += n*stride_qkv; + if (M && S) { M += n; S += n; } + return false; + }; + if (nk1 >= 512) { + if (nq1 >= 128) { + int n_step = nq1/128; + FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap); + fa.compute(kh, vh, 128*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + if (update(128*n_step)) return; + } if (nq1 >= 64) { + int n_step = nq1/64; FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); - return; + fa.compute(kh, vh, 64*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + if (update(64*n_step)) return; } if (nq1 >= 32) { + int n_step = nq1/32; FlashAttn<Dk, Dv, 32, k_step> fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); - return; + fa.compute(kh, vh, 32*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + if (update(32*n_step)) return; } if (nq1 >= 16) { + int n_step = nq1/16; FlashAttn<Dk, Dv, 16, k_step> fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); - return; + fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + if (update(16*n_step)) return; } } if (nq1 >= 8) { + int n_step = nq1/8; FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + fa.compute(kh, vh, 8*n_step, nk1, stride_q, 
stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + if (update(8*n_step)) return; } else if (nq1 >= 4) { + int n_step = nq1/4; FlashAttn<Dk, Dv, 4, k_step> fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + if (update(4*n_step)) return; } else if (nq1 >= 2) { + int n_step = nq1/2; FlashAttn<Dk, Dv, 2, k_step> fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); - } - else { - FlashAttn<Dk, Dv, 1, k_step> fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + if (update(2*n_step)) return; } + FlashAttn<Dk, Dv, 1, k_step> fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); } #ifdef __AVX512BF16__ @@ -17327,11 +17827,11 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v, HelperQ60<Dv, k_step> vh(v, stride_v); iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); } break; -#if GGML_IQK_FA_ALL_QUANTS case GGML_TYPE_Q4_0: { HelperQ40<Dv, k_step> vh(v, stride_v); iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); } break; +#if GGML_IQK_FA_ALL_QUANTS case GGML_TYPE_Q4_1: { HelperQ41<Dv, k_step> vh(v, stride_v); iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); @@ -17360,6 +17860,10 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, HelperQ80<Dk, k_step> kh(k, stride_k); iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); } break; + case GGML_TYPE_Q8_0_R8: { + HelperQ80R8<Dk, k_step> kh(k, stride_k); + iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); + } break; case GGML_TYPE_Q8_KV: { HelperQ8KV<Dk, k_step> kh(k, stride_k); iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); @@ -17368,11 +17872,11 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, HelperQ60<Dk, k_step> kh(k, stride_k); iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); } break; -#if GGML_IQK_FA_ALL_QUANTS case GGML_TYPE_Q4_0: { HelperQ40<Dk, k_step> kh(k, stride_k); iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); } break; +#if GGML_IQK_FA_ALL_QUANTS case GGML_TYPE_Q4_1: { HelperQ41<Dk, k_step> kh(k, stride_k); iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); @@ -17393,9 +17897,10 @@ inline bool flash_attn_is_supported(ggml_type type) { #endif #if GGML_IQK_FA_ALL_QUANTS if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || - type == GGML_TYPE_Q6_0 || type == GGML_TYPE_IQ4_NL) return true; + type == GGML_TYPE_Q6_0 || type == 
GGML_TYPE_IQ4_NL || type == GGML_TYPE_Q8_0_R8) return true; #else - if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0 || type == GGML_TYPE_Q8_KV) return true; + if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0 || type == GGML_TYPE_Q8_KV || type == GGML_TYPE_Q8_0_R8 + || type == GGML_TYPE_Q4_0) return true; #endif return false; } @@ -17404,25 +17909,35 @@ template <int step_k, typename KHelper, typename VHelper> inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) { + auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) { + nq1 -= n; + if (nq1 == 0) return true; + q += n*stride_q; + mask += n*stride_m; + qkv += n*stride_qkv; + if (M && S) { M += n; S += n; } + return false; + }; if (nq1 >= 8) { + int n_step = nq1/8; FlashAttn<576, 512, 8, step_k> fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); + fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); + if (update(8*n_step)) return; } - else if (nq1 >= 4) { + if (nq1 >= 4) { + int n_step = nq1/4; FlashAttn<576, 512, 4, step_k> fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); + fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); + if (update(4*n_step)) return; } - else { - FlashAttn<576, 512, 1, step_k> fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); - } - //if (nq1 % 8 == 0) { - // FlashAttn<576, 512, 8, step_k> fa(scale, softcap); - // fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); - //} else { - // FlashAttn<576, 512, 1, step_k> fa(scale, softcap); - // fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); - //} + if (nq1 >= 2) { + int n_step = nq1/2; + FlashAttn<576, 512, 2, step_k> fa(scale, softcap); + fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); + if (update(2*n_step)) return; + } + FlashAttn<576, 512, 1, step_k> fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); } template <int step_k> @@ -17436,6 +17951,12 @@ inline bool iqk_deepseek_helper(ggml_type type_k, iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); return true; } + if (type_k == GGML_TYPE_Q8_0_R8) { + HelperQ80R8<576, step_k> kh((const char *)k, stride_k); + HelperQ80<512, step_k> vh((const char *)v, stride_v); + iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + return true; + } if (type_k == GGML_TYPE_Q6_0) { HelperQ60<576, step_k> kh((const char *)k, stride_k); HelperQ60<512, step_k> vh((const char *)v, stride_v); @@ -17558,6 +18079,23 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k } #endif + if (nk1%128 == 0) { + switch (Dk) { + case 64: + iqk_flash_helper_T< 64, 64, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; + case 96: + iqk_flash_helper_T< 96, 96, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; + case 128: + 
iqk_flash_helper_T<128, 128, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; + case 192: + iqk_flash_helper_T<192, 128, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; + case 256: + iqk_flash_helper_T<256, 256, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; + default: + return false; + } + return true; + } if (nk1%64 == 0) { switch (Dk) { case 64: |