author     Kawrakow <iwankawrakow@gmail.com>    2025-01-20 08:57:38 +0200
committer  GitHub <noreply@github.com>          2025-01-20 08:57:38 +0200
commit     3c5f87225f0ddd379ab712ddb8ad0013c10167c2 (patch)
tree       7f339e1e1fe99218065a297cbf2632dcce8804a9
parent     0b74397d596bbcdfba27299393406d2b6330b133 (diff)
More Flash Attention improvements (#173)
* FA: slightly faster V*softmax(K*Q) on Zen4
* FA: it is also faster on AVX2 and ARM_NEON
* Deleted forgotten commented-out code
* FA: slightly faster V*softmax(K*Q) also for fp16 K-cache
* FA: slightly faster V*softmax(K*Q) on Zen4
  We now get 130.9 t/s for a context of 32k tokens.
* FA: don't store the sum scaling factor in SIMD registers
* FA: timing
* FA: faster q8_0 cache via run-time repacking
  On Zen4 the q8_0 KV-cache now slightly outperforms BF16. We get 134 t/s for 32k tokens,
  which is ~30% better than the main branch and ~18% better than the last commit.
  We simply repack the K-cache to q8_0_r4 before the K*Q multiplication and use the
  q8_0_r4 x q8_0_x4 matrix multiplication template.
* FA: fix AVX2
* FA: fix ARM_NEON
* FA: vectorize q8_0 -> q8_0_r4 repacking also on NEON
* FA: dedicated mat mul for D = 128 also for ARM_NEON
* FA: turn off performance timer

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
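The central change is repacking the q8_0 K-cache at run time into a 4-row interleaved layout (q8_0_r4) before the K*Q multiplication. As orientation, here is a minimal, self-contained sketch of that interleave. The struct names and plain C++ layout are simplified stand-ins (the real code uses block_q8_0 / block_q8_0_x4 with ggml_half scales and has AVX2/NEON fast paths), but the index pattern mirrors the scalar fallback in HelperQ80R4::repack further down in the diff.

    // Sketch: fuse 4 consecutive K rows so that one interleaved block carries the
    // four scales and the quants of the four rows, grouped 4 values at a time.
    #include <cstdint>
    #include <vector>

    struct BlockQ80   { uint16_t d;    int8_t qs[32];  };   // 1 row,  32 quants (stand-in)
    struct BlockQ80x4 { uint16_t d[4]; int8_t qs[128]; };   // 4 rows, interleaved (stand-in)

    static std::vector<BlockQ80x4> repack_r4(int nrows, int nblocks_per_row,
                                             const std::vector<std::vector<BlockQ80>>& rows) {
        std::vector<BlockQ80x4> out(nblocks_per_row * nrows / 4);
        auto* y = out.data();
        for (int row = 0; row < nrows; row += 4) {
            for (int ib = 0; ib < nblocks_per_row; ++ib) {
                for (int k = 0; k < 4; ++k) y[ib].d[k] = rows[row + k][ib].d;
                // interleave groups of 4 quants, separately for the two 16-quant halves
                for (int l = 0; l < 4; ++l) {
                    for (int k = 0; k < 4; ++k) {
                        for (int i = 0; i < 4; ++i) {
                            y[ib].qs[32*l + 4*k + i +  0] = rows[row + k][ib].qs[4*l + i +  0];
                            y[ib].qs[32*l + 4*k + i + 16] = rows[row + k][ib].qs[4*l + i + 16];
                        }
                    }
                }
            }
            y += nblocks_per_row;
        }
        return out;
    }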
-rw-r--r--  ggml/src/ggml.c               |   23
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.cpp  | 1021
2 files changed, 841 insertions, 203 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index bcb8bf41..b3c8a951 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -17471,25 +17471,30 @@ static void ggml_compute_forward_flash_attn_ext_f16(
#if GGML_USE_IQK_MULMAT
if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) {
- int64_t work_per_slice = D*nek1*neq1;
- int ntg = 1;
+ // I keep changing my mind about the best strategy for splitting the threads when processing
+ // multiple heads. This is my current thinking; the commented-out code below was the previous one.
+ int ntg = nth/simple_gcd(neq2*neq3, nth);
+ int64_t neq1g = (neq1 + ntg - 1)/ntg;
+ //int64_t work_per_slice = D*nek1*neq1;
+ //int ntg = 1;
//
// When neq1 is large, it is better to have more than one thread process one (iq2,iq3) matrix
// But we also want each thread to process the same amount of rows, so neq1 must be a multiple of
// the number of threads processing the (iq2, iq3) matrix.
//
- if (neq1 >= 8*nth) {
- if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8;
- else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4;
- else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2;
- }
+ //if (neq1 >= 8*nth) {
+ // if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8;
+ // else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4;
+ // else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2;
+ //}
int counter = 0;
for (int64_t iq3 = 0; iq3 < neq3; iq3++) {
for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
if (counter++ % (nth/ntg) == ith/ntg) {
- int iq1 = (ith%ntg)*neq1/ntg;
+ int iq1 = (ith%ntg)*neq1g;
+ int this_neq1 = MIN(neq1g, neq1-iq1);
if (!iqk_flash_attn_noalibi(k->type, v->type,
- D, neq1/ntg, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
+ D, this_neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
(const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]),
(const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]),
(const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]),
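For reference, a small self-contained sketch of the thread-split arithmetic introduced above, assuming simple_gcd behaves like std::gcd: the neq2*neq3 head slices are first spread over the threads that divide them evenly, and the remaining factor ntg splits the neq1 query rows of each slice into chunks of neq1g, with the last chunk clamped as the MIN above does. The helper name fa_split is hypothetical.

    #include <algorithm>
    #include <cstdint>
    #include <numeric>   // std::gcd

    // For thread `ith` of `nth`: which (iq2,iq3) slices it handles (counter % slice_mod
    // == slice_rem) and the [first_row, first_row + n_rows) range within each slice.
    struct Split { int slice_mod, slice_rem; int64_t first_row, n_rows; };

    static Split fa_split(int ith, int nth, int64_t neq1, int64_t neq2, int64_t neq3) {
        int ntg = int(nth / std::gcd<int64_t>(neq2*neq3, nth)); // threads per (iq2,iq3) slice
        int64_t neq1g = (neq1 + ntg - 1) / ntg;                 // rows per thread, rounded up
        int64_t iq1   = (ith % ntg) * neq1g;
        Split s;
        s.slice_mod = nth / ntg;
        s.slice_rem = ith / ntg;
        s.first_row = iq1;
        s.n_rows    = std::max<int64_t>(0, std::min(neq1g, neq1 - iq1));
        return s;
    }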
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 5577ea99..109ac08e 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -17,6 +17,7 @@
#include <cstring>
#include <type_traits>
+#include <vector>
#if defined IQK_IMPLEMENT
@@ -47,8 +48,57 @@
// For fp16/fp32 matrix multiplications tiling is used to improve
// performance.
+#define FA_TIMING 0
+
#include <utility>
#include <array>
+#if FA_TIMING
+#include <chrono>
+#include <mutex>
+struct Perf {
+ using TimePoint = std::chrono::time_point<std::chrono::high_resolution_clock>;
+ std::array<double, 5> times = {};
+ std::mutex mutex;
+ bool report;
+ static auto cur_time() { return std::chrono::high_resolution_clock::now(); }
+ inline void accum(int what, const TimePoint& t1) {
+ auto t2 = cur_time();
+ auto dt = delta(t1, t2);
+ std::lock_guard<std::mutex> lock(mutex);
+ times[what] += dt;
+ }
+ inline void accum_nolock(int what, const TimePoint& t1) {
+ auto t2 = cur_time();
+ auto dt = delta(t1, t2);
+ times[what] += dt;
+ }
+ inline void add(const Perf& other) {
+ std::lock_guard<std::mutex> lock(mutex);
+ for (int i = 0; i < int(times.size()); ++i) times[i] += other.times[i];
+ }
+ Perf(bool r) : report(r) {}
+ ~Perf() {
+ if (report) {
+ double tot = 0;
+ for (auto& t : times) tot += t;
+ if (!tot) return;
+ printf("======================= Timing: %g ms in total\n", tot);
+ for (int i = 0; i < int(times.size()); ++i) {
+ if (times[i]) {
+ printf("%d: %g ms -> %g%c\n", i, times[i], 100*times[i]/tot, '%');
+ }
+ }
+ }
+ }
+ static Perf& instance() {
+ static Perf p(true);
+ return p;
+ }
+ static double delta(const TimePoint& t1, const TimePoint& t2) {
+ return 1e-6*std::chrono::duration_cast<std::chrono::nanoseconds>(t2-t1).count();
+ }
+};
+#endif
#ifdef _MSC_VER
#define IQK_NOINLINE __declspec(noinline)
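A minimal usage sketch for the FA_TIMING instrumentation above (the worker_sketch function and its loop are hypothetical; the real call sites appear further down in compute_helper_q): each thread accumulates into its own Perf without locking and merges it into the reporting singleton once at the end.

    #if FA_TIMING
    static void worker_sketch() {
        Perf perf(false);                      // per-thread accumulator, no report on destruction
        for (int step = 0; step < 100; ++step) {
            auto t1 = Perf::cur_time();
            // ... K*Q multiplication ...
            perf.accum_nolock(1, t1);          // slot 1: K*Q, as in compute_helper_q
            t1 = Perf::cur_time();
            // ... V*softmax(K*Q) accumulation ...
            perf.accum_nolock(2, t1);          // slot 2: V*P
        }
        Perf::instance().add(perf);            // thread-safe merge into the reporting singleton
    }
    #endif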
@@ -6895,6 +6945,25 @@ struct QFBase {
static inline Data load4Floats(const Float * x) {
return _mm512_insertf32x4(_mm512_setzero_ps(), load128(x), 0);
}
+ static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) {
+ acc = _mm512_fmadd_ps(xv[0], _mm512_shuffle_ps(yv, yv, 0x00), acc);
+ acc = _mm512_fmadd_ps(xv[1], _mm512_shuffle_ps(yv, yv, 0x55), acc);
+ acc = _mm512_fmadd_ps(xv[2], _mm512_shuffle_ps(yv, yv, 0xaa), acc);
+ acc = _mm512_fmadd_ps(xv[3], _mm512_shuffle_ps(yv, yv, 0xff), acc);
+ return acc;
+ }
+ static inline Acc acc_r4_first(const Data * xv, const Data& yv) {
+ auto acc = _mm512_mul_ps(xv[0], _mm512_shuffle_ps(yv, yv, 0x00));
+ acc = _mm512_fmadd_ps(xv[1], _mm512_shuffle_ps(yv, yv, 0x55), acc);
+ acc = _mm512_fmadd_ps(xv[2], _mm512_shuffle_ps(yv, yv, 0xaa), acc);
+ acc = _mm512_fmadd_ps(xv[3], _mm512_shuffle_ps(yv, yv, 0xff), acc);
+ return acc;
+ }
+ static inline __m128 hsum_r4(Acc acc) {
+ auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(acc, 0), _mm512_extractf32x4_ps(acc, 1));
+ auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(acc, 2), _mm512_extractf32x4_ps(acc, 3));
+ return _mm_add_ps(sum1, sum2);
+ }
#else
constexpr static int k_step = 8;
using Data = __m256;
@@ -6904,12 +6973,29 @@ struct QFBase {
static inline Acc acc(Acc prev, const Data& y, const Data& x) {
return _mm256_fmadd_ps(y, x, prev);
}
+ static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) {
+ acc = _mm256_fmadd_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00), acc);
+ acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 0x55), acc);
+ acc = _mm256_fmadd_ps(xv[2], _mm256_shuffle_ps(yv, yv, 0xaa), acc);
+ acc = _mm256_fmadd_ps(xv[3], _mm256_shuffle_ps(yv, yv, 0xff), acc);
+ return acc;
+ }
+ static inline Acc acc_r4_first(const Data * xv, const Data& yv) {
+ auto acc = _mm256_mul_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00));
+ acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 0x55), acc);
+ acc = _mm256_fmadd_ps(xv[2], _mm256_shuffle_ps(yv, yv, 0xaa), acc);
+ acc = _mm256_fmadd_ps(xv[3], _mm256_shuffle_ps(yv, yv, 0xff), acc);
+ return acc;
+ }
static inline Acc acc_first(const Data& y, const Data& x) {
return _mm256_mul_ps(y, x);
}
static inline float hsum(Acc acc) {
return hsum_float_8(acc);
}
+ static inline __m128 hsum_r4(Acc acc) {
+ return _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));
+ }
template <typename Float>
static inline Data load4Floats(const Float * x) {
return _mm256_insertf128_ps(_mm256_setzero_ps(), load128(x), 0);
@@ -6928,6 +7014,31 @@ template <typename Float, int nrc_in> struct QFT final : public QFBase {
}
IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load4Floats(y[iy] + 4*i); }
+ IQK_ALWAYS_INLINE void load_r4(int ix, int i, Data * xv) const {
+ xv[0] = load1(ix+0, i);
+ xv[1] = load1(ix+1, i);
+ xv[2] = load1(ix+2, i);
+ xv[3] = load1(ix+3, i);
+#ifdef HAVE_FANCY_SIMD
+ auto t0 = _mm512_unpacklo_ps(xv[0], xv[1]);
+ auto t1 = _mm512_unpacklo_ps(xv[2], xv[3]);
+ auto t2 = _mm512_unpackhi_ps(xv[0], xv[1]);
+ auto t3 = _mm512_unpackhi_ps(xv[2], xv[3]);
+ xv[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(t0), _mm512_castps_pd(t1)));
+ xv[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(t0), _mm512_castps_pd(t1)));
+ xv[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(t2), _mm512_castps_pd(t3)));
+ xv[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(t2), _mm512_castps_pd(t3)));
+#else
+ auto t0 = _mm256_unpacklo_ps(xv[0], xv[1]);
+ auto t1 = _mm256_unpacklo_ps(xv[2], xv[3]);
+ auto t2 = _mm256_unpackhi_ps(xv[0], xv[1]);
+ auto t3 = _mm256_unpackhi_ps(xv[2], xv[3]);
+ xv[0] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1)));
+ xv[1] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1)));
+ xv[2] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3)));
+ xv[3] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3)));
+#endif
+ }
const Float * y[nrc];
};
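The load_r4/acc_r4 pair above computes four dot products at once: load_r4 transposes 4x4 tiles of four rows into interleaved registers, acc_r4 broadcasts each group of four y values so that one set of FMAs advances all four rows, and hsum_r4 reduces the accumulator to a 4-float result. A simplified scalar model of the quantity this computes (hypothetical name, plain arrays, n assumed a multiple of 4; it does not model the register layout):

    #include <array>

    // acc[j] == dot(x_j, y) for the four interleaved rows, i.e. the 4-float vector
    // later stored via info.store(ix0 + 4*ix, iy, hsum_r4(...)).
    static std::array<float, 4> dot4(const float* const x[4], const float* y, int n) {
        std::array<float, 4> acc = {0.f, 0.f, 0.f, 0.f};   // one accumulator per row
        for (int i = 0; i < n; i += 4) {
            for (int lane = 0; lane < 4; ++lane) {          // the broadcast lane of y
                for (int j = 0; j < 4; ++j) {               // the four interleaved rows
                    acc[j] += x[j][i + lane] * y[i + lane];
                }
            }
        }
        return acc;
    }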
@@ -6973,6 +7084,56 @@ IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0,
for (int iy = 0; iy < Qy::nrc; ++iy) for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, iy, QFBase::hsum(acc[Qx::nrc*iy+ix]));
}
+template <typename Qy, typename Qx>
+inline void mul_mat_Qx_Qy_MxN_fa(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
+ int nb = n/QFBase::k_step;
+ Qy y(info);
+ Qx x(cx + ix0*bx, bx);
+ QFBase::Data xv[Qx::nrc];
+ QFBase::Acc acc[Qx::nrc*Qy::nrc];
+ auto yv = y.load1(0, 0);
+ for (int ix = 0; ix < Qx::nrc; ++ix) {
+ xv[ix] = x.load1(ix, 0);
+ acc[ix] = QFBase::acc_first(yv, xv[ix]);
+ }
+ for (int iy = 1; iy < Qy::nrc; ++iy) {
+ yv = y.load1(iy, 0);
+ for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc_first(yv, xv[ix]);
+ }
+ for (int i = 1; i < nb; ++i) {
+ yv = y.load1(0, i);
+ for (int ix = 0; ix < Qx::nrc; ++ix) {
+ xv[ix] = x.load1(ix, i);
+ acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]);
+ }
+ for (int iy = 1; iy < Qy::nrc; ++iy) {
+ yv = y.load1(iy, i);
+ for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc(acc[Qx::nrc*iy + ix], yv, xv[ix]);
+ }
+ }
+ for (int iy = 0; iy < Qy::nrc; ++iy) for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, iy, QFBase::hsum(acc[Qx::nrc*iy+ix]));
+}
+
+template <typename Qy, typename Qx>
+inline void mul_mat_Qx_Qy_MxN_fa4(int D, const char * cx, size_t bx, int ix0, const DataInfo& info) {
+ static_assert(Qx::nrc%4 == 0);
+ int nb = D/QFBase::k_step;
+ Qy y(info);
+ Qx x(cx + ix0*bx, bx);
+ QFBase::Data xv[Qx::nrc];
+ QFBase::Acc acc[Qx::nrc*Qy::nrc/4] = {};
+ for (int i = 0; i < nb; ++i) {
+ for (int ix = 0; ix < Qx::nrc/4; ++ix) x.load_r4(4*ix, i, xv + 4*ix);
+ for (int iy = 0; iy < Qy::nrc; ++iy) {
+ auto yv = y.load1(iy, i);
+ for (int ix = 0; ix < Qx::nrc/4; ++ix) acc[ix*Qy::nrc + iy] = QFBase::acc_r4(acc[ix*Qy::nrc + iy], xv + 4*ix, yv);
+ }
+ }
+ for (int iy = 0; iy < Qy::nrc; ++iy) {
+ for (int ix = 0; ix < Qx::nrc/4; ++ix) info.store(ix0+4*ix, iy, QFBase::hsum_r4(acc[ix*Qy::nrc + iy]));
+ }
+}
+
// This will handle any of f16 x f32, f32 x f16, f16 x f16, f32 x f32, with computations done
// in f32 (i.e., f16 is first converted to f32). It is easy to extend to computations done in
// f16, but I don't have a CPU capable of f16 vector arithmetic, so not doing it for now.
@@ -11902,6 +12063,46 @@ void mul_mat_q8_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& inf
}
}
+template <int nrc_y>
+void mul_mat_q8_0_r4_q8_0_128(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ GGML_ASSERT(n == 128);
+ int8x16x4_t qx[8];
+ float32x4_t scales[4];
+ float32x4_t scales_y[4];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_q8_0_x4 * iq8 = (const block_q8_0_x4 *)((const char *)vx + ix*bx);
+ for (int k = 0; k < 4; ++k) {
+ scales[k] = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[k].d));
+ qx[2*k+0] = vld1q_s8_x4(iq8[k].qs);
+ qx[2*k+1] = vld1q_s8_x4(iq8[k].qs+64);
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto by = (const block_q8_0_x4 *)info.src1_row(iy);
+ auto d8 = vcvt_f32_f16(vld1_f16((const float16_t *)by->d));
+ scales_y[0] = vmulq_laneq_f32(scales[0], d8, 0);
+ scales_y[1] = vmulq_laneq_f32(scales[1], d8, 1);
+ scales_y[2] = vmulq_laneq_f32(scales[2], d8, 2);
+ scales_y[3] = vmulq_laneq_f32(scales[3], d8, 3);
+ auto sumf = vdupq_n_f32(0.f);
+ for (int k = 0; k < 4; ++k) {
+ auto y = vld1q_s8_x2(by->qs+32*k);
+ auto sumi = vdupq_n_s32(0);
+ sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[0], y.val[0], 0);
+ sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[1], y.val[1], 0);
+ sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[2], y.val[0], 1);
+ sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[3], y.val[1], 1);
+ sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[0], y.val[0], 2);
+ sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[1], y.val[1], 2);
+ sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[2], y.val[0], 3);
+ sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[3], y.val[1], 3);
+ sumf = vfmaq_f32(sumf, scales_y[k], vcvtq_f32_s32(sumi));
+ }
+ info.store(ix, iy, sumf);
+ }
+ }
+}
+
#define SET_MUL_MAT_FUNCTIONS_T(m, func, Dequantizer) \
m.funcs[0] = func<Dequantizer, 1>;\
m.funcs[1] = func<Dequantizer, 2>;\
@@ -12354,6 +12555,15 @@ struct F16 {
static inline float reduce_add(Data data) { return _mm512_reduce_add_ps(data); }
static inline Data max(Data v1, Data v2) { return _mm512_max_ps(v1, v2); }
static inline Data add(Data v1, Data v2) { return _mm512_add_ps(v1, v2); }
+ static inline Data set4(const float * ptr) {
+ auto v128 = _mm_loadu_ps(ptr);
+ auto v256 = _mm256_set_m128(v128, v128);
+ return _mm512_insertf32x8(_mm512_castps256_ps512(v256), v256, 1);
+ }
+ static inline Data fmadd_lane0(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0x00), prev); }
+ static inline Data fmadd_lane1(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0x55), prev); }
+ static inline Data fmadd_lane2(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0xaa), prev); }
+ static inline Data fmadd_lane3(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0xff), prev); }
#elif defined __AVX2__
using Data = __m256;
constexpr static int block_size = 8;
@@ -12371,6 +12581,14 @@ struct F16 {
static inline float reduce_add(Data data) { return hsum_float_8(data); }
static inline Data max(Data v1, Data v2) { return _mm256_max_ps(v1, v2); }
static inline Data add(Data v1, Data v2) { return _mm256_add_ps(v1, v2); }
+ static inline Data set4(const float * ptr) {
+ auto v128 = _mm_loadu_ps(ptr);
+ return _mm256_set_m128(v128, v128);
+ }
+ static inline Data fmadd_lane0(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0x00), prev); }
+ static inline Data fmadd_lane1(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0x55), prev); }
+ static inline Data fmadd_lane2(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0xaa), prev); }
+ static inline Data fmadd_lane3(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0xff), prev); }
#else
using Data = float16x8_t;
constexpr static int block_size = 8;
@@ -12402,6 +12620,14 @@ struct F16 {
}
static inline Data max(Data v1, Data v2) { return vmaxq_f16(v1, v2); }
static inline Data add(Data v1, Data v2) { return vaddq_f16(v1, v2); }
+ static inline float16x4_t set4(const float * ptr) {
+ auto val32 = vld1q_f32(ptr);
+ return vcvt_f16_f32(val32);
+ }
+ static inline Data fmadd_lane0(Data prev, Data v1, float16x4_t v2) { return vfmaq_lane_f16(prev, v1, v2, 0); }
+ static inline Data fmadd_lane1(Data prev, Data v1, float16x4_t v2) { return vfmaq_lane_f16(prev, v1, v2, 1); }
+ static inline Data fmadd_lane2(Data prev, Data v1, float16x4_t v2) { return vfmaq_lane_f16(prev, v1, v2, 2); }
+ static inline Data fmadd_lane3(Data prev, Data v1, float16x4_t v2) { return vfmaq_lane_f16(prev, v1, v2, 3); }
#endif
template <int k_step> static inline float reduce_max(const Data * data) {
return reduce_T<k_step, &F16::max, &F16::reduce_max>(data);
@@ -12454,7 +12680,6 @@ template <int D, int step>
struct HelperQ80 final : public BaseHelper<step> {
using Base = BaseHelper<step>;
#ifdef HAVE_FANCY_SIMD
- //using block_q8 = block_q8_1;
using block_q8 = block_q8_1;
#else
using block_q8 = block_q8_0;
@@ -12478,14 +12703,14 @@ struct HelperQ80 final : public BaseHelper<step> {
v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+1))));
#else
int ii = j%QK8_0;
- v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+ii)+0))));
- v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+ii)+1))));
+ v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+ii+0)))));
+ v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+ii+8)))));
#endif
#endif
}
static inline void convert(int nq, int stride_q, const float * q, block_q8_0 * y) {
- GGML_ASSERT(nq <= step);
+ //GGML_ASSERT(nq <= step); Why did I have this assert?
for (int i = 0; i < nq; ++i) {
quantize_row_q8_0_x4(q, y, D);
q += stride_q;
@@ -12494,7 +12719,7 @@ struct HelperQ80 final : public BaseHelper<step> {
}
static inline void convert(int nq, int stride_q, const float * q, block_q8_1 * y) {
- GGML_ASSERT(nq <= step);
+ //GGML_ASSERT(nq <= step); Why did I have this assert?
for (int i = 0; i < nq; ++i) {
quantize_row_q8_1_x4(q, y, D);
q += stride_q;
@@ -12504,6 +12729,86 @@ struct HelperQ80 final : public BaseHelper<step> {
};
template <int D, int step>
+struct HelperQ80R4 : public BaseHelper<step> {
+ using Base = BaseHelper<step>;
+#ifdef __AVX2__
+ using block_q8 = block_q8_1;
+#else
+ using block_q8 = block_q8_0;
+#endif
+ HelperQ80R4(int nk, const HelperQ80<D, step>& q8) : Base(q8.data, q8.stride) {
+ r4 = repack(nk, q8);
+ Base::data = (const char *)r4.data();
+ Base::stride = (D/QK8_0)*sizeof(block_q8_0);
+ }
+
+ static std::vector<block_q8_0_x4> repack(int nk, const HelperQ80<D, step> q8) {
+ static_assert(D%QK8_0 == 0);
+ GGML_ASSERT(nk%4 == 0);
+ constexpr int nblock = D/QK8_0;
+ std::vector<block_q8_0_x4> result(nblock * nk/4);
+ auto y = result.data();
+ const block_q8_0 * x4[4];
+ for (int row = 0; row < nk; row += 4) {
+ for (int k = 0; k < 4; ++k) x4[k] = (const block_q8_0 *)(q8.data + (row + k)*q8.stride);
+ for (int ib = 0; ib < nblock; ++ib) {
+ for (int k = 0; k < 4; ++k) y[ib].d[k] = x4[k][ib].d;
+#ifdef __AVX2__
+ auto m0 = _mm256_loadu_si256((const __m256i *)x4[0][ib].qs);
+ auto m1 = _mm256_loadu_si256((const __m256i *)x4[1][ib].qs);
+ auto m2 = _mm256_loadu_si256((const __m256i *)x4[2][ib].qs);
+ auto m3 = _mm256_loadu_si256((const __m256i *)x4[3][ib].qs);
+ auto t0 = _mm256_unpacklo_epi32(m0, m1);
+ auto t1 = _mm256_unpacklo_epi32(m2, m3);
+ auto t2 = _mm256_unpackhi_epi32(m0, m1);
+ auto t3 = _mm256_unpackhi_epi32(m2, m3);
+ m0 = _mm256_unpacklo_epi64(t0, t1);
+ m1 = _mm256_unpackhi_epi64(t0, t1);
+ m2 = _mm256_unpacklo_epi64(t2, t3);
+ m3 = _mm256_unpackhi_epi64(t2, t3);
+ _mm256_storeu_si256((__m256i *)y[ib].qs + 0, m0);
+ _mm256_storeu_si256((__m256i *)y[ib].qs + 1, m1);
+ _mm256_storeu_si256((__m256i *)y[ib].qs + 2, m2);
+ _mm256_storeu_si256((__m256i *)y[ib].qs + 3, m3);
+#elif defined __ARM_NEON
+ auto m0 = vld1q_s8_x2(x4[0][ib].qs);
+ auto m1 = vld1q_s8_x2(x4[1][ib].qs);
+ auto m2 = vld1q_s8_x2(x4[2][ib].qs);
+ auto m3 = vld1q_s8_x2(x4[3][ib].qs);
+ auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0]));
+ auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0]));
+ m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1]));
+ row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1]));
+ m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ vst1q_s8_x2(y[ib].qs + 0, m0);
+ vst1q_s8_x2(y[ib].qs + 32, m1);
+ vst1q_s8_x2(y[ib].qs + 64, m2);
+ vst1q_s8_x2(y[ib].qs + 96, m3);
+#else
+ for (int l = 0; l < 4; ++l) {
+ for (int k = 0; k < 4; ++k) for (int i = 0; i < 4; ++i) {
+ y[ib].qs[32*l+4*k+i+ 0] = x4[k][ib].qs[i+4*l+ 0];
+ y[ib].qs[32*l+4*k+i+16] = x4[k][ib].qs[i+4*l+16];
+ }
+ }
+#endif
+ }
+ y += nblock;
+ }
+ return result;
+ }
+
+ std::vector<block_q8_0_x4> r4;
+};
+
+template <int D, int step>
struct HelperQ40 final : public BaseHelper<step> {
using Base = BaseHelper<step>;
using block_q8 = block_q8_0;
@@ -12725,7 +13030,7 @@ struct FlashMS {
if (smax > M[j]) {
if (M[j] > -INFINITY) {
float m = expf(M[j] - smax);
- vms[j] = F16::set1(m);
+ vms[j] = m;
need_scaling[j] = 1;
S[j] *= m;
} else {
@@ -12907,7 +13212,7 @@ struct FlashMS {
cache_t cache[q_step*k_step];
float S[q_step], M[q_step];
int need_scaling[q_step];
- F16::Data vms[q_step];
+ float vms[q_step];
const F16::Data vscale;
const float softcap;
const ggml_half h_inf;
@@ -12927,79 +13232,90 @@ struct FlashQKV {
// Hence, for now, we will not handle head sizes of 80 and 112
template <typename VHelper>
inline void accumulate_qkv(const VHelper& vh, const FlashMS<q_step, k_step>& fms) {
- F16::Data vk[2*q_step];
- for (int i = 0; i < D/F16::block_size; i += 2) {
- for (int j = 0; j < q_step; ++j) {
- if (fms.need_scaling[j] == 2) {
- vk[2*j+0] = vk[2*j+1] = F16::zero();
- } else {
- auto R = qkv_cache + D*j;
- vk[2*j+0] = F16::load(R + F16::block_size*i);
- vk[2*j+1] = F16::load(R + F16::block_size*(i + 1));
- if (fms.need_scaling[j] == 1) {
- vk[2*j+0] = F16::mul(vk[2*j+0], fms.vms[j]);
- vk[2*j+1] = F16::mul(vk[2*j+1], fms.vms[j]);
- }
+ F16::Data v[8];
+ for (int j = 0; j < q_step; ++j) {
+ auto R = qkv_cache + D*j;
+ if (fms.need_scaling[j] == 2) {
+ std::memset(R, 0, D*sizeof(qkv_cache_t));
+ }
+ else if (fms.need_scaling[j] == 1) {
+ auto vms = F16::set1(fms.vms[j]);
+ for (int i = 0; i < D/F16::block_size; ++i) {
+ F16::store(R + F16::block_size*i, F16::mul(vms, F16::load(R + F16::block_size*i)));
}
}
- F16::Data v1, v2, v3, v4;
- for (int l1 = 0; l1 < k_step; l1 += 2) {
- vh.load(l1+0, i, v1, v2);
- vh.load(l1+1, i, v3, v4);
+ }
+ for (int i = 0; i < D/F16::block_size; i += 2) {
+ for (int l = 0; l < k_step; l += 4) {
+ vh.load(l+0, i, v[0], v[4]);
+ vh.load(l+1, i, v[1], v[5]);
+ vh.load(l+2, i, v[2], v[6]);
+ vh.load(l+3, i, v[3], v[7]);
for (int j = 0; j < q_step; ++j) {
- auto vs1 = F16::set1(fms.cache[k_step*j + l1+0]);
- auto vs2 = F16::set1(fms.cache[k_step*j + l1+1]);
- vk[2*j+0] = F16::fmadd(F16::fmadd(vk[2*j+0], v1, vs1), v3, vs2);
- vk[2*j+1] = F16::fmadd(F16::fmadd(vk[2*j+1], v2, vs1), v4, vs2);
+ auto R = qkv_cache + D*j;
+ auto s1 = F16::load(R + F16::block_size*(i+0));
+ auto s2 = F16::load(R + F16::block_size*(i+1));
+ auto vs = F16::set4(fms.cache + k_step*j + l);
+ s1 = F16::fmadd_lane0(s1, v[0], vs);
+ s2 = F16::fmadd_lane0(s2, v[4], vs);
+ s1 = F16::fmadd_lane1(s1, v[1], vs);
+ s2 = F16::fmadd_lane1(s2, v[5], vs);
+ s1 = F16::fmadd_lane2(s1, v[2], vs);
+ s2 = F16::fmadd_lane2(s2, v[6], vs);
+ s1 = F16::fmadd_lane3(s1, v[3], vs);
+ s2 = F16::fmadd_lane3(s2, v[7], vs);
+ F16::store(R + F16::block_size*(i+0), s1);
+ F16::store(R + F16::block_size*(i+1), s2);
}
}
- for (int j = 0; j < q_step; ++j) {
- auto R = qkv_cache + D*j;
- F16::store(R + F16::block_size*(i + 0), vk[2*j+0]);
- F16::store(R + F16::block_size*(i + 1), vk[2*j+1]);
- }
}
}
- template <typename VHelper, int Nq = q_step, class = std::enable_if<Nq >= 2>>
+ template <typename VHelper>
inline void accumulate_qkv(int nq1, const VHelper& vh, const FlashMS<q_step, k_step>& fms) {
- F16::Data vk[2*q_step];
- for (int i = 0; i < D/F16::block_size; i += 2) {
- for (int j = 0; j < nq1; ++j) {
- if (fms.need_scaling[j] == 2) {
- vk[2*j+0] = vk[2*j+1] = F16::zero();
- } else {
- auto R = qkv_cache + D*j;
- vk[2*j+0] = F16::load(R + F16::block_size*i);
- vk[2*j+1] = F16::load(R + F16::block_size*(i + 1));
- if (fms.need_scaling[j] == 1) {
- vk[2*j+0] = F16::mul(vk[2*j+0], fms.vms[j]);
- vk[2*j+1] = F16::mul(vk[2*j+1], fms.vms[j]);
- }
+ F16::Data v[8];
+ for (int j = 0; j < nq1; ++j) {
+ auto R = qkv_cache + D*j;
+ if (fms.need_scaling[j] == 2) {
+ std::memset(R, 0, D*sizeof(qkv_cache_t));
+ }
+ else if (fms.need_scaling[j] == 1) {
+ auto vms = F16::set1(fms.vms[j]);
+ for (int i = 0; i < D/F16::block_size; ++i) {
+ F16::store(R + F16::block_size*i, F16::mul(vms, F16::load(R + F16::block_size*i)));
}
}
- F16::Data v1, v2, v3, v4;
- for (int l1 = 0; l1 < k_step; l1 += 2) {
- vh.load(l1+0, i, v1, v2);
- vh.load(l1+1, i, v3, v4);
+ }
+ for (int i = 0; i < D/F16::block_size; i += 2) {
+ for (int l = 0; l < k_step; l += 4) {
+ vh.load(l+0, i, v[0], v[4]);
+ vh.load(l+1, i, v[1], v[5]);
+ vh.load(l+2, i, v[2], v[6]);
+ vh.load(l+3, i, v[3], v[7]);
for (int j = 0; j < nq1; ++j) {
- auto vs1 = F16::set1(fms.cache[k_step*j + l1+0]);
- auto vs2 = F16::set1(fms.cache[k_step*j + l1+1]);
- vk[2*j+0] = F16::fmadd(F16::fmadd(vk[2*j+0], v1, vs1), v3, vs2);
- vk[2*j+1] = F16::fmadd(F16::fmadd(vk[2*j+1], v2, vs1), v4, vs2);
+ auto R = qkv_cache + D*j;
+ auto s1 = F16::load(R + F16::block_size*(i+0));
+ auto s2 = F16::load(R + F16::block_size*(i+1));
+ auto vs = F16::set4(fms.cache + k_step*j + l);
+ s1 = F16::fmadd_lane0(s1, v[0], vs);
+ s2 = F16::fmadd_lane0(s2, v[4], vs);
+ s1 = F16::fmadd_lane1(s1, v[1], vs);
+ s2 = F16::fmadd_lane1(s2, v[5], vs);
+ s1 = F16::fmadd_lane2(s1, v[2], vs);
+ s2 = F16::fmadd_lane2(s2, v[6], vs);
+ s1 = F16::fmadd_lane3(s1, v[3], vs);
+ s2 = F16::fmadd_lane3(s2, v[7], vs);
+ F16::store(R + F16::block_size*(i+0), s1);
+ F16::store(R + F16::block_size*(i+1), s2);
}
}
- for (int j = 0; j < nq1; ++j) {
- auto R = qkv_cache + D*j;
- F16::store(R + F16::block_size*(i + 0), vk[2*j+0]);
- F16::store(R + F16::block_size*(i + 1), vk[2*j+1]);
- }
}
}
inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int j, const qkv_cache_t * R, float * qkv) const {
GGML_ASSERT(fms.S[j] > 0);
auto norm = F16::set1(1/fms.S[j]);
+ //auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f);
for (int i = 0; i < D/F16::block_size; ++i) {
auto r = F16::load(R + F16::block_size*i);
F16::store(qkv + F16::block_size*i, F16::mul(norm, r));
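The rework above moves the running accumulators out of SIMD registers into qkv_cache, rescales or zeroes them in a pre-pass driven by need_scaling, and then consumes four V rows per step through F16::set4 and fmadd_lane0..3. A simplified scalar model of what one tile of the new loop computes (plain arrays, hypothetical name; the real code works on two F16::block_size-wide column chunks s1/s2 at a time):

    // For every query row j: R_j += sum over 4 consecutive k positions of
    // p[j][l+t] * V[l+t][:], where p = softmax(K*Q) values held in fms.cache.
    static void accumulate_tile(float* R,            // [q_rows][Dh] accumulators (qkv_cache)
                                const float* p,      // [q_rows][k_step] softmax weights
                                const float* V,      // [k_step][Dh] value rows
                                int q_rows, int k_step, int Dh) {
        for (int l = 0; l < k_step; l += 4) {
            for (int j = 0; j < q_rows; ++j) {
                for (int t = 0; t < 4; ++t) {
                    const float w = p[j*k_step + l + t];    // F16::set4 + fmadd_laneT
                    for (int d = 0; d < Dh; ++d) R[j*Dh + d] += w * V[(l + t)*Dh + d];
                }
            }
        }
    }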
@@ -13024,156 +13340,281 @@ struct FlashQKV {
}
}
- qkv_cache_t qkv_cache[D*q_step];
+ // qkv_cache_t qkv_cache[D*q_step];
+ // The initializer is not actually required. But the compiler cannot figure out that when qkv_cache is
+ // first used for q_step rows, fms.need_scaling[j] is always 2, which zeroes the content of qkv_cache.
+ // As a result, we get an infinite stream of warnings about uninitialized variable use (one for each
+ // combination of D, q_step, k_step), which is extremely annoying. Hence, I succumb to the trend of
+ // constantly being saved by others (the compiler in this case), and add this 100% unnecessary initialization.
+ qkv_cache_t qkv_cache[D*q_step] = {};
};
+#ifdef HAVE_FANCY_SIMD
+template <int nrc_y>
+static void mul_mat_q8_0_r4_q8_1_128([[maybe_unused]] int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%8 == 0);
+ GGML_ASSERT(n == 128);
+ //Q8<nrc_y, block_q8_1_x4> q8(info);
+ __m512i qx[16];
+ __m512 scales[4];
+ __m512 scales_m[4];
+ __m512 dy[4];
+ auto m127 = _mm512_set1_epi8(127);
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ const block_q8_0_x4 * q8l = (const block_q8_0_x4 *)((const char *)vx + (ix+0)*bx);
+ const block_q8_0_x4 * q8h = (const block_q8_0_x4 *)((const char *)vx + (ix+4)*bx);
+ for (int k = 0; k < 4; ++k) {
+ auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8l[k].d));
+ auto scales1 = _mm256_set_m128(scales128, scales128);
+ scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8h[k].d));
+ auto scales2 = _mm256_set_m128(scales128, scales128);
+ scales[k] = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
+ scales_m[k] = _mm512_mul_ps(scales[k], _mm512_set1_ps(-63.5f));
+ qx[4*k+0] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[k].qs+0)),
+ _mm256_loadu_si256((const __m256i *)q8h[k].qs+0), 1);
+ qx[4*k+1] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[k].qs+1)),
+ _mm256_loadu_si256((const __m256i *)q8h[k].qs+1), 1);
+ qx[4*k+2] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[k].qs+2)),
+ _mm256_loadu_si256((const __m256i *)q8h[k].qs+2), 1);
+ qx[4*k+3] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[k].qs+3)),
+ _mm256_loadu_si256((const __m256i *)q8h[k].qs+3), 1);
+ qx[4*k+0] = _mm512_add_epi8(qx[4*k+0], m127);
+ qx[4*k+1] = _mm512_add_epi8(qx[4*k+1], m127);
+ qx[4*k+2] = _mm512_add_epi8(qx[4*k+2], m127);
+ qx[4*k+3] = _mm512_add_epi8(qx[4*k+3], m127);
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto by = (const block_q8_1_x4 *)info.src1_row(iy);
+ //auto dall = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][0].d));
+ auto dall = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)by->d));
+ auto d128 = _mm256_castps256_ps128(dall);
+ auto m128 = _mm256_extractf128_ps(dall, 1);
+ auto m256 = _mm256_set_m128(m128, m128);
+ auto m512 = _mm512_insertf32x8(_mm512_castps256_ps512(m256), m256, 1);
+ auto sumf = _mm512_mul_ps(scales_m[0], _mm512_shuffle_ps(m512, m512, 0x00));
+ sumf = _mm512_fmadd_ps(scales_m[1], _mm512_shuffle_ps(m512, m512, 0x55), sumf);
+ sumf = _mm512_fmadd_ps(scales_m[2], _mm512_shuffle_ps(m512, m512, 0xaa), sumf);
+ sumf = _mm512_fmadd_ps(scales_m[3], _mm512_shuffle_ps(m512, m512, 0xff), sumf);
+ auto d256 = _mm256_set_m128(d128, d128);
+ auto d512 = _mm512_insertf32x8(_mm512_castps256_ps512(d256), d256, 1);
+ dy[0] = _mm512_mul_ps(scales[0], _mm512_shuffle_ps(d512, d512, 0x00));
+ dy[1] = _mm512_mul_ps(scales[1], _mm512_shuffle_ps(d512, d512, 0x55));
+ dy[2] = _mm512_mul_ps(scales[2], _mm512_shuffle_ps(d512, d512, 0xaa));
+ dy[3] = _mm512_mul_ps(scales[3], _mm512_shuffle_ps(d512, d512, 0xff));
+ for (int k = 0; k < 4; ++k) {
+ //auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][0].qs+k);
+ auto y8 = _mm256_loadu_si256((const __m256i*)by->qs+k);
+ auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
+ auto sumi = _mm512_setzero_si512();
+ sumi = _mm512_dpbusd_epi32(sumi, qx[4*k+0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[4*k+1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[4*k+2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[4*k+3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+ sumf = _mm512_fmadd_ps(dy[k], _mm512_cvtepi32_ps(sumi), sumf);
+ }
+ auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sumf, 0), _mm512_extractf32x4_ps(sumf, 1));
+ auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sumf, 2), _mm512_extractf32x4_ps(sumf, 3));
+ info.store(ix+0, iy, sum1);
+ info.store(ix+4, iy, sum2);
+ }
+ }
+}
+#endif
+
template <int D, int q_step, int k_step>
struct FlashQKfp32 {
static_assert(D%F16::block_size == 0 && D <= 256);
static_assert(k_step%F16::block_size == 0);
static_assert(q_step <= 4 || q_step%4 == 0);
-#ifdef __aarch64__
- constexpr static bool is_small_head = false;
+#ifdef __AVX2__
+ template <typename KHelper, typename q_float>
+ static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
+ FlashMS<q_step, k_step>& fms) {
+#ifdef HAVE_FANCY_SIMD
+ constexpr int nrc_q = 8;
+ constexpr int nrc_k = 8;
#else
- constexpr static bool is_small_head = D <= (F16::num_registers/2)*F16::block_size;
+ // somewhat surprisingly, nrc_q = 4, nrc_k = 8 is better than nrc_q = 8, nrc_k = 4
+ constexpr int nrc_q = 4;
+ constexpr int nrc_k = 8;
#endif
-
- template <bool small = is_small_head, class = std::enable_if<small>, typename q_float>
- static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const q_float * q, const char * mask,
- F16::Data * qv, F16::Data * vk, FlashMS<q_step, k_step>& fms) {
- // q index is q_step*i1 + m1
- // k index is k_step*k1 + l1
- const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
- fms.cache[k_step*m1 + l1 + 0] = fms.cache[k_step*m1 + l1 + 1] = -INFINITY;
- if (mp[l1+0] == fms.h_inf && mp[l1+1] == fms.h_inf) {
- return;
- }
- auto qr = q + m1*stride_q;
- for (int i = 0; i < D/F16::block_size; ++i) qv[i] = F16::load(qr + F16::block_size*i);
- if (mp[l1+0] != fms.h_inf) {
- auto vsum = F16::zero();
- for (int i = 0; i < D/F16::block_size; ++i) vsum = F16::fmadd(vsum, vk[i], qv[i]);
- fms.cache[k_step*m1 + l1 + 0] = F16::reduce_add(vsum);
- }
- if (mp[l1+1] != fms.h_inf) {
- auto vsum = F16::zero();
- for (int i = 0; i < D/F16::block_size; ++i) vsum = F16::fmadd(vsum, vk[i+D/16], qv[i]);
- fms.cache[k_step*m1 + l1 + 1] = F16::reduce_add(vsum);
+ constexpr int qrem = q_step - nrc_q*(q_step/nrc_q);
+ constexpr int krem = k_step - nrc_k*(k_step/nrc_k);
+ DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr};
+ for (int iq = 0; iq < q_step/nrc_q; ++iq) {
+ for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+ mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, nrc_q>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
+ }
+ if constexpr (krem > 0) {
+ mul_mat_Qx_Qy_MxN_fa<QFT<q_float, nrc_q>, QFT<ggml_half, krem>>(D, kh.block, kh.stride, k_step - krem, info);
+ }
+ info.cur_y += nrc_q;
}
- }
-
- template <bool small = is_small_head, class = std::enable_if<!small>, typename q_float>
- static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const q_float * q, const char * mask,
- F16::Data * vk, FlashMS<q_step, k_step>& fms) {
- // q index is q_step*i1 + m1
- // k index is k_step*k1 + l1
- const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
- if (mp[l1] == fms.h_inf) {
- fms.cache[k_step*m1 + l1] = -INFINITY;
- return;
+ if constexpr (qrem > 0) {
+ for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+ mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, qrem>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
+ }
+ if constexpr (krem > 0) {
+ mul_mat_Qx_Qy_MxN_fa<QFT<q_float, qrem>, QFT<ggml_half, krem>>(D, kh.block, kh.stride, k_step - krem, info);
+ }
}
- auto qr = q + m1*stride_q;
- auto vsum = F16::zero();
- for (int i = 0; i < D/F16::block_size; ++i) {
- vsum = F16::fmadd(vsum, vk[i], F16::load(qr + F16::block_size*i));
+ F16::Data vk[k_step/F16::block_size];
+ for (int j = 0; j < q_step; ++j) {
+ fms.update_M_S(j, vk, mask + stride_m*j);
}
- fms.cache[k_step*m1 + l1] = F16::reduce_add(vsum);
}
-
- template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>, typename q_float>
- static inline void mult_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
+#else
+ template <typename KHelper, typename q_float>
+ static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
FlashMS<q_step, k_step>& fms) {
- F16::Data qv[D/F16::block_size];
- F16::Data vk[D/(F16::block_size/2)];
- for (int l1 = 0; l1 < k_step; l1 += 2) {
- kh.load_2(l1, vk);
- for (int m1 = 0; m1 < q_step; ++m1) {
- mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, qv, vk, fms);
+ constexpr int nrc_q = 4;
+ constexpr int nrc_k = 6;
+ constexpr int qrem = q_step - nrc_q*(q_step/nrc_q);
+ constexpr int krem = k_step - nrc_k*(k_step/nrc_k);
+ DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr};
+ for (int iq = 0; iq < q_step/nrc_q; ++iq) {
+ for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+ mul_mat_f16_f16_NxN<nrc_q, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
}
- }
- }
-
- template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>, typename q_float>
- static inline void mult_mask_kq_l(const KHelper& kh, int stride_q, int stride_m,
- const q_float * q, const char * mask, FlashMS<q_step, k_step>& fms) {
- F16::Data vk[D/F16::block_size];
- for (int l1 = 0; l1 < k_step; ++l1) {
- kh.load(l1, vk);
- for (int m1 = 0; m1 < q_step; ++m1) {
- mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, vk, fms);
+ if constexpr (krem > 0) {
+ mul_mat_f16_f16_NxN<nrc_q, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
}
+ info.cur_y += nrc_q;
}
- }
-
- template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>, typename q_float>
- static inline void mult_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
- FlashMS<q_step, k_step>& fms) {
- F16::Data qv[D/F16::block_size];
- F16::Data vk[D/(F16::block_size/2)];
- for (int l1 = 0; l1 < k_step; l1 += 2) {
- kh.load_2(l1, vk);
- for (int m1 = 0; m1 < nq; ++m1) {
- mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, qv, vk, fms);
+ if constexpr (qrem > 0) {
+ for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+ mul_mat_f16_f16_NxN<qrem, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
}
- }
- }
-
- template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>, typename q_float>
- static inline void mult_mask_kq_l(int nq, const KHelper& kh, int stride_q, int stride_m,
- const q_float * q, const char * mask, FlashMS<q_step, k_step>& fms) {
- F16::Data vk[D/F16::block_size];
- for (int l1 = 0; l1 < k_step; ++l1) {
- kh.load(l1, vk);
- for (int m1 = 0; m1 < nq; ++m1) {
- mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, vk, fms);
+ if constexpr (krem > 0) {
+ mul_mat_f16_f16_NxN<qrem, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
}
}
+ float32x4_t vk[k_step/4];
+ for (int j = 0; j < q_step; ++j) {
+ fms.update_M_S(j, vk, mask + stride_m*j);
+ }
}
+#endif
+#ifdef __AVX2__
template <typename KHelper, typename q_float>
- static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
+ static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
FlashMS<q_step, k_step>& fms) {
- if constexpr (is_small_head) {
- mult_mask_kq(kh, stride_q, stride_m, q, mask, fms);
- }
- else {
- mult_mask_kq_l(kh, stride_q, stride_m, q, mask, fms);
+#ifdef HAVE_FANCY_SIMD
+ constexpr int nrc_q = 8;
+ constexpr int nrc_k = 8;
+#else
+ // somewhat surprisingly, nrc_q = 4, nrc_k = 8 is better than nrc_q = 8, nrc_k = 4
+ constexpr int nrc_q = 4;
+ constexpr int nrc_k = 8;
+#endif
+ static_assert(k_step%nrc_k == 0);
+ int qrem = q_step - nrc_q*(q_step/nrc_q);
+ DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr};
+ for (int iq = 0; iq < nq/nrc_q; ++iq) {
+ for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+ mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, nrc_q>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
+ }
+ info.cur_y += nrc_q;
}
-#ifdef __aarch64__
- float32x4_t vk[k_step/4];
- for (int j = 0; j < q_step; ++j) {
- fms.update_M_S(j, vk);
+ if (qrem > 0) {
+ switch (qrem) {
+ case 1: {
+ for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+ mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 1>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
+ }
+ } break;
+ case 2: {
+ for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+ mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 2>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
+ }
+ } break;
+ case 3: {
+ for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+ mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 3>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
+ }
+ } break;
+#ifdef HAVE_FANCY_SIMD
+ case 4: {
+ for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+ mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 4>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
+ }
+ } break;
+ case 5: {
+ for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+ mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 5>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
+ }
+ } break;
+ case 6: {
+ for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+ mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 6>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
+ }
+ } break;
+ case 7: {
+ for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+ mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 7>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
+ }
+ } break;
+#endif
+ }
}
-#else
F16::Data vk[k_step/F16::block_size];
for (int j = 0; j < q_step; ++j) {
- fms.update_M_S(j, vk);
+ fms.update_M_S(j, vk, mask + stride_m*j);
}
-#endif
}
-
+#else
template <typename KHelper, typename q_float>
static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
FlashMS<q_step, k_step>& fms) {
- if constexpr (is_small_head) {
- mult_mask_kq(nq, kh, stride_q, stride_m, q, mask, fms);
+ constexpr int nrc_q = 4;
+ constexpr int nrc_k = 6;
+ constexpr int krem = k_step - nrc_k*(k_step/nrc_k);
+ const int qrem = q_step - nrc_q*(q_step/nrc_q);
+ DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr};
+ for (int iq = 0; iq < nq/nrc_q; ++iq) {
+ for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+ mul_mat_f16_f16_NxN<nrc_q, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
+ }
+ if constexpr (krem > 0) {
+ mul_mat_f16_f16_NxN<nrc_q, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
+ }
+ info.cur_y += nrc_q;
}
- else {
- mult_mask_kq_l(nq, kh, stride_q, stride_m, q, mask, fms);
+ switch (qrem) {
+ case 0: break;
+ case 1: {
+ for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+ mul_mat_f16_f16_NxN<1, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
+ }
+ if constexpr (krem > 0) {
+ mul_mat_f16_f16_NxN<1, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
+ }
+ } break;
+ case 2: {
+ for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+ mul_mat_f16_f16_NxN<2, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
+ }
+ if constexpr (krem > 0) {
+ mul_mat_f16_f16_NxN<2, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
+ }
+ } break;
+ case 3: {
+ for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+ mul_mat_f16_f16_NxN<3, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
+ }
+ if constexpr (krem > 0) {
+ mul_mat_f16_f16_NxN<3, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
+ }
+ } break;
}
-#ifdef __aarch64__
float32x4_t vk[k_step/4];
- for (int j = 0; j < nq; ++j) {
- fms.update_M_S(j, vk);
- }
-#else
- F16::Data vk[k_step/F16::block_size];
- for (int j = 0; j < nq; ++j) {
- fms.update_M_S(j, vk);
+ for (int j = 0; j < q_step; ++j) {
+ fms.update_M_S(j, vk, mask + stride_m*j);
}
-#endif
}
+#endif
#ifdef __aarch64__
static inline void convert(int nq, int stride_q, const float * q, float16_t * q_f16) {
@@ -13206,6 +13647,19 @@ struct FlashQKfp32 {
case 7: return std::make_pair(mul_mat, 7>, 7);\
}\
}
+#define MAKE_FUNCS_ONLY_NRC(mul_mat, n) \
+ if (n >= kMaxQ) return std::make_pair(mul_mat<kMaxQ>, kMaxQ);\
+ else {\
+ switch (n) {\
+ case 1: return std::make_pair(mul_mat<1>, 1);\
+ case 2: return std::make_pair(mul_mat<2>, 2);\
+ case 3: return std::make_pair(mul_mat<3>, 3);\
+ case 4: return std::make_pair(mul_mat<4>, 4);\
+ case 5: return std::make_pair(mul_mat<5>, 5);\
+ case 6: return std::make_pair(mul_mat<6>, 6);\
+ case 7: return std::make_pair(mul_mat<7>, 7);\
+ }\
+ }
if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) {
#ifdef __aarch64__
MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ40, nq);
@@ -13229,6 +13683,48 @@ struct FlashQKfp32 {
}
#endif
}
+ else if constexpr (std::is_same_v<KHelper, HelperQ80R4<D, k_step>>) {
+#ifdef __aarch64__
+ if constexpr (D == 128) {
+ if (q_step >= 64 && nq >= 64) {
+ return std::make_pair(mul_mat_q8_0_r4_q8_0_128<64>, 64);
+ }
+ else if (q_step >= 32 && nq >= 32) {
+ return std::make_pair(mul_mat_q8_0_r4_q8_0_128<32>, 32);
+ }
+ else if (q_step >= 16 && nq >= 16) {
+ return std::make_pair(mul_mat_q8_0_r4_q8_0_128<16>, 16);
+ }
+ else {
+ MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0_128, nq);
+ }
+ } else {
+ MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq);
+ }
+ //MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq);
+#else
+#ifdef HAVE_FANCY_SIMD
+ if constexpr (D == 128) {
+ if (q_step >= 64 && nq >= 64) {
+ return std::make_pair(mul_mat_q8_0_r4_q8_1_128<64>, 64);
+ }
+ else if (q_step >= 32 && nq >= 32) {
+ return std::make_pair(mul_mat_q8_0_r4_q8_1_128<32>, 32);
+ }
+ else if (q_step >= 16 && nq >= 16) {
+ return std::make_pair(mul_mat_q8_0_r4_q8_1_128<16>, 16);
+ }
+ else {
+ MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_1_128, nq);
+ }
+ } else {
+ MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_1, nq);
+ }
+#else
+ MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_1, nq);
+#endif
+#endif
+ }
else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) {
#ifdef __aarch64__
MAKE_FUNCS(mul_mat_qX_1_q8_1<DequantizerQ41, nq);
@@ -13373,20 +13869,44 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
FlashQKV<D, q_step, k_step>& fqkv,
const float * q, const char * mask, float * qkv) {
typename KHelper::block_q8 q8[q_step*(D/QK8_0)];
+#if FA_TIMING
+ Perf perf(false);
+#endif
for (int i1 = 0; i1 < nq1/q_step; ++i1) {
+#if FA_TIMING
+ auto t1 = Perf::cur_time();
+#endif
fms.init_qstep();
kh.reset_block();
vh.reset_block();
HelperQ80<D, QK8_0>::convert(q_step, stride_q, q, q8);
+#if FA_TIMING
+ perf.accum_nolock(0, t1);
+#endif
auto mr = mask;
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
+#if FA_TIMING
+ t1 = Perf::cur_time();
KQHelper::mul_mask_kq(kh, stride_m, q8, mr, fms);
+ perf.accum_nolock(1, t1);
+ t1 = Perf::cur_time();
fqkv.accumulate_qkv(vh, fms);
+ perf.accum_nolock(2, t1);
+#else
+ KQHelper::mul_mask_kq(kh, stride_m, q8, mr, fms);
+ fqkv.accumulate_qkv(vh, fms);
+#endif
kh.next_block();
vh.next_block();
mr += k_step*sizeof(ggml_half);
}
+#if FA_TIMING
+ t1 = Perf::cur_time();
fqkv.normalize_and_store(fms, stride_qkv, qkv);
+ perf.accum_nolock(3, t1);
+#else
+ fqkv.normalize_and_store(fms, stride_qkv, qkv);
+#endif
q += q_step*stride_q;
mask += q_step*stride_m;
@@ -13408,6 +13928,9 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
}
fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv);
}
+#if FA_TIMING
+ Perf::instance().add(perf);
+#endif
}
// Some of the methods in FlashAttn have two identical implementations that only differ by
@@ -13431,11 +13954,38 @@ struct FlashAttn {
template <typename KHelper, typename VHelper>
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
const float * q, const char * mask, float * qkv) {
+// if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>> ||
+// std::is_same_v<KHelper, HelperIQ4nl<D, k_step>> ||
+// std::is_same_v<KHelper, HelperQ80<D, k_step>> ||
+// std::is_same_v<KHelper, HelperQ80R4<D, k_step>> ||
+// std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
+// compute_helper_q<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
+// kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
+// } else {
+// compute_helper<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
+// kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
+// }
if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>> ||
- std::is_same_v<KHelper, HelperQ80<D, k_step>> || std::is_same_v<KHelper, HelperIQ4nl<D, k_step>> ||
+ std::is_same_v<KHelper, HelperIQ4nl<D, k_step>> ||
std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
compute_helper_q<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
+ }
+ else if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) {
+ if (nq1 >= 8) {
+#if FA_TIMING
+ auto t1 = Perf::cur_time();
+ HelperQ80R4<D, k_step> khr4(nk1, kh);
+ Perf::instance().accum(4, t1);
+#else
+ HelperQ80R4<D, k_step> khr4(nk1, kh);
+#endif
+ compute_helper_q<D, q_step, k_step, HelperQ80R4<D, k_step>, VHelper, FlashQKfp32<D, q_step, k_step>>(
+ khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
+ } else{
+ compute_helper_q<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
+ kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
+ }
} else {
compute_helper<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
@@ -13475,6 +14025,10 @@ struct HelperBF16 final : public BaseHelper<step> {
load(l1+2, vk+2*D/32);
load(l1+3, vk+3*D/32);
}
+
+ inline void load_8(int l1, __m512bh * vk) const {
+ for (int k = 0; k < 8; ++k) load(l1 + k, vk + k*D/32);
+ }
};
template <int D, int q_step, int k_step>
@@ -13627,6 +14181,29 @@ struct FlashQKbf16 {
_mm_storeu_ps(fms.cache + k_step*m1 + l1, hsum_float_4x4(sum));
}
+ static IQK_ALWAYS_INLINE __m256 hsum_float_8x8(__m256 * accm) {
+ for (int i = 0; i < 4; ++i) {
+ accm[i] = _mm256_add_ps(_mm256_permute2f128_ps(accm[i], accm[i+4], 0x20), _mm256_permute2f128_ps(accm[i], accm[i+4], 0x31));
+ //accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)),
+ // _mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1)));
+ }
+ for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2]));
+ return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1]));
+ }
+
+ static inline void mult_mask_kq_8(int l1, int m1, const ggml_bf16_t * q,
+ __m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) {
+ auto qr = q + m1*D;
+ for (int i = 0; i < D/32; ++i) qv[i] = __m512bh(_mm512_loadu_si512((const __m512i *)qr + i));
+ __m256 sum[8];
+ for (int k = 0; k < 8; ++k) {
+ auto vsum = _mm512_setzero_ps();
+ for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+k*(D/32)], qv[i]);
+ sum[k] = _mm256_add_ps(_mm512_castps512_ps256(vsum), _mm512_extractf32x8_ps(vsum, 1));
+ }
+ _mm256_storeu_ps(fms.cache + k_step*m1 + l1, hsum_float_8x8(sum));
+ }
+
static inline void mult_mask_kq_one(int l1, int m1, const ggml_bf16_t * q,
__m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) {
auto qr = q + m1*D;
@@ -13639,16 +14216,23 @@ struct FlashQKbf16 {
fms.cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum);
}
+#if FA_TIMING
+ template <typename KHelper>
+ static inline void multiply_mask_kq(const KHelper& kh, int stride_m, const ggml_bf16_t * q,
+ const char * mask, FlashMS<q_step, k_step>& fms, Perf& perf) {
+ auto t1 = Perf::cur_time();
+#else
template <typename KHelper>
static inline void multiply_mask_kq(const KHelper& kh, int stride_m, const ggml_bf16_t * q,
const char * mask, FlashMS<q_step, k_step>& fms) {
+#endif
{
__m512bh qv[D/32];
if constexpr (D <= 128) {
- __m512bh vkh[D/8];
- for (int l1 = 0; l1 < k_step; l1 += 4) {
- kh.load_4(l1, vkh);
- for (int j = 0; j < q_step; ++j) mult_mask_kq_4(l1, j, q, qv, vkh, fms);
+ __m512bh vkh[D/4];
+ for (int l1 = 0; l1 < k_step; l1 += 8) {
+ kh.load_8(l1, vkh);
+ for (int j = 0; j < q_step; ++j) mult_mask_kq_8(l1, j, q, qv, vkh, fms);
}
} else {
__m512bh vkh[D/16];
@@ -13658,10 +14242,17 @@ struct FlashQKbf16 {
}
}
}
+#if FA_TIMING
+ perf.accum_nolock(1, t1);
+ t1 = Perf::cur_time();
+#endif
F16::Data vk[k_step/16];
for (int j = 0; j < q_step; ++j) {
fms.update_M_S(j, vk, mask + stride_m*j);
}
+#if FA_TIMING
+ perf.accum_nolock(2, t1);
+#endif
}
template <typename KHelper>
@@ -13747,20 +14338,44 @@ struct FlashAttnBF16 {
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
const float * q, const char * mask, float * qkv) {
ggml_bf16_t q_bf16[q_step*D];
+#if FA_TIMING
+ Perf perf(false);
+#endif
for (int i1 = 0; i1 < nq1/q_step; ++i1) {
+#if FA_TIMING
+ auto t1 = Perf::cur_time();
+#endif
fms.init_qstep();
kh.reset_block();
vh.reset_block();
FlashQKbf16<D, q_step, k_step>::convert(stride_q, q, q_bf16);
+#if FA_TIMING
+ perf.accum_nolock(0, t1);
+#endif
auto mr = mask;
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
+#if FA_TIMING
+ //t1 = Perf::cur_time();
+ FlashQKbf16<D, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms, perf);
+ //perf.accum_nolock(1, t1);
+ t1 = Perf::cur_time();
+ fqkv.accumulate_qkv(vh, fms);
+ perf.accum_nolock(3, t1);
+#else
FlashQKbf16<D, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms);
fqkv.accumulate_qkv(vh, fms);
+#endif
kh.next_block();
vh.next_block();
mr += k_step*sizeof(ggml_half);
}
+#if FA_TIMING
+ t1 = Perf::cur_time();
+#endif
fqkv.normalize_and_store(fms, stride_qkv, qkv);
+#if FA_TIMING
+ perf.accum_nolock(4, t1);
+#endif
q += q_step*stride_q;
mask += q_step*stride_m;
@@ -13782,6 +14397,9 @@ struct FlashAttnBF16 {
}
fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv);
}
+#if FA_TIMING
+ Perf::instance().add(perf);
+#endif
}
FlashMS<q_step, k_step> fms;
@@ -13793,23 +14411,21 @@ template <int D, int k_step, typename KHelper, typename VHelper>
inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
const float * q, const char * mask, float scale, float softcap, float * qkv) {
-#if defined __AVX2__
- constexpr bool kUseLargeStepsQ = !std::is_same_v<KHelper, HelperF16<D, k_step>>;
-#else
- constexpr bool kUseLargeStepsQ = true;
-#endif
- if constexpr (kUseLargeStepsQ) {
- if (nk1 >= 4096) {
- if (nq1 >= 32) {
- FlashAttn<D, 32, k_step> fa(scale, softcap);
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
- return;
- }
- else if (nq1 >= 8) {
- FlashAttn<D, 8, k_step> fa(scale, softcap);
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
- return;
- }
+ if (nk1 >= 256) { //4096) {
+ if (nq1 >= 64) {
+ FlashAttn<D, 64, k_step> fa(scale, softcap);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+ return;
+ }
+ if (nq1 >= 32) {
+ FlashAttn<D, 32, k_step> fa(scale, softcap);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+ return;
+ }
+ if (nq1 >= 16) {
+ FlashAttn<D, 16, k_step> fa(scale, softcap);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+ return;
}
}
if (nq1 >= 8) {
@@ -13833,12 +14449,13 @@ inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int
if (nq1 >= 64) {
FlashAttnBF16<D, 64, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+ return;
}
else if (nq1 >= 16) {
FlashAttnBF16<D, 16, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+ return;
}
- return;
}
if (nq1 >= 8) {
FlashAttnBF16<D, 8, k_step> fa(scale, softcap);
@@ -13968,6 +14585,22 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
#ifdef __AVX512BF16__
if (type_k == GGML_TYPE_BF16) {
+ if (nk1%64 == 0) {
+ if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
+ switch (D) {
+ case 64:
+ iqk_flash_helper_T< 64, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ case 96:
+ iqk_flash_helper_T< 96, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ case 128:
+ iqk_flash_helper_T<128, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ case 256:
+ iqk_flash_helper_T<256, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ default:
+ return false;
+ }
+ return true;
+ }
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
switch (D) {
case 64: