author | Kawrakow <48489457+ikawrakow@users.noreply.github.com> | 2024-09-12 19:03:20 +0300
---|---|---
committer | GitHub <noreply@github.com> | 2024-09-12 19:03:20 +0300
commit | 5017f8b3f04790de686c04e4ab23443c5ca02345 (patch) | |
tree | 7d5776ebefbcdb351e4f34d649ab572b8b9bf119 | |
parent | c920195edd80ab24beb9a0fd3e2f4df582e735d0 (diff) | |
Quantized Flash Attention for all supported CPU platforms (#51)
* NEON Flash Attention: add support for Q8_0, Q4_0, Q4_1
* NEON Flash Attention: quantized K*Q for q4_0
I could finally take advantage of the matrix multiplication
templates, and we get quite a bit of speedup that way for q4_0:
for Gemma-2b, using mul_mat_qX_0_q8_0<DequantizerQ40, q_step>
gives PP-2048 = 287 t/s vs 268 t/s when the q4_0 K-cache and Q
are converted to fp16 and multiplied in fp16 (a scalar sketch of
the quantized block dot product follows the commit message).
* NEON Flash Attention: quantized K*Q for q4_1
* NEON Flash Attention: quantized K*Q for q8_0
This makes quite a bit of difference:
For Gemma2-2b, PP-8192 is 228 t/s with quantized K*Q vs
178 t/s when everything is converted to fp16 and multiplied
with fp16 matrix multiplication.
With PP-512 = 307 t/s, PP-8192 now runs at ~75% of the
PP-512 performance. In contrast, llama.cpp with a Q8_0
cache runs at 38% of PP-512.
* Zen4 Flash Attention: quantized K*Q for q4_0, q4_1, q8_0
* AVX2 Flash Attention: quantized K*Q for q4_0, q4_1, q8_0
* Tidy up FlashMS
* Delete no longer used stuff
Now that quantized matrix multiplications are used for the
quantized K- and/or V-cache, we no longer need the helper
methods that load entire rows.
* Disallow mixing bf16 with other types for kv caches
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
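
The core idea in this change is that each row of Q is quantized to Q8_0 and multiplied against the quantized K-cache with the existing quantized matmul templates, instead of dequantizing everything to fp16. As a rough illustration, here is a scalar sketch of the per-block q4_0 × q8_0 dot product those templates vectorize (struct names and the float scales are simplifications of ggml's fp16 block layouts):

```cpp
// Scalar sketch (not the SIMD code from the diff) of the per-block dot product
// behind quantized K*Q: one Q4_0 block of the K-cache against one Q8_0 block of
// the quantized Q row. Struct names and float scales are simplifications;
// ggml stores the scales as fp16 and calls the blocks block_q4_0 / block_q8_0.
#include <cstdint>

struct BlockQ4_0 { float d; uint8_t qs[16]; };  // 32 4-bit quants: low nibbles = elems 0..15, high = 16..31
struct BlockQ8_0 { float d; int8_t  qs[32]; };  // 32 8-bit quants

float dot_q4_0_q8_0(const BlockQ4_0& k, const BlockQ8_0& q) {
    int32_t sum = 0;
    for (int j = 0; j < 16; ++j) {
        const int lo = (k.qs[j] & 0x0F) - 8;    // dequantized low-nibble value in [-8, 7]
        const int hi = (k.qs[j] >> 4)   - 8;    // dequantized high-nibble value in [-8, 7]
        sum += lo * q.qs[j] + hi * q.qs[j + 16];
    }
    return k.d * q.d * sum;                     // block scales applied once per 32 values
}
```

The templates used in the diff (mul_mat_qX_0_q8_0<DequantizerQ40, ...> on NEON, mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, ...> on the x86 path) vectorize this kind of per-block accumulation over whole rows of the K-cache.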
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 899 |
1 file changed, 672 insertions, 227 deletions
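
The diff below adds NEON and AVX2 implementations of quantize_row_q8_0 / quantize_row_q8_1 that additionally interleave groups of four blocks into block_q8_0_x4 / block_q8_1_x4 for the flash-attention path. For orientation, here is a scalar sketch of the plain Q8_0 quantization they compute (the x4 interleaving and fp16 scale storage are omitted, and the names are illustrative):

```cpp
// Scalar sketch of Q8_0 row quantization: per 32-value block, scale by
// amax/127 and round each value to the nearest int8.
#include <algorithm>
#include <cmath>
#include <cstdint>

constexpr int kBlock = 32;                       // ggml's QK8_0
struct BlockQ8_0 { float d; int8_t qs[kBlock]; };

void quantize_row_q8_0_ref(const float * x, BlockQ8_0 * y, int k) {
    const int nb = k / kBlock;                   // k is assumed to be a multiple of 32
    for (int i = 0; i < nb; ++i) {
        float amax = 0.0f;                       // max |x| over the block
        for (int j = 0; j < kBlock; ++j) amax = std::max(amax, std::fabs(x[i*kBlock + j]));
        const float d  = amax / 127.0f;
        const float id = d ? 1.0f/d : 0.0f;
        y[i].d = d;
        for (int j = 0; j < kBlock; ++j)
            y[i].qs[j] = (int8_t)std::nearbyint(x[i*kBlock + j] * id);  // round to nearest, as the SIMD code does
    }
}
```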
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 5a8cbce2..ce868514 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -3045,7 +3045,6 @@ template <typename Minus, int nrc_y, bool is_multiple_of_4> struct AccumT { } for (int iy = 0; iy < nrc_y; ++iy) { info.store(ix, iy, accm.result(acc[iy], iy)); - //s[iy*bs] = accm.result(acc[iy], iy); } } }; @@ -3212,6 +3211,35 @@ struct Q_Unpacker { } }; +struct Q8_0_x4_Unpacker { + using Sum4T = Sum4TypeQ80; + inline static int block_size() { return QK8_0; } + Q8_0_x4_Unpacker(const void * vx, size_t bx) : cx_0((const char *)vx), x((const block_q8_0_x4 *)cx_0), bx(bx) {} + + const char * cx_0; + const block_q8_0_x4 * x; + size_t bx; + + __m256i qx[4]; + + inline const __m256i* quants() const { return qx; } + + inline void set_row(int ix) { x = (const block_q8_0_x4 *)(cx_0 + ix*bx); } + + inline auto set_block_4(int i) { + auto scales = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)x[i].d)); + for (int j = 0; j < 4; ++j) { + qx[j] = _mm256_loadu_si256((const __m256i *)x[i].qs + j); + } + return scales; + } + inline auto set_block(int i) { + auto q8 = (const block_q8_0 *)(x + i); + qx[0] = _mm256_loadu_si256((const __m256i *)q8->qs); + return GGML_FP16_TO_FP32(q8->d); + } +}; + struct Q8_0_Unpacker final : public Q_Unpacker<block_q8_0, ScaleHelperQ_0, Q8_0_Dequantizer> { Q8_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} using Sum4T = Sum4TypeQ80; @@ -5461,6 +5489,27 @@ struct DequantizerQ80 final : public BaseLegacyDequantizer<block_q8_0> { }; +// TODO: handle case where row size is not a multiple of 128 +struct DequantizerQ80_x4 final : public BaseLegacyDequantizer<block_q8_0_x4> { + + DequantizerQ80_x4(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} + + inline void prepare1(int i) { + bits.b[0] = vld1q_s8(x[i].qs); + bits.b[1] = vld1q_s8(x[i].qs+16); + } + + inline float16x4_t new_block(int i) { + auto scale = vld1_f16((const float16_t *)x[i].d); + for (int k = 0; k < 4; ++k) { + bits.b[2*k+0] = vld1q_s8(x[i].qs+32*k); + bits.b[2*k+1] = vld1q_s8(x[i].qs+32*k+16); + } + return scale; + } + +}; + struct DequantizerQ51 final : public BaseLegacyDequantizer<block_q5_1> { DequantizerQ51(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} @@ -5529,9 +5578,9 @@ inline void mul_mat_qX_Y_q8_Y(int n, Dequantizer& deq, Q8& q8, const DataInfo& i q8.process_scales(i, deq, sc16, acc); sum_4(i, deq, q8, sc16, acc); } - for (int i = 4*(nb/4); i < nb; ++i) { - q8.process_1_block(i, deq, acc); - } + //for (int i = 4*(nb/4); i < nb; ++i) { + // q8.process_1_block(i, deq, acc); + //} for (int iy = 0; iy < Q8::nrc_y; ++iy) { info.store(ix, iy, vaddvq_f32(acc[iy])); @@ -5591,9 +5640,9 @@ inline void mul_mat_qX_Y_q8_Y_1(int n, Dequantizer& deq1, Dequantizer& deq2, Q8& q8.process_scales(i, deq1, sc16, acc); sum_4(i, deq1, q8, sc16, acc); } - for (int i = 4*(nb/4); i < nb; ++i) { - q8.process_1_block(i, deq1, acc); - } + //for (int i = 4*(nb/4); i < nb; ++i) { + // q8.process_1_block(i, deq1, acc); + //} info.store(ix, 0, vaddvq_f32(vaddq_f32(acc[0], acc[1]))); } @@ -6548,70 +6597,275 @@ struct HelperF16 final : public BaseHelper<step> { } }; -#if defined __AVX2__ -template <int D, int step> -struct HelperQ80 final : public BaseHelper<step> { - static_assert(step == QK8_0); - using Base = BaseHelper<step>; - //using F16 = HelperF16<D, step>; - HelperQ80(const char * data, int stride) : Base(data, stride) {} +void quantize_row_q8_0(const float * x, block_q8_0 * y, int k) { + 
const int nb = k / QK8_0; + const int nb4 = 4*(nb/4); - inline void load(int l1, F16::Data * vk) const { - auto dl = (const block_q8_0_x4 *)Base::lblock(l1); - if constexpr (D >= 128) { - F16::Data vd[4]; - for (int ib = 0; ib < D/128; ++ib) { - const auto& b8 = dl[ib]; - auto scales4 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)b8.d)); - auto scales8 = _mm256_insertf128_ps(_mm256_castps128_ps256(scales4), scales4, 1); -#ifdef HAVE_FANCY_SIMD - auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales8), scales8, 1); - vd[0] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(0, 0, 0, 0)); - vd[1] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(1, 1, 1, 1)); - vd[2] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(2, 2, 2, 2)); - vd[3] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(3, 3, 3, 3)); - for (int i = 0; i < 4; ++i) { - vk[8*ib+2*i+0] = _mm512_mul_ps(vd[i], _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*i+0)))); - vk[8*ib+2*i+1] = _mm512_mul_ps(vd[i], _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*i+1)))); - } +#if defined(__aarch64__) + block_q8_0_x4 * y4 = (block_q8_0_x4 *)y; + for (int i = 0; i < nb; i++) { + int i4 = i/4, ir = i%4; + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 
1.0f/d : 0.0f; + + if (i < nb4) { + y4[i4].d[ir] = GGML_FP32_TO_FP16(d); + } else { + y[i].d = GGML_FP32_TO_FP16(d); + } + + for (int j = 0; j < 8; j++) { + const float32x4_t v = vmulq_n_f32(srcv[j], id); + const int32x4_t vi = vcvtnq_s32_f32(v); + + if (i < nb4) { + y4[i4].qs[32*ir + 4*j + 0] = vgetq_lane_s32(vi, 0); + y4[i4].qs[32*ir + 4*j + 1] = vgetq_lane_s32(vi, 1); + y4[i4].qs[32*ir + 4*j + 2] = vgetq_lane_s32(vi, 2); + y4[i4].qs[32*ir + 4*j + 3] = vgetq_lane_s32(vi, 3); + } else { + y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); + } + } + } #else - vd[0] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(0, 0, 0, 0)); - vd[1] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(1, 1, 1, 1)); - vd[2] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(2, 2, 2, 2)); - vd[3] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(3, 3, 3, 3)); - for (int i = 0; i < 4; ++i) { - vk[16*ib+4*i+0] = _mm256_mul_ps(vd[i], _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*i+ 0))))); - vk[16*ib+4*i+1] = _mm256_mul_ps(vd[i], _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*i+ 8))))); - vk[16*ib+4*i+2] = _mm256_mul_ps(vd[i], _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*i+16))))); - vk[16*ib+4*i+3] = _mm256_mul_ps(vd[i], _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*i+24))))); - } + block_q8_0_x4 * y4 = (block_q8_0_x4 *)y; + for (int i = 0; i < nb; i++) { + int i4 = i/4, ir = i%4; + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x ); + __m256 v1 = _mm256_loadu_ps( x + 8 ); + __m256 v2 = _mm256_loadu_ps( x + 16 ); + __m256 v3 = _mm256_loadu_ps( x + 24 ); + x += 32; + + const __m256 signBit = _mm256_set1_ps( -0.0f ); + __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); + + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + const float maxScalar = _mm_cvtss_f32( max4 ); + + const float d = maxScalar / 127.f; + if (i < nb4) { + y4[i4].d[ir] = GGML_FP32_TO_FP16(d); + } else { + y[i].d = GGML_FP32_TO_FP16(d); + } + const float id = ( maxScalar != 0.0f ) ? 
127.f / maxScalar : 0.0f; + const __m256 mul = _mm256_set1_ps( id ); + + v0 = _mm256_mul_ps( v0, mul ); + v1 = _mm256_mul_ps( v1, mul ); + v2 = _mm256_mul_ps( v2, mul ); + v3 = _mm256_mul_ps( v3, mul ); + + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + + // Convert int32 to int16 + i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + // Convert int16 to int8 + i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + + // We got our precious signed bytes, but the order is now wrong + // These AVX2 pack instructions process 16-byte pieces independently + // The following instruction is fixing the order + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + + if (i < nb4) { + _mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0); + } else { + _mm256_storeu_si256((__m256i *)y[i].qs, i0); + } + } #endif +} + +void quantize_row_q8_1(const float * x, block_q8_1 * y, int k) { + assert(k % QK8_1 == 0); + const int nb = k / QK8_1; + + const int nb4 = 4*(nb/4); + block_q8_1_x4 * y4 = (block_q8_1_x4 *)y; +#if defined(__aarch64__) + for (int i = 0; i < nb; i++) { + int i4 = i/4, ir = i%4; + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 
1.0f/d : 0.0f; + + if (i < nb4) { + y4[i4].d[ir] = GGML_FP32_TO_FP16(d); + } else { + y[i].d = GGML_FP32_TO_FP16(d); + } + + int32x4_t accv = vdupq_n_s32(0); + + for (int j = 0; j < 8; j++) { + const float32x4_t v = vmulq_n_f32(srcv[j], id); + const int32x4_t vi = vcvtnq_s32_f32(v); + + if (i < nb4) { + y4[i4].qs[QK8_1*ir + 4*j + 0] = vgetq_lane_s32(vi, 0); + y4[i4].qs[QK8_1*ir + 4*j + 1] = vgetq_lane_s32(vi, 1); + y4[i4].qs[QK8_1*ir + 4*j + 2] = vgetq_lane_s32(vi, 2); + y4[i4].qs[QK8_1*ir + 4*j + 3] = vgetq_lane_s32(vi, 3); + } else { + y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); } + + accv = vaddq_s32(accv, vi); + } + + if (i < nb4) { + y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); } else { - for (int i = 0; i < D/32; ++i) { - const auto& b8 = dl[i/4]; - int ii = i%4; - auto vd = F16::set1(GGML_FP16_TO_FP32(b8.d[ii])); -#ifdef HAVE_FANCY_SIMD - vk[2*i+0] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*ii+0)))); - vk[2*i+1] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*ii+1)))); + y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); + } + } #else - vk[4*i+0] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*ii+ 0))))); - vk[4*i+1] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*ii+ 8))))); - vk[4*i+2] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*ii+16))))); - vk[4*i+3] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*ii+24))))); -#endif - } + for (int i = 0; i < nb; i++) { + int i4 = i/4, ir = i%4; + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x ); + __m256 v1 = _mm256_loadu_ps( x + 8 ); + __m256 v2 = _mm256_loadu_ps( x + 16 ); + __m256 v3 = _mm256_loadu_ps( x + 24 ); + x += 32; + + // Compute max(abs(e)) for the block + const __m256 signBit = _mm256_set1_ps( -0.0f ); + __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); + + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + const float max_scalar = _mm_cvtss_f32( max4 ); + + // Quantize these floats + const float d = max_scalar / 127.f; + if (i < nb4) { + y4[i4].d[ir] = GGML_FP32_TO_FP16(d); + } else { + y[i].d = GGML_FP32_TO_FP16(d); + } + const float id = ( max_scalar != 0.0f ) ? 
127.f / max_scalar : 0.0f; + const __m256 mul = _mm256_set1_ps( id ); + + // Apply the multiplier + v0 = _mm256_mul_ps( v0, mul ); + v1 = _mm256_mul_ps( v1, mul ); + v2 = _mm256_mul_ps( v2, mul ); + v3 = _mm256_mul_ps( v3, mul ); + + // Round to nearest integer + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + + // Convert floats to integers + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + + // Compute the sum of the quants and set y[i].s + if (i < nb4) { + y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)))); + } else { + y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)))); + } + + // Convert int32 to int16 + i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + // Convert int16 to int8 + i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + + // We got our precious signed bytes, but the order is now wrong + // These AVX2 pack instructions process 16-byte pieces independently + // The following instruction is fixing the order + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + + if (i < nb4) { + _mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0); + } else { + _mm256_storeu_si256((__m256i *)y[i].qs, i0); } } +#endif +} + +template <int D, int step> +struct HelperQ80 final : public BaseHelper<step> { + static_assert(step == QK8_0); + using Base = BaseHelper<step>; + using block_q8 = block_q8_0; + HelperQ80(const char * data, int stride) : Base(data, stride) {} + // Needed for v * softmax(k * q) inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const { - // Say D = 256 -> i is 0, 2, 4, 6, 8, ..., 28, 30. 128/8 = 16 -> we use 1st block of 128 for i = 0, 2, ..., 14, second for i = 16, 18, ..., 30 - // i = 0, 2 -> ii = 0, i = 4, 6 -> ii = 1, i = 8, 10 -> ii = 2, i = 12, 14 -> ii = 3, i = 16, 18 -> ii = 0, etc. 
- // i*F16::block_size/128 int j = F16::block_size*i; auto dl = (const block_q8_0_x4 *)Base::lblock(l1) + j/(4*QK8_0); int ii = (j/QK8_0)%4; +#ifdef __aarch64__ + const float16_t * d = (const float16_t *)dl->d; + auto vd = F16::set1(d[ii]); + auto qs = vld1_s8_x2(dl->qs + 32*ii + j%32); + v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[0]))); + v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1]))); +#else auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d[ii])); #ifdef HAVE_FANCY_SIMD v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+2*ii+0)))); @@ -6620,11 +6874,25 @@ struct HelperQ80 final : public BaseHelper<step> { v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+32*ii+j%32))))); v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+32*ii+j%32+8))))); #endif +#endif } - inline void load_2(int l1, F16::Data * vk) const { - load(l1+0, vk+0); - load(l1+1, vk+D/F16::block_size); + static inline void convert(int nq, int stride_q, const float * q, block_q8_0 * y) { + GGML_ASSERT(nq <= step); + for (int i = 0; i < nq; ++i) { + quantize_row_q8_0(q, y, D); + q += stride_q; + y += D/QK8_0; + } + } + + static inline void convert(int nq, int stride_q, const float * q, block_q8_1 * y) { + GGML_ASSERT(nq <= step); + for (int i = 0; i < nq; ++i) { + quantize_row_q8_1(q, y, D); + q += stride_q; + y += D/QK8_1; + } } }; @@ -6632,82 +6900,21 @@ template <int D, int step> struct HelperQ40 final : public BaseHelper<step> { static_assert(step == QK4_0); using Base = BaseHelper<step>; + using block_q8 = block_q8_0; HelperQ40(const char * data, int stride) : Base(data, stride) {} - - inline void load(int l1, F16::Data * vk) const { - auto dl = (const block_q4_0 *)Base::lblock(l1); - if constexpr (D >= 128) { - ggml_half aux[4]; - F16::Data vd[4]; - for (int ib = 0; ib < D/128; ++ib) { - for (int i = 0; i < 4; ++i) { - auto& b4 = dl[4*ib+i]; - aux[i] = b4.d; - auto q = _mm_loadu_si128((const __m128i *)b4.qs); - auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8); - auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8); -#ifdef HAVE_FANCY_SIMD - vk[8*ib+2*i+0] = _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)); - vk[8*ib+2*i+1] = _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)); -#else - auto ql16 = _mm256_cvtepi8_epi16(ql); - auto qh16 = _mm256_cvtepi8_epi16(qh); - vk[16*ib+4*i+0] = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(ql16))); - vk[16*ib+4*i+1] = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(ql16, 1))); - vk[16*ib+4*i+2] = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(qh16))); - vk[16*ib+4*i+3] = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(qh16, 1))); -#endif - } - auto scales4 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)aux)); - auto scales8 = _mm256_insertf128_ps(_mm256_castps128_ps256(scales4), scales4, 1); -#ifdef HAVE_FANCY_SIMD - auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales8), scales8, 1); - vd[0] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(0, 0, 0, 0)); - vd[1] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(1, 1, 1, 1)); - vd[2] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(2, 2, 2, 2)); - vd[3] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(3, 3, 3, 3)); - for (int i = 0; i < 4; ++i) { - vk[8*ib+2*i+0] = _mm512_mul_ps(vd[i], vk[8*ib+2*i+0]); - vk[8*ib+2*i+1] = _mm512_mul_ps(vd[i], vk[8*ib+2*i+1]); - } -#else - vd[0] = 
_mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(0, 0, 0, 0)); - vd[1] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(1, 1, 1, 1)); - vd[2] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(2, 2, 2, 2)); - vd[3] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(3, 3, 3, 3)); - for (int i = 0; i < 4; ++i) { - vk[16*ib+4*i+0] = _mm256_mul_ps(vd[i], vk[16*ib+4*i+0]); - vk[16*ib+4*i+1] = _mm256_mul_ps(vd[i], vk[16*ib+4*i+1]); - vk[16*ib+4*i+2] = _mm256_mul_ps(vd[i], vk[16*ib+4*i+2]); - vk[16*ib+4*i+3] = _mm256_mul_ps(vd[i], vk[16*ib+4*i+3]); - } -#endif - } - } else { - for (int i = 0; i < D/32; ++i) { - auto vd = F16::set1(GGML_FP16_TO_FP32(dl[i].d)); - auto q = _mm_loadu_si128((const __m128i *)dl[i].qs); - auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8); - auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8); -#ifdef HAVE_FANCY_SIMD - vk[2*i+0] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql))); - vk[2*i+1] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh))); -#else - auto ql16 = _mm256_cvtepi8_epi16(ql); - auto qh16 = _mm256_cvtepi8_epi16(qh); - vk[4*i+0] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(ql16)))); - vk[4*i+1] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(ql16, 1)))); - vk[4*i+2] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(qh16)))); - vk[4*i+3] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(qh16, 1)))); -#endif - } - } - } - + // Needed for v * softmax(k * q) inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const { int j = F16::block_size*i; auto dl = (const block_q4_0 *)Base::lblock(l1) + j/QK4_0; +#ifdef __aarch64__ + auto vd = F16::set1(*(const float16_t *)&dl->d); + auto q = vld1q_u8(dl->qs); + q = j%QK4_0 ? 
vshrq_n_u8(q, 4) : vandq_u8(q, mask); + q = vaddq_s8(q, m8); + v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_low_s8(q)))); + v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_high_s8(q)))); +#else auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d)); auto q = _mm_loadu_si128((const __m128i *)dl->qs); #ifdef HAVE_FANCY_SIMD @@ -6721,50 +6928,37 @@ struct HelperQ40 final : public BaseHelper<step> { v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16)))); v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16, 1)))); #endif +#endif } - inline void load_2(int l1, F16::Data * vk) const { - load(l1+0, vk+0); - load(l1+1, vk+D/F16::block_size); - } - +#ifdef __AVX2__ const __m128i mask = _mm_set1_epi8(0xf); const __m128i m8 = _mm_set1_epi8(-8); +#else + const uint8x16_t mask = vdupq_n_u8(0xf); + const int8x16_t m8 = vdupq_n_s8(-8); +#endif }; template <int D, int step> struct HelperQ41 final : public BaseHelper<step> { static_assert(step == QK4_1); using Base = BaseHelper<step>; + using block_q8 = block_q8_1; HelperQ41(const char * data, int stride) : Base(data, stride) {} - - inline void load(int l1, F16::Data * vk) const { - auto dl = (const block_q4_1 *)Base::lblock(l1); - for (int i = 0; i < D/32; ++i) { - auto vd = F16::set1(GGML_FP16_TO_FP32(dl[i].d)); - auto vm = F16::set1(GGML_FP16_TO_FP32(dl[i].m)); - auto q = _mm_loadu_si128((const __m128i *)dl[i].qs); - auto ql = _mm_and_si128(q, mask); - auto qh = _mm_and_si128(_mm_srli_epi16(q, 4), mask); -#ifdef HAVE_FANCY_SIMD - vk[2*i+0] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)), vm); - vk[2*i+1] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)), vm); -#else - auto ql16 = _mm256_cvtepi8_epi16(ql); - auto qh16 = _mm256_cvtepi8_epi16(qh); - vk[4*i+0] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(ql16))), vm); - vk[4*i+1] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(ql16, 1))), vm); - vk[4*i+2] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(qh16))), vm); - vk[4*i+3] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(qh16, 1))), vm); - vk[4*i+0] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(ql)), vm); -#endif - } - } - + // Needed for v * softmax(k * q) inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const { int j = F16::block_size*i; auto dl = (const block_q4_1 *)Base::lblock(l1) + j/QK4_1; +#ifdef __aarch64__ + auto vd = F16::set1(*(const float16_t *)&dl->d); + auto vm = F16::set1(*(const float16_t *)&dl->m); + auto q = vld1q_u8(dl->qs); + q = (j%QK4_1) ? 
vshrq_n_u8(q, 4) : vandq_u8(q, mask); + v1 = vfmaq_f16(vm, vd, vcvtq_f16_u16(vmovl_u8(vget_low_u8(q)))); + v2 = vfmaq_f16(vm, vd, vcvtq_f16_u16(vmovl_u8(vget_high_u8(q)))); +#else auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d)); auto vm = F16::set1(GGML_FP16_TO_FP32(dl->m)); auto q = _mm_loadu_si128((const __m128i *)dl->qs); @@ -6779,16 +6973,15 @@ struct HelperQ41 final : public BaseHelper<step> { v1 = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16))), vm); v2 = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16, 1))), vm); #endif +#endif } - inline void load_2(int l1, F16::Data * vk) const { - load(l1+0, vk+0); - load(l1+1, vk+D/F16::block_size); - } - +#ifdef __aarch64__ + const uint8x16_t mask = vdupq_n_u8(0xf); +#else const __m128i mask = _mm_set1_epi8(0xf); -}; #endif +}; template <int q_step, int k_step> struct FlashMS { @@ -6811,8 +7004,51 @@ struct FlashMS { } } + inline void update_M(int j, float smax) { + if (smax == -INFINITY) { + std::memset(cache + k_step*j, 0, k_step*sizeof(float)); + need_scaling[j] = M[j] == -INFINITY ? 2 : 0; + return; + } + need_scaling[j] = 0; + if (smax > M[j]) { + if (M[j] > -INFINITY) { + float m = expf(M[j] - smax); + vms[j] = F16::set1(m); + need_scaling[j] = 1; + S[j] *= m; + } else { + need_scaling[j] = 2; + S[j] = 0; + } + M[j] = smax; + } + } + #ifdef __aarch64__ - inline void update_M_S(int j, float32x4_t * vk) { + inline void update_S(int j, float32x4_t * vk) { + auto vm = vdupq_n_f32(M[j]); + auto vsum = vdupq_n_f32(0); + for (int l = 0; l < k_step/4; ++l) { + vk[l] = v_expf(vsubq_f32(vk[l], vm)); + vsum = vaddq_f32(vsum, vk[l]); + F16::store(cache + k_step*j + 4*l, vk[l]); + } + S[j] += vaddvq_f32(vsum); + } +#else + inline void update_S(int j, F16::Data * vk) { + auto vm = F16::set1(M[j]); + for (int l = 0; l < k_step/F16::block_size; ++l) { + vk[l] = v_expf(F16::sub(vk[l], vm)); + F16::store(cache + k_step*j + F16::block_size*l, vk[l]); + } + S[j] += F16::reduce_add<k_step>(vk); + } +#endif + +#ifdef __aarch64__ + inline float load_and_scale(int j, float32x4_t * vk) { float32x4_t vmax = vdupq_n_f32(-INFINITY); // Something goes wrong when storing and manipulating K*Q as fp16. // It works for some models (e.g., Gemma-2), but not for others (e.g., LLaMA-3.1-8B). @@ -6850,37 +7086,40 @@ struct FlashMS { vmax = vmaxq_f32(vmax, vk[l]); } } - - float smax = vmaxvq_f32(vmax); - if (smax == -INFINITY) { - std::memset(cache + k_step*j, 0, k_step*sizeof(float)); - need_scaling[j] = M[j] == -INFINITY ? 
2 : 0; - return; + return vmaxvq_f32(vmax); + } + inline float load_apply_mask_and_scale(int j, float32x4_t * vk, const char * mask) { + auto vzero = vdupq_n_f32(0); + auto vinf = vdupq_n_f32(-INFINITY); + for (int l = 0; l < k_step/8; ++l) { + auto vm = vceqq_f16(vzero, vld1q_f16((const float16_t *)mask + 8*l)); + auto vm1 = vzip1q_u16(vm, vm); + auto vm2 = vzip2q_u16(vm, vm); + auto kq = vld1q_f32_x2(cache + k_step*j + 8*l); + vk[2*l+0] = vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(kq.val[0]), vm1), + vbicq_u32(vinf, vm1))); + vk[2*l+1] = vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(kq.val[1]), vm2), + vbicq_u32(vinf, vm2))); } - need_scaling[j] = 0; - if (smax > M[j]) { - if (M[j] > -INFINITY) { - float m = expf(M[j] - smax); - vms[j] = F16::set1(m); - need_scaling[j] = 1; - S[j] *= m; - } else { - need_scaling[j] = 2; - S[j] = 0; + float32x4_t vmax = vdupq_n_f32(-INFINITY); + auto vscale32 = vcvt_f32_f16(vget_low_f16(vscale)); + if (softcap <= 0.0f) { + for (int l = 0; l < k_step/4; ++l) { + vk[l] = vmulq_f32(vscale32, vk[l]); + vmax = vmaxq_f32(vmax, vk[l]); + } + } else { + auto v_softcap = vdupq_n_f32(softcap); + for (int l = 0; l < k_step/4; ++l) { + vk[l] = vmulq_f32(vscale32, vk[l]); + vk[l] = vmulq_f32(v_softcap, v_tanh(vk[l])); + vmax = vmaxq_f32(vmax, vk[l]); } - M[j] = smax; - } - auto vm = vdupq_n_f32(M[j]); - auto vsum = vdupq_n_f32(0); - for (int l = 0; l < k_step/4; ++l) { - vk[l] = v_expf(vsubq_f32(vk[l], vm)); - vsum = vaddq_f32(vsum, vk[l]); - F16::store(cache + k_step*j + 4*l, vk[l]); } - S[j] += vaddvq_f32(vsum); + return vmaxvq_f32(vmax); } #else - inline void update_M_S(int j, F16::Data * vk) { + inline float load_and_scale(int j, F16::Data * vk) { if (softcap <= 0.0f) { for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l)); } else { @@ -6890,32 +7129,67 @@ struct FlashMS { vk[l] = F16::mul(v_softcap, v_tanh(F16::mul(vscale, val))); } } - - float smax = F16::reduce_max<k_step>(vk); - if (smax == -INFINITY) { - std::memset(cache + k_step*j, 0, k_step*sizeof(float)); - need_scaling[j] = M[j] == -INFINITY ? 
2 : 0; - return; - } - need_scaling[j] = 0; - if (smax > M[j]) { - if (M[j] > -INFINITY) { - float m = expf(M[j] - smax); - vms[j] = F16::set1(m); - need_scaling[j] = 1; - S[j] *= m; - } else { - need_scaling[j] = 2; - S[j] = 0; + return F16::reduce_max<k_step>(vk); + } + inline float load_apply_mask_and_scale(int j, F16::Data * vk, const char * mask) { +#ifdef HAVE_FANCY_SIMD + auto vzero = _mm256_set1_epi16(0); + auto vinf = _mm512_set1_ps(-INFINITY); + if (softcap <= 0) { + for (int l = 0; l < k_step/F16::block_size; ++l) { + auto m16 = _mm256_cmpeq_epi16_mask(_mm256_loadu_si256((const __m256i *)mask + l), vzero); + vk[l] = _mm512_mask_mul_ps(vinf, m16, vscale, F16::load(cache + k_step*j + F16::block_size*l)); + } + } else { + auto v_softcap = F16::set1(softcap); + for (int l = 0; l < k_step/F16::block_size; ++l) { + auto m16 = _mm256_cmpeq_epi16_mask(_mm256_loadu_si256((const __m256i *)mask + l), vzero); + vk[l] = _mm512_mask_mul_ps(vinf, m16, v_softcap, v_tanh(F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l)))); } - M[j] = smax; } - auto vm = F16::set1(M[j]); +#else + auto vzero = _mm_set1_epi16(0); + auto vinf = F16::set1(-INFINITY); for (int l = 0; l < k_step/F16::block_size; ++l) { - vk[l] = v_expf(F16::sub(vk[l], vm)); - F16::store(cache + k_step*j + F16::block_size*l, vk[l]); + auto m128 = _mm_loadu_si128((const __m128i *)mask + l); + m128 = _mm_cmpeq_epi16(m128, vzero); + auto m256 = _mm256_cvtepi16_epi32(m128); + auto mf = _mm256_castsi256_ps(_mm256_or_si256(m256, _mm256_slli_epi32(m256, 16))); + auto val = _mm256_loadu_ps(cache + k_step*j + F16::block_size*l); + vk[l] = _mm256_or_ps(_mm256_and_ps(mf, val), _mm256_andnot_ps(mf, vinf)); + } + if (softcap <= 0) { + for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(vscale, vk[l]); + } else { + auto v_softcap = F16::set1(softcap); + for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(v_softcap, v_tanh(F16::mul(vscale, vk[l]))); } - S[j] += F16::reduce_add<k_step>(vk); +#endif + return F16::reduce_max<k_step>(vk); + } +#endif + +#ifdef __aarch64__ + inline void update_M_S(int j, float32x4_t * vk) { + float smax = load_and_scale(j, vk); + update_M(j, smax); + update_S(j, vk); + } + inline void update_M_S(int j, float32x4_t * vk, const char * mask) { + float smax = load_apply_mask_and_scale(j, vk, mask); + update_M(j, smax); + update_S(j, vk); + } +#else + inline void update_M_S(int j, F16::Data * vk) { + float smax = load_and_scale(j, vk); + update_M(j, smax); + update_S(j, vk); + } + inline void update_M_S(int j, F16::Data * vk, const char * mask) { + float smax = load_apply_mask_and_scale(j, vk, mask); + update_M(j, smax); + update_S(j, vk); } #endif @@ -7200,6 +7474,135 @@ struct FlashQKfp32 { } } #endif + + template <typename KHelper, typename block_q8> + static inline void mul_mask_kq(const KHelper& kh, int stride_m, + const block_q8 * q, const char * mask, FlashMS<q_step, k_step>& fms) { + static_assert(q_step <= 8); + if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) { + DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; +#ifdef __aarch64__ + mul_mat_qX_0_q8_0<DequantizerQ40, q_step>(D, kh.block, kh.stride, info, k_step); +#else + mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step); +#endif + } + else if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) { + DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; +#ifdef __aarch64__ + 
mul_mat_qX_0_q8_0<DequantizerQ80_x4, q_step>(D, kh.block, kh.stride, info, k_step); +#else + mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step); +#endif + } + else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) { + DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr}; +#ifdef __aarch64__ + mul_mat_qX_1_q8_1<DequantizerQ41, q_step>(D, kh.block, kh.stride, info, k_step); +#else + mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step); +#endif + } + else { + GGML_ASSERT(false); + } +#ifdef __aarch64__ + float32x4_t vk[k_step/4]; + for (int j = 0; j < q_step; ++j) { + fms.update_M_S(j, vk, mask + stride_m*j); + } +#else + F16::Data vk[k_step/F16::block_size]; + for (int j = 0; j < q_step; ++j) { + fms.update_M_S(j, vk, mask + stride_m*j); + } +#endif + } + template <typename KHelper, typename block_q8> + static inline void mul_mask_kq(int nq, const KHelper& kh, int stride_m, + const block_q8 * q, const char * mask, FlashMS<q_step, k_step>& fms) { + GGML_ASSERT(nq < 8); + if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) { + DataInfo info{fms.cache, (const char *)q, D*sizeof(float), (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; + switch (nq) { +#ifdef __aarch64__ + case 1: mul_mat_qX_0_q8_0<DequantizerQ40, 1>(D, kh.block, kh.stride, info, k_step); break; + case 2: mul_mat_qX_0_q8_0<DequantizerQ40, 2>(D, kh.block, kh.stride, info, k_step); break; + case 3: mul_mat_qX_0_q8_0<DequantizerQ40, 3>(D, kh.block, kh.stride, info, k_step); break; + case 4: mul_mat_qX_0_q8_0<DequantizerQ40, 4>(D, kh.block, kh.stride, info, k_step); break; + case 5: mul_mat_qX_0_q8_0<DequantizerQ40, 5>(D, kh.block, kh.stride, info, k_step); break; + case 6: mul_mat_qX_0_q8_0<DequantizerQ40, 6>(D, kh.block, kh.stride, info, k_step); break; + case 7: mul_mat_qX_0_q8_0<DequantizerQ40, 7>(D, kh.block, kh.stride, info, k_step); break; +#else + case 1: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break; + case 2: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break; + case 3: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break; + case 4: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break; + case 5: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break; + case 6: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break; + case 7: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break; +#endif + } + } + else if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) { + DataInfo info{fms.cache, (const char *)q, D*sizeof(float), (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; + switch (nq) { +#ifdef __aarch64__ + case 1: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 1>(D, kh.block, kh.stride, info, k_step); break; + case 2: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 2>(D, kh.block, kh.stride, info, k_step); break; + case 3: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 3>(D, kh.block, kh.stride, info, k_step); break; + case 4: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 4>(D, kh.block, kh.stride, info, k_step); break; + case 5: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 5>(D, kh.block, kh.stride, info, k_step); break; + case 6: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 6>(D, kh.block, kh.stride, info, k_step); break; + case 7: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 7>(D, kh.block, kh.stride, info, k_step); break; +#else + case 1: 
mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break; + case 2: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break; + case 3: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break; + case 4: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break; + case 5: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break; + case 6: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break; + case 7: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break; +#endif + } + } + else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) { + DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr}; + switch (nq) { +#ifdef __aarch64__ + case 1: mul_mat_qX_1_q8_1<DequantizerQ41, 1>(D, kh.block, kh.stride, info, k_step); break; + case 2: mul_mat_qX_1_q8_1<DequantizerQ41, 2>(D, kh.block, kh.stride, info, k_step); break; + case 3: mul_mat_qX_1_q8_1<DequantizerQ41, 3>(D, kh.block, kh.stride, info, k_step); break; + case 4: mul_mat_qX_1_q8_1<DequantizerQ41, 4>(D, kh.block, kh.stride, info, k_step); break; + case 5: mul_mat_qX_1_q8_1<DequantizerQ41, 5>(D, kh.block, kh.stride, info, k_step); break; + case 6: mul_mat_qX_1_q8_1<DequantizerQ41, 6>(D, kh.block, kh.stride, info, k_step); break; + case 7: mul_mat_qX_1_q8_1<DequantizerQ41, 7>(D, kh.block, kh.stride, info, k_step); break; +#else + case 1: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break; + case 2: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break; + case 3: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break; + case 4: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break; + case 5: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break; + case 6: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break; + case 7: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break; +#endif + } + } + else { + GGML_ASSERT(false); + } +#ifdef __aarch64__ + float32x4_t vk[k_step/4]; + for (int j = 0; j < nq; ++j) { + fms.update_M_S(j, vk, mask + stride_m*j); + } +#else + F16::Data vk[k_step/F16::block_size]; + for (int j = 0; j < nq; ++j) { + fms.update_M_S(j, vk, mask + stride_m*j); + } +#endif + } }; template <int D, int q_step, int k_step, typename KHelper, typename VHelper, typename KQHelper> @@ -7259,6 +7662,49 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in } } +template <int D, int q_step, int k_step, typename KHelper, typename VHelper, typename KQHelper> +void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, + FlashMS<q_step, k_step>& fms, + FlashQKV<D, q_step, k_step>& fqkv, + const float * q, const char * mask, float * qkv) { + typename KHelper::block_q8 q8[q_step*(D/QK8_0)]; + for (int i1 = 0; i1 < nq1/q_step; ++i1) { + fms.init_qstep(); + kh.reset_block(); + vh.reset_block(); + HelperQ80<D, QK8_0>::convert(q_step, stride_q, q, q8); + auto mr = mask; + for (int k1 = 0; k1 < nk1/k_step; ++k1) { + KQHelper::mul_mask_kq(kh, stride_m, q8, mr, fms); + fqkv.accumulate_qkv(vh, fms); + kh.next_block(); + vh.next_block(); + mr += k_step*sizeof(ggml_half); + } + fqkv.normalize_and_store(fms, stride_qkv, qkv); + + q += 
q_step*stride_q; + mask += q_step*stride_m; + qkv += q_step*stride_qkv; + } + int n_left = nq1 - q_step*(nq1/q_step); + if (n_left > 0) { + fms.init_qstep(); + kh.reset_block(); + vh.reset_block(); + HelperQ80<D, QK8_0>::convert(n_left, stride_q, q, q8); + auto mr = mask; + for (int k1 = 0; k1 < nk1/k_step; ++k1) { + KQHelper::mul_mask_kq(n_left, kh, stride_m, q8, mr, fms); + fqkv.accumulate_qkv(n_left, vh, fms); + kh.next_block(); + vh.next_block(); + mr += k_step*sizeof(ggml_half); + } + fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv); + } +} + // Some of the methods in FlashAttn have two identical implementations that only differ by // one version using a loop over the template parameter q_step, while the other using a loop // over an input parameter nq (these are loops over the rows of q^T). I dislike this a lot, @@ -7280,8 +7726,14 @@ struct FlashAttn { template <typename KHelper, typename VHelper> void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, const float * q, const char * mask, float * qkv) { - compute_helper<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>( - kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); + if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>> || + std::is_same_v<KHelper, HelperQ80<D, k_step>>) { + compute_helper_q<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>( + kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); + } else { + compute_helper<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>( + kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); + } } FlashMS<q_step, k_step> fms; @@ -7604,7 +8056,6 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v, HelperF16<D, k_step> vh(v, stride_v); iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; -#ifdef __AVX2__ case GGML_TYPE_Q8_0: { HelperQ80<D, k_step> vh(v, stride_v); iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); @@ -7617,7 +8068,6 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v, HelperQ41<D, k_step> vh(v, stride_v); iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; -#endif default: break; } } @@ -7633,7 +8083,6 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, HelperF16<D, k_step> kh(k, stride_k); iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; -#ifdef __AVX2__ case GGML_TYPE_Q8_0: { HelperQ80<D, k_step> kh(k, stride_k); iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); @@ -7646,21 +8095,16 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, HelperQ41<D, k_step> kh(k, stride_k); iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; -#endif default: break; } } inline bool flash_attn_is_supported(ggml_type type) { -#ifdef __AVX2__ if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1) return true; #ifdef __AVX512BF16__ if (type == GGML_TYPE_BF16) return true; #endif 
-#else - if (type == GGML_TYPE_F16) return true; -#endif return false; } } @@ -7695,7 +8139,8 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k stride_q /= sizeof(float); // q stride as float #ifdef __AVX512BF16__ - if (type_k == GGML_TYPE_BF16 && type_v == GGML_TYPE_BF16) { + if (type_k == GGML_TYPE_BF16 || type_v == GGML_TYPE_BF16) { + if (type_k != GGML_TYPE_BF16 || type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 with other types switch (D) { case 64: iqk_flash_helper_T< 64, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; |
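
The FlashMS refactor above splits the streaming-softmax bookkeeping into update_M (rescale the running state when a new tile of K*Q scores raises the row maximum) and update_S (exponentiate and accumulate). A minimal scalar sketch of that online-softmax update, with simplified, hypothetical names and without the masking and soft-cap handling of the real kernel:

```cpp
// Illustration of the rescale-by-exp(M_old - M_new) step used in FlashMS.
#include <algorithm>
#include <cmath>
#include <vector>

struct RunningSoftmax {
    float M = -INFINITY;  // running maximum of all K*Q scores seen so far
    float S = 0.0f;       // running sum of exp(score - M)

    // Consume one tile of scores; on return the tile holds exp(score - M).
    // The returned factor is what the caller must scale its accumulated
    // V rows by (1.0f when no rescaling is needed).
    float add_tile(std::vector<float>& scores) {
        float smax = -INFINITY;
        for (float s : scores) smax = std::max(smax, s);
        if (smax == -INFINITY) {                 // fully masked tile: contributes nothing
            std::fill(scores.begin(), scores.end(), 0.0f);
            return 1.0f;
        }
        float scale = 1.0f;
        if (smax > M) {
            if (M > -INFINITY) { scale = std::exp(M - smax); S *= scale; }
            M = smax;
        }
        for (float& s : scores) { s = std::exp(s - M); S += s; }
        return scale;
    }
};
```

After the last tile the accumulated output is divided by S, which corresponds to the normalize_and_store step in the diff.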