author     Kawrakow <48489457+ikawrakow@users.noreply.github.com>   2024-09-12 19:03:20 +0300
committer  GitHub <noreply@github.com>                              2024-09-12 19:03:20 +0300
commit     5017f8b3f04790de686c04e4ab23443c5ca02345 (patch)
tree       7d5776ebefbcdb351e4f34d649ab572b8b9bf119
parent     c920195edd80ab24beb9a0fd3e2f4df582e735d0 (diff)
Quantized Flash Attention for all supported CPU platforms (#51)
* NEON Flash Attention: add support for Q8_0, Q4_0, Q4_1

* NEON Flash Attention: quantized K*Q for q4_0

  I could finally take advantage of the matrix multiplication templates.
  We get quite a bit of speedup that way for q4_0: for Gemma-2b, using
  mul_mat_qX_0_q8_0<DequantizerQ40, q_step> results in PP-2048 = 287 t/s
  vs 268 t/s when converting the q4_0 k-cache and Q to fp16 and using
  fp16 multiplication.

* NEON Flash Attention: quantized K*Q for q4_1

* NEON Flash Attention: quantized K*Q for q8_0

  This makes quite a bit of difference: for Gemma2-2b, PP-8192 is 228 t/s
  with quantized K*Q vs 178 t/s when converting things to fp16 and using
  fp16 matrix multiplication. We have PP-512 = 307 t/s, so PP-8192 now runs
  at ~75% of the PP-512 performance. In contrast, llama.cpp with a Q8_0
  cache runs at 38% of PP-512.

* Zen4 Flash Attention: quantized K*Q for q4_0, q4_1, q8_0

* AVX2 Flash Attention: quantized K*Q for q4_0, q4_1, q8_0

* Tidy up FlashMS

* Delete no longer used stuff

  With quantized matrix multiplications used for the quantized k- and/or
  v-cache, we no longer need the helper methods that load entire rows.

* Disallow mixing bf16 with other types for kv caches

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
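For orientation, here is a minimal scalar sketch of the q8_0 quantization that the new quantize_row_q8_0 in this patch implements with NEON and AVX2 intrinsics. The block_q8_0_x4 interleaving and all SIMD are omitted, and the kQK8_0 / BlockQ8_0Ref / *_ref names are illustrative stand-ins, not the ggml API.

// Scalar reference for q8_0 quantization (illustration only; the diff below
// does this with NEON/AVX2 and additionally packs four blocks into block_q8_0_x4).
#include <algorithm>
#include <cmath>
#include <cstdint>

constexpr int kQK8_0 = 32;            // block size, same value as ggml's QK8_0

struct BlockQ8_0Ref {                 // scalar stand-in for ggml's block_q8_0
    float  d;                         // per-block scale (stored as fp16 in ggml)
    int8_t qs[kQK8_0];                // quantized values, x[j] ~= d * qs[j]
};

void quantize_row_q8_0_ref(const float * x, BlockQ8_0Ref * y, int k) {
    const int nb = k / kQK8_0;
    for (int i = 0; i < nb; ++i) {
        float amax = 0.0f;
        for (int j = 0; j < kQK8_0; ++j) amax = std::max(amax, std::fabs(x[i*kQK8_0 + j]));
        const float d  = amax / 127.0f;
        const float id = d ? 1.0f/d : 0.0f;
        y[i].d = d;
        for (int j = 0; j < kQK8_0; ++j)
            y[i].qs[j] = (int8_t)std::lround(x[i*kQK8_0 + j] * id);
    }
}

// With K and Q both in this format, a 32-wide chunk of a K*Q dot product is an
// int8 dot product scaled by the two block scales; this is what the quantized
// matrix multiplication templates exploit.
float dot_q8_0_ref(const BlockQ8_0Ref& a, const BlockQ8_0Ref& b) {
    int32_t sum = 0;
    for (int j = 0; j < kQK8_0; ++j) sum += int32_t(a.qs[j]) * int32_t(b.qs[j]);
    return a.d * b.d * float(sum);
}

Packing four consecutive blocks as block_q8_0_x4 (four fp16 scales followed by the 4x32 quants) is what lets the new Q8_0_x4_Unpacker and DequantizerQ80_x4 below fetch scales and quants with just a few vector loads.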
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.cpp  899
1 file changed, 672 insertions, 227 deletions
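Schematically, the CPU flash-attention path in the diff below quantizes the query rows once per q_step tile (HelperQ80::convert), runs K*Q as a quantized matrix multiplication into a cache, then applies scaling, optional soft-capping, masking and the online softmax update (the refactored FlashMS) before accumulating V. The following is a hedged, single-row scalar sketch of that flow with the tiling, SIMD and quantization stripped out; flash_attn_row_ref and its parameters are illustrative names, not the actual API.

// Single-query-row sketch of the loop that compute_helper_q / FlashMS / FlashQKV
// below implement in q_step x k_step tiles (illustration only).
#include <cmath>
#include <limits>
#include <vector>

void flash_attn_row_ref(int D, int n_kv,
                        const float * q,     // one query row, length D (quantized to q8_0/q8_1 in the real code)
                        const float * k,     // n_kv key rows of length D   (quantized K cache in the real code)
                        const float * v,     // n_kv value rows of length D (quantized V cache in the real code)
                        const float * mask,  // 0 = keep, -inf = masked (fp16 in the real code)
                        float scale, float softcap,
                        float * out) {       // length D
    const float inf = std::numeric_limits<float>::infinity();
    float M = -inf, S = 0.0f;                // running max and running sum of exp()
    std::vector<float> acc(D, 0.0f);         // running V accumulator
    for (int t = 0; t < n_kv; ++t) {
        if (mask[t] == -inf) continue;       // load_apply_mask_and_scale handles this per k_step tile
        float s = 0.0f;
        for (int j = 0; j < D; ++j) s += k[t*D + j] * q[j];   // a quantized matmul in the real code
        s *= scale;
        if (softcap > 0.0f) s = softcap * std::tanh(s);
        if (s > M) {                         // new maximum: rescale what we have so far (FlashMS::update_M)
            const float m = (M == -inf) ? 0.0f : std::exp(M - s);
            for (int j = 0; j < D; ++j) acc[j] *= m;
            S *= m;
            M = s;
        }
        const float p = std::exp(s - M);     // FlashMS::update_S
        S += p;
        for (int j = 0; j < D; ++j) acc[j] += p * v[t*D + j]; // FlashQKV::accumulate_qkv
    }
    const float norm = S > 0.0f ? 1.0f/S : 0.0f;              // normalize_and_store
    for (int j = 0; j < D; ++j) out[j] = norm * acc[j];
}

Since the query tile is quantized once and then reused across every k_step block, the quantization cost is amortized over the whole KV range, which is presumably where the long-context (PP-8192) gains quoted in the commit message come from.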
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 5a8cbce2..ce868514 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -3045,7 +3045,6 @@ template <typename Minus, int nrc_y, bool is_multiple_of_4> struct AccumT {
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, accm.result(acc[iy], iy));
- //s[iy*bs] = accm.result(acc[iy], iy);
}
}
};
@@ -3212,6 +3211,35 @@ struct Q_Unpacker {
}
};
+struct Q8_0_x4_Unpacker {
+ using Sum4T = Sum4TypeQ80;
+ inline static int block_size() { return QK8_0; }
+ Q8_0_x4_Unpacker(const void * vx, size_t bx) : cx_0((const char *)vx), x((const block_q8_0_x4 *)cx_0), bx(bx) {}
+
+ const char * cx_0;
+ const block_q8_0_x4 * x;
+ size_t bx;
+
+ __m256i qx[4];
+
+ inline const __m256i* quants() const { return qx; }
+
+ inline void set_row(int ix) { x = (const block_q8_0_x4 *)(cx_0 + ix*bx); }
+
+ inline auto set_block_4(int i) {
+ auto scales = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)x[i].d));
+ for (int j = 0; j < 4; ++j) {
+ qx[j] = _mm256_loadu_si256((const __m256i *)x[i].qs + j);
+ }
+ return scales;
+ }
+ inline auto set_block(int i) {
+ auto q8 = (const block_q8_0 *)(x + i);
+ qx[0] = _mm256_loadu_si256((const __m256i *)q8->qs);
+ return GGML_FP16_TO_FP32(q8->d);
+ }
+};
+
struct Q8_0_Unpacker final : public Q_Unpacker<block_q8_0, ScaleHelperQ_0, Q8_0_Dequantizer> {
Q8_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
using Sum4T = Sum4TypeQ80;
@@ -5461,6 +5489,27 @@ struct DequantizerQ80 final : public BaseLegacyDequantizer<block_q8_0> {
};
+// TODO: handle case where row size is not a multiple of 128
+struct DequantizerQ80_x4 final : public BaseLegacyDequantizer<block_q8_0_x4> {
+
+ DequantizerQ80_x4(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}
+
+ inline void prepare1(int i) {
+ bits.b[0] = vld1q_s8(x[i].qs);
+ bits.b[1] = vld1q_s8(x[i].qs+16);
+ }
+
+ inline float16x4_t new_block(int i) {
+ auto scale = vld1_f16((const float16_t *)x[i].d);
+ for (int k = 0; k < 4; ++k) {
+ bits.b[2*k+0] = vld1q_s8(x[i].qs+32*k);
+ bits.b[2*k+1] = vld1q_s8(x[i].qs+32*k+16);
+ }
+ return scale;
+ }
+
+};
+
struct DequantizerQ51 final : public BaseLegacyDequantizer<block_q5_1> {
DequantizerQ51(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}
@@ -5529,9 +5578,9 @@ inline void mul_mat_qX_Y_q8_Y(int n, Dequantizer& deq, Q8& q8, const DataInfo& i
q8.process_scales(i, deq, sc16, acc);
sum_4(i, deq, q8, sc16, acc);
}
- for (int i = 4*(nb/4); i < nb; ++i) {
- q8.process_1_block(i, deq, acc);
- }
+ //for (int i = 4*(nb/4); i < nb; ++i) {
+ // q8.process_1_block(i, deq, acc);
+ //}
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
info.store(ix, iy, vaddvq_f32(acc[iy]));
@@ -5591,9 +5640,9 @@ inline void mul_mat_qX_Y_q8_Y_1(int n, Dequantizer& deq1, Dequantizer& deq2, Q8&
q8.process_scales(i, deq1, sc16, acc);
sum_4(i, deq1, q8, sc16, acc);
}
- for (int i = 4*(nb/4); i < nb; ++i) {
- q8.process_1_block(i, deq1, acc);
- }
+ //for (int i = 4*(nb/4); i < nb; ++i) {
+ // q8.process_1_block(i, deq1, acc);
+ //}
info.store(ix, 0, vaddvq_f32(vaddq_f32(acc[0], acc[1])));
}
@@ -6548,70 +6597,275 @@ struct HelperF16 final : public BaseHelper<step> {
}
};
-#if defined __AVX2__
-template <int D, int step>
-struct HelperQ80 final : public BaseHelper<step> {
- static_assert(step == QK8_0);
- using Base = BaseHelper<step>;
- //using F16 = HelperF16<D, step>;
- HelperQ80(const char * data, int stride) : Base(data, stride) {}
+void quantize_row_q8_0(const float * x, block_q8_0 * y, int k) {
+ const int nb = k / QK8_0;
+ const int nb4 = 4*(nb/4);
- inline void load(int l1, F16::Data * vk) const {
- auto dl = (const block_q8_0_x4 *)Base::lblock(l1);
- if constexpr (D >= 128) {
- F16::Data vd[4];
- for (int ib = 0; ib < D/128; ++ib) {
- const auto& b8 = dl[ib];
- auto scales4 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)b8.d));
- auto scales8 = _mm256_insertf128_ps(_mm256_castps128_ps256(scales4), scales4, 1);
-#ifdef HAVE_FANCY_SIMD
- auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales8), scales8, 1);
- vd[0] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(0, 0, 0, 0));
- vd[1] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(1, 1, 1, 1));
- vd[2] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(2, 2, 2, 2));
- vd[3] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(3, 3, 3, 3));
- for (int i = 0; i < 4; ++i) {
- vk[8*ib+2*i+0] = _mm512_mul_ps(vd[i], _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*i+0))));
- vk[8*ib+2*i+1] = _mm512_mul_ps(vd[i], _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*i+1))));
- }
+#if defined(__aarch64__)
+ block_q8_0_x4 * y4 = (block_q8_0_x4 *)y;
+ for (int i = 0; i < nb; i++) {
+ int i4 = i/4, ir = i%4;
+ float32x4_t srcv [8];
+ float32x4_t asrcv[8];
+ float32x4_t amaxv[8];
+
+ for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j);
+ for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
+
+ for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
+ for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
+ for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
+
+ const float amax = vmaxvq_f32(amaxv[0]);
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ if (i < nb4) {
+ y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
+ } else {
+ y[i].d = GGML_FP32_TO_FP16(d);
+ }
+
+ for (int j = 0; j < 8; j++) {
+ const float32x4_t v = vmulq_n_f32(srcv[j], id);
+ const int32x4_t vi = vcvtnq_s32_f32(v);
+
+ if (i < nb4) {
+ y4[i4].qs[32*ir + 4*j + 0] = vgetq_lane_s32(vi, 0);
+ y4[i4].qs[32*ir + 4*j + 1] = vgetq_lane_s32(vi, 1);
+ y4[i4].qs[32*ir + 4*j + 2] = vgetq_lane_s32(vi, 2);
+ y4[i4].qs[32*ir + 4*j + 3] = vgetq_lane_s32(vi, 3);
+ } else {
+ y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
+ y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
+ y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
+ y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
+ }
+ }
+ }
#else
- vd[0] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(0, 0, 0, 0));
- vd[1] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(1, 1, 1, 1));
- vd[2] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(2, 2, 2, 2));
- vd[3] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(3, 3, 3, 3));
- for (int i = 0; i < 4; ++i) {
- vk[16*ib+4*i+0] = _mm256_mul_ps(vd[i], _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*i+ 0)))));
- vk[16*ib+4*i+1] = _mm256_mul_ps(vd[i], _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*i+ 8)))));
- vk[16*ib+4*i+2] = _mm256_mul_ps(vd[i], _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*i+16)))));
- vk[16*ib+4*i+3] = _mm256_mul_ps(vd[i], _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*i+24)))));
- }
+ block_q8_0_x4 * y4 = (block_q8_0_x4 *)y;
+ for (int i = 0; i < nb; i++) {
+ int i4 = i/4, ir = i%4;
+ // Load elements into 4 AVX vectors
+ __m256 v0 = _mm256_loadu_ps( x );
+ __m256 v1 = _mm256_loadu_ps( x + 8 );
+ __m256 v2 = _mm256_loadu_ps( x + 16 );
+ __m256 v3 = _mm256_loadu_ps( x + 24 );
+ x += 32;
+
+ const __m256 signBit = _mm256_set1_ps( -0.0f );
+ __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+ const float maxScalar = _mm_cvtss_f32( max4 );
+
+ const float d = maxScalar / 127.f;
+ if (i < nb4) {
+ y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
+ } else {
+ y[i].d = GGML_FP32_TO_FP16(d);
+ }
+ const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+ const __m256 mul = _mm256_set1_ps( id );
+
+ v0 = _mm256_mul_ps( v0, mul );
+ v1 = _mm256_mul_ps( v1, mul );
+ v2 = _mm256_mul_ps( v2, mul );
+ v3 = _mm256_mul_ps( v3, mul );
+
+ v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+ v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+ v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+ v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+ __m256i i0 = _mm256_cvtps_epi32( v0 );
+ __m256i i1 = _mm256_cvtps_epi32( v1 );
+ __m256i i2 = _mm256_cvtps_epi32( v2 );
+ __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+ // Convert int32 to int16
+ i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
+ i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
+ // Convert int16 to int8
+ i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+ // We got our precious signed bytes, but the order is now wrong
+ // These AVX2 pack instructions process 16-byte pieces independently
+ // The following instruction is fixing the order
+ const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+ i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+ if (i < nb4) {
+ _mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0);
+ } else {
+ _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+ }
+ }
#endif
+}
+
+void quantize_row_q8_1(const float * x, block_q8_1 * y, int k) {
+ assert(k % QK8_1 == 0);
+ const int nb = k / QK8_1;
+
+ const int nb4 = 4*(nb/4);
+ block_q8_1_x4 * y4 = (block_q8_1_x4 *)y;
+#if defined(__aarch64__)
+ for (int i = 0; i < nb; i++) {
+ int i4 = i/4, ir = i%4;
+ float32x4_t srcv [8];
+ float32x4_t asrcv[8];
+ float32x4_t amaxv[8];
+
+ for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j);
+ for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
+
+ for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
+ for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
+ for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
+
+ const float amax = vmaxvq_f32(amaxv[0]);
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ if (i < nb4) {
+ y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
+ } else {
+ y[i].d = GGML_FP32_TO_FP16(d);
+ }
+
+ int32x4_t accv = vdupq_n_s32(0);
+
+ for (int j = 0; j < 8; j++) {
+ const float32x4_t v = vmulq_n_f32(srcv[j], id);
+ const int32x4_t vi = vcvtnq_s32_f32(v);
+
+ if (i < nb4) {
+ y4[i4].qs[QK8_1*ir + 4*j + 0] = vgetq_lane_s32(vi, 0);
+ y4[i4].qs[QK8_1*ir + 4*j + 1] = vgetq_lane_s32(vi, 1);
+ y4[i4].qs[QK8_1*ir + 4*j + 2] = vgetq_lane_s32(vi, 2);
+ y4[i4].qs[QK8_1*ir + 4*j + 3] = vgetq_lane_s32(vi, 3);
+ } else {
+ y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
+ y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
+ y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
+ y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
}
+
+ accv = vaddq_s32(accv, vi);
+ }
+
+ if (i < nb4) {
+ y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
} else {
- for (int i = 0; i < D/32; ++i) {
- const auto& b8 = dl[i/4];
- int ii = i%4;
- auto vd = F16::set1(GGML_FP16_TO_FP32(b8.d[ii]));
-#ifdef HAVE_FANCY_SIMD
- vk[2*i+0] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*ii+0))));
- vk[2*i+1] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*ii+1))));
+ y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
+ }
+ }
#else
- vk[4*i+0] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*ii+ 0)))));
- vk[4*i+1] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*ii+ 8)))));
- vk[4*i+2] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*ii+16)))));
- vk[4*i+3] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(b8.qs+32*ii+24)))));
-#endif
- }
+ for (int i = 0; i < nb; i++) {
+ int i4 = i/4, ir = i%4;
+ // Load elements into 4 AVX vectors
+ __m256 v0 = _mm256_loadu_ps( x );
+ __m256 v1 = _mm256_loadu_ps( x + 8 );
+ __m256 v2 = _mm256_loadu_ps( x + 16 );
+ __m256 v3 = _mm256_loadu_ps( x + 24 );
+ x += 32;
+
+ // Compute max(abs(e)) for the block
+ const __m256 signBit = _mm256_set1_ps( -0.0f );
+ __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+ const float max_scalar = _mm_cvtss_f32( max4 );
+
+ // Quantize these floats
+ const float d = max_scalar / 127.f;
+ if (i < nb4) {
+ y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
+ } else {
+ y[i].d = GGML_FP32_TO_FP16(d);
+ }
+ const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f;
+ const __m256 mul = _mm256_set1_ps( id );
+
+ // Apply the multiplier
+ v0 = _mm256_mul_ps( v0, mul );
+ v1 = _mm256_mul_ps( v1, mul );
+ v2 = _mm256_mul_ps( v2, mul );
+ v3 = _mm256_mul_ps( v3, mul );
+
+ // Round to nearest integer
+ v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+ v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+ v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+ v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+ // Convert floats to integers
+ __m256i i0 = _mm256_cvtps_epi32( v0 );
+ __m256i i1 = _mm256_cvtps_epi32( v1 );
+ __m256i i2 = _mm256_cvtps_epi32( v2 );
+ __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+ // Compute the sum of the quants and set y[i].s
+ if (i < nb4) {
+ y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))));
+ } else {
+ y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))));
+ }
+
+ // Convert int32 to int16
+ i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
+ i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
+ // Convert int16 to int8
+ i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+ // We got our precious signed bytes, but the order is now wrong
+ // These AVX2 pack instructions process 16-byte pieces independently
+ // The following instruction is fixing the order
+ const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+ i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+ if (i < nb4) {
+ _mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0);
+ } else {
+ _mm256_storeu_si256((__m256i *)y[i].qs, i0);
}
}
+#endif
+}
+
+template <int D, int step>
+struct HelperQ80 final : public BaseHelper<step> {
+ static_assert(step == QK8_0);
+ using Base = BaseHelper<step>;
+ using block_q8 = block_q8_0;
+ HelperQ80(const char * data, int stride) : Base(data, stride) {}
+ // Needed for v * softmax(k * q)
inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
- // Say D = 256 -> i is 0, 2, 4, 6, 8, ..., 28, 30. 128/8 = 16 -> we use 1st block of 128 for i = 0, 2, ..., 14, second for i = 16, 18, ..., 30
- // i = 0, 2 -> ii = 0, i = 4, 6 -> ii = 1, i = 8, 10 -> ii = 2, i = 12, 14 -> ii = 3, i = 16, 18 -> ii = 0, etc.
- // i*F16::block_size/128
int j = F16::block_size*i;
auto dl = (const block_q8_0_x4 *)Base::lblock(l1) + j/(4*QK8_0);
int ii = (j/QK8_0)%4;
+#ifdef __aarch64__
+ const float16_t * d = (const float16_t *)dl->d;
+ auto vd = F16::set1(d[ii]);
+ auto qs = vld1_s8_x2(dl->qs + 32*ii + j%32);
+ v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[0])));
+ v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1])));
+#else
auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d[ii]));
#ifdef HAVE_FANCY_SIMD
v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+2*ii+0))));
@@ -6620,11 +6874,25 @@ struct HelperQ80 final : public BaseHelper<step> {
v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+32*ii+j%32)))));
v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+32*ii+j%32+8)))));
#endif
+#endif
}
- inline void load_2(int l1, F16::Data * vk) const {
- load(l1+0, vk+0);
- load(l1+1, vk+D/F16::block_size);
+ static inline void convert(int nq, int stride_q, const float * q, block_q8_0 * y) {
+ GGML_ASSERT(nq <= step);
+ for (int i = 0; i < nq; ++i) {
+ quantize_row_q8_0(q, y, D);
+ q += stride_q;
+ y += D/QK8_0;
+ }
+ }
+
+ static inline void convert(int nq, int stride_q, const float * q, block_q8_1 * y) {
+ GGML_ASSERT(nq <= step);
+ for (int i = 0; i < nq; ++i) {
+ quantize_row_q8_1(q, y, D);
+ q += stride_q;
+ y += D/QK8_1;
+ }
}
};
@@ -6632,82 +6900,21 @@ template <int D, int step>
struct HelperQ40 final : public BaseHelper<step> {
static_assert(step == QK4_0);
using Base = BaseHelper<step>;
+ using block_q8 = block_q8_0;
HelperQ40(const char * data, int stride) : Base(data, stride) {}
-
- inline void load(int l1, F16::Data * vk) const {
- auto dl = (const block_q4_0 *)Base::lblock(l1);
- if constexpr (D >= 128) {
- ggml_half aux[4];
- F16::Data vd[4];
- for (int ib = 0; ib < D/128; ++ib) {
- for (int i = 0; i < 4; ++i) {
- auto& b4 = dl[4*ib+i];
- aux[i] = b4.d;
- auto q = _mm_loadu_si128((const __m128i *)b4.qs);
- auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8);
- auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8);
-#ifdef HAVE_FANCY_SIMD
- vk[8*ib+2*i+0] = _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql));
- vk[8*ib+2*i+1] = _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh));
-#else
- auto ql16 = _mm256_cvtepi8_epi16(ql);
- auto qh16 = _mm256_cvtepi8_epi16(qh);
- vk[16*ib+4*i+0] = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(ql16)));
- vk[16*ib+4*i+1] = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(ql16, 1)));
- vk[16*ib+4*i+2] = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(qh16)));
- vk[16*ib+4*i+3] = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(qh16, 1)));
-#endif
- }
- auto scales4 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)aux));
- auto scales8 = _mm256_insertf128_ps(_mm256_castps128_ps256(scales4), scales4, 1);
-#ifdef HAVE_FANCY_SIMD
- auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales8), scales8, 1);
- vd[0] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(0, 0, 0, 0));
- vd[1] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(1, 1, 1, 1));
- vd[2] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(2, 2, 2, 2));
- vd[3] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(3, 3, 3, 3));
- for (int i = 0; i < 4; ++i) {
- vk[8*ib+2*i+0] = _mm512_mul_ps(vd[i], vk[8*ib+2*i+0]);
- vk[8*ib+2*i+1] = _mm512_mul_ps(vd[i], vk[8*ib+2*i+1]);
- }
-#else
- vd[0] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(0, 0, 0, 0));
- vd[1] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(1, 1, 1, 1));
- vd[2] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(2, 2, 2, 2));
- vd[3] = _mm256_shuffle_ps(scales8, scales8, _MM_SHUFFLE(3, 3, 3, 3));
- for (int i = 0; i < 4; ++i) {
- vk[16*ib+4*i+0] = _mm256_mul_ps(vd[i], vk[16*ib+4*i+0]);
- vk[16*ib+4*i+1] = _mm256_mul_ps(vd[i], vk[16*ib+4*i+1]);
- vk[16*ib+4*i+2] = _mm256_mul_ps(vd[i], vk[16*ib+4*i+2]);
- vk[16*ib+4*i+3] = _mm256_mul_ps(vd[i], vk[16*ib+4*i+3]);
- }
-#endif
- }
- } else {
- for (int i = 0; i < D/32; ++i) {
- auto vd = F16::set1(GGML_FP16_TO_FP32(dl[i].d));
- auto q = _mm_loadu_si128((const __m128i *)dl[i].qs);
- auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8);
- auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8);
-#ifdef HAVE_FANCY_SIMD
- vk[2*i+0] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
- vk[2*i+1] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)));
-#else
- auto ql16 = _mm256_cvtepi8_epi16(ql);
- auto qh16 = _mm256_cvtepi8_epi16(qh);
- vk[4*i+0] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(ql16))));
- vk[4*i+1] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(ql16, 1))));
- vk[4*i+2] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(qh16))));
- vk[4*i+3] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(qh16, 1))));
-#endif
- }
- }
- }
-
+ // Needed for v * softmax(k * q)
inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
int j = F16::block_size*i;
auto dl = (const block_q4_0 *)Base::lblock(l1) + j/QK4_0;
+#ifdef __aarch64__
+ auto vd = F16::set1(*(const float16_t *)&dl->d);
+ auto q = vld1q_u8(dl->qs);
+ q = j%QK4_0 ? vshrq_n_u8(q, 4) : vandq_u8(q, mask);
+ q = vaddq_s8(q, m8);
+ v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_low_s8(q))));
+ v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_high_s8(q))));
+#else
auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
auto q = _mm_loadu_si128((const __m128i *)dl->qs);
#ifdef HAVE_FANCY_SIMD
@@ -6721,50 +6928,37 @@ struct HelperQ40 final : public BaseHelper<step> {
v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16))));
v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16, 1))));
#endif
+#endif
}
- inline void load_2(int l1, F16::Data * vk) const {
- load(l1+0, vk+0);
- load(l1+1, vk+D/F16::block_size);
- }
-
+#ifdef __AVX2__
const __m128i mask = _mm_set1_epi8(0xf);
const __m128i m8 = _mm_set1_epi8(-8);
+#else
+ const uint8x16_t mask = vdupq_n_u8(0xf);
+ const int8x16_t m8 = vdupq_n_s8(-8);
+#endif
};
template <int D, int step>
struct HelperQ41 final : public BaseHelper<step> {
static_assert(step == QK4_1);
using Base = BaseHelper<step>;
+ using block_q8 = block_q8_1;
HelperQ41(const char * data, int stride) : Base(data, stride) {}
-
- inline void load(int l1, F16::Data * vk) const {
- auto dl = (const block_q4_1 *)Base::lblock(l1);
- for (int i = 0; i < D/32; ++i) {
- auto vd = F16::set1(GGML_FP16_TO_FP32(dl[i].d));
- auto vm = F16::set1(GGML_FP16_TO_FP32(dl[i].m));
- auto q = _mm_loadu_si128((const __m128i *)dl[i].qs);
- auto ql = _mm_and_si128(q, mask);
- auto qh = _mm_and_si128(_mm_srli_epi16(q, 4), mask);
-#ifdef HAVE_FANCY_SIMD
- vk[2*i+0] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)), vm);
- vk[2*i+1] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)), vm);
-#else
- auto ql16 = _mm256_cvtepi8_epi16(ql);
- auto qh16 = _mm256_cvtepi8_epi16(qh);
- vk[4*i+0] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(ql16))), vm);
- vk[4*i+1] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(ql16, 1))), vm);
- vk[4*i+2] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(qh16))), vm);
- vk[4*i+3] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(qh16, 1))), vm);
- vk[4*i+0] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(ql)), vm);
-#endif
- }
- }
-
+ // Needed for v * softmax(k * q)
inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
int j = F16::block_size*i;
auto dl = (const block_q4_1 *)Base::lblock(l1) + j/QK4_1;
+#ifdef __aarch64__
+ auto vd = F16::set1(*(const float16_t *)&dl->d);
+ auto vm = F16::set1(*(const float16_t *)&dl->m);
+ auto q = vld1q_u8(dl->qs);
+ q = (j%QK4_1) ? vshrq_n_u8(q, 4) : vandq_u8(q, mask);
+ v1 = vfmaq_f16(vm, vd, vcvtq_f16_u16(vmovl_u8(vget_low_u8(q))));
+ v2 = vfmaq_f16(vm, vd, vcvtq_f16_u16(vmovl_u8(vget_high_u8(q))));
+#else
auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
auto vm = F16::set1(GGML_FP16_TO_FP32(dl->m));
auto q = _mm_loadu_si128((const __m128i *)dl->qs);
@@ -6779,16 +6973,15 @@ struct HelperQ41 final : public BaseHelper<step> {
v1 = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16))), vm);
v2 = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16, 1))), vm);
#endif
+#endif
}
- inline void load_2(int l1, F16::Data * vk) const {
- load(l1+0, vk+0);
- load(l1+1, vk+D/F16::block_size);
- }
-
+#ifdef __aarch64__
+ const uint8x16_t mask = vdupq_n_u8(0xf);
+#else
const __m128i mask = _mm_set1_epi8(0xf);
-};
#endif
+};
template <int q_step, int k_step>
struct FlashMS {
@@ -6811,8 +7004,51 @@ struct FlashMS {
}
}
+ inline void update_M(int j, float smax) {
+ if (smax == -INFINITY) {
+ std::memset(cache + k_step*j, 0, k_step*sizeof(float));
+ need_scaling[j] = M[j] == -INFINITY ? 2 : 0;
+ return;
+ }
+ need_scaling[j] = 0;
+ if (smax > M[j]) {
+ if (M[j] > -INFINITY) {
+ float m = expf(M[j] - smax);
+ vms[j] = F16::set1(m);
+ need_scaling[j] = 1;
+ S[j] *= m;
+ } else {
+ need_scaling[j] = 2;
+ S[j] = 0;
+ }
+ M[j] = smax;
+ }
+ }
+
#ifdef __aarch64__
- inline void update_M_S(int j, float32x4_t * vk) {
+ inline void update_S(int j, float32x4_t * vk) {
+ auto vm = vdupq_n_f32(M[j]);
+ auto vsum = vdupq_n_f32(0);
+ for (int l = 0; l < k_step/4; ++l) {
+ vk[l] = v_expf(vsubq_f32(vk[l], vm));
+ vsum = vaddq_f32(vsum, vk[l]);
+ F16::store(cache + k_step*j + 4*l, vk[l]);
+ }
+ S[j] += vaddvq_f32(vsum);
+ }
+#else
+ inline void update_S(int j, F16::Data * vk) {
+ auto vm = F16::set1(M[j]);
+ for (int l = 0; l < k_step/F16::block_size; ++l) {
+ vk[l] = v_expf(F16::sub(vk[l], vm));
+ F16::store(cache + k_step*j + F16::block_size*l, vk[l]);
+ }
+ S[j] += F16::reduce_add<k_step>(vk);
+ }
+#endif
+
+#ifdef __aarch64__
+ inline float load_and_scale(int j, float32x4_t * vk) {
float32x4_t vmax = vdupq_n_f32(-INFINITY);
// Something goes wrong when storing and manipulating K*Q as fp16.
// It works for some models (e.g., Gemma-2), but not for others (e.g., LLaMA-3.1-8B).
@@ -6850,37 +7086,40 @@ struct FlashMS {
vmax = vmaxq_f32(vmax, vk[l]);
}
}
-
- float smax = vmaxvq_f32(vmax);
- if (smax == -INFINITY) {
- std::memset(cache + k_step*j, 0, k_step*sizeof(float));
- need_scaling[j] = M[j] == -INFINITY ? 2 : 0;
- return;
+ return vmaxvq_f32(vmax);
+ }
+ inline float load_apply_mask_and_scale(int j, float32x4_t * vk, const char * mask) {
+ auto vzero = vdupq_n_f32(0);
+ auto vinf = vdupq_n_f32(-INFINITY);
+ for (int l = 0; l < k_step/8; ++l) {
+ auto vm = vceqq_f16(vzero, vld1q_f16((const float16_t *)mask + 8*l));
+ auto vm1 = vzip1q_u16(vm, vm);
+ auto vm2 = vzip2q_u16(vm, vm);
+ auto kq = vld1q_f32_x2(cache + k_step*j + 8*l);
+ vk[2*l+0] = vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(kq.val[0]), vm1),
+ vbicq_u32(vinf, vm1)));
+ vk[2*l+1] = vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(kq.val[1]), vm2),
+ vbicq_u32(vinf, vm2)));
}
- need_scaling[j] = 0;
- if (smax > M[j]) {
- if (M[j] > -INFINITY) {
- float m = expf(M[j] - smax);
- vms[j] = F16::set1(m);
- need_scaling[j] = 1;
- S[j] *= m;
- } else {
- need_scaling[j] = 2;
- S[j] = 0;
+ float32x4_t vmax = vdupq_n_f32(-INFINITY);
+ auto vscale32 = vcvt_f32_f16(vget_low_f16(vscale));
+ if (softcap <= 0.0f) {
+ for (int l = 0; l < k_step/4; ++l) {
+ vk[l] = vmulq_f32(vscale32, vk[l]);
+ vmax = vmaxq_f32(vmax, vk[l]);
+ }
+ } else {
+ auto v_softcap = vdupq_n_f32(softcap);
+ for (int l = 0; l < k_step/4; ++l) {
+ vk[l] = vmulq_f32(vscale32, vk[l]);
+ vk[l] = vmulq_f32(v_softcap, v_tanh(vk[l]));
+ vmax = vmaxq_f32(vmax, vk[l]);
}
- M[j] = smax;
- }
- auto vm = vdupq_n_f32(M[j]);
- auto vsum = vdupq_n_f32(0);
- for (int l = 0; l < k_step/4; ++l) {
- vk[l] = v_expf(vsubq_f32(vk[l], vm));
- vsum = vaddq_f32(vsum, vk[l]);
- F16::store(cache + k_step*j + 4*l, vk[l]);
}
- S[j] += vaddvq_f32(vsum);
+ return vmaxvq_f32(vmax);
}
#else
- inline void update_M_S(int j, F16::Data * vk) {
+ inline float load_and_scale(int j, F16::Data * vk) {
if (softcap <= 0.0f) {
for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l));
} else {
@@ -6890,32 +7129,67 @@ struct FlashMS {
vk[l] = F16::mul(v_softcap, v_tanh(F16::mul(vscale, val)));
}
}
-
- float smax = F16::reduce_max<k_step>(vk);
- if (smax == -INFINITY) {
- std::memset(cache + k_step*j, 0, k_step*sizeof(float));
- need_scaling[j] = M[j] == -INFINITY ? 2 : 0;
- return;
- }
- need_scaling[j] = 0;
- if (smax > M[j]) {
- if (M[j] > -INFINITY) {
- float m = expf(M[j] - smax);
- vms[j] = F16::set1(m);
- need_scaling[j] = 1;
- S[j] *= m;
- } else {
- need_scaling[j] = 2;
- S[j] = 0;
+ return F16::reduce_max<k_step>(vk);
+ }
+ inline float load_apply_mask_and_scale(int j, F16::Data * vk, const char * mask) {
+#ifdef HAVE_FANCY_SIMD
+ auto vzero = _mm256_set1_epi16(0);
+ auto vinf = _mm512_set1_ps(-INFINITY);
+ if (softcap <= 0) {
+ for (int l = 0; l < k_step/F16::block_size; ++l) {
+ auto m16 = _mm256_cmpeq_epi16_mask(_mm256_loadu_si256((const __m256i *)mask + l), vzero);
+ vk[l] = _mm512_mask_mul_ps(vinf, m16, vscale, F16::load(cache + k_step*j + F16::block_size*l));
+ }
+ } else {
+ auto v_softcap = F16::set1(softcap);
+ for (int l = 0; l < k_step/F16::block_size; ++l) {
+ auto m16 = _mm256_cmpeq_epi16_mask(_mm256_loadu_si256((const __m256i *)mask + l), vzero);
+ vk[l] = _mm512_mask_mul_ps(vinf, m16, v_softcap, v_tanh(F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l))));
}
- M[j] = smax;
}
- auto vm = F16::set1(M[j]);
+#else
+ auto vzero = _mm_set1_epi16(0);
+ auto vinf = F16::set1(-INFINITY);
for (int l = 0; l < k_step/F16::block_size; ++l) {
- vk[l] = v_expf(F16::sub(vk[l], vm));
- F16::store(cache + k_step*j + F16::block_size*l, vk[l]);
+ auto m128 = _mm_loadu_si128((const __m128i *)mask + l);
+ m128 = _mm_cmpeq_epi16(m128, vzero);
+ auto m256 = _mm256_cvtepi16_epi32(m128);
+ auto mf = _mm256_castsi256_ps(_mm256_or_si256(m256, _mm256_slli_epi32(m256, 16)));
+ auto val = _mm256_loadu_ps(cache + k_step*j + F16::block_size*l);
+ vk[l] = _mm256_or_ps(_mm256_and_ps(mf, val), _mm256_andnot_ps(mf, vinf));
+ }
+ if (softcap <= 0) {
+ for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(vscale, vk[l]);
+ } else {
+ auto v_softcap = F16::set1(softcap);
+ for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(v_softcap, v_tanh(F16::mul(vscale, vk[l])));
}
- S[j] += F16::reduce_add<k_step>(vk);
+#endif
+ return F16::reduce_max<k_step>(vk);
+ }
+#endif
+
+#ifdef __aarch64__
+ inline void update_M_S(int j, float32x4_t * vk) {
+ float smax = load_and_scale(j, vk);
+ update_M(j, smax);
+ update_S(j, vk);
+ }
+ inline void update_M_S(int j, float32x4_t * vk, const char * mask) {
+ float smax = load_apply_mask_and_scale(j, vk, mask);
+ update_M(j, smax);
+ update_S(j, vk);
+ }
+#else
+ inline void update_M_S(int j, F16::Data * vk) {
+ float smax = load_and_scale(j, vk);
+ update_M(j, smax);
+ update_S(j, vk);
+ }
+ inline void update_M_S(int j, F16::Data * vk, const char * mask) {
+ float smax = load_apply_mask_and_scale(j, vk, mask);
+ update_M(j, smax);
+ update_S(j, vk);
}
#endif
@@ -7200,6 +7474,135 @@ struct FlashQKfp32 {
}
}
#endif
+
+ template <typename KHelper, typename block_q8>
+ static inline void mul_mask_kq(const KHelper& kh, int stride_m,
+ const block_q8 * q, const char * mask, FlashMS<q_step, k_step>& fms) {
+ static_assert(q_step <= 8);
+ if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) {
+ DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
+#ifdef __aarch64__
+ mul_mat_qX_0_q8_0<DequantizerQ40, q_step>(D, kh.block, kh.stride, info, k_step);
+#else
+ mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
+#endif
+ }
+ else if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) {
+ DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
+#ifdef __aarch64__
+ mul_mat_qX_0_q8_0<DequantizerQ80_x4, q_step>(D, kh.block, kh.stride, info, k_step);
+#else
+ mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
+#endif
+ }
+ else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) {
+ DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};
+#ifdef __aarch64__
+ mul_mat_qX_1_q8_1<DequantizerQ41, q_step>(D, kh.block, kh.stride, info, k_step);
+#else
+ mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
+#endif
+ }
+ else {
+ GGML_ASSERT(false);
+ }
+#ifdef __aarch64__
+ float32x4_t vk[k_step/4];
+ for (int j = 0; j < q_step; ++j) {
+ fms.update_M_S(j, vk, mask + stride_m*j);
+ }
+#else
+ F16::Data vk[k_step/F16::block_size];
+ for (int j = 0; j < q_step; ++j) {
+ fms.update_M_S(j, vk, mask + stride_m*j);
+ }
+#endif
+ }
+ template <typename KHelper, typename block_q8>
+ static inline void mul_mask_kq(int nq, const KHelper& kh, int stride_m,
+ const block_q8 * q, const char * mask, FlashMS<q_step, k_step>& fms) {
+ GGML_ASSERT(nq < 8);
+ if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) {
+ DataInfo info{fms.cache, (const char *)q, D*sizeof(float), (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
+ switch (nq) {
+#ifdef __aarch64__
+ case 1: mul_mat_qX_0_q8_0<DequantizerQ40, 1>(D, kh.block, kh.stride, info, k_step); break;
+ case 2: mul_mat_qX_0_q8_0<DequantizerQ40, 2>(D, kh.block, kh.stride, info, k_step); break;
+ case 3: mul_mat_qX_0_q8_0<DequantizerQ40, 3>(D, kh.block, kh.stride, info, k_step); break;
+ case 4: mul_mat_qX_0_q8_0<DequantizerQ40, 4>(D, kh.block, kh.stride, info, k_step); break;
+ case 5: mul_mat_qX_0_q8_0<DequantizerQ40, 5>(D, kh.block, kh.stride, info, k_step); break;
+ case 6: mul_mat_qX_0_q8_0<DequantizerQ40, 6>(D, kh.block, kh.stride, info, k_step); break;
+ case 7: mul_mat_qX_0_q8_0<DequantizerQ40, 7>(D, kh.block, kh.stride, info, k_step); break;
+#else
+ case 1: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break;
+ case 2: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break;
+ case 3: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break;
+ case 4: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break;
+ case 5: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break;
+ case 6: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break;
+ case 7: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break;
+#endif
+ }
+ }
+ else if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) {
+ DataInfo info{fms.cache, (const char *)q, D*sizeof(float), (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
+ switch (nq) {
+#ifdef __aarch64__
+ case 1: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 1>(D, kh.block, kh.stride, info, k_step); break;
+ case 2: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 2>(D, kh.block, kh.stride, info, k_step); break;
+ case 3: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 3>(D, kh.block, kh.stride, info, k_step); break;
+ case 4: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 4>(D, kh.block, kh.stride, info, k_step); break;
+ case 5: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 5>(D, kh.block, kh.stride, info, k_step); break;
+ case 6: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 6>(D, kh.block, kh.stride, info, k_step); break;
+ case 7: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 7>(D, kh.block, kh.stride, info, k_step); break;
+#else
+ case 1: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break;
+ case 2: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break;
+ case 3: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break;
+ case 4: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break;
+ case 5: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break;
+ case 6: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break;
+ case 7: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break;
+#endif
+ }
+ }
+ else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) {
+ DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};
+ switch (nq) {
+#ifdef __aarch64__
+ case 1: mul_mat_qX_1_q8_1<DequantizerQ41, 1>(D, kh.block, kh.stride, info, k_step); break;
+ case 2: mul_mat_qX_1_q8_1<DequantizerQ41, 2>(D, kh.block, kh.stride, info, k_step); break;
+ case 3: mul_mat_qX_1_q8_1<DequantizerQ41, 3>(D, kh.block, kh.stride, info, k_step); break;
+ case 4: mul_mat_qX_1_q8_1<DequantizerQ41, 4>(D, kh.block, kh.stride, info, k_step); break;
+ case 5: mul_mat_qX_1_q8_1<DequantizerQ41, 5>(D, kh.block, kh.stride, info, k_step); break;
+ case 6: mul_mat_qX_1_q8_1<DequantizerQ41, 6>(D, kh.block, kh.stride, info, k_step); break;
+ case 7: mul_mat_qX_1_q8_1<DequantizerQ41, 7>(D, kh.block, kh.stride, info, k_step); break;
+#else
+ case 1: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break;
+ case 2: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break;
+ case 3: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break;
+ case 4: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break;
+ case 5: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break;
+ case 6: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break;
+ case 7: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break;
+#endif
+ }
+ }
+ else {
+ GGML_ASSERT(false);
+ }
+#ifdef __aarch64__
+ float32x4_t vk[k_step/4];
+ for (int j = 0; j < nq; ++j) {
+ fms.update_M_S(j, vk, mask + stride_m*j);
+ }
+#else
+ F16::Data vk[k_step/F16::block_size];
+ for (int j = 0; j < nq; ++j) {
+ fms.update_M_S(j, vk, mask + stride_m*j);
+ }
+#endif
+ }
};
template <int D, int q_step, int k_step, typename KHelper, typename VHelper, typename KQHelper>
@@ -7259,6 +7662,49 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
}
}
+template <int D, int q_step, int k_step, typename KHelper, typename VHelper, typename KQHelper>
+void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
+ FlashMS<q_step, k_step>& fms,
+ FlashQKV<D, q_step, k_step>& fqkv,
+ const float * q, const char * mask, float * qkv) {
+ typename KHelper::block_q8 q8[q_step*(D/QK8_0)];
+ for (int i1 = 0; i1 < nq1/q_step; ++i1) {
+ fms.init_qstep();
+ kh.reset_block();
+ vh.reset_block();
+ HelperQ80<D, QK8_0>::convert(q_step, stride_q, q, q8);
+ auto mr = mask;
+ for (int k1 = 0; k1 < nk1/k_step; ++k1) {
+ KQHelper::mul_mask_kq(kh, stride_m, q8, mr, fms);
+ fqkv.accumulate_qkv(vh, fms);
+ kh.next_block();
+ vh.next_block();
+ mr += k_step*sizeof(ggml_half);
+ }
+ fqkv.normalize_and_store(fms, stride_qkv, qkv);
+
+ q += q_step*stride_q;
+ mask += q_step*stride_m;
+ qkv += q_step*stride_qkv;
+ }
+ int n_left = nq1 - q_step*(nq1/q_step);
+ if (n_left > 0) {
+ fms.init_qstep();
+ kh.reset_block();
+ vh.reset_block();
+ HelperQ80<D, QK8_0>::convert(n_left, stride_q, q, q8);
+ auto mr = mask;
+ for (int k1 = 0; k1 < nk1/k_step; ++k1) {
+ KQHelper::mul_mask_kq(n_left, kh, stride_m, q8, mr, fms);
+ fqkv.accumulate_qkv(n_left, vh, fms);
+ kh.next_block();
+ vh.next_block();
+ mr += k_step*sizeof(ggml_half);
+ }
+ fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv);
+ }
+}
+
// Some of the methods in FlashAttn have two identical implementations that only differ by
// one version using a loop over the template parameter q_step, while the other using a loop
// over an input parameter nq (these are loops over the rows of q^T). I dislike this a lot,
@@ -7280,8 +7726,14 @@ struct FlashAttn {
template <typename KHelper, typename VHelper>
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
const float * q, const char * mask, float * qkv) {
- compute_helper<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
- kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
+ if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>> ||
+ std::is_same_v<KHelper, HelperQ80<D, k_step>>) {
+ compute_helper_q<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
+ kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
+ } else {
+ compute_helper<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
+ kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
+ }
}
FlashMS<q_step, k_step> fms;
@@ -7604,7 +8056,6 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
HelperF16<D, k_step> vh(v, stride_v);
iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
-#ifdef __AVX2__
case GGML_TYPE_Q8_0: {
HelperQ80<D, k_step> vh(v, stride_v);
iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
@@ -7617,7 +8068,6 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
HelperQ41<D, k_step> vh(v, stride_v);
iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
-#endif
default: break;
}
}
@@ -7633,7 +8083,6 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
HelperF16<D, k_step> kh(k, stride_k);
iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
-#ifdef __AVX2__
case GGML_TYPE_Q8_0: {
HelperQ80<D, k_step> kh(k, stride_k);
iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
@@ -7646,21 +8095,16 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
HelperQ41<D, k_step> kh(k, stride_k);
iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
-#endif
default: break;
}
}
inline bool flash_attn_is_supported(ggml_type type) {
-#ifdef __AVX2__
if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1) return true;
#ifdef __AVX512BF16__
if (type == GGML_TYPE_BF16) return true;
#endif
-#else
- if (type == GGML_TYPE_F16) return true;
-#endif
return false;
}
}
@@ -7695,7 +8139,8 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
stride_q /= sizeof(float); // q stride as float
#ifdef __AVX512BF16__
- if (type_k == GGML_TYPE_BF16 && type_v == GGML_TYPE_BF16) {
+ if (type_k == GGML_TYPE_BF16 || type_v == GGML_TYPE_BF16) {
+ if (type_k != GGML_TYPE_BF16 || type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 with other types
switch (D) {
case 64:
iqk_flash_helper_T< 64, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;