author     Kawrakow <iwankawrakow@gmail.com>          2025-04-29 07:19:43 +0200
committer  GitHub <noreply@github.com>                2025-04-29 07:19:43 +0200
commit     cda24b58cbef34154651d0083910fed860a506c1 (patch)
tree       90cd3bd7f772c3b240a6553eca5e50edf95c53da /ggml/src/iqk/iqk_mul_mat.cpp
parent     baeefb4731fb24cdace168f6dbc74516d470efc0 (diff)
CPU FA improvements (#351)
* FA: provide work buffer for K repacking
* Add header to avoid compiler warnings
* WIP
* WIP
* WIP
* WIP
* Slightly better
* WIP (Zen4)
* WIP
* Try to improve performance for unusual combinations of number of heads and number of threads
* Use mul_mat_qX_0_q8_2_Tx for q6_0 in FA
* Use mul_mat_qX_0_q8_2_Tx for q4_0 in FA
* Use Sum4q4 for q4_0
* WIP
* WIP
* Much better FA TG with q8_0 KV cache
Just repack it even for TG, but do the repacking for k_step rows at a time,
not for the whole K tensor (the interleaved layout is sketched below).
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
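The last bullet above is the heart of the TG speedup. Below is a minimal sketch of the row interleaving it refers to, modeled directly on the scalar fallback in the patch (the struct and function names here are illustrative stand-ins for ggml's block_q8_0 / block_q8_0_r8 and the repacking loop in iqk_repack_k):

    #include <cstdint>

    constexpr int QK8_0 = 32;

    struct block_q8_0    { uint16_t d;    int8_t qs[QK8_0]; };    // one 32-weight block of one row
    struct block_q8_0_r8 { uint16_t d[8]; int8_t qs[8*QK8_0]; };  // the same block of 8 rows, interleaved

    // Interleave 8 consecutive K rows in groups of 4 bytes, so a GEMV kernel can fetch
    // the same 4 weights of all 8 rows with a single vector load. Doing this k_step rows
    // at a time (instead of for the whole K tensor) keeps the work buffer small enough
    // to stay cache resident during token generation.
    void repack_8_rows(int nblock, const block_q8_0 * x8[8], block_q8_0_r8 * y) {
        for (int ib = 0; ib < nblock; ++ib) {
            for (int k = 0; k < 8; ++k) y[ib].d[k] = x8[k][ib].d;
            for (int l = 0; l < 4; ++l) {
                for (int k = 0; k < 8; ++k) {
                    for (int i = 0; i < 4; ++i) {
                        y[ib].qs[32*l + 4*k + i +   0] = x8[k][ib].qs[i + 4*l +  0];
                        y[ib].qs[32*l + 4*k + i + 128] = x8[k][ib].qs[i + 4*l + 16];
                    }
                }
            }
        }
    }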
Diffstat (limited to 'ggml/src/iqk/iqk_mul_mat.cpp')
-rw-r--r--   ggml/src/iqk/iqk_mul_mat.cpp   708
1 file changed, 623 insertions(+), 85 deletions(-)
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index e7ab2e5b..5f916584 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -19,6 +19,7 @@
 #include "ggml-quants.h"
 #include "iqk_mul_mat.h"
 #include "iqk_quantize.h"
+#include "iqk_flash_impl.h"
 #define GGML_COMMON_IMPL_C
 #include "ggml-common.h"
@@ -6639,6 +6640,84 @@ static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataI
     }
 }
 
+#ifdef HAVE_FANCY_SIMD
+template <int nrc_y>
+static void mul_mat_q8_KV_q8_KV_8(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    GGML_ASSERT(nrc_x%8 == 0);
+    GGML_ASSERT(n%32 == 0);
+    __m512i qx[4];
+    __m512i acc[nrc_y <= 4 ? 2*nrc_y : nrc_y] = {};
+    float dy[nrc_y];
+    int32_t sy[nrc_y];
+    const int8_t * q8y[nrc_y];
+    for (int iy = 0; iy < nrc_y; ++iy) {
+        auto dptr = (const float *)info.src1_row(iy);
+        dy[iy] = dptr[0];
+        auto iptr = (const int32_t *)(dptr + 1);
+        sy[iy] = -64*iptr[0];
+        q8y[iy] = (const int8_t *)(dptr + 2);
+    }
+    const int8_t * q8x[8];
+    float dx[8];
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+        for (int kx = 0; kx < 8; ++kx) {
+            auto dptr = (const float *)((const char *)vx + (ix+kx)*bx);
+            dx[kx] = dptr[0];
+            q8x[kx] = (const int8_t *)(dptr + 2);
+        }
+        for (int i = 0; i < n/32; ++i) {
+            for (int kx = 0; kx < 4; ++kx) {
+                qx[kx] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8x[kx+0] + i)),
+                                            _mm256_loadu_si256((const __m256i *)q8x[kx+4] + i), 1);
+            }
+            auto t0 = _mm512_unpacklo_epi32(qx[0], qx[1]);
+            auto t1 = _mm512_unpacklo_epi32(qx[2], qx[3]);
+            auto t2 = _mm512_unpackhi_epi32(qx[0], qx[1]);
+            auto t3 = _mm512_unpackhi_epi32(qx[2], qx[3]);
+            qx[0] = _mm512_xor_si512(_mm512_unpacklo_epi64(t0, t1), _mm512_set1_epi8(-128));
+            qx[1] = _mm512_xor_si512(_mm512_unpackhi_epi64(t0, t1), _mm512_set1_epi8(-128));
+            qx[2] = _mm512_xor_si512(_mm512_unpacklo_epi64(t2, t3), _mm512_set1_epi8(-128));
+            qx[3] = _mm512_xor_si512(_mm512_unpackhi_epi64(t2, t3), _mm512_set1_epi8(-128));
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto y256 = _mm256_loadu_si256((const __m256i *)q8y[iy] + i);
+                auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1);
+                if constexpr (nrc_y <= 4) {
+                    acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+                    acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+                    acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+                    acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+                } else {
+                    acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+                    acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+                    acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+                    acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+                }
+            }
+        }
+        auto scales_x = _mm256_loadu_ps(dx);
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            if constexpr (nrc_y <= 4) {
+                auto ss = _mm512_add_epi32(_mm512_add_epi32(acc[2*iy+0], acc[2*iy+1]), _mm512_set1_epi32(sy[iy]));
+                auto sum1 = _mm_add_epi32(_mm512_extracti32x4_epi32(ss, 0), _mm512_extracti32x4_epi32(ss, 1));
+                auto sum2 = _mm_add_epi32(_mm512_extracti32x4_epi32(ss, 2), _mm512_extracti32x4_epi32(ss, 3));
+                auto scale = _mm256_mul_ps(scales_x, _mm256_set1_ps(dy[iy]));
+                info.store(ix+0, iy, _mm_mul_ps(_mm256_castps256_ps128(scale), _mm_cvtepi32_ps(sum1)));
+                info.store(ix+4, iy, _mm_mul_ps(_mm256_extractf128_ps(scale, 1), _mm_cvtepi32_ps(sum2)));
+                acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_si512();
+            } else {
+                acc[iy] = _mm512_add_epi32(acc[iy], _mm512_set1_epi32(sy[iy]));
+                auto sum1 = _mm_add_epi32(_mm512_extracti32x4_epi32(acc[iy], 0), _mm512_extracti32x4_epi32(acc[iy], 1));
+                auto sum2 = _mm_add_epi32(_mm512_extracti32x4_epi32(acc[iy], 2), _mm512_extracti32x4_epi32(acc[iy], 3));
+                auto scale = _mm256_mul_ps(scales_x, _mm256_set1_ps(dy[iy]));
+                info.store(ix+0, iy, _mm_mul_ps(_mm256_castps256_ps128(scale), _mm_cvtepi32_ps(sum1)));
+                info.store(ix+4, iy, _mm_mul_ps(_mm256_extractf128_ps(scale, 1), _mm_cvtepi32_ps(sum2)));
+                acc[iy] = _mm512_setzero_si512();
+            }
+        }
+    }
+}
+#endif
+
 template <int nrc_y>
 static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     GGML_ASSERT(nrc_x%4 == 0);
@@ -8208,6 +8287,22 @@ template <typename Q8, typename Q8x4, typename Dot, bool can_pack = true> struct
             return _mm256_add_epi32(_mm256_unpacklo_epi64(p01, p23), _mm256_unpackhi_epi64(p01, p23)); // 0,1,2,3, 0,1,2,3
         }
     }
+    inline __m256i compute(__m256i x, __m256i y) const { return dot.compute(x, y); }
+};
+
+template <typename Q8, typename Q8x4> struct Sum4q4 {
+    inline __m256i compute(const __m256i * qx, const Q8 * y) const {
+        const Q8x4 * y4 = (const Q8x4 *)y;
+        auto p0 = _mm256_maddubs_epi16(qx[0], _mm256_loadu_si256((const __m256i *)y4->qs+0)); // 16x block 0
+        auto p1 = _mm256_maddubs_epi16(qx[1], _mm256_loadu_si256((const __m256i *)y4->qs+1)); // 16x block 1
+        auto p2 = _mm256_maddubs_epi16(qx[2], _mm256_loadu_si256((const __m256i *)y4->qs+2)); // 16x block 2
+        auto p3 = _mm256_maddubs_epi16(qx[3], _mm256_loadu_si256((const __m256i *)y4->qs+3)); // 16x block 3
+        auto p01 = _mm256_add_epi16(_mm256_unpacklo_epi32(p0, p1), _mm256_unpackhi_epi32(p0, p1));     // 0,0, 1,1, 0,0, 1,1, 0,0, 1,1, 0,0, 1,1
+        auto p23 = _mm256_add_epi16(_mm256_unpacklo_epi32(p2, p3), _mm256_unpackhi_epi32(p2, p3));     // 2,2, 3,3, 2,2, 3,3, 2,2, 3,3, 2,2, 3,3
+        auto p0123 = _mm256_add_epi16(_mm256_unpacklo_epi64(p01, p23), _mm256_unpackhi_epi64(p01, p23)); // 0,0, 1,1, 2,2, 3,3, 0,0, 1,1, 2,2, 3,3
+        return _mm256_madd_epi16(_mm256_set1_epi16(1), p0123);
+    }
+    inline __m256i compute(__m256i x, __m256i y) const { return _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(x, y)); }
 };
 
 struct ScaleHelperQ8_0 {
@@ -8362,6 +8457,7 @@ struct MinusType0 {
     inline __m256 compute(__m128 d, int) const { return _mm256_set_m128(d, d); }
     inline float compute(float d, int) const { return d; }
     inline float result(__m256 acc, int) const { return hsum_float_8(acc); }
+    inline __m256 vresult(__m256 acc, int) const { return acc; }
 };
 
 template <int nrc_y> struct MinusType1 {
@@ -8381,6 +8477,9 @@ template <int nrc_y> struct MinusType1 {
         const __m128 sum = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));
         return hsum_float_4(_mm_add_ps(sum, accm[iy]));
     }
+    inline __m256 vresult(__m256 acc, int iy) const {
+        return _mm256_add_ps(acc, _mm256_insertf128_ps(_mm256_setzero_ps(), accm[iy], 0));
+    }
 };
 
 template <typename Minus, int nrc_y, bool is_multiple_of_4> struct AccumT {
@@ -8408,7 +8507,7 @@ template <typename Minus, int nrc_y, bool is_multiple_of_4> struct AccumT {
                 for (int iy = 0; iy < nrc_y; ++iy) {
                     auto s12 = scales.prepare1(other_scales, y[iy] + i);
                     auto d = accm.compute(s12, iy);
-                    const __m256i p0 = sum.dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs));
+                    const __m256i p0 = sum.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs));
                     acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(p0), acc[iy]);
                 }
             }
@@ -8417,6 +8516,36 @@
             info.store(ix, iy, accm.result(acc[iy], iy));
         }
     }
+    template <typename Unpacker, typename Scales, typename Sum, typename Q8>
+    inline void compute(int nb, Unpacker& unp, Scales& scales, Sum& sum, const Q8 ** y, __m256 * result) {
+        auto qx = unp.quants();
+        __m256 dall[nrc_y];
+        for (int i = 0; i < nb/4; ++i) {
+            auto other_scales = unp.set_block_4(i);
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto s12 = scales.prepare4(other_scales, y[iy] + 4*i);
+                dall[iy] = accm.compute(s12, iy);
+            }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto pall = sum.compute(qx, y[iy] + 4*i);
+                acc[iy] = _mm256_fmadd_ps(dall[iy], _mm256_cvtepi32_ps(pall), acc[iy]);
+            }
+        }
+        if (!is_multiple_of_4) {
+            for (int i = 4*(nb/4); i < nb; ++i) {
+                auto other_scales = unp.set_block(i);
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto s12 = scales.prepare1(other_scales, y[iy] + i);
+                    auto d = accm.compute(s12, iy);
+                    const __m256i p0 = sum.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs));
+                    acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(p0), acc[iy]);
+                }
+            }
+        }
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            result[iy] = accm.vresult(acc[iy], iy);
+        }
+    }
 };
 
 template <int nrc_y, bool is_multiple_of_4>
@@ -8425,10 +8554,7 @@ using AccumType0 = AccumT<MinusType0, nrc_y, is_multiple_of_4>;
 template <int nrc_y, bool is_multiple_of_4>
 using AccumType1 = AccumT<MinusType1<nrc_y>, nrc_y, is_multiple_of_4>;
 
-using Sum4Type0 = Sum4<block_q8_0, block_q8_0_x4, SignedDot>;
-using Sum4Type1 = Sum4<block_q8_1, block_q8_1_x4, UnsignedDot>;
 using Sum4TypeQ80 = Sum4<block_q8_0, block_q8_0_x4, SignedDot, false>;
-//using Sum4TypeQ81 = Sum4<block_q8_1, block_q8_1_x4, UnsignedDot, false>;
 using Sum4TypeQ82 = Sum4<block_q8_2, block_q8_2_x4, UnsignedDot, false>;
 
 template <typename Unpacker, typename AccumType, typename Scales, typename Q8, int nrc_y>
@@ -8443,6 +8569,19 @@ void mul_mat_qX_q8_Helper(int nb, const void * vx, size_t bx, const DataInfo& in
     }
 }
 
+template <typename Unpacker, typename AccumType, typename Scales, typename Q8, int nrc_y>
+void mul_mat_qX_q8_Helper_x2(int nb, const void * vx, size_t bx, const DataInfo& info, const Q8 ** y, int nrc_x) {
+    GGML_ASSERT(nrc_x%2 == 0);
+    Unpacker unp(vx, bx);
+    typename Unpacker::Sum4T sum4;
+    Scales scales;
+    for (int ix = 0; ix < nrc_x; ix += 2) {
+        unp.set_row(ix);
+        AccumType accum;
+        accum.compute(nb, unp, scales, sum4, y, info, ix);
+    }
+}
+
 template <typename Unpacker, int nrc_y>
 void mul_mat_qX_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     assert(n%Unpacker::block_size() == 0);
@@ -8459,6 +8598,63 @@ void mul_mat_qX_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info
     }
 }
 
+inline __m256 hsum_float_8x8(__m256 * accm) {
+    for (int i = 0; i < 4; ++i) {
+        accm[i] = _mm256_add_ps(_mm256_permute2f128_ps(accm[i], accm[i+4], 0x20), _mm256_permute2f128_ps(accm[i], accm[i+4], 0x31));
+        //accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)),
+        //                          _mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1)));
+    }
+    for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2]));
+    return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1]));
+}
+
+template <typename Unpacker, int nrc_y, int nrc_x>
+void mul_mat_qX_0_q8_0_Tx(int n, const void * vx, size_t bx, const DataInfo& info, int) {
+    static_assert(8%nrc_y == 0);
+    Q8<nrc_y, block_q8_0> q8(info);
+    int nb = n/Unpacker::block_size();
+    Unpacker unp(vx, bx);
+    typename Unpacker::Sum4T sum4;
+    ScaleHelperQ8_0 scales;
+    __m256 result[8];
+    auto store = [&info, &result] (int ix0) {
+        if constexpr (nrc_y == 1) {
+            info.store(ix0, 0, hsum_float_8x8(result));
+        }
+        else if constexpr (nrc_y == 2) {
+            auto value = hsum_float_8x8(result);
+            auto value1 = _mm256_extractf128_ps(value, 1);
+            info.store(ix0, 0, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0x88));
+            info.store(ix0, 1, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0xdd));
+        }
+        else {
+            float val[8];
+            _mm256_storeu_ps(val, hsum_float_8x8(result));
+            for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]);
+        }
+    };
+    if (nb%4 == 0) {
+        for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) {
+            for (int ix = 0; ix < 8/nrc_y; ++ix) {
+                unp.set_row(ix0 + ix);
+                AccumType0<nrc_y, true> accum;
+                accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix);
+            }
+            store(ix0);
+        }
+    } else {
+        for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) {
+            for (int ix = 0; ix < 8/nrc_y; ++ix) {
+                unp.set_row(ix0 + ix);
+                AccumType0<nrc_y, false> accum;
+                accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix);
+            }
+            store(ix0);
+        }
+    }
+}
+
 template <typename Unpacker, int nrc_y>
 void mul_mat_qX_1_q8_1_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     assert(n%Unpacker::block_size() == 0);
@@ -8491,6 +8687,52 @@ void mul_mat_qX_1_q8_2_T(int n, const void * vx, size_t bx, const DataInfo& info
     }
 }
 
+template <typename Unpacker, int nrc_y, int nrc_x>
+void mul_mat_qX_0_q8_2_Tx(int n, const void * vx, size_t bx, const DataInfo& info, int) {
+    static_assert(8%nrc_y == 0);
+    Q8<nrc_y, block_q8_2> q8(info);
+    int nb = n/Unpacker::block_size();
+    Unpacker unp(vx, bx);
+    typename Unpacker::Sum4T sum4;
+    ScaleHelperQ8_2 scales;
+    __m256 result[8];
+    auto store = [&info, &result] (int ix0) {
+        if constexpr (nrc_y == 1) {
+            info.store(ix0, 0, hsum_float_8x8(result));
+        }
+        else if constexpr (nrc_y == 2) {
+            auto value = hsum_float_8x8(result);
+            auto value1 = _mm256_extractf128_ps(value, 1);
+            info.store(ix0, 0, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0x88));
+            info.store(ix0, 1, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0xdd));
+        }
+        else {
+            float val[8];
+            _mm256_storeu_ps(val, hsum_float_8x8(result));
+            for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]);
+        }
+    };
+    if (nb%4 == 0) {
+        for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) {
+            for (int ix = 0; ix < 8/nrc_y; ++ix) {
+                unp.set_row(ix0 + ix);
+                AccumType1<nrc_y, true> accum;
+                accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix);
+            }
+            store(ix0);
+        }
+    } else {
+        for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) {
+            for (int ix = 0; ix < 8/nrc_y; ++ix) {
+                unp.set_row(ix0 + ix);
+                AccumType1<nrc_y, false> accum;
+                accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix);
+            }
+            store(ix0);
+        }
+    }
+}
+
 struct Dequantizer4bit {
     const __m256i m4 = _mm256_set1_epi8(0xf);
     inline __m256i dequant(const uint8_t * qs) const {
@@ -8640,7 +8882,8 @@ struct Q4_0_Unpacker final : public Q_Unpacker<block_q4_0, ScaleHelperQ_0, Q4_0_
 };
 struct Q4_0_1_Unpacker final : public Q_Unpacker<block_q4_0, ScaleHelperQ_0_1<8>, Q4_0_1_Dequantizer> {
     Q4_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
-    using Sum4T = Sum4TypeQ82;
+    //using Sum4T = Sum4TypeQ82;
+    using Sum4T = Sum4q4<block_q8_2, block_q8_2_x4>;
     inline static int block_size() { return QK4_0; }
 };
 #ifdef HAVE_FANCY_SIMD
@@ -15168,6 +15411,13 @@ struct F16 {
         auto v256 = _mm256_set_m128(v128, v128);
         return _mm512_insertf32x8(_mm512_castps256_ps512(v256), v256, 1);
     }
+    static inline void set4(const float * ptr, Data * vs) {
+        auto v = set4(ptr);
+        vs[0] = _mm512_shuffle_ps(v, v, 0x00);
+        vs[1] = _mm512_shuffle_ps(v, v, 0x55);
+        vs[2] = _mm512_shuffle_ps(v, v, 0xaa);
+        vs[3] = _mm512_shuffle_ps(v, v, 0xff);
+    }
     static inline Data fmadd_lane0(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0x00), prev); }
     static inline Data fmadd_lane1(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0x55), prev); }
     static inline Data fmadd_lane2(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0xaa), prev); }
@@ -15193,6 +15443,13 @@ struct F16 {
         auto v128 = _mm_loadu_ps(ptr);
         return _mm256_set_m128(v128, v128);
    }
+    static inline void set4(const float * ptr, Data * vs) {
+        auto v = set4(ptr);
+        vs[0] = _mm256_shuffle_ps(v, v, 0x00);
+        vs[1] = _mm256_shuffle_ps(v, v, 0x55);
+        vs[2] = _mm256_shuffle_ps(v, v, 0xaa);
+        vs[3] = _mm256_shuffle_ps(v, v, 0xff);
+    }
     static inline Data fmadd_lane0(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0x00), prev); }
     static inline Data fmadd_lane1(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0x55), prev); }
     static inline Data fmadd_lane2(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0xaa), prev); }
@@ -15388,7 +15645,119 @@ struct HelperQ80 final : public BaseHelper<step> {
         }
     }
 };
+}
+
+void * iqk_repack_k(int int_type_k, int nek0, int nek1, int nek2, int nek3, long nbk1, long nbk2, long nbk3,
+        const void * data, void * work, int ith, int nth, int& repacked_type, uint64_t& row_size) {
+    repacked_type = int_type_k;
+    auto type_k = ggml_type(int_type_k);
+    if (type_k != GGML_TYPE_Q8_0 || nek0%QK8_0 != 0) return work;
+    int nrows = nek1*nek2*nek3;
+    if (nrows%8 != 0) return work;
+    repacked_type = int(GGML_TYPE_Q8_0_R8);
+    row_size = ggml_row_size(GGML_TYPE_Q8_0, nek0);
+    void * result = (char *)work + nrows*row_size;
+    int npt = 8*((nrows/8 + nth - 1)/nth);
+    int first = npt*ith;
+    if (first >= nrows) return result;
+    int last = std::min(first + npt, nrows);
+    const block_q8_0 * x8[8];
+    auto y = (block_q8_0_r8 *)((char *)work + first*row_size);
+    int nblock = nek0/QK8_0;
+#ifdef __ARM_NEON
+    int8x16x2_t m0, m1, m2, m3;
+#endif
+    for (int row = first; row < last; row += 8) {
+        int ik3 = row/(nek1*nek2);
+        int ik2 = (row - ik3*nek1*nek2)/nek1;
+        int ik1 = row - ik3*nek1*nek2 - ik2*nek1;
+        auto this_data = (const char *)data + ik1*nbk1 + ik2*nbk2 + ik3*nbk3;
+        for (int k = 0; k < 8; ++k) x8[k] = (const block_q8_0 *)(this_data + k*nbk1);
+        for (int ib = 0; ib < nblock; ++ib) {
+            for (int k = 0; k < 8; ++k) y[ib].d[k] = x8[k][ib].d;
+#ifdef __AVX2__
+            auto m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4][ib].qs), _mm_loadu_si128((const __m128i *)x8[0][ib].qs));
+            auto m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5][ib].qs), _mm_loadu_si128((const __m128i *)x8[1][ib].qs));
+            auto m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6][ib].qs), _mm_loadu_si128((const __m128i *)x8[2][ib].qs));
+            auto m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7][ib].qs), _mm_loadu_si128((const __m128i *)x8[3][ib].qs));
+            auto t0 = _mm256_unpacklo_epi32(m0, m1);
+            auto t1 = _mm256_unpacklo_epi32(m2, m3);
+            auto t2 = _mm256_unpackhi_epi32(m0, m1);
+            auto t3 = _mm256_unpackhi_epi32(m2, m3);
+            m0 = _mm256_unpacklo_epi64(t0, t1);
+            m1 = _mm256_unpackhi_epi64(t0, t1);
+            m2 = _mm256_unpacklo_epi64(t2, t3);
+            m3 = _mm256_unpackhi_epi64(t2, t3);
+//#ifdef HAVE_FANCY_SIMD
+//            m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127));
+//            m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127));
+//            m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127));
+//            m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127));
+//#endif
+            _mm256_storeu_si256((__m256i *)y[ib].qs + 0, m0);
+            _mm256_storeu_si256((__m256i *)y[ib].qs + 1, m1);
+            _mm256_storeu_si256((__m256i *)y[ib].qs + 2, m2);
+            _mm256_storeu_si256((__m256i *)y[ib].qs + 3, m3);
+            m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[0][ib].qs+1));
+            m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[1][ib].qs+1));
+            m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[2][ib].qs+1));
+            m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[3][ib].qs+1));
+            t0 = _mm256_unpacklo_epi32(m0, m1);
+            t1 = _mm256_unpacklo_epi32(m2, m3);
+            t2 = _mm256_unpackhi_epi32(m0, m1);
+            t3 = _mm256_unpackhi_epi32(m2, m3);
+            m0 = _mm256_unpacklo_epi64(t0, t1);
+            m1 = _mm256_unpackhi_epi64(t0, t1);
+            m2 = _mm256_unpacklo_epi64(t2, t3);
+            m3 = _mm256_unpackhi_epi64(t2, t3);
+//#ifdef HAVE_FANCY_SIMD
+//            m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127));
+//            m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127));
+//            m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127));
+//            m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127));
+//#endif
+            _mm256_storeu_si256((__m256i *)y[ib].qs + 4, m0);
+            _mm256_storeu_si256((__m256i *)y[ib].qs + 5, m1);
+            _mm256_storeu_si256((__m256i *)y[ib].qs + 6, m2);
+            _mm256_storeu_si256((__m256i *)y[ib].qs + 7, m3);
+#elif defined __ARM_NEON
+            for (int l = 0; l < 2; ++l) {
+                m0.val[0] = vld1q_s8(x8[0][ib].qs+16*l); m0.val[1] = vld1q_s8(x8[4][ib].qs+16*l);
+                m1.val[0] = vld1q_s8(x8[1][ib].qs+16*l); m1.val[1] = vld1q_s8(x8[5][ib].qs+16*l);
+                m2.val[0] = vld1q_s8(x8[2][ib].qs+16*l); m2.val[1] = vld1q_s8(x8[6][ib].qs+16*l);
+                m3.val[0] = vld1q_s8(x8[3][ib].qs+16*l); m3.val[1] = vld1q_s8(x8[7][ib].qs+16*l);
+                auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0]));
+                auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0]));
+                m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+                m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+                m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+                m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+                row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1]));
+                row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1]));
+                m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+                m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+                m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+                m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+                vst1q_s8_x2(y[ib].qs +  0 + 128*l, m0);
+                vst1q_s8_x2(y[ib].qs + 32 + 128*l, m1);
+                vst1q_s8_x2(y[ib].qs + 64 + 128*l, m2);
+                vst1q_s8_x2(y[ib].qs + 96 + 128*l, m3);
+            }
+#else
+            for (int l = 0; l < 4; ++l) {
+                for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) {
+                    y[ib].qs[32*l+4*k+i+  0] = x8[k][ib].qs[i+4*l+ 0];
+                    y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16];
+                }
+            }
+#endif
+        }
+        y += nblock;
+    }
+    return result;
+}
+namespace {
 template <int D, int step>
 struct HelperQ80R8 : public BaseHelper<step> {
     using Base = BaseHelper<step>;
@@ -15399,24 +15768,21 @@ struct HelperQ80R8 : public BaseHelper<step> {
     constexpr static int block_size_q = QK8_0;
     using block_q8 = block_q8_0;
 #endif
+    HelperQ80R8(const char * data, int stride) : Base(data, stride) {}
     HelperQ80R8(int nk, const HelperQ80<D, step>& q8) : Base(q8.data, q8.stride) {
         r4 = repack(nk, q8);
         Base::data = (const char *)r4.data();
         Base::stride = (D/QK8_0)*sizeof(block_q8_0);
     }
 
-    static std::vector<block_q8_0_r8> repack(int nk, const HelperQ80<D, step>& q8) {
-        static_assert(D%QK8_0 == 0);
-        GGML_ASSERT(nk%8 == 0);
+    static void repack(int nk, const char * q8_data, int q8_stride, block_q8_0_r8 * y) {
         constexpr int nblock = D/QK8_0;
-        std::vector<block_q8_0_r8> result(nblock * nk/8);
-        auto y = result.data();
         const block_q8_0 * x8[8];
 #ifdef __ARM_NEON
         int8x16x2_t m0, m1, m2, m3;
 #endif
         for (int row = 0; row < nk; row += 8) {
-            for (int k = 0; k < 8; ++k) x8[k] = (const block_q8_0 *)(q8.data + (row + k)*q8.stride);
+            for (int k = 0; k < 8; ++k) x8[k] = (const block_q8_0 *)(q8_data + (row + k)*q8_stride);
             for (int ib = 0; ib < nblock; ++ib) {
                 for (int k = 0; k < 8; ++k) y[ib].d[k] = x8[k][ib].d;
 #ifdef __AVX2__
@@ -15498,6 +15864,15 @@ struct HelperQ80R8 : public BaseHelper<step> {
             }
             y += nblock;
         }
+    }
+
+    static std::vector<block_q8_0_r8> repack(int nk, const HelperQ80<D, step>& q8) {
+        static_assert(D%QK8_0 == 0);
+        GGML_ASSERT(nk%8 == 0);
+        constexpr int nblock = D/QK8_0;
+        std::vector<block_q8_0_r8> result(nblock * nk/8);
+        auto y = result.data();
+        repack(nk, q8.data, q8.stride, y);
         return result;
     }
 
@@ -15952,12 +16327,13 @@ struct FlashMS {
         }
         return F16::reduce_max<k_step>(vk);
     }
-    static inline __m256 apply_mask(int l, const char * mask, __m256 val, __m256 vinf) {
-        auto m128 = _mm_loadu_si128((const __m128i *)mask+l);
-        m128 = _mm_cmpeq_epi16(m128, _mm_setzero_si128());
-        auto m256 = _mm256_cvtepi16_epi32(m128);
-        auto mf = _mm256_castsi256_ps(_mm256_or_si256(m256, _mm256_slli_epi32(m256, 16)));
-        return _mm256_or_ps(_mm256_and_ps(mf, val), _mm256_andnot_ps(mf, vinf));
+    static inline __m256 apply_mask(int l, const char * mask, __m256 val, [[maybe_unused]] __m256 vinf) {
+        return _mm256_add_ps(val, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)mask+l)));
+        //auto m128 = _mm_loadu_si128((const __m128i *)mask+l);
+        //m128 = _mm_cmpeq_epi16(m128, _mm_setzero_si128());
+        //auto m256 = _mm256_cvtepi16_epi32(m128);
+        //auto mf = _mm256_castsi256_ps(_mm256_or_si256(m256, _mm256_slli_epi32(m256, 16)));
+        //return _mm256_or_ps(_mm256_and_ps(mf, val), _mm256_andnot_ps(mf, vinf));
     }
 #ifdef __AVX512F__
     static inline __m512 apply_mask(int l, const char * mask, __m512 val, __m512 vinf) {
@@ -16087,7 +16463,6 @@ struct FlashQKV {
             accumulate_qkv_1(vh, fms);
             return;
         }
-        F16::Data v[8];
        for (int j = 0; j < q_step; ++j) {
             auto R = qkv_cache + D*j;
             if (fms.need_scaling[j] == 2) {
@@ -16100,6 +16475,43 @@
                 }
             }
         }
+#ifdef __AVX512F__
+        if constexpr ((D/F16::block_size)%4 == 0) {
+            F16::Data v[16];
+            F16::Data vs[4];
+            for (int i = 0; i < D/F16::block_size; i += 4) {
+                for (int l = 0; l < k_step; l += 4) {
+                    for (int k = 0; k < 4; ++k) {
+                        vh.load(l+k, i+0, v[4*k+0], v[4*k+1]);
+                        vh.load(l+k, i+2, v[4*k+2], v[4*k+3]);
+                    }
+                    for (int j = 0; j < q_step; ++j) {
+                        auto R = qkv_cache + D*j;
+                        auto s1 = F16::load(R + F16::block_size*(i+0));
+                        auto s2 = F16::load(R + F16::block_size*(i+1));
+                        auto s3 = F16::load(R + F16::block_size*(i+2));
+                        auto s4 = F16::load(R + F16::block_size*(i+3));
+                        F16::set4(fms.cache + k_step*j + l, vs);
+                        for (int k = 0; k < 4; ++k) {
+                            s1 = F16::fmadd(s1, v[4*k+0], vs[k]);
+                            s2 = F16::fmadd(s2, v[4*k+1], vs[k]);
+                            s3 = F16::fmadd(s3, v[4*k+2], vs[k]);
+                            s4 = F16::fmadd(s4, v[4*k+3], vs[k]);
+                        }
+                        F16::store(R + F16::block_size*(i+0), s1);
+                        F16::store(R + F16::block_size*(i+1), s2);
+                        F16::store(R + F16::block_size*(i+2), s3);
+                        F16::store(R + F16::block_size*(i+3), s4);
+                    }
+                }
+            }
+            return;
+        }
+#endif
+        F16::Data v[8];
+#ifdef __AVX2__
+        F16::Data vs[4];
+#endif
         for (int i = 0; i < D/F16::block_size; i += 2) {
             for (int l = 0; l < k_step; l += 4) {
                 vh.load(l+0, i, v[0], v[4]);
@@ -16110,6 +16522,13 @@
                     auto R = qkv_cache + D*j;
                     auto s1 = F16::load(R + F16::block_size*(i+0));
                     auto s2 = F16::load(R + F16::block_size*(i+1));
+#ifdef __AVX2__
+                    F16::set4(fms.cache + k_step*j + l, vs);
+                    for (int k = 0; k < 4; ++k) {
+                        s1 = F16::fmadd(s1, v[k+0], vs[k]);
+                        s2 = F16::fmadd(s2, v[k+4], vs[k]);
+                    }
+#else
                     auto vs = F16::set4(fms.cache + k_step*j + l);
                     s1 = F16::fmadd_lane0(s1, v[0], vs);
                     s2 = F16::fmadd_lane0(s2, v[4], vs);
                    s1 = F16::fmadd_lane1(s1, v[1], vs);
                     s2 = F16::fmadd_lane1(s2, v[5], vs);
                     s1 = F16::fmadd_lane2(s1, v[2], vs);
                     s2 = F16::fmadd_lane2(s2, v[6], vs);
                     s1 = F16::fmadd_lane3(s1, v[3], vs);
                     s2 = F16::fmadd_lane3(s2, v[7], vs);
+#endif
                     F16::store(R + F16::block_size*(i+0), s1);
                     F16::store(R + F16::block_size*(i+1), s2);
                 }
@@ -16239,7 +16659,8 @@
     // As a result, we get an infinite stream of warnings about uninitialized variable use (one for each
     // combination of D, q_step, k_step), which is extremely annoying. Hence, I succumb to the trend of
     // constantly being saved by others (the compiler in this case), and add this 100% unnecessary initialization.
-    qkv_cache_t qkv_cache[D*q_step] = {};
+    qkv_cache_t qkv_cache[D*q_step]; // = {};
+    //qkv_cache_t * qkv_cache;
 };
 
 template <int D, int q_step, int k_step>
@@ -16481,8 +16902,14 @@ struct FlashQKfp32 {
         MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80, nq);
 #else
 #ifdef HAVE_FANCY_SIMD
+        if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q8_0_1_Unpacker, 1, k_step>, 1);
+        if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q8_0_1_Unpacker, 2, k_step>, 2);
+        if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q8_0_1_Unpacker, 4, k_step>, 4);
        MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q8_0_1_Unpacker, nq);
#else
+        if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_0_Tx<Q8_0_Unpacker, 1, k_step>, 1);
+        if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_0_Tx<Q8_0_Unpacker, 2, k_step>, 2);
+        if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_0_Tx<Q8_0_Unpacker, 4, k_step>, 4);
         MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, nq);
 #endif
 #endif
@@ -16493,10 +16920,15 @@
         if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1, 1);
         MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq);
 #else
+        if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1<1>, 1);
 #ifdef HAVE_FANCY_SIMD
-        if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16);
+        if constexpr (D%32 == 0 && k_step%8 == 0) {
+            if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV_8<16>, 16);
+            MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV_8, nq);
+        } else {
+            if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16);
+        }
 #endif
-        if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1<1>, 1);
         MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq);
 #endif
     }
@@ -16514,17 +16946,23 @@
 #ifdef __aarch64__
         MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ60, nq);
 #else
+        if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q6_0_1_Unpacker, 1, k_step>, 1);
+        if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q6_0_1_Unpacker, 2, k_step>, 2);
+        if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q6_0_1_Unpacker, 4, k_step>, 4);
         MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q6_0_1_Unpacker, nq);
 #endif
     }
-#if GGML_IQK_FA_ALL_QUANTS
     else if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) {
 #ifdef __aarch64__
         MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ40, nq);
 #else
+        if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q4_0_1_Unpacker, 1, k_step>, 1);
+        if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q4_0_1_Unpacker, 2, k_step>, 2);
+        if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q4_0_1_Unpacker, 4, k_step>, 4);
         MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q4_0_1_Unpacker, nq);
 #endif
     }
+#if GGML_IQK_FA_ALL_QUANTS
     else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) {
 #ifdef __aarch64__
         MAKE_FUNCS(mul_mat_qX_1_q8_1<DequantizerQ41, nq);
@@ -16664,8 +17102,29 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
                       FlashMS<q_step, k_step>& fms,
                       FlashQKV<Dv, q_step, k_step>& fqkv,
                       const float * q, const char * mask, float * qkv,
-                      float * M, float * S) {
-    typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)];
+                      float * M, float * S, char * qptr) {
+    auto q8 = (typename KHelper::block_q8 *)qptr;
+    if constexpr (q_step > 1 && std::is_same_v<KHelper, HelperQ80<Dk, k_step>>) {
+        if (nq1 == q_step) {
+            fms.init_qstep();
+            kh.reset_block();
+            vh.reset_block();
+            block_q8_0_r8 q8r8[Dk/QK8_0 * k_step/8];
+            HelperQ80R8<Dk, k_step> khr8((const char *)q8r8, Dk/QK8_0*sizeof(block_q8_0));
+            HelperQ80<Dk, QK8_0>::convert(q_step, stride_q, q, q8);
+            auto mr = mask;
+            for (int k1 = 0; k1 < nk1/k_step; ++k1) {
+                HelperQ80R8<Dk, k_step>::repack(k_step, kh.data, kh.stride, q8r8);
+                KQHelper::mul_mask_kq(khr8, stride_m, q8, mr, fms);
+                fqkv.accumulate_qkv(vh, fms);
+                kh.next_block();
+                vh.next_block();
+                mr += k_step*sizeof(ggml_half);
+            }
+            fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
+            return;
+        }
+    }
 #if FA_TIMING
     Perf perf(false);
 #endif
@@ -16731,6 +17190,12 @@
 #endif
 }
 
+char * get_q_storage(size_t size) {
+    thread_local std::vector<char> q_storage;
+    if (q_storage.size() < size) q_storage.resize(size);
+    return q_storage.data();
+}
+
 // Some of the methods in FlashAttn have two identical implementations that only differ by
 // one version using a loop over the template parameter q_step, while the other using a loop
 // over an input parameter nq (these are loops over the rows of q^T). I dislike this a lot,
@@ -16753,44 +17218,57 @@ struct FlashAttn {
     template <typename KHelper, typename VHelper>
     void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
             const float * q, const char * mask, float * qkv, [[maybe_unused]] float * M, [[maybe_unused]] float * S) {
-        if constexpr (std::is_same_v<KHelper, HelperQ40<Dk, k_step>> || std::is_same_v<KHelper, HelperQ41<Dk, k_step>> ||
+        if constexpr (std::is_same_v<KHelper, HelperQ40<Dk, k_step>> ||
+                      std::is_same_v<KHelper, HelperQ41<Dk, k_step>> ||
                       std::is_same_v<KHelper, HelperIQ4nl<Dk, k_step>> ||
-                      std::is_same_v<KHelper, HelperQ60<Dk, k_step>>) {
-            compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
-                    kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
-        }
-        else if constexpr (std::is_same_v<KHelper, HelperQ80<Dk, k_step>>) {
-            if (nq1 >= 8) {
+                      std::is_same_v<KHelper, HelperQ60<Dk, k_step>> ||
+                      std::is_same_v<KHelper, HelperQ80R8<Dk, k_step>> ||
+                      std::is_same_v<KHelper, HelperQ80<Dk, k_step>> ||
+                      std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>> ||
+                      std::is_same_v<KHelper, HelperQ8KVR8<Dk, k_step>>) {
+            constexpr size_t kMaxOnStackSize = 576;
+            auto q_size = q_step*(Dk/KHelper::block_size_q)*sizeof(typename KHelper::block_q8);
+            q_size = GGML_PAD(q_size, 64);
+            if (q_size > kMaxOnStackSize) {
+                auto qptr = get_q_storage(q_size);
+                if (nq1 >= 8) {
+                    if constexpr (std::is_same_v<KHelper, HelperQ80<Dk, k_step>>) {
 #if FA_TIMING
-                auto t1 = Perf::cur_time();
-                HelperQ80R8<Dk, k_step> khr4(nk1, kh);
-                Perf::instance().accum(4, t1);
+                        auto t1 = Perf::cur_time();
+                        HelperQ80R8<Dk, k_step> khr4(nk1, kh);
+                        Perf::instance().accum(4, t1);
 #else
-                HelperQ80R8<Dk, k_step> khr4(nk1, kh);
+                        HelperQ80R8<Dk, k_step> khr4(nk1, kh);
 #endif
-                compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
-                        khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
-            } else{
-                compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
-                        kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
-            }
-        }
-        else if constexpr (std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>>) {
-            if (nq1 >= 8) {
+                        compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
+                                khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
+                        return;
+
+                    }
+                    if constexpr (std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>>) {
 #if FA_TIMING
-                auto t1 = Perf::cur_time();
-                HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
-                Perf::instance().accum(4, t1);
+                        auto t1 = Perf::cur_time();
+                        HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
+                        Perf::instance().accum(4, t1);
 #else
-                HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
+                        HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
 #endif
-                compute_helper_q<Dk, Dv, q_step, k_step, HelperQ8KVR8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
-                        khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
-            } else{
+                        compute_helper_q<Dk, Dv, q_step, k_step, HelperQ8KVR8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
+                                khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
+                        return;
+                    }
+                }
                 compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
-                        kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
+                        kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
             }
-        } else {
+            else {
+                typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)];
+                compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
+                        kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, (char *)q8);
+            }
+        }
+        else {
             compute_helper<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
                     kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
         }
@@ -17234,39 +17712,61 @@
 template <int Dk, int Dv, int k_step, typename KHelper, typename VHelper>
 inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
         const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
-    if (nk1 >= 256) { //4096) {
+    auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) {
+        nq1 -= n;
+        if (nq1 == 0) return true;
+        q    += n*stride_q;
+        mask += n*stride_m;
+        qkv  += n*stride_qkv;
+        if (M && S) { M += n; S += n; }
+        return false;
+    };
+    if (nk1 >= 512) {
+        if (nq1 >= 128) {
+            int n_step = nq1/128;
+            FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
+            fa.compute(kh, vh, 128*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+            if (update(128*n_step)) return;
+        }
         if (nq1 >= 64) {
+            int n_step = nq1/64;
             FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
-            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
-            return;
+            fa.compute(kh, vh, 64*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+            if (update(64*n_step)) return;
         }
         if (nq1 >= 32) {
+            int n_step = nq1/32;
             FlashAttn<Dk, Dv, 32, k_step> fa(scale, softcap);
-            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
-            return;
+            fa.compute(kh, vh, 32*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+            if (update(32*n_step)) return;
        }
         if (nq1 >= 16) {
+            int n_step = nq1/16;
             FlashAttn<Dk, Dv, 16, k_step> fa(scale, softcap);
-            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
-            return;
+            fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+            if (update(16*n_step)) return;
        }
     }
     if (nq1 >= 8) {
+        int n_step = nq1/8;
         FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap);
-        fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+        fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+        if (update(8*n_step)) return;
    }
     else if (nq1 >= 4) {
+        int n_step = nq1/4;
         FlashAttn<Dk, Dv, 4, k_step> fa(scale, softcap);
-        fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+        fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+        if (update(4*n_step)) return;
    }
     else if (nq1 >= 2) {
+        int n_step = nq1/2;
         FlashAttn<Dk, Dv, 2, k_step> fa(scale, softcap);
-        fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
-    }
-    else {
-        FlashAttn<Dk, Dv, 1, k_step> fa(scale, softcap);
-        fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+        fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+        if (update(2*n_step)) return;
    }
+    FlashAttn<Dk, Dv, 1, k_step> fa(scale, softcap);
+    fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
 }
 
 #ifdef __AVX512BF16__
@@ -17327,11 +17827,11 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
             HelperQ60<Dv, k_step> vh(v, stride_v);
             iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
         } break;
-#if GGML_IQK_FA_ALL_QUANTS
         case GGML_TYPE_Q4_0: {
             HelperQ40<Dv, k_step> vh(v, stride_v);
             iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
         } break;
+#if GGML_IQK_FA_ALL_QUANTS
         case GGML_TYPE_Q4_1: {
             HelperQ41<Dv, k_step> vh(v, stride_v);
             iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
@@ -17360,6 +17860,10 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
             HelperQ80<Dk, k_step> kh(k, stride_k);
             iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
         } break;
+        case GGML_TYPE_Q8_0_R8: {
+            HelperQ80R8<Dk, k_step> kh(k, stride_k);
+            iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
+        } break;
         case GGML_TYPE_Q8_KV: {
             HelperQ8KV<Dk, k_step> kh(k, stride_k);
             iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
@@ -17368,11 +17872,11 @@
             HelperQ60<Dk, k_step> kh(k, stride_k);
             iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
         } break;
-#if GGML_IQK_FA_ALL_QUANTS
         case GGML_TYPE_Q4_0: {
             HelperQ40<Dk, k_step> kh(k, stride_k);
             iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
         } break;
+#if GGML_IQK_FA_ALL_QUANTS
         case GGML_TYPE_Q4_1: {
             HelperQ41<Dk, k_step> kh(k, stride_k);
             iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
@@ -17393,9 +17897,10 @@ inline bool flash_attn_is_supported(ggml_type type) {
 #endif
 #if GGML_IQK_FA_ALL_QUANTS
     if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 ||
-        type == GGML_TYPE_Q6_0 || type == GGML_TYPE_IQ4_NL) return true;
+        type == GGML_TYPE_Q6_0 || type == GGML_TYPE_IQ4_NL || type == GGML_TYPE_Q8_0_R8) return true;
 #else
-    if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0 || type == GGML_TYPE_Q8_KV) return true;
+    if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0 || type == GGML_TYPE_Q8_KV || type == GGML_TYPE_Q8_0_R8
+        || type == GGML_TYPE_Q4_0) return true;
 #endif
     return false;
 }
@@ -17404,25 +17909,35 @@
 template <int step_k, typename KHelper, typename VHelper>
 inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
         const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
+    auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) {
+        nq1 -= n;
+        if (nq1 == 0) return true;
+        q    += n*stride_q;
+        mask += n*stride_m;
+        qkv  += n*stride_qkv;
+        if (M && S) { M += n; S += n; }
+        return false;
+    };
     if (nq1 >= 8) {
+        int n_step = nq1/8;
         FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
-        fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
+        fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
+        if (update(8*n_step)) return;
    }
-    else if (nq1 >= 4) {
+    if (nq1 >= 4) {
+        int n_step = nq1/4;
         FlashAttn<576, 512, 4, step_k> fa(scale, softcap);
-        fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
+        fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
+        if (update(4*n_step)) return;
    }
-    else {
-        FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
-        fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
-    }
-    //if (nq1 % 8 == 0) {
-    //    FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
-    //    fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
-    //} else {
-    //    FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
-    //    fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
-    //}
+    if (nq1 >= 2) {
+        int n_step = nq1/2;
+        FlashAttn<576, 512, 2, step_k> fa(scale, softcap);
+        fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
+        if (update(2*n_step)) return;
+    }
+    FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
+    fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
 }
 
 template <int step_k>
@@ -17436,6 +17951,12 @@ inline bool iqk_deepseek_helper(ggml_type type_k,
         iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
         return true;
     }
+    if (type_k == GGML_TYPE_Q8_0_R8) {
+        HelperQ80R8<576, step_k> kh((const char *)k, stride_k);
+        HelperQ80<512, step_k> vh((const char *)v, stride_v);
+        iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
+        return true;
+    }
     if (type_k == GGML_TYPE_Q6_0) {
         HelperQ60<576, step_k> kh((const char *)k, stride_k);
         HelperQ60<512, step_k> vh((const char *)v, stride_v);
@@ -17558,6 +18079,23 @@ bool iqk_flash_attn_impl(int int_type_k,            // type of k
     }
 #endif
 
+    if (nk1%128 == 0) {
+        switch (Dk) {
+            case 64:
+                iqk_flash_helper_T< 64,  64, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
+            case 96:
+                iqk_flash_helper_T< 96,  96, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
+            case 128:
+                iqk_flash_helper_T<128, 128, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
+            case 192:
+                iqk_flash_helper_T<192, 128, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
+            case 256:
+                iqk_flash_helper_T<256, 256, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
+            default:
+                return false;
+        }
+        return true;
+    }
    if (nk1%64 == 0) {
        switch (Dk) {
            case 64:
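A few notes on the main techniques in this patch, with simplified sketches (all helper names in the sketches are illustrative, not part of the patch). First, the new mul_mat_q8_KV_q8_KV_8 kernel XORs the signed K quants with 0x80 so it can use the unsigned-by-signed _mm512_dpbusd_epi32 instruction; the precomputed sy[] term then removes the bias. A scalar model of why this is exact:

    #include <cstdint>

    // (x ^ 0x80) reinterpreted as unsigned equals x + 128 for any int8_t x, so
    // sum((x + 128)*y) - 128*sum(y) == sum(x*y). The kernel precomputes the
    // -128*sum(y) correction (the sy[] values) from a row sum stored next to the scale.
    int dot_with_offset(const int8_t * x, const int8_t * y, int n) {
        int sum_xy = 0, sum_y = 0;
        for (int i = 0; i < n; ++i) {
            uint8_t xu = uint8_t(x[i] ^ 0x80);  // what XOR with _mm512_set1_epi8(-128) produces
            sum_xy += int(xu)*y[i];             // what dpbusd accumulates
            sum_y  += y[i];
        }
        return sum_xy - 128*sum_y;              // the true signed dot product
    }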
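Second, the new Sum4q4 helper used for q4_0 defers the widening madd until after all four blocks have been combined in 16-bit lanes, replacing four _mm256_madd_epi16 calls with one. This appears safe only because 4-bit quants are used unsigned, and a rough bound (my arithmetic, under that assumption) shows the int16 lanes cannot overflow:

    #include <cstdint>

    // With x in [0, 15] (4-bit quants, as the Q4_0_1 unpacker leaves them) and y in
    // [-128, 127], one maddubs pair is bounded by 2*15*128 = 3840 in magnitude, and
    // the two unpack/add stages that combine four blocks at most quadruple that:
    // 4*3840 = 15360 < 32767. Scalar model of the deferred widening:
    int32_t sum4_q4_scalar(const uint8_t * x, const int8_t * y, int n) {
        int32_t total = 0;
        for (int i = 0; i < n; i += 2) {
            int16_t pair = int16_t(x[i]*y[i] + x[i+1]*y[i+1]);  // |pair| <= 3840, no overflow
            total += pair;  // the SIMD version widens only once, after combining 4 blocks
        }
        return total;
    }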
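Third, the rewritten FlashMS::apply_mask drops the old compare-and-blend against vinf: the ggml attention mask is stored as fp16 values that are 0 for visible positions and -inf for masked ones, so masking reduces to a half-to-float conversion plus an add. A minimal sketch (requires F16C):

    #include <immintrin.h>

    // mask_row points at the fp16 mask values for the current row; l selects a group of 8.
    static inline __m256 apply_mask_add(__m256 val, const char * mask_row, int l) {
        __m128i mh = _mm_loadu_si128((const __m128i *)mask_row + l);  // 8 fp16 values: 0 or -inf
        return _mm256_add_ps(val, _mm256_cvtph_ps(mh));               // adding -inf masks the position
    }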
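Fourth, the quantized-Q scratch that used to be a variable-length array on the stack is now moved to the heap once it exceeds kMaxOnStackSize (576 bytes), via get_q_storage(). The pattern is a grow-only, per-thread buffer, so no locking is needed and steady-state calls stop allocating:

    #include <vector>
    #include <cstddef>

    // Same pattern as get_q_storage() in the patch: thread_local gives each worker its
    // own buffer, and resizing only upward bounds the number of allocations by the
    // number of distinct high-water marks, not by the number of calls.
    char * per_thread_scratch(size_t size) {
        thread_local std::vector<char> buf;
        if (buf.size() < size) buf.resize(size);
        return buf.data();
    }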
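Finally, iqk_flash_helper and iqk_deepseek_helper no longer pick a single q_step instantiation for all query rows: they peel off the largest batches first and let the update lambda advance the q/mask/qkv (and M/S) pointers past the processed rows. A condensed model of the control flow, with compute_batch standing in for FlashAttn<Dk, Dv, N, k_step>::compute:

    // Processes n_rows query rows, n_rows a multiple of N (body elided in this sketch).
    template <int N> void compute_batch(int /*n_rows*/) { /* FlashAttn<..., N, ...>::compute */ }

    void dispatch(int nq) {
        auto peel = [&nq](int n) { nq -= n; return nq == 0; };  // like update(), minus the pointer bumps
        if (nq >= 8) { int n = 8*(nq/8); compute_batch<8>(n); if (peel(n)) return; }
        if (nq >= 4) { int n = 4*(nq/4); compute_batch<4>(n); if (peel(n)) return; }
        if (nq >= 2) { int n = 2*(nq/2); compute_batch<2>(n); if (peel(n)) return; }
        compute_batch<1>(nq);  // whatever single rows remain
    }

This way a prompt of, say, 13 rows runs one 8-row batch, one 4-row batch, and one single row, instead of 13 single-row passes.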