diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2025-01-20 08:57:38 +0200
committer | GitHub <noreply@github.com> | 2025-01-20 08:57:38 +0200
commit | 3c5f87225f0ddd379ab712ddb8ad0013c10167c2 (patch)
tree | 7f339e1e1fe99218065a297cbf2632dcce8804a9
parent | 0b74397d596bbcdfba27299393406d2b6330b133 (diff)
More Flash Attention improvements (#173)
* FA: slightly faster V*softmax(K*Q) on Zen4
* FA: it is also faster on AVX2 and ARM_NEON
* Deleted forgotten commented-out code
* FA: slightly faster V*softmax(K*Q) also for fp16 K-cache
* FA: slightly faster V*softmax(K*Q) on Zen4
We now get 130.9 t/s for a context of 32k tokens.
* FA: don't store sum scaling factor in SIMD registers
* FA: timing
* FA: faster q8_0 cache via run-time-repacking
On Zen4 q8_0 KV-cache now slightly outperforms BF16.
We get 134 t/s for 32k tokens, which is ~30% better than
the main branch, and ~18% better than the last commit.
We simply repack the K-cache to q8_0_r4 before the K*Q
multiplication and use the q8_0_r4 x q8_0_x4 matrix multiplication
template (a scalar sketch of this repacking is included after the change log below).
* FA: Fix AVX2
* FA: fix ARM_NEON
* FA: vectorize q8_0 -> q8_0_r4 repacking also on NEON
* FA: dedicated mat mul for D = 128 also for ARM_NEON
* FA: turn off performance timer
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
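
For readers unfamiliar with the r4 layout: q8_0 stores each row as independent 32-quant blocks, while q8_0_r4 interleaves four consecutive rows block by block, so that one interleaved block can feed four dot products at once. Below is a minimal scalar sketch of that interleaving, mirroring the scalar fallback in the patch; the two struct definitions are simplified stand-ins for the actual ggml types (in particular, `d` is really a `ggml_half`).

```cpp
// Minimal scalar sketch of q8_0 -> q8_0_r4 repacking (simplified stand-ins
// for the real ggml block types; QK8_0 = 32 as in ggml).
#include <cstdint>
#include <vector>

constexpr int QK8_0 = 32;

struct block_q8_0    { uint16_t d;    int8_t qs[QK8_0];   }; // one row, 32 quants
struct block_q8_0_x4 { uint16_t d[4]; int8_t qs[4*QK8_0]; }; // four rows, interleaved

// Interleave four rows of n_blocks q8_0 blocks each into the r4 layout:
// within each block, 4-byte groups from the four rows alternate. The index
// arithmetic matches the scalar fallback in HelperQ80R4::repack below.
std::vector<block_q8_0_x4> repack_r4(const block_q8_0 * rows[4], int n_blocks) {
    std::vector<block_q8_0_x4> out(n_blocks);
    for (int ib = 0; ib < n_blocks; ++ib) {
        for (int k = 0; k < 4; ++k) out[ib].d[k] = rows[k][ib].d;
        for (int l = 0; l < 4; ++l) {
            for (int k = 0; k < 4; ++k) {
                for (int i = 0; i < 4; ++i) {
                    out[ib].qs[32*l + 4*k + i +  0] = rows[k][ib].qs[i + 4*l +  0];
                    out[ib].qs[32*l + 4*k + i + 16] = rows[k][ib].qs[i + 4*l + 16];
                }
            }
        }
    }
    return out;
}
```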
-rw-r--r-- | ggml/src/ggml.c | 23
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 1021
2 files changed, 841 insertions, 203 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index bcb8bf41..b3c8a951 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -17471,25 +17471,30 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 #if GGML_USE_IQK_MULMAT
     if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) {
-        int64_t work_per_slice = D*nek1*neq1;
-        int ntg = 1;
+        // I keep changing my mind what is the best strategy to split the threads when processing
+        // multiple heads. This is my current thinking, the commented out code below was the previous.
+        int ntg = nth/simple_gcd(neq2*neq3, nth);
+        int64_t neq1g = (neq1 + ntg - 1)/ntg;
+        //int64_t work_per_slice = D*nek1*neq1;
+        //int ntg = 1;
         //
         // When neq1 is large, it is better to have more than one thread process one (iq2,iq3) matrix
         // But we also want each thread to process the same amount of rows, so neq1 must be a multiple of
         // the number of threads processing the (iq2, iq3) matrix.
         //
-        if (neq1 >= 8*nth) {
-            if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8;
-            else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4;
-            else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2;
-        }
+        //if (neq1 >= 8*nth) {
+        //    if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8;
+        //    else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4;
+        //    else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2;
+        //}
         int counter = 0;
         for (int64_t iq3 = 0; iq3 < neq3; iq3++) {
             for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
                 if (counter++ % (nth/ntg) == ith/ntg) {
-                    int iq1 = (ith%ntg)*neq1/ntg;
+                    int iq1 = (ith%ntg)*neq1g;
+                    int this_neq1 = MIN(neq1g, neq1-iq1);
                     if (!iqk_flash_attn_noalibi(k->type, v->type,
-                                D, neq1/ntg, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
+                                D, this_neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
                                 (const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]),
                                 (const void  *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]),
                                 (const void  *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]),
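To make the new splitting strategy concrete: the code above dedicates `nth/simple_gcd(neq2*neq3, nth)` threads to each (iq2, iq3) head slice and hands each of those threads a contiguous chunk of ⌈neq1/ntg⌉ query rows. A small standalone model of that arithmetic follows; `simple_gcd` is assumed to be a plain Euclidean gcd (its definition lies outside this diff), and the numbers are illustrative only.

```cpp
// Model of the new thread-split arithmetic from the hunk above.
#include <algorithm>
#include <cstdint>
#include <cstdio>

static int simple_gcd(int a, int b) { while (b) { int t = a % b; a = b; b = t; } return a; }

int main() {
    int     nth  = 16;            // total threads
    int64_t neq1 = 37;            // query rows per head
    int64_t neq2 = 8, neq3 = 1;   // heads (and head groups)

    int     ntg   = nth / simple_gcd(int(neq2*neq3), nth); // threads per (iq2,iq3) slice
    int64_t neq1g = (neq1 + ntg - 1) / ntg;                // rows per thread, rounded up

    for (int ith = 0; ith < nth; ++ith) {
        int64_t iq1       = (ith % ntg) * neq1g;               // first row for this thread
        int64_t this_neq1 = std::min(neq1g, neq1 - iq1);       // rows it actually processes
        printf("thread %2d: slice group %d, rows [%lld, %lld)\n",
               ith, ith / ntg, (long long)iq1, (long long)(iq1 + this_neq1));
    }
}
```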
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 5577ea99..109ac08e 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -17,6 +17,7 @@
 #include <cstring>
 #include <type_traits>
+#include <vector>
 #if defined IQK_IMPLEMENT
@@ -47,8 +48,57 @@
 // For fp16/fp32 matri multiplications tiling is used to improve
 // performance.
+#define FA_TIMING 0
+
 #include <utility>
 #include <array>
+#if FA_TIMING
+#include <chrono>
+#include <mutex>
+struct Perf {
+    using TimePoint = std::chrono::time_point<std::chrono::high_resolution_clock>;
+    std::array<double, 5> times = {};
+    std::mutex mutex;
+    bool report;
+    static auto cur_time() { return std::chrono::high_resolution_clock::now(); }
+    inline void accum(int what, const TimePoint& t1) {
+        auto t2 = cur_time();
+        auto dt = delta(t1, t2);
+        std::lock_guard<std::mutex> lock(mutex);
+        times[what] += dt;
+    }
+    inline void accum_nolock(int what, const TimePoint& t1) {
+        auto t2 = cur_time();
+        auto dt = delta(t1, t2);
+        times[what] += dt;
+    }
+    inline void add(const Perf& other) {
+        std::lock_guard<std::mutex> lock(mutex);
+        for (int i = 0; i < int(times.size()); ++i) times[i] += other.times[i];
+    }
+    Perf(bool r) : report(r) {}
+    ~Perf() {
+        if (report) {
+            double tot = 0;
+            for (auto& t : times) tot += t;
+            if (!tot) return;
+            printf("======================= Timing: %g ms in total\n", tot);
+            for (int i = 0; i < int(times.size()); ++i) {
+                if (times[i]) {
+                    printf("%d: %g ms -> %g%c\n", i, times[i], 100*times[i]/tot, '%');
+                }
+            }
+        }
+    }
+    static Perf& instance() {
+        static Perf p(true);
+        return p;
+    }
+    static double delta(const TimePoint& t1, const TimePoint& t2) {
+        return 1e-6*std::chrono::duration_cast<std::chrono::nanoseconds>(t2-t1).count();
+    }
+};
+#endif
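The Perf helper above is used later in this patch roughly as follows. This is an illustrative sketch only (active when FA_TIMING is set to 1), not an excerpt from the diff:

```cpp
// Hypothetical usage of the Perf helper above: collect per-phase wall-clock
// time in a thread-local instance, then merge into the reporting singleton,
// which prints a breakdown when the program exits.
#if FA_TIMING
static void timed_phase() {
    Perf perf(false);              // local collector, does not report on its own
    auto t1 = Perf::cur_time();
    // ... do some work ...
    perf.accum_nolock(0, t1);      // slot 0: time spent in this phase
    Perf::instance().add(perf);    // merge into the global reporting instance
}
#endif
```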
 #ifdef _MSC_VER
 #define IQK_NOINLINE __declspec(noinline)
@@ -6895,6 +6945,25 @@ struct QFBase {
     static inline Data load4Floats(const Float * x) { return _mm512_insertf32x4(_mm512_setzero_ps(), load128(x), 0); }
+    static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) {
+        acc = _mm512_fmadd_ps(xv[0], _mm512_shuffle_ps(yv, yv, 0x00), acc);
+        acc = _mm512_fmadd_ps(xv[1], _mm512_shuffle_ps(yv, yv, 0x55), acc);
+        acc = _mm512_fmadd_ps(xv[2], _mm512_shuffle_ps(yv, yv, 0xaa), acc);
+        acc = _mm512_fmadd_ps(xv[3], _mm512_shuffle_ps(yv, yv, 0xff), acc);
+        return acc;
+    }
+    static inline Acc acc_r4_first(const Data * xv, const Data& yv) {
+        auto acc = _mm512_mul_ps(xv[0], _mm512_shuffle_ps(yv, yv, 0x00));
+        acc = _mm512_fmadd_ps(xv[1], _mm512_shuffle_ps(yv, yv, 0x55), acc);
+        acc = _mm512_fmadd_ps(xv[2], _mm512_shuffle_ps(yv, yv, 0xaa), acc);
+        acc = _mm512_fmadd_ps(xv[3], _mm512_shuffle_ps(yv, yv, 0xff), acc);
+        return acc;
+    }
+    static inline __m128 hsum_r4(Acc acc) {
+        auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(acc, 0), _mm512_extractf32x4_ps(acc, 1));
+        auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(acc, 2), _mm512_extractf32x4_ps(acc, 3));
+        return _mm_add_ps(sum1, sum2);
+    }
 #else
     constexpr static int k_step = 8;
     using Data = __m256;
@@ -6904,12 +6973,29 @@ struct QFBase {
     static inline Acc acc(Acc prev, const Data& y, const Data& x) { return _mm256_fmadd_ps(y, x, prev); }
+    static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) {
+        acc = _mm256_fmadd_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00), acc);
+        acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 0x55), acc);
+        acc = _mm256_fmadd_ps(xv[2], _mm256_shuffle_ps(yv, yv, 0xaa), acc);
+        acc = _mm256_fmadd_ps(xv[3], _mm256_shuffle_ps(yv, yv, 0xff), acc);
+        return acc;
+    }
+    static inline Acc acc_r4_first(const Data * xv, const Data& yv) {
+        auto acc = _mm256_mul_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00));
+        acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 0x55), acc);
+        acc = _mm256_fmadd_ps(xv[2], _mm256_shuffle_ps(yv, yv, 0xaa), acc);
+        acc = _mm256_fmadd_ps(xv[3], _mm256_shuffle_ps(yv, yv, 0xff), acc);
+        return acc;
+    }
     static inline Acc acc_first(const Data& y, const Data& x) { return _mm256_mul_ps(y, x); }
     static inline float hsum(Acc acc) { return hsum_float_8(acc); }
+    static inline __m128 hsum_r4(Acc acc) {
+        return _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));
+    }
     template <typename Float> static inline Data load4Floats(const Float * x) {
         return _mm256_insertf128_ps(_mm256_setzero_ps(), load128(x), 0);
@@ -6928,6 +7014,31 @@ template <typename Float, int nrc_in> struct QFT final : public QFBase {
     }
     IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
     IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load4Floats(y[iy] + 4*i); }
+    IQK_ALWAYS_INLINE void load_r4(int ix, int i, Data * xv) const {
+        xv[0] = load1(ix+0, i);
+        xv[1] = load1(ix+1, i);
+        xv[2] = load1(ix+2, i);
+        xv[3] = load1(ix+3, i);
+#ifdef HAVE_FANCY_SIMD
+        auto t0 = _mm512_unpacklo_ps(xv[0], xv[1]);
+        auto t1 = _mm512_unpacklo_ps(xv[2], xv[3]);
+        auto t2 = _mm512_unpackhi_ps(xv[0], xv[1]);
+        auto t3 = _mm512_unpackhi_ps(xv[2], xv[3]);
+        xv[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(t0), _mm512_castps_pd(t1)));
+        xv[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(t0), _mm512_castps_pd(t1)));
+        xv[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(t2), _mm512_castps_pd(t3)));
+        xv[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(t2), _mm512_castps_pd(t3)));
+#else
+        auto t0 = _mm256_unpacklo_ps(xv[0], xv[1]);
+        auto t1 = _mm256_unpacklo_ps(xv[2], xv[3]);
+        auto t2 = _mm256_unpackhi_ps(xv[0], xv[1]);
+        auto t3 = _mm256_unpackhi_ps(xv[2], xv[3]);
+        xv[0] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1)));
+        xv[1] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1)));
+        xv[2] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3)));
+        xv[3] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3)));
+#endif
+    }
     const Float * y[nrc];
 };
@@ -6973,6 +7084,56 @@ IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0,
     for (int iy = 0; iy < Qy::nrc; ++iy) for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, iy, QFBase::hsum(acc[Qx::nrc*iy+ix]));
 }
+template <typename Qy, typename Qx>
+inline void mul_mat_Qx_Qy_MxN_fa(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
+    int nb = n/QFBase::k_step;
+    Qy y(info);
+    Qx x(cx + ix0*bx, bx);
+    QFBase::Data xv[Qx::nrc];
+    QFBase::Acc  acc[Qx::nrc*Qy::nrc];
+    auto yv = y.load1(0, 0);
+    for (int ix = 0; ix < Qx::nrc; ++ix) {
+        xv[ix] = x.load1(ix, 0);
+        acc[ix] = QFBase::acc_first(yv, xv[ix]);
+    }
+    for (int iy = 1; iy < Qy::nrc; ++iy) {
+        yv = y.load1(iy, 0);
+        for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc_first(yv, xv[ix]);
+    }
+    for (int i = 1; i < nb; ++i) {
+        yv = y.load1(0, i);
+        for (int ix = 0; ix < Qx::nrc; ++ix) {
+            xv[ix] = x.load1(ix, i);
+            acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]);
+        }
+        for (int iy = 1; iy < Qy::nrc; ++iy) {
+            yv = y.load1(iy, i);
+            for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc(acc[Qx::nrc*iy + ix], yv, xv[ix]);
+        }
+    }
+    for (int iy = 0; iy < Qy::nrc; ++iy) for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, iy, QFBase::hsum(acc[Qx::nrc*iy+ix]));
+}
+
+template <typename Qy, typename Qx>
+inline void mul_mat_Qx_Qy_MxN_fa4(int D, const char * cx, size_t bx, int ix0, const DataInfo& info) {
+    static_assert(Qx::nrc%4 == 0);
+    int nb = D/QFBase::k_step;
+    Qy y(info);
+    Qx x(cx + ix0*bx, bx);
+    QFBase::Data xv[Qx::nrc];
+    QFBase::Acc  acc[Qx::nrc*Qy::nrc/4] = {};
+    for (int i = 0; i < nb; ++i) {
+        for (int ix = 0; ix < Qx::nrc/4; ++ix) x.load_r4(4*ix, i, xv + 4*ix);
+        for (int iy = 0; iy < Qy::nrc; ++iy) {
+            auto yv = y.load1(iy, i);
+            for (int ix = 0; ix < Qx::nrc/4; ++ix) acc[ix*Qy::nrc + iy] = QFBase::acc_r4(acc[ix*Qy::nrc + iy], xv + 4*ix, yv);
+        }
+    }
+    for (int iy = 0; iy < Qy::nrc; ++iy) {
+        for (int ix = 0; ix < Qx::nrc/4; ++ix) info.store(ix0+4*ix, iy, QFBase::hsum_r4(acc[ix*Qy::nrc + iy]));
+    }
+}
+
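What `load_r4`, `acc_r4`, and `hsum_r4` compute together can be stated without intrinsics: after the in-register 4x4 transpose, each broadcast-FMA multiplies one y value against the matching element of four x rows, so the four slots of one SIMD lane carry partial sums for four rows at once; `hsum_r4` then collapses the lanes into four finished dot products. A scalar model of that bookkeeping (toy data, not the production code):

```cpp
// Scalar model of the four-rows-at-a-time dot products computed by
// load_r4 + acc_r4 + hsum_r4.
#include <cstdio>

int main() {
    const int n = 16;                 // model size; a multiple of 4
    float x[4][n], y[n], dot[4] = {};
    for (int i = 0; i < n; ++i) {
        y[i] = 0.5f*i - 3;
        for (int k = 0; k < 4; ++k) x[k][i] = k + 0.25f*i;
    }
    // Vector scheme, modeled with a 4-wide "lane": acc[k] plays the role of
    // slot k; each step FMAs one broadcast y value against 4-row column data.
    float acc[4] = {};
    for (int i = 0; i < n; i += 4)
        for (int j = 0; j < 4; ++j)               // the four shuffle/broadcast FMAs
            for (int k = 0; k < 4; ++k) acc[k] += x[k][i+j]*y[i+j];
    // Reference: four independent dot products.
    for (int k = 0; k < 4; ++k) for (int i = 0; i < n; ++i) dot[k] += x[k][i]*y[i];
    for (int k = 0; k < 4; ++k) printf("row %d: %g == %g\n", k, acc[k], dot[k]);
}
```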
 // This will handle any of f16 x f32, f32 x f16, f16 x f16, f32 x f32, with computations done
 // in f32 (i.e., f16 is first converted to f32). It is easy to extend to computations done in
 // f16, but I don't have a CPU capable of f16 vector arithmetic, so not doing it for now.
@@ -11902,6 +12063,46 @@ void mul_mat_q8_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& inf
     }
 }
+template <int nrc_y>
+void mul_mat_q8_0_r4_q8_0_128(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    GGML_ASSERT(nrc_x%4 == 0);
+    GGML_ASSERT(n == 128);
+    int8x16x4_t qx[8];
+    float32x4_t scales[4];
+    float32x4_t scales_y[4];
+    for (int ix = 0; ix < nrc_x; ix += 4) {
+        const block_q8_0_x4 * iq8 = (const block_q8_0_x4 *)((const char *)vx + ix*bx);
+        for (int k = 0; k < 4; ++k) {
+            scales[k] = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[k].d));
+            qx[2*k+0] = vld1q_s8_x4(iq8[k].qs);
+            qx[2*k+1] = vld1q_s8_x4(iq8[k].qs+64);
+        }
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            auto by = (const block_q8_0_x4 *)info.src1_row(iy);
+            auto d8 = vcvt_f32_f16(vld1_f16((const float16_t *)by->d));
+            scales_y[0] = vmulq_laneq_f32(scales[0], d8, 0);
+            scales_y[1] = vmulq_laneq_f32(scales[1], d8, 1);
+            scales_y[2] = vmulq_laneq_f32(scales[2], d8, 2);
+            scales_y[3] = vmulq_laneq_f32(scales[3], d8, 3);
+            auto sumf = vdupq_n_f32(0.f);
+            for (int k = 0; k < 4; ++k) {
+                auto y = vld1q_s8_x2(by->qs+32*k);
+                auto sumi = vdupq_n_s32(0);
+                sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[0], y.val[0], 0);
+                sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[1], y.val[1], 0);
+                sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[2], y.val[0], 1);
+                sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[3], y.val[1], 1);
+                sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[0], y.val[0], 2);
+                sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[1], y.val[1], 2);
+                sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[2], y.val[0], 3);
+                sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[3], y.val[1], 3);
+                sumf = vfmaq_f32(sumf, scales_y[k], vcvtq_f32_s32(sumi));
+            }
+            info.store(ix, iy, sumf);
+        }
+    }
+}
+
 #define SET_MUL_MAT_FUNCTIONS_T(m, func, Dequantizer) \
     m.funcs[0] = func<Dequantizer, 1>;\
     m.funcs[1] = func<Dequantizer, 2>;\
@@ -12354,6 +12555,15 @@ struct F16 {
     static inline float reduce_add(Data data) { return _mm512_reduce_add_ps(data); }
     static inline Data max(Data v1, Data v2) { return _mm512_max_ps(v1, v2); }
     static inline Data add(Data v1, Data v2) { return _mm512_add_ps(v1, v2); }
+    static inline Data set4(const float * ptr) {
+        auto v128 = _mm_loadu_ps(ptr);
+        auto v256 = _mm256_set_m128(v128, v128);
+        return _mm512_insertf32x8(_mm512_castps256_ps512(v256), v256, 1);
+    }
+    static inline Data fmadd_lane0(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0x00), prev); }
+    static inline Data fmadd_lane1(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0x55), prev); }
+    static inline Data fmadd_lane2(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0xaa), prev); }
+    static inline Data fmadd_lane3(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0xff), prev); }
 #elif defined __AVX2__
     using Data = __m256;
     constexpr static int block_size = 8;
@@ -12371,6 +12581,14 @@ struct F16 {
     static inline float reduce_add(Data data) { return hsum_float_8(data); }
     static inline Data max(Data v1, Data v2) { return _mm256_max_ps(v1, v2); }
     static inline Data add(Data v1, Data v2) { return _mm256_add_ps(v1, v2); }
+    static inline Data set4(const float * ptr) {
+        auto v128 = _mm_loadu_ps(ptr);
+        return _mm256_set_m128(v128, v128);
+    }
+    static inline Data fmadd_lane0(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0x00), prev); }
+    static inline Data fmadd_lane1(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0x55), prev); }
+    static inline Data fmadd_lane2(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0xaa), prev); }
+    static inline Data fmadd_lane3(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0xff), prev); }
 #else
     using Data = float16x8_t;
     constexpr static int block_size = 8;
@@ -12402,6 +12620,14 @@ struct F16 {
     }
     static inline Data max(Data v1, Data v2) { return vmaxq_f16(v1, v2); }
     static inline Data add(Data v1, Data v2) { return vaddq_f16(v1, v2); }
+    static inline float16x4_t set4(const float * ptr) {
+        auto val32 = vld1q_f32(ptr);
+        return vcvt_f16_f32(val32);
+    }
+    static inline Data fmadd_lane0(Data prev, Data v1, float16x4_t v2) { return vfmaq_lane_f16(prev, v1, v2, 0); }
+    static inline Data fmadd_lane1(Data prev, Data v1, float16x4_t v2) { return vfmaq_lane_f16(prev, v1, v2, 1); }
+    static inline Data fmadd_lane2(Data prev, Data v1, float16x4_t v2) { return vfmaq_lane_f16(prev, v1, v2, 2); }
+    static inline Data fmadd_lane3(Data prev, Data v1, float16x4_t v2) { return vfmaq_lane_f16(prev, v1, v2, 3); }
 #endif
     template <int k_step> static inline float reduce_max(const Data * data) {
         return reduce_T<k_step, &F16::max, &F16::reduce_max>(data);
@@ -12454,7 +12680,6 @@ template <int D, int step>
 struct HelperQ80 final : public BaseHelper<step> {
     using Base = BaseHelper<step>;
 #ifdef HAVE_FANCY_SIMD
-    //using block_q8 = block_q8_1;
     using block_q8 = block_q8_1;
 #else
     using block_q8 = block_q8_0;
@@ -12478,14 +12703,14 @@ struct HelperQ80 final : public BaseHelper<step> {
             v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+1))));
 #else
             int ii = j%QK8_0;
-            v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+ii)+0))));
-            v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+ii)+1))));
+            v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+ii+0)))));
+            v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+ii+8)))));
 #endif
 #endif
     }
     static inline void convert(int nq, int stride_q, const float * q, block_q8_0 * y) {
-        GGML_ASSERT(nq <= step);
+        //GGML_ASSERT(nq <= step); Why did I have this assert?
         for (int i = 0; i < nq; ++i) {
             quantize_row_q8_0_x4(q, y, D);
             q += stride_q;
@@ -12494,7 +12719,7 @@ struct HelperQ80 final : public BaseHelper<step> {
     }
     static inline void convert(int nq, int stride_q, const float * q, block_q8_1 * y) {
-        GGML_ASSERT(nq <= step);
+        //GGML_ASSERT(nq <= step); Why did I have this assert?
         for (int i = 0; i < nq; ++i) {
             quantize_row_q8_1_x4(q, y, D);
             q += stride_q;
@@ -12504,6 +12729,86 @@ struct HelperQ80 final : public BaseHelper<step> {
 };

 template <int D, int step>
+struct HelperQ80R4 : public BaseHelper<step> {
+    using Base = BaseHelper<step>;
+#ifdef __AVX2__
+    using block_q8 = block_q8_1;
+#else
+    using block_q8 = block_q8_0;
+#endif
+    HelperQ80R4(int nk, const HelperQ80<D, step>& q8) : Base(q8.data, q8.stride) {
+        r4 = repack(nk, q8);
+        Base::data = (const char *)r4.data();
+        Base::stride = (D/QK8_0)*sizeof(block_q8_0);
+    }
+
+    static std::vector<block_q8_0_x4> repack(int nk, const HelperQ80<D, step> q8) {
+        static_assert(D%QK8_0 == 0);
+        GGML_ASSERT(nk%4 == 0);
+        constexpr int nblock = D/QK8_0;
+        std::vector<block_q8_0_x4> result(nblock * nk/4);
+        auto y = result.data();
+        const block_q8_0 * x4[4];
+        for (int row = 0; row < nk; row += 4) {
+            for (int k = 0; k < 4; ++k) x4[k] = (const block_q8_0 *)(q8.data + (row + k)*q8.stride);
+            for (int ib = 0; ib < nblock; ++ib) {
+                for (int k = 0; k < 4; ++k) y[ib].d[k] = x4[k][ib].d;
+#ifdef __AVX2__
+                auto m0 = _mm256_loadu_si256((const __m256i *)x4[0][ib].qs);
+                auto m1 = _mm256_loadu_si256((const __m256i *)x4[1][ib].qs);
+                auto m2 = _mm256_loadu_si256((const __m256i *)x4[2][ib].qs);
+                auto m3 = _mm256_loadu_si256((const __m256i *)x4[3][ib].qs);
+                auto t0 = _mm256_unpacklo_epi32(m0, m1);
+                auto t1 = _mm256_unpacklo_epi32(m2, m3);
+                auto t2 = _mm256_unpackhi_epi32(m0, m1);
+                auto t3 = _mm256_unpackhi_epi32(m2, m3);
+                m0 = _mm256_unpacklo_epi64(t0, t1);
+                m1 = _mm256_unpackhi_epi64(t0, t1);
+                m2 = _mm256_unpacklo_epi64(t2, t3);
+                m3 = _mm256_unpackhi_epi64(t2, t3);
+                _mm256_storeu_si256((__m256i *)y[ib].qs + 0, m0);
+                _mm256_storeu_si256((__m256i *)y[ib].qs + 1, m1);
+                _mm256_storeu_si256((__m256i *)y[ib].qs + 2, m2);
+                _mm256_storeu_si256((__m256i *)y[ib].qs + 3, m3);
+#elif defined __ARM_NEON
+                auto m0 = vld1q_s8_x2(x4[0][ib].qs);
+                auto m1 = vld1q_s8_x2(x4[1][ib].qs);
+                auto m2 = vld1q_s8_x2(x4[2][ib].qs);
+                auto m3 = vld1q_s8_x2(x4[3][ib].qs);
+                auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0]));
+                auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0]));
+                m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+                m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+                m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+                m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+                row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1]));
+                row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1]));
+                m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+                m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+                m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+                m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+                vst1q_s8_x2(y[ib].qs +  0, m0);
+                vst1q_s8_x2(y[ib].qs + 32, m1);
+                vst1q_s8_x2(y[ib].qs + 64, m2);
+                vst1q_s8_x2(y[ib].qs + 96, m3);
+#else
+                for (int l = 0; l < 4; ++l) {
+                    for (int k = 0; k < 4; ++k) for (int i = 0; i < 4; ++i) {
+                        y[ib].qs[32*l+4*k+i+ 0] = x4[k][ib].qs[i+4*l+ 0];
+                        y[ib].qs[32*l+4*k+i+16] = x4[k][ib].qs[i+4*l+16];
+                    }
+                }
+#endif
+            }
+            y += nblock;
+        }
+        return result;
+    }
+
+    std::vector<block_q8_0_x4> r4;
+};
+
+template <int D, int step>
 struct HelperQ40 final : public BaseHelper<step> {
     using Base = BaseHelper<step>;
     using block_q8 = block_q8_0;
@@ -12725,7 +13030,7 @@ struct FlashMS {
             if (smax > M[j]) {
                 if (M[j] > -INFINITY) {
                     float m = expf(M[j] - smax);
-                    vms[j] = F16::set1(m);
+                    vms[j] = m;
                     need_scaling[j] = 1;
                     S[j] *= m;
                 } else {
@@ -12907,7 +13212,7 @@ struct FlashMS {
     cache_t cache[q_step*k_step];
     float S[q_step], M[q_step];
     int need_scaling[q_step];
-    F16::Data vms[q_step];
+    float vms[q_step];
     const F16::Data vscale;
     const float softcap;
     const ggml_half h_inf;
@@ -12927,79 +13232,90 @@ struct FlashQKV {
     // Hence, for now, we will not handle head sizes of 80 and 112
     template <typename VHelper>
     inline void accumulate_qkv(const VHelper& vh, const FlashMS<q_step, k_step>& fms) {
-        F16::Data vk[2*q_step];
-        for (int i = 0; i < D/F16::block_size; i += 2) {
-            for (int j = 0; j < q_step; ++j) {
-                if (fms.need_scaling[j] == 2) {
-                    vk[2*j+0] = vk[2*j+1] = F16::zero();
-                } else {
-                    auto R = qkv_cache + D*j;
-                    vk[2*j+0] = F16::load(R + F16::block_size*i);
-                    vk[2*j+1] = F16::load(R + F16::block_size*(i + 1));
-                    if (fms.need_scaling[j] == 1) {
-                        vk[2*j+0] = F16::mul(vk[2*j+0], fms.vms[j]);
-                        vk[2*j+1] = F16::mul(vk[2*j+1], fms.vms[j]);
-                    }
+        F16::Data v[8];
+        for (int j = 0; j < q_step; ++j) {
+            auto R = qkv_cache + D*j;
+            if (fms.need_scaling[j] == 2) {
+                std::memset(R, 0, D*sizeof(qkv_cache_t));
+            }
+            else if (fms.need_scaling[j] == 1) {
+                auto vms = F16::set1(fms.vms[j]);
+                for (int i = 0; i < D/F16::block_size; ++i) {
+                    F16::store(R + F16::block_size*i, F16::mul(vms, F16::load(R + F16::block_size*i)));
                 }
             }
-            F16::Data v1, v2, v3, v4;
-            for (int l1 = 0; l1 < k_step; l1 += 2) {
-                vh.load(l1+0, i, v1, v2);
-                vh.load(l1+1, i, v3, v4);
+        }
+        for (int i = 0; i < D/F16::block_size; i += 2) {
+            for (int l = 0; l < k_step; l += 4) {
+                vh.load(l+0, i, v[0], v[4]);
+                vh.load(l+1, i, v[1], v[5]);
+                vh.load(l+2, i, v[2], v[6]);
+                vh.load(l+3, i, v[3], v[7]);
                 for (int j = 0; j < q_step; ++j) {
-                    auto vs1 = F16::set1(fms.cache[k_step*j + l1+0]);
-                    auto vs2 = F16::set1(fms.cache[k_step*j + l1+1]);
-                    vk[2*j+0] = F16::fmadd(F16::fmadd(vk[2*j+0], v1, vs1), v3, vs2);
-                    vk[2*j+1] = F16::fmadd(F16::fmadd(vk[2*j+1], v2, vs1), v4, vs2);
+                    auto R = qkv_cache + D*j;
+                    auto s1 = F16::load(R + F16::block_size*(i+0));
+                    auto s2 = F16::load(R + F16::block_size*(i+1));
+                    auto vs = F16::set4(fms.cache + k_step*j + l);
+                    s1 = F16::fmadd_lane0(s1, v[0], vs);
+                    s2 = F16::fmadd_lane0(s2, v[4], vs);
+                    s1 = F16::fmadd_lane1(s1, v[1], vs);
+                    s2 = F16::fmadd_lane1(s2, v[5], vs);
+                    s1 = F16::fmadd_lane2(s1, v[2], vs);
+                    s2 = F16::fmadd_lane2(s2, v[6], vs);
+                    s1 = F16::fmadd_lane3(s1, v[3], vs);
+                    s2 = F16::fmadd_lane3(s2, v[7], vs);
+                    F16::store(R + F16::block_size*(i+0), s1);
+                    F16::store(R + F16::block_size*(i+1), s2);
                 }
             }
-            for (int j = 0; j < q_step; ++j) {
-                auto R = qkv_cache + D*j;
-                F16::store(R + F16::block_size*(i + 0), vk[2*j+0]);
-                F16::store(R + F16::block_size*(i + 1), vk[2*j+1]);
-            }
         }
     }

-    template <typename VHelper, int Nq = q_step, class = std::enable_if<Nq >= 2>>
+    template <typename VHelper>
     inline void accumulate_qkv(int nq1, const VHelper& vh, const FlashMS<q_step, k_step>& fms) {
-        F16::Data vk[2*q_step];
-        for (int i = 0; i < D/F16::block_size; i += 2) {
-            for (int j = 0; j < nq1; ++j) {
-                if (fms.need_scaling[j] == 2) {
-                    vk[2*j+0] = vk[2*j+1] = F16::zero();
-                } else {
-                    auto R = qkv_cache + D*j;
-                    vk[2*j+0] = F16::load(R + F16::block_size*i);
-                    vk[2*j+1] = F16::load(R + F16::block_size*(i + 1));
-                    if (fms.need_scaling[j] == 1) {
-                        vk[2*j+0] = F16::mul(vk[2*j+0], fms.vms[j]);
-                        vk[2*j+1] = F16::mul(vk[2*j+1], fms.vms[j]);
-                    }
+        F16::Data v[8];
+        for (int j = 0; j < nq1; ++j) {
+            auto R = qkv_cache + D*j;
+            if (fms.need_scaling[j] == 2) {
+                std::memset(R, 0, D*sizeof(qkv_cache_t));
+            }
+            else if (fms.need_scaling[j] == 1) {
+                auto vms = F16::set1(fms.vms[j]);
+                for (int i = 0; i < D/F16::block_size; ++i) {
+                    F16::store(R + F16::block_size*i, F16::mul(vms, F16::load(R + F16::block_size*i)));
                 }
             }
-            F16::Data v1, v2, v3, v4;
-            for (int l1 = 0; l1 < k_step; l1 += 2) {
-                vh.load(l1+0, i, v1, v2);
-                vh.load(l1+1, i, v3, v4);
+        }
+        for (int i = 0; i < D/F16::block_size; i += 2) {
+            for (int l = 0; l < k_step; l += 4) {
+                vh.load(l+0, i, v[0], v[4]);
+                vh.load(l+1, i, v[1], v[5]);
+                vh.load(l+2, i, v[2], v[6]);
+                vh.load(l+3, i, v[3], v[7]);
                 for (int j = 0; j < nq1; ++j) {
-                    auto vs1 = F16::set1(fms.cache[k_step*j + l1+0]);
-                    auto vs2 = F16::set1(fms.cache[k_step*j + l1+1]);
-                    vk[2*j+0] = F16::fmadd(F16::fmadd(vk[2*j+0], v1, vs1), v3, vs2);
-                    vk[2*j+1] = F16::fmadd(F16::fmadd(vk[2*j+1], v2, vs1), v4, vs2);
+                    auto R = qkv_cache + D*j;
+                    auto s1 = F16::load(R + F16::block_size*(i+0));
+                    auto s2 = F16::load(R + F16::block_size*(i+1));
+                    auto vs = F16::set4(fms.cache + k_step*j + l);
+                    s1 = F16::fmadd_lane0(s1, v[0], vs);
+                    s2 = F16::fmadd_lane0(s2, v[4], vs);
+                    s1 = F16::fmadd_lane1(s1, v[1], vs);
+                    s2 = F16::fmadd_lane1(s2, v[5], vs);
+                    s1 = F16::fmadd_lane2(s1, v[2], vs);
+                    s2 = F16::fmadd_lane2(s2, v[6], vs);
+                    s1 = F16::fmadd_lane3(s1, v[3], vs);
+                    s2 = F16::fmadd_lane3(s2, v[7], vs);
+                    F16::store(R + F16::block_size*(i+0), s1);
+                    F16::store(R + F16::block_size*(i+1), s2);
                 }
             }
-            for (int j = 0; j < nq1; ++j) {
-                auto R = qkv_cache + D*j;
-                F16::store(R + F16::block_size*(i + 0), vk[2*j+0]);
-                F16::store(R + F16::block_size*(i + 1), vk[2*j+1]);
-            }
         }
     }

     inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int j, const qkv_cache_t * R, float * qkv) const {
         GGML_ASSERT(fms.S[j] > 0);
         auto norm = F16::set1(1/fms.S[j]);
+        //auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f);
         for (int i = 0; i < D/F16::block_size; ++i) {
             auto r = F16::load(R + F16::block_size*i);
             F16::store(qkv + F16::block_size*i, F16::mul(norm, r));
@@ -13024,156 +13340,281 @@ struct FlashQKV {
         }
     }

-    qkv_cache_t qkv_cache[D*q_step];
+    // qkv_cache_t qkv_cache[D*q_step];
+    // The initializer is not actually required. But the compiler cannot figure out that when qkv_cache is
+    // first used for q_step rows, fms.need_scaling[j] is always 2, which zeroes the content of qkv_cache.
+    // As a result, we get an infinite stream of warnings about uninitialized variable use (one for each
+    // combination of D, q_step, k_step), which is extremely annoying. Hence, I succumb to the trend of
+    // constantly being saved by others (the compiler in this case), and add this 100% unnecessary initialization.
+    qkv_cache_t qkv_cache[D*q_step] = {};
 };
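A scalar model of the new accumulation pattern in `accumulate_qkv` above: `F16::set4` loads four consecutive softmax weights from `fms.cache`, and `fmadd_lane0..3` each broadcast one of them against the corresponding V row, so a single register update covers four K positions. Toy sizes below are illustrative only:

```cpp
// Scalar model of the V*softmax(K*Q) accumulation: four V rows are folded
// into the running result per inner iteration, weighted by four consecutive
// softmax values, mirroring set4 + fmadd_lane0..3.
#include <cstdio>

int main() {
    const int D = 8, k_step = 8;            // toy sizes
    float V[k_step][D], w[k_step], R[D] = {};
    for (int l = 0; l < k_step; ++l) {
        w[l] = 1.0f/(l + 1);                // stand-in for softmax weights
        for (int i = 0; i < D; ++i) V[l][i] = l + 0.1f*i;
    }
    for (int l = 0; l < k_step; l += 4) {   // four V rows per step, as in the patch
        // vs = set4(w + l); the lane index below models fmadd_lane0..fmadd_lane3
        for (int lane = 0; lane < 4; ++lane)
            for (int i = 0; i < D; ++i) R[i] += w[l + lane]*V[l + lane][i];
    }
    for (int i = 0; i < D; ++i) printf("R[%d] = %g\n", i, R[i]);
}
```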
+#ifdef HAVE_FANCY_SIMD
+template <int nrc_y>
+static void mul_mat_q8_0_r4_q8_1_128([[maybe_unused]] int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    GGML_ASSERT(nrc_x%8 == 0);
+    GGML_ASSERT(n == 128);
+    //Q8<nrc_y, block_q8_1_x4> q8(info);
+    __m512i qx[16];
+    __m512  scales[4];
+    __m512  scales_m[4];
+    __m512  dy[4];
+    auto m127 = _mm512_set1_epi8(127);
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+        const block_q8_0_x4 * q8l = (const block_q8_0_x4 *)((const char *)vx + (ix+0)*bx);
+        const block_q8_0_x4 * q8h = (const block_q8_0_x4 *)((const char *)vx + (ix+4)*bx);
+        for (int k = 0; k < 4; ++k) {
+            auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8l[k].d));
+            auto scales1 = _mm256_set_m128(scales128, scales128);
+            scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8h[k].d));
+            auto scales2 = _mm256_set_m128(scales128, scales128);
+            scales[k] = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
+            scales_m[k] = _mm512_mul_ps(scales[k], _mm512_set1_ps(-63.5f));
+            qx[4*k+0] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[k].qs+0)),
+                                           _mm256_loadu_si256((const __m256i *)q8h[k].qs+0), 1);
+            qx[4*k+1] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[k].qs+1)),
+                                           _mm256_loadu_si256((const __m256i *)q8h[k].qs+1), 1);
+            qx[4*k+2] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[k].qs+2)),
+                                           _mm256_loadu_si256((const __m256i *)q8h[k].qs+2), 1);
+            qx[4*k+3] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[k].qs+3)),
+                                           _mm256_loadu_si256((const __m256i *)q8h[k].qs+3), 1);
+            qx[4*k+0] = _mm512_add_epi8(qx[4*k+0], m127);
+            qx[4*k+1] = _mm512_add_epi8(qx[4*k+1], m127);
+            qx[4*k+2] = _mm512_add_epi8(qx[4*k+2], m127);
+            qx[4*k+3] = _mm512_add_epi8(qx[4*k+3], m127);
+        }
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            auto by = (const block_q8_1_x4 *)info.src1_row(iy);
+            //auto dall = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][0].d));
+            auto dall = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)by->d));
+            auto d128 = _mm256_castps256_ps128(dall);
+            auto m128 = _mm256_extractf128_ps(dall, 1);
+            auto m256 = _mm256_set_m128(m128, m128);
+            auto m512 = _mm512_insertf32x8(_mm512_castps256_ps512(m256), m256, 1);
+            auto sumf = _mm512_mul_ps(scales_m[0], _mm512_shuffle_ps(m512, m512, 0x00));
+            sumf = _mm512_fmadd_ps(scales_m[1], _mm512_shuffle_ps(m512, m512, 0x55), sumf);
+            sumf = _mm512_fmadd_ps(scales_m[2], _mm512_shuffle_ps(m512, m512, 0xaa), sumf);
+            sumf = _mm512_fmadd_ps(scales_m[3], _mm512_shuffle_ps(m512, m512, 0xff), sumf);
+            auto d256 = _mm256_set_m128(d128, d128);
+            auto d512 = _mm512_insertf32x8(_mm512_castps256_ps512(d256), d256, 1);
+            dy[0] = _mm512_mul_ps(scales[0], _mm512_shuffle_ps(d512, d512, 0x00));
+            dy[1] = _mm512_mul_ps(scales[1], _mm512_shuffle_ps(d512, d512, 0x55));
+            dy[2] = _mm512_mul_ps(scales[2], _mm512_shuffle_ps(d512, d512, 0xaa));
+            dy[3] = _mm512_mul_ps(scales[3], _mm512_shuffle_ps(d512, d512, 0xff));
+            for (int k = 0; k < 4; ++k) {
+                //auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][0].qs+k);
+                auto y8 = _mm256_loadu_si256((const __m256i*)by->qs+k);
+                auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
+                auto sumi = _mm512_setzero_si512();
+                sumi = _mm512_dpbusd_epi32(sumi, qx[4*k+0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+                sumi = _mm512_dpbusd_epi32(sumi, qx[4*k+1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+                sumi = _mm512_dpbusd_epi32(sumi, qx[4*k+2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+                sumi = _mm512_dpbusd_epi32(sumi, qx[4*k+3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+                sumf = _mm512_fmadd_ps(dy[k], _mm512_cvtepi32_ps(sumi), sumf);
+            }
+            auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sumf, 0), _mm512_extractf32x4_ps(sumf, 1));
+            auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sumf, 2), _mm512_extractf32x4_ps(sumf, 3));
+            info.store(ix+0, iy, sum1);
+            info.store(ix+4, iy, sum2);
+        }
+    }
+}
+#endif
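The `m127` bias in the kernel above relies on the identity (x + 127) · y = x · y + 127 · Σy: adding 127 shifts the q8_0 quants (range [-127, 127]) into unsigned range so they can feed `_mm512_dpbusd_epi32`, and the compensating term proportional to the row sums carried by block_q8_1 is folded into `scales_m` (the -63.5f factor presumably accounts for the exact scaling with which those sums are stored). A scalar check of the identity:

```cpp
// Scalar check of the signed->unsigned bias trick used with dpbusd.
#include <cstdio>

int main() {
    // q8_0 quants lie in [-127, 127], so q + 127 fits in an unsigned byte.
    signed char x[8] = {-127, -5, 0, 3,  7, 100, -64, 127};
    signed char y[8] = {  12, -3, 9, 1, -7,   5,  30,  -2};
    int dot = 0, biased = 0, ysum = 0;
    for (int i = 0; i < 8; ++i) {
        dot    += x[i]*y[i];
        biased += (x[i] + 127)*y[i];   // what the unsigned-by-signed dot computes
        ysum   += y[i];
    }
    // Subtracting 127*sum(y) recovers the signed dot product; the kernel folds
    // this correction into scales_m using the row sums stored in block_q8_1.
    printf("%d == %d\n", dot, biased - 127*ysum);
}
```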
 template <int D, int q_step, int k_step>
 struct FlashQKfp32 {
     static_assert(D%F16::block_size == 0 && D <= 256);
     static_assert(k_step%F16::block_size == 0);
     static_assert(q_step <= 4 || q_step%4 == 0);
-#ifdef __aarch64__
-    constexpr static bool is_small_head = false;
-#else
-    constexpr static bool is_small_head = D <= (F16::num_registers/2)*F16::block_size;
-#endif
-
-    template <bool small = is_small_head, class = std::enable_if<small>, typename q_float>
-    static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const q_float * q, const char * mask,
-            F16::Data * qv, F16::Data * vk, FlashMS<q_step, k_step>& fms) {
-        // q index is q_step*i1 + m1
-        // k index is k_step*k1 + l1
-        const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
-        fms.cache[k_step*m1 + l1 + 0] = fms.cache[k_step*m1 + l1 + 1] = -INFINITY;
-        if (mp[l1+0] == fms.h_inf && mp[l1+1] == fms.h_inf) {
-            return;
-        }
-        auto qr = q + m1*stride_q;
-        for (int i = 0; i < D/F16::block_size; ++i) qv[i] = F16::load(qr + F16::block_size*i);
-        if (mp[l1+0] != fms.h_inf) {
-            auto vsum = F16::zero();
-            for (int i = 0; i < D/F16::block_size; ++i) vsum = F16::fmadd(vsum, vk[i], qv[i]);
-            fms.cache[k_step*m1 + l1 + 0] = F16::reduce_add(vsum);
-        }
-        if (mp[l1+1] != fms.h_inf) {
-            auto vsum = F16::zero();
-            for (int i = 0; i < D/F16::block_size; ++i) vsum = F16::fmadd(vsum, vk[i+D/16], qv[i]);
-            fms.cache[k_step*m1 + l1 + 1] = F16::reduce_add(vsum);
-        }
-    }
-
-    template <bool small = is_small_head, class = std::enable_if<!small>, typename q_float>
-    static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const q_float * q, const char * mask,
-            F16::Data * vk, FlashMS<q_step, k_step>& fms) {
-        // q index is q_step*i1 + m1
-        // k index is k_step*k1 + l1
-        const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
-        if (mp[l1] == fms.h_inf) {
-            fms.cache[k_step*m1 + l1] = -INFINITY;
-            return;
-        }
-        auto qr = q + m1*stride_q;
-        auto vsum = F16::zero();
-        for (int i = 0; i < D/F16::block_size; ++i) {
-            vsum = F16::fmadd(vsum, vk[i], F16::load(qr + F16::block_size*i));
-        }
-        fms.cache[k_step*m1 + l1] = F16::reduce_add(vsum);
-    }
-
-    template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>, typename q_float>
-    static inline void mult_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
-            FlashMS<q_step, k_step>& fms) {
-        F16::Data qv[D/F16::block_size];
-        F16::Data vk[D/(F16::block_size/2)];
-        for (int l1 = 0; l1 < k_step; l1 += 2) {
-            kh.load_2(l1, vk);
-            for (int m1 = 0; m1 < q_step; ++m1) {
-                mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, qv, vk, fms);
-            }
-        }
-    }
-
-    template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>, typename q_float>
-    static inline void mult_mask_kq_l(const KHelper& kh, int stride_q, int stride_m,
-            const q_float * q, const char * mask, FlashMS<q_step, k_step>& fms) {
-        F16::Data vk[D/F16::block_size];
-        for (int l1 = 0; l1 < k_step; ++l1) {
-            kh.load(l1, vk);
-            for (int m1 = 0; m1 < q_step; ++m1) {
-                mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, vk, fms);
-            }
-        }
-    }
+#ifdef __AVX2__
+    template <typename KHelper, typename q_float>
+    static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
+            FlashMS<q_step, k_step>& fms) {
+#ifdef HAVE_FANCY_SIMD
+        constexpr int nrc_q = 8;
+        constexpr int nrc_k = 8;
+#else
+        // somewhat surprisingly, nrc_q = 4, nrc_k = 8 is better than nrc_q = 8, nrc_k = 4
+        constexpr int nrc_q = 4;
+        constexpr int nrc_k = 8;
+#endif
+        constexpr int qrem = q_step - nrc_q*(q_step/nrc_q);
+        constexpr int krem = k_step - nrc_k*(k_step/nrc_k);
+        DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr};
+        for (int iq = 0; iq < q_step/nrc_q; ++iq) {
+            for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+                mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, nrc_q>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
+            }
+            if constexpr (krem > 0) {
+                mul_mat_Qx_Qy_MxN_fa<QFT<q_float, nrc_q>, QFT<ggml_half, krem>>(D, kh.block, kh.stride, k_step - krem, info);
+            }
+            info.cur_y += nrc_q;
+        }
+        if constexpr (qrem > 0) {
+            for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+                mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, qrem>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
+            }
+            if constexpr (krem > 0) {
+                mul_mat_Qx_Qy_MxN_fa<QFT<q_float, qrem>, QFT<ggml_half, krem>>(D, kh.block, kh.stride, k_step - krem, info);
+            }
+        }
+        F16::Data vk[k_step/F16::block_size];
+        for (int j = 0; j < q_step; ++j) {
+            fms.update_M_S(j, vk, mask + stride_m*j);
+        }
+    }
+#else
+    template <typename KHelper, typename q_float>
+    static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
+            FlashMS<q_step, k_step>& fms) {
+        constexpr int nrc_q = 4;
+        constexpr int nrc_k = 6;
+        constexpr int qrem = q_step - nrc_q*(q_step/nrc_q);
+        constexpr int krem = k_step - nrc_k*(k_step/nrc_k);
+        DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr};
+        for (int iq = 0; iq < q_step/nrc_q; ++iq) {
+            for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+                mul_mat_f16_f16_NxN<nrc_q, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
+            }
+            if constexpr (krem > 0) {
+                mul_mat_f16_f16_NxN<nrc_q, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
+            }
+            info.cur_y += nrc_q;
+        }
+        if constexpr (qrem > 0) {
+            for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+                mul_mat_f16_f16_NxN<qrem, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
+            }
+            if constexpr (krem > 0) {
+                mul_mat_f16_f16_NxN<qrem, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
+            }
+        }
+        float32x4_t vk[k_step/4];
+        for (int j = 0; j < q_step; ++j) {
+            fms.update_M_S(j, vk, mask + stride_m*j);
+        }
+    }
+#endif

-    template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>, typename q_float>
-    static inline void mult_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
-            FlashMS<q_step, k_step>& fms) {
-        F16::Data qv[D/F16::block_size];
-        F16::Data vk[D/(F16::block_size/2)];
-        for (int l1 = 0; l1 < k_step; l1 += 2) {
-            kh.load_2(l1, vk);
-            for (int m1 = 0; m1 < nq; ++m1) {
-                mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, qv, vk, fms);
-            }
-        }
-    }
-
-    template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>, typename q_float>
-    static inline void mult_mask_kq_l(int nq, const KHelper& kh, int stride_q, int stride_m,
-            const q_float * q, const char * mask, FlashMS<q_step, k_step>& fms) {
-        F16::Data vk[D/F16::block_size];
-        for (int l1 = 0; l1 < k_step; ++l1) {
-            kh.load(l1, vk);
-            for (int m1 = 0; m1 < nq; ++m1) {
-                mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, vk, fms);
-            }
-        }
-    }
-
+#ifdef __AVX2__
     template <typename KHelper, typename q_float>
-    static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
+    static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
             FlashMS<q_step, k_step>& fms) {
-        if constexpr (is_small_head) {
-            mult_mask_kq(kh, stride_q, stride_m, q, mask, fms);
-        }
-        else {
-            mult_mask_kq_l(kh, stride_q, stride_m, q, mask, fms);
-        }
-#ifdef __aarch64__
-        float32x4_t vk[k_step/4];
-        for (int j = 0; j < q_step; ++j) {
-            fms.update_M_S(j, vk);
-        }
-#else
-        F16::Data vk[k_step/F16::block_size];
-        for (int j = 0; j < q_step; ++j) {
-            fms.update_M_S(j, vk);
-        }
-#endif
-    }
-
+#ifdef HAVE_FANCY_SIMD
+        constexpr int nrc_q = 8;
+        constexpr int nrc_k = 8;
+#else
+        // somewhat surprisingly, nrc_q = 4, nrc_k = 8 is better than nrc_q = 8, nrc_k = 4
+        constexpr int nrc_q = 4;
+        constexpr int nrc_k = 8;
+#endif
+        static_assert(k_step%nrc_k == 0);
+        int qrem = q_step - nrc_q*(q_step/nrc_q);
+        DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr};
+        for (int iq = 0; iq < nq/nrc_q; ++iq) {
+            for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+                mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, nrc_q>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
+            }
+            info.cur_y += nrc_q;
+        }
+        if (qrem > 0) {
+            switch (qrem) {
+                case 1: {
+                    for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+                        mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 1>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
+                    }
+                } break;
+                case 2: {
+                    for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+                        mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 2>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
+                    }
+                } break;
+                case 3: {
+                    for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+                        mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 3>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
+                    }
+                } break;
+#ifdef HAVE_FANCY_SIMD
+                case 4: {
+                    for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+                        mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 4>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
+                    }
+                } break;
+                case 5: {
+                    for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+                        mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 5>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
+                    }
+                } break;
+                case 6: {
+                    for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+                        mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 6>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
+                    }
+                } break;
+                case 7: {
+                    for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+                        mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 7>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
+                    }
+                } break;
+#endif
+            }
+        }
+        F16::Data vk[k_step/F16::block_size];
+        for (int j = 0; j < q_step; ++j) {
+            fms.update_M_S(j, vk, mask + stride_m*j);
+        }
+    }
+#else
     template <typename KHelper, typename q_float>
     static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
             FlashMS<q_step, k_step>& fms) {
-        if constexpr (is_small_head) {
-            mult_mask_kq(nq, kh, stride_q, stride_m, q, mask, fms);
-        }
-        else {
-            mult_mask_kq_l(nq, kh, stride_q, stride_m, q, mask, fms);
-        }
-#ifdef __aarch64__
-        float32x4_t vk[k_step/4];
-        for (int j = 0; j < nq; ++j) {
-            fms.update_M_S(j, vk);
-        }
-#else
-        F16::Data vk[k_step/F16::block_size];
-        for (int j = 0; j < nq; ++j) {
-            fms.update_M_S(j, vk);
-        }
-#endif
-    }
+        constexpr int nrc_q = 4;
+        constexpr int nrc_k = 6;
+        constexpr int krem = k_step - nrc_k*(k_step/nrc_k);
+        const int qrem = q_step - nrc_q*(q_step/nrc_q);
+        DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr};
+        for (int iq = 0; iq < nq/nrc_q; ++iq) {
+            for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+                mul_mat_f16_f16_NxN<nrc_q, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
+            }
+            if constexpr (krem > 0) {
+                mul_mat_f16_f16_NxN<nrc_q, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
+            }
+            info.cur_y += nrc_q;
+        }
+        switch (qrem) {
+            case 0: break;
+            case 1: {
+                for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+                    mul_mat_f16_f16_NxN<1, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
+                }
+                if constexpr (krem > 0) {
+                    mul_mat_f16_f16_NxN<1, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
+                }
+            } break;
+            case 2: {
+                for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+                    mul_mat_f16_f16_NxN<2, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
+                }
+                if constexpr (krem > 0) {
+                    mul_mat_f16_f16_NxN<2, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
+                }
+            } break;
+            case 3: {
+                for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+                    mul_mat_f16_f16_NxN<3, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
+                }
+                if constexpr (krem > 0) {
+                    mul_mat_f16_f16_NxN<3, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
+                }
+            } break;
+        }
+        float32x4_t vk[k_step/4];
+        for (int j = 0; j < q_step; ++j) {
+            fms.update_M_S(j, vk, mask + stride_m*j);
+        }
+    }
+#endif

 #ifdef __aarch64__
     static inline void convert(int nq, int stride_q, const float * q, float16_t * q_f16) {
@@ -13206,6 +13647,19 @@ struct FlashQKfp32 {
             case 7: return std::make_pair(mul_mat, 7>, 7);\
         }\
     }
+#define MAKE_FUNCS_ONLY_NRC(mul_mat, n) \
+    if (n >= kMaxQ) return std::make_pair(mul_mat<kMaxQ>, kMaxQ);\
+    else {\
+        switch (n) {\
+            case 1: return std::make_pair(mul_mat<1>, 1);\
+            case 2: return std::make_pair(mul_mat<2>, 2);\
+            case 3: return std::make_pair(mul_mat<3>, 3);\
+            case 4: return std::make_pair(mul_mat<4>, 4);\
+            case 5: return std::make_pair(mul_mat<5>, 5);\
+            case 6: return std::make_pair(mul_mat<6>, 6);\
+            case 7: return std::make_pair(mul_mat<7>, 7);\
+        }\
+    }
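The `MAKE_FUNCS_ONLY_NRC` macro implements a common dispatch pattern in this file: map a runtime row count onto a kernel instantiated for a fixed template row count, returning the pair (function pointer, rows consumed per call). A minimal standalone model of the pattern (names hypothetical, not from the diff):

```cpp
// Model of the (function pointer, rows-per-call) dispatch pattern.
#include <cstdio>
#include <utility>

template <int nrc_y> static void kernel(int n) { printf("kernel<%d>(%d)\n", nrc_y, n); }
using kernel_t = void (*)(int);

static std::pair<kernel_t, int> select_kernel(int nq) {
    constexpr int kMaxQ = 8;
    if (nq >= kMaxQ) return {kernel<kMaxQ>, kMaxQ};
    switch (nq) {
        case 2: return {kernel<2>, 2};
        case 3: return {kernel<3>, 3};
        case 4: return {kernel<4>, 4};
        case 5: return {kernel<5>, 5};
        case 6: return {kernel<6>, 6};
        case 7: return {kernel<7>, 7};
        default: return {kernel<1>, 1};
    }
}

int main() {
    auto [f, step] = select_kernel(5);  // the caller then advances by `step` rows
    f(128);
}
```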
     if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) {
 #ifdef __aarch64__
         MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ40, nq);
@@ -13229,6 +13683,48 @@ struct FlashQKfp32 {
         }
 #endif
     }
+    else if constexpr (std::is_same_v<KHelper, HelperQ80R4<D, k_step>>) {
+#ifdef __aarch64__
+        if constexpr (D == 128) {
+            if (q_step >= 64 && nq >= 64) {
+                return std::make_pair(mul_mat_q8_0_r4_q8_0_128<64>, 64);
+            }
+            else if (q_step >= 32 && nq >= 32) {
+                return std::make_pair(mul_mat_q8_0_r4_q8_0_128<32>, 32);
+            }
+            else if (q_step >= 16 && nq >= 16) {
+                return std::make_pair(mul_mat_q8_0_r4_q8_0_128<16>, 16);
+            }
+            else {
+                MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0_128, nq);
+            }
+        } else {
+            MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq);
+        }
+        //MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq);
+#else
+#ifdef HAVE_FANCY_SIMD
+        if constexpr (D == 128) {
+            if (q_step >= 64 && nq >= 64) {
+                return std::make_pair(mul_mat_q8_0_r4_q8_1_128<64>, 64);
+            }
+            else if (q_step >= 32 && nq >= 32) {
+                return std::make_pair(mul_mat_q8_0_r4_q8_1_128<32>, 32);
+            }
+            else if (q_step >= 16 && nq >= 16) {
+                return std::make_pair(mul_mat_q8_0_r4_q8_1_128<16>, 16);
+            }
+            else {
+                MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_1_128, nq);
+            }
+        } else {
+            MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_1, nq);
+        }
+#else
+        MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_1, nq);
+#endif
+#endif
+    }
     else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) {
 #ifdef __aarch64__
         MAKE_FUNCS(mul_mat_qX_1_q8_1<DequantizerQ41, nq);
@@ -13373,20 +13869,44 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
                       FlashQKV<D, q_step, k_step>& fqkv,
                       const float * q, const char * mask, float * qkv) {
     typename KHelper::block_q8 q8[q_step*(D/QK8_0)];
+#if FA_TIMING
+    Perf perf(false);
+#endif
     for (int i1 = 0; i1 < nq1/q_step; ++i1) {
+#if FA_TIMING
+        auto t1 = Perf::cur_time();
+#endif
         fms.init_qstep();
         kh.reset_block();
         vh.reset_block();
         HelperQ80<D, QK8_0>::convert(q_step, stride_q, q, q8);
+#if FA_TIMING
+        perf.accum_nolock(0, t1);
+#endif
        auto mr = mask;
         for (int k1 = 0; k1 < nk1/k_step; ++k1) {
+#if FA_TIMING
+            t1 = Perf::cur_time();
             KQHelper::mul_mask_kq(kh, stride_m, q8, mr, fms);
+            perf.accum_nolock(1, t1);
+            t1 = Perf::cur_time();
             fqkv.accumulate_qkv(vh, fms);
+            perf.accum_nolock(2, t1);
+#else
+            KQHelper::mul_mask_kq(kh, stride_m, q8, mr, fms);
+            fqkv.accumulate_qkv(vh, fms);
+#endif
             kh.next_block();
             vh.next_block();
             mr += k_step*sizeof(ggml_half);
         }
+#if FA_TIMING
+        t1 = Perf::cur_time();
         fqkv.normalize_and_store(fms, stride_qkv, qkv);
+        perf.accum_nolock(3, t1);
+#else
+        fqkv.normalize_and_store(fms, stride_qkv, qkv);
+#endif

         q    += q_step*stride_q;
         mask += q_step*stride_m;
@@ -13408,6 +13928,9 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
         }
         fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv);
     }
+#if FA_TIMING
+    Perf::instance().add(perf);
+#endif
 }

 // Some of the methods in FlashAttn have two identical implementations that only differ by
@@ -13431,11 +13954,38 @@ struct FlashAttn {
     template <typename KHelper, typename VHelper>
     void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
             const float * q, const char * mask, float * qkv) {
+//        if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>> ||
+//                      std::is_same_v<KHelper, HelperIQ4nl<D, k_step>> ||
+//                      std::is_same_v<KHelper, HelperQ80<D, k_step>> ||
+//                      std::is_same_v<KHelper, HelperQ80R4<D, k_step>> ||
+//                      std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
+//            compute_helper_q<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
+//                    kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
+//        } else {
+//            compute_helper<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
+//                    kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
+//        }
         if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>> ||
-                      std::is_same_v<KHelper, HelperQ80<D, k_step>> || std::is_same_v<KHelper, HelperIQ4nl<D, k_step>> ||
+                      std::is_same_v<KHelper, HelperIQ4nl<D, k_step>> ||
                       std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
             compute_helper_q<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
                     kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
+        }
+        else if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) {
+            if (nq1 >= 8) {
+#if FA_TIMING
+                auto t1 = Perf::cur_time();
+                HelperQ80R4<D, k_step> khr4(nk1, kh);
+                Perf::instance().accum(4, t1);
+#else
+                HelperQ80R4<D, k_step> khr4(nk1, kh);
+#endif
+                compute_helper_q<D, q_step, k_step, HelperQ80R4<D, k_step>, VHelper, FlashQKfp32<D, q_step, k_step>>(
+                        khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
+            } else{
+                compute_helper_q<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
+                        kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
+            }
         } else {
             compute_helper<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
                     kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
@@ -13475,6 +14025,10 @@ struct HelperBF16 final : public BaseHelper<step> {
         load(l1+2, vk+2*D/32);
         load(l1+3, vk+3*D/32);
     }
+
+    inline void load_8(int l1, __m512bh * vk) const {
+        for (int k = 0; k < 8; ++k) load(l1 + k, vk + k*D/32);
+    }
 };

 template <int D, int q_step, int k_step>
@@ -13627,6 +14181,29 @@ struct FlashQKbf16 {
         _mm_storeu_ps(fms.cache + k_step*m1 + l1, hsum_float_4x4(sum));
     }

+    static IQK_ALWAYS_INLINE __m256 hsum_float_8x8(__m256 * accm) {
+        for (int i = 0; i < 4; ++i) {
+            accm[i] = _mm256_add_ps(_mm256_permute2f128_ps(accm[i], accm[i+4], 0x20), _mm256_permute2f128_ps(accm[i], accm[i+4], 0x31));
+            //accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)),
+            //                          _mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1)));
+        }
+        for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2]));
+        return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1]));
+    }
+
+    static inline void mult_mask_kq_8(int l1, int m1, const ggml_bf16_t * q,
+            __m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) {
+        auto qr = q + m1*D;
+        for (int i = 0; i < D/32; ++i) qv[i] = __m512bh(_mm512_loadu_si512((const __m512i *)qr + i));
+        __m256 sum[8];
+        for (int k = 0; k < 8; ++k) {
+            auto vsum = _mm512_setzero_ps();
+            for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+k*(D/32)], qv[i]);
+            sum[k] = _mm256_add_ps(_mm512_castps512_ps256(vsum), _mm512_extractf32x8_ps(vsum, 1));
+        }
+        _mm256_storeu_ps(fms.cache + k_step*m1 + l1, hsum_float_8x8(sum));
+    }
+
     static inline void mult_mask_kq_one(int l1, int m1, const ggml_bf16_t * q,
             __m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) {
         auto qr = q + m1*D;
@@ -13639,16 +14216,23 @@ struct FlashQKbf16 {
         fms.cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum);
     }

+#if FA_TIMING
+    template <typename KHelper>
+    static inline void multiply_mask_kq(const KHelper& kh, int stride_m, const ggml_bf16_t * q,
+            const char * mask, FlashMS<q_step, k_step>& fms, Perf& perf) {
+        auto t1 = Perf::cur_time();
+#else
     template <typename KHelper>
     static inline void multiply_mask_kq(const KHelper& kh, int stride_m, const ggml_bf16_t * q,
             const char * mask, FlashMS<q_step, k_step>& fms) {
+#endif
         {
            __m512bh qv[D/32];
            if constexpr (D <= 128) {
-                __m512bh vkh[D/8];
-                for (int l1 = 0; l1 < k_step; l1 += 4) {
-                    kh.load_4(l1, vkh);
-                    for (int j = 0; j < q_step; ++j) mult_mask_kq_4(l1, j, q, qv, vkh, fms);
+                __m512bh vkh[D/4];
+                for (int l1 = 0; l1 < k_step; l1 += 8) {
+                    kh.load_8(l1, vkh);
+                    for (int j = 0; j < q_step; ++j) mult_mask_kq_8(l1, j, q, qv, vkh, fms);
                 }
             } else {
                 __m512bh vkh[D/16];
@@ -13658,10 +14242,17 @@ struct FlashQKbf16 {
                 }
             }
         }
+#if FA_TIMING
+        perf.accum_nolock(1, t1);
+        t1 = Perf::cur_time();
+#endif
         F16::Data vk[k_step/16];
         for (int j = 0; j < q_step; ++j) {
             fms.update_M_S(j, vk, mask + stride_m*j);
         }
+#if FA_TIMING
+        perf.accum_nolock(2, t1);
+#endif
     }

     template <typename KHelper>
@@ -13747,20 +14338,44 @@ struct FlashAttnBF16 {
     void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
             const float * q, const char * mask, float * qkv) {
         ggml_bf16_t q_bf16[q_step*D];
+#if FA_TIMING
+        Perf perf(false);
+#endif
         for (int i1 = 0; i1 < nq1/q_step; ++i1) {
+#if FA_TIMING
+            auto t1 = Perf::cur_time();
+#endif
             fms.init_qstep();
             kh.reset_block();
             vh.reset_block();
             FlashQKbf16<D, q_step, k_step>::convert(stride_q, q, q_bf16);
+#if FA_TIMING
+            perf.accum_nolock(0, t1);
+#endif
             auto mr = mask;
             for (int k1 = 0; k1 < nk1/k_step; ++k1) {
+#if FA_TIMING
+                //t1 = Perf::cur_time();
+                FlashQKbf16<D, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms, perf);
+                //perf.accum_nolock(1, t1);
+                t1 = Perf::cur_time();
+                fqkv.accumulate_qkv(vh, fms);
+                perf.accum_nolock(3, t1);
+#else
                 FlashQKbf16<D, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms);
                 fqkv.accumulate_qkv(vh, fms);
+#endif
                 kh.next_block();
                 vh.next_block();
                 mr += k_step*sizeof(ggml_half);
             }
+#if FA_TIMING
+            t1 = Perf::cur_time();
+#endif
             fqkv.normalize_and_store(fms, stride_qkv, qkv);
+#if FA_TIMING
+            perf.accum_nolock(4, t1);
+#endif

             q    += q_step*stride_q;
             mask += q_step*stride_m;
@@ -13782,6 +14397,9 @@ struct FlashAttnBF16 {
             }
             fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv);
         }
+#if FA_TIMING
+        Perf::instance().add(perf);
+#endif
     }

     FlashMS<q_step, k_step> fms;
@@ -13793,23 +14411,21 @@ template <int D, int k_step, typename KHelper, typename VHelper>
 inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
                              const float * q, const char * mask, float scale, float softcap, float * qkv) {
-#if defined __AVX2__
-    constexpr bool kUseLargeStepsQ = !std::is_same_v<KHelper, HelperF16<D, k_step>>;
-#else
-    constexpr bool kUseLargeStepsQ = true;
-#endif
-    if constexpr (kUseLargeStepsQ) {
-        if (nk1 >= 4096) {
-            if (nq1 >= 32) {
-                FlashAttn<D, 32, k_step> fa(scale, softcap);
-                fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
-                return;
-            }
-            else if (nq1 >= 8) {
-                FlashAttn<D, 8, k_step> fa(scale, softcap);
-                fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
-                return;
-            }
+    if (nk1 >= 256) { //4096) {
+        if (nq1 >= 64) {
+            FlashAttn<D, 64, k_step> fa(scale, softcap);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+            return;
+        }
+        if (nq1 >= 32) {
+            FlashAttn<D, 32, k_step> fa(scale, softcap);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+            return;
+        }
+        if (nq1 >= 16) {
+            FlashAttn<D, 16, k_step> fa(scale, softcap);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+            return;
        }
     }
     if (nq1 >= 8) {
@@ -13833,12 +14449,13 @@ inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int
     if (nq1 >= 64) {
         FlashAttnBF16<D, 64, k_step> fa(scale, softcap);
         fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+        return;
     }
     else if (nq1 >= 16) {
         FlashAttnBF16<D, 16, k_step> fa(scale, softcap);
         fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+        return;
     }
-    return;
     }
     if (nq1 >= 8) {
         FlashAttnBF16<D, 8, k_step> fa(scale, softcap);
@@ -13968,6 +14585,22 @@ bool iqk_flash_attn_noalibi(int int_type_k,            // type of k
 #ifdef __AVX512BF16__
     if (type_k == GGML_TYPE_BF16) {
+        if (nk1%64 == 0) {
+            if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
+            switch (D) {
+                case 64:
+                    iqk_flash_helper_T< 64, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+                case 96:
+                    iqk_flash_helper_T< 96, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+                case 128:
+                    iqk_flash_helper_T<128, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+                case 256:
+                    iqk_flash_helper_T<256, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+                default:
+                    return false;
+            }
+            return true;
+        }
         if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
         switch (D) {
             case 64: