author    Kawrakow <iwankawrakow@gmail.com>    2025-04-29 07:19:43 +0200
committer GitHub <noreply@github.com>          2025-04-29 07:19:43 +0200
commit    cda24b58cbef34154651d0083910fed860a506c1 (patch)
tree      90cd3bd7f772c3b240a6553eca5e50edf95c53da
parent    baeefb4731fb24cdace168f6dbc74516d470efc0 (diff)
CPU FA improvements (#351)
* FA: provide work buffer for K repacking
* Add header to avoid compiler warnings
* WIP
* WIP
* WIP
* WIP
* Slightly better
* WIP (Zen4)
* WIP
* Try to improve for unusual number of heads/number of threads
* Use mul_mat_qX_0_q8_2_Tx for q6_0 in FA
* Use mul_mat_qX_0_q8_2_Tx for q4_0 in FA
* Use Sum4q4 for q4_0
* WIP
* WIP
* Much better FA TG with q8_0 KV cache

  Just repack it even for TG. But do the repacking for k_step rows, not the whole K tensor.

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--  ggml/src/ggml.c                   |  29
-rw-r--r--  ggml/src/iqk/iqk_flash_attn.cpp   | 147
-rw-r--r--  ggml/src/iqk/iqk_flash_impl.h     |   4
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.cpp      | 708
4 files changed, 763 insertions, 125 deletions
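
The combining step that this commit factors out into accumulate_qkv() is the usual streaming-softmax merge: every K/V chunk produces a partial result R together with a running maximum M and a sum of exponentials S, and partials are merged by rescaling whichever side carries the smaller maximum. A minimal standalone restatement of that merge (illustrative names, mirroring the helper added in iqk_flash_attn.cpp):

#include <cmath>
#include <cstring>

// Merge one partial flash-attention result (Rj, Mj, Sj) into the running
// accumulator (Racc, M, S); Racc and Rj both hold Dv floats.
static void merge_partial(int Dv, float& M, float& S, float Mj, float Sj,
                          float* Racc, const float* Rj) {
    if (Mj == -INFINITY) return;              // chunk was fully masked out
    if (Mj > M) {
        if (M == -INFINITY) {                 // first non-empty chunk
            std::memcpy(Racc, Rj, Dv*sizeof(float));
            S = Sj;
        } else {
            float c = std::exp(M - Mj);       // rescale the old accumulator
            S = c*S + Sj;
            for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + Rj[i];
        }
        M = Mj;
    } else {
        float c = std::exp(Mj - M);           // rescale the incoming chunk
        S += c*Sj;
        for (int i = 0; i < Dv; ++i) Racc[i] += c*Rj[i];
    }
}

After all chunks have been merged, the output row is Racc scaled by 1/S (with S == 0 treated as 1), exactly as done after the barrier in the code below.
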
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index f3cfd9a0..4cd18a28 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -21786,15 +21786,36 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
cur = 3*sizeof(float)*D*n_tasks; // 3x head size/thread
#if GGML_USE_IQK_MULMAT
+ size_t qsize = 0;
const struct ggml_tensor * q = node->src[0];
const struct ggml_tensor * k = node->src[1];
+ if (k->type == GGML_TYPE_Q8_0) {
+ qsize = ggml_nrows(k)*ggml_row_size(k->type, k->ne[0]);
+ }
if (q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1 && n_tasks > 1 && k->ne[1]/32 > 1) {
if (k->ne[2] > 1) {
- int nk = MAX(1, 32 * (k->ne[2]*k->ne[1]/(32*n_tasks)));
+ int gcd = simple_gcd(k->ne[2], n_tasks);
+ int nth_k = n_tasks/gcd;
+ int nek2_k = k->ne[2]/gcd;
+ int nchunk = nek2_k*k->ne[1]/32;
+ int npt = (nchunk + nth_k - 1)/nth_k;
+ int nk;
+ if (npt*nth_k == nchunk) {
+ nk = 32 * (k->ne[1]*k->ne[2]/(32*n_tasks));
+ } else {
+ //int nm = std::max(1, npt/8);
+ int nm = 1;
+ while (true) {
+ if (nm*4 >= npt) break;
+ nm *= 2;
+ }
+ nk = 32*nm;
+ }
+ //int nk = 32 * (k->ne[2]*k->ne[1]/(32*n_tasks));
int nstep_k = k->ne[2]*k->ne[1]/nk;
size_t result_size = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float);
size_t size = nstep_k*result_size;
- cur = MAX(cur, size);
+ cur = MAX(cur, size+qsize);
} else {
int nstep_k = k->ne[1]/32;
int gcd_k = simple_gcd(nstep_k, n_tasks);
@@ -21808,9 +21829,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]);
size += q->ne[2]*row_size;
}
- cur = MAX(cur, size);
+ cur = MAX(cur, size+qsize);
}
}
+ } else {
+ cur = MAX(cur, qsize);
}
#endif
} break;
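
The work-size logic above has to agree with the chunking done later in iqk_flash_attn.cpp: K is processed in steps of nk rows (a multiple of 32), and when the 32-row chunks do not divide evenly over the threads, nk falls back to 32 times the smallest power of two nm with 4*nm >= npt, so the steps stay small enough to balance across threads. A hedged standalone restatement of that selection (names mirror the hunk; nek1 = k->ne[1], nek2 = k->ne[2], nth = n_tasks):

static int simple_gcd(int a, int b) { while (b) { int t = a%b; a = b; b = t; } return a; }

static int choose_k_step(int nek1, int nek2, int nth) {
    int gcd    = simple_gcd(nek2, nth);
    int nth_k  = nth/gcd;
    int nek2_k = nek2/gcd;
    int nchunk = nek2_k*nek1/32;              // available 32-row chunks
    int npt    = (nchunk + nth_k - 1)/nth_k;  // chunks per thread, rounded up
    if (npt*nth_k == nchunk) {
        return 32*(nek1*nek2/(32*nth));       // even split: keep the old formula
    }
    int nm = 1;                               // smallest power of two with 4*nm >= npt
    while (nm*4 < npt) nm *= 2;
    return 32*nm;
}

The extra qsize term simply reserves room at the front of the same work buffer for a repacked copy of a Q8_0 K cache (one Q8_0-sized row per K row), which iqk_repack_k() fills at run time.
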
diff --git a/ggml/src/iqk/iqk_flash_attn.cpp b/ggml/src/iqk/iqk_flash_attn.cpp
index 0de68b94..fd0d5dd0 100644
--- a/ggml/src/iqk/iqk_flash_attn.cpp
+++ b/ggml/src/iqk/iqk_flash_attn.cpp
@@ -25,6 +25,24 @@ inline uint32_t simple_gcd(uint32_t a, uint32_t b) {
}
return a;
}
+inline void accumulate_qkv(int Dv, float& M, float& S, float Mj, float Sj, float * Racc, const float * R) {
+ if (Mj == -INFINITY) return;
+ if (Mj > M) {
+ if (M == -INFINITY) {
+ std::memcpy(Racc, R, Dv*sizeof(float));
+ S = Sj;
+ } else {
+ float c = exp(M - Mj);
+ S = c*S + Sj;
+ for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i];
+ }
+ M = Mj;
+ } else {
+ float c = exp(Mj - M);
+ S += c*Sj;
+ for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i];
+ }
+}
}
// TODO: get the ggml_type enum here without pollution
@@ -34,7 +52,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
int nek3, int nek2, long nbk3, long nbk2,
int nev3, int nev2, long nbv3, long nbv2,
int ne2, int ne1, long nb1,
- int int_type_k, // type of k
+ int int_type_k_in, // type of k
int int_type_v, // type of v
int Dk, // K head size
int Dv, // V head size
@@ -51,7 +69,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
float scale, // scale applied before softmax
float softcap, // if > 0, a "soft-cap" operation is applied before softmax
float * qkv, // v*softmax(scale*(k*q))
- [[maybe_unused]] void * work_buffer, [[maybe_unused]] barrier_t barrier, [[maybe_unused]] void * barrier_data,
+ [[maybe_unused]] void * work_buffer_in, [[maybe_unused]] barrier_t barrier, [[maybe_unused]] void * barrier_data,
int ith, int nth) {
if (type_q != 0 || type_mask != 1 || max_bias > 0) return false;
@@ -61,6 +79,29 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
int rk3 = neq3/nek3;
int rv3 = neq3/nev3;
+ int int_type_k = int_type_k_in;
+ auto work_buffer = work_buffer_in;
+ if (neq1 >= 8 || rk2 >= 8) {
+ uint64_t row_size = 0;
+ work_buffer = iqk_repack_k(int_type_k, Dk, nek1, nek2, nek3, stride_k, nbk2, nbk3, k, work_buffer_in, ith, nth, int_type_k, row_size);
+ if (int_type_k != int_type_k_in) {
+ stride_k = row_size;
+ nbk2 = stride_k*nek1;
+ nbk3 = nbk2*nek2;
+ k = work_buffer_in;
+ barrier(barrier_data);
+ }
+ }
+ //uint64_t row_size = 0;
+ //auto work_buffer = iqk_repack_k(int_type_k, Dk, nek1, nek2, nek3, stride_k, nbk2, nbk3, k, work_buffer_in, ith, nth, int_type_k, row_size);
+ //if (int_type_k != int_type_k_in) {
+ // stride_k = row_size;
+ // nbk2 = stride_k*nek1;
+ // nbk3 = nbk2*nek2;
+ // k = work_buffer_in;
+ // barrier(barrier_data);
+ //}
+
// Getting confused all the time about where to load data from and store the results to
// (especially when combining the results from the threads).
// So, for now, making it work just for MLA (nek2 = 1).
@@ -128,22 +169,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
auto Mj = R + Dv*nq_this_j;
auto Sj = Mj + nq_this_j;
R += jj*Dv;
- if (Mj[jj] == -INFINITY) continue;
- if (Mj[jj] > M) {
- if (M == -INFINITY) {
- std::memcpy(Racc, R, Dv*sizeof(float));
- S = Sj[jj];
- } else {
- float c = exp(M - Mj[jj]);
- S = c*S + Sj[jj];
- for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i];
- }
- M = Mj[jj];
- } else {
- float c = exp(Mj[jj] - M);
- S += c*Sj[jj];
- for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i];
- }
+ accumulate_qkv(Dv, M, S, Mj[jj], Sj[jj], Racc, R);
}
float norm = S > 0 ? 1/S : 1;
for (int i = 0; i < Dv; ++i) Racc[i] *= norm;
@@ -154,10 +180,72 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
}
if (neq3 == 1 && rk2 > 1 && rk2 == rv2 && neq1 == 1 && nth >= 1 && nek2*nek1 >= 32*nth) {
- int nk = std::max(1, 32 * (nek2*nek1/(32*nth)));
+ auto result_size = (Dv + 16)*rk2*sizeof(float);
+ int gcd = simple_gcd(nek2, nth);
+ if (false && gcd > 1) {
+ int nth_g = nth/gcd;
+ int ith_g = ith%nth_g;
+ int nek1_32 = nek1/32;
+ int nek1_pt = (nek1_32 + nth_g - 1)/nth_g;
+ int ith_mid = nth_g;
+ if (nek1_pt*nth_g > nek1_32) {
+ ith_mid = nek1_32 - nth_g*(nek1_pt - 1);
+ }
+ nek1_pt *= 32;
+ int nek1_mid = ith_mid*nek1_pt;
+ int nek1_thread = ith_g < ith_mid ? nek1_pt : nek1_pt - 32;
+ for (int ik02 = ith/nth_g; ik02 < nek2; ik02 += gcd) {
+ int ik01 = ith_g < ith_mid ? ith_g*nek1_pt : nek1_mid + (ith_g - ith_mid)*nek1_thread;
+ auto this_result = (float *)((char *)work_buffer + (ik02*nth_g + ith_g)*result_size);
+ auto this_q = (const float *)((const char *)q + ik02*rk2*nbq2);
+ auto this_k = (const char *)k + ik01*stride_k + ik02*nbk2;
+ auto this_v = (const char *)v + ik01*stride_v + ik02*nbv2;
+ auto this_m = (const char *)mask + ik01*sizeof(uint16_t); // we don't have ggml_half available here
+ if (!iqk_flash_attn_impl(int_type_k, int_type_v,
+ Dk, Dv, rk2, nek1_thread, nbq2, stride_k, stride_v, 0, Dv,
+ this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m,
+ scale, softcap, this_result, this_result + (Dv+0)*rk2, this_result + (Dv+1)*rk2)) return false;
+ }
+
+ barrier(barrier_data);
+
+ for (int iq2 = ith; iq2 < neq2; iq2 += nth) {
+ int ik02 = iq2/rk2;
+ int il = iq2 - ik02*rk2;
+ auto Racc = qkv + iq2*nb1/sizeof(float);
+ float M = -INFINITY, S = 0;
+ for (int ig = 0; ig < nth_g; ++ig) {
+ int istep_k = ik02*nth_g + ig;
+ auto this_result = (float *)((char *)work_buffer + istep_k*result_size);
+ const float * R = this_result + il*Dv;
+ const float * Mj = this_result + Dv*rk2;
+ const float * Sj = Mj + rk2;
+ accumulate_qkv(Dv, M, S, Mj[il], Sj[il], Racc, R);
+ }
+ float norm = S > 0 ? 1/S : 1;
+ for (int i = 0; i < Dv; ++i) Racc[i] *= norm;
+ }
+ return true;
+ }
+ int nth_k = nth/gcd;
+ int nek2_k = nek2/gcd;
+ int nchunk = nek2_k*nek1/32;
+ int npt = (nchunk + nth_k - 1)/nth_k;
+ int nk;
+ if (npt*nth_k == nchunk) {
+ nk = 32 * (nek2*nek1/(32*nth));
+ } else {
+ //int nm = std::max(1, npt/8);
+ int nm = 1;
+ while (true) {
+ if (nm*4 >= npt) break;
+ nm *= 2;
+ }
+ nk = 32*nm;
+ }
+ //int nk = 32 * (nek2*nek1/(32*nth));
int nkk = (nek1 + nk - 1)/nk;
int nstep_k = nek2*nkk;
- auto result_size = (Dv + 16)*rk2*sizeof(float);
//if (ith == 0) printf("rk2 = %d, nek1 = %d, nek2 = %d, nk = %d, nkk = %d, nstep_k = %d\n", (int)rk2, (int)nek1, (int)nek2, nk, nkk, nstep_k);
for (int istep_k = ith; istep_k < nstep_k; istep_k += nth) {
int ik02 = istep_k/nkk;
@@ -183,7 +271,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
int ik02 = iq2/rk2;
int il = iq2 - ik02*rk2;
auto Racc = qkv + iq2*nb1/sizeof(float);
- std::memset(Racc, 0, Dv*sizeof(float));
+ //std::memset(Racc, 0, Dv*sizeof(float));
float M = -INFINITY, S = 0;
for (int ikk = 0; ikk < nkk; ++ikk) {
int istep_k = ik02*nkk + ikk;
@@ -191,22 +279,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
const float * R = this_result + il*Dv;
const float * Mj = this_result + Dv*rk2;
const float * Sj = Mj + rk2;
- if (Mj[il] == -INFINITY) continue;
- if (Mj[il] > M) {
- if (M == -INFINITY) {
- std::memcpy(Racc, R, Dv*sizeof(float));
- S = Sj[il];
- } else {
- float c = exp(M - Mj[il]);
- S = c*S + Sj[il];
- for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i];
- }
- M = Mj[il];
- } else {
- float c = exp(Mj[il] - M);
- S += c*Sj[il];
- for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i];
- }
+ accumulate_qkv(Dv, M, S, Mj[il], Sj[il], Racc, R);
}
float norm = S > 0 ? 1/S : 1;
for (int i = 0; i < Dv; ++i) Racc[i] *= norm;
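
In the branch above, each K step writes its partials into the shared work buffer as one slot of (Dv + 16)*rk2 floats: rk2*Dv result values, then rk2 row maxima, then rk2 exponential sums (the remainder is padding). A small accessor sketch of that layout (illustrative only, not part of the patch):

#include <cstddef>

struct PartialSlot {
    float * base;      // start of this step's slot in the work buffer
    int     Dv, rk2;
    const float * R(int il) const { return base + il*Dv; }            // partial result for head il
    float         M(int il) const { return base[(Dv + 0)*rk2 + il]; } // running maximum
    float         S(int il) const { return base[(Dv + 1)*rk2 + il]; } // sum of exponentials
    static size_t bytes(int Dv, int rk2) { return (size_t)(Dv + 16)*rk2*sizeof(float); }
};

A combining thread then walks the nkk slots belonging to its query head and folds slot.R(il), slot.M(il), slot.S(il) into its accumulator with accumulate_qkv(), as in the loop above.
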
diff --git a/ggml/src/iqk/iqk_flash_impl.h b/ggml/src/iqk/iqk_flash_impl.h
index 68802927..6f62e56b 100644
--- a/ggml/src/iqk/iqk_flash_impl.h
+++ b/ggml/src/iqk/iqk_flash_impl.h
@@ -6,6 +6,8 @@
#pragma once
+#include <cstdint>
+
bool iqk_flash_attn_impl(int type_k, // type of k
int type_v, // type of v
int Dk, // K head size
@@ -27,3 +29,5 @@ bool iqk_flash_attn_impl(int type_k, // type of k
float * M,
float * S);
+void * iqk_repack_k(int type_k, int nek0, int nek1, int nek2, int nek3, long nbk1, long nbk2, long nbk3,
+ const void * k, void * work, int ith, int nth, int& repacked_type, uint64_t& row_size);
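
iqk_repack_k() packs a Q8_0 K cache into Q8_0_R8 inside the caller-provided work buffer and returns a pointer just past the repacked data (which remains usable as scratch); repacked_type and row_size tell the caller whether anything changed. A hedged sketch of the calling pattern, condensed from iqk_flash_attn_noalibi() above (maybe_repack_k is an illustrative wrapper, not part of the patch):

#include <cstdint>
#include "iqk_flash_impl.h"

// Try to repack K; on success, retarget k at the work buffer and fix the strides.
// Returns the scratch pointer that iqk_repack_k() hands back.
static void * maybe_repack_k(int& type_k, int Dk, int nek1, int nek2, int nek3,
                             long& stride_k, long& nbk2, long& nbk3,
                             const void *& k, void * work, int ith, int nth) {
    const int type_in  = type_k;
    uint64_t  row_size = 0;
    void * scratch = iqk_repack_k(type_in, Dk, nek1, nek2, nek3,
                                  stride_k, nbk2, nbk3, k, work, ith, nth,
                                  type_k, row_size);
    if (type_k != type_in) {        // K was repacked (Q8_0 -> Q8_0_R8)
        stride_k = row_size;        // repacked rows are contiguous
        nbk2     = stride_k*nek1;
        nbk3     = nbk2*nek2;
        k        = work;            // repacked data starts at the work buffer
        // all threads must still synchronize (barrier) before using k
    }
    return scratch;
}
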
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index e7ab2e5b..5f916584 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -19,6 +19,7 @@
#include "ggml-quants.h"
#include "iqk_mul_mat.h"
#include "iqk_quantize.h"
+#include "iqk_flash_impl.h"
#define GGML_COMMON_IMPL_C
#include "ggml-common.h"
@@ -6639,6 +6640,84 @@ static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataI
}
}
+#ifdef HAVE_FANCY_SIMD
+template <int nrc_y>
+static void mul_mat_q8_KV_q8_KV_8(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%8 == 0);
+ GGML_ASSERT(n%32 == 0);
+ __m512i qx[4];
+ __m512i acc[nrc_y <= 4 ? 2*nrc_y : nrc_y] = {};
+ float dy[nrc_y];
+ int32_t sy[nrc_y];
+ const int8_t * q8y[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dptr = (const float *)info.src1_row(iy);
+ dy[iy] = dptr[0];
+ auto iptr = (const int32_t *)(dptr + 1);
+ sy[iy] = -64*iptr[0];
+ q8y[iy] = (const int8_t *)(dptr + 2);
+ }
+ const int8_t * q8x[8];
+ float dx[8];
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int kx = 0; kx < 8; ++kx) {
+ auto dptr = (const float *)((const char *)vx + (ix+kx)*bx);
+ dx[kx] = dptr[0];
+ q8x[kx] = (const int8_t *)(dptr + 2);
+ }
+ for (int i = 0; i < n/32; ++i) {
+ for (int kx = 0; kx < 4; ++kx) {
+ qx[kx] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8x[kx+0] + i)),
+ _mm256_loadu_si256((const __m256i *)q8x[kx+4] + i), 1);
+ }
+ auto t0 = _mm512_unpacklo_epi32(qx[0], qx[1]);
+ auto t1 = _mm512_unpacklo_epi32(qx[2], qx[3]);
+ auto t2 = _mm512_unpackhi_epi32(qx[0], qx[1]);
+ auto t3 = _mm512_unpackhi_epi32(qx[2], qx[3]);
+ qx[0] = _mm512_xor_si512(_mm512_unpacklo_epi64(t0, t1), _mm512_set1_epi8(-128));
+ qx[1] = _mm512_xor_si512(_mm512_unpackhi_epi64(t0, t1), _mm512_set1_epi8(-128));
+ qx[2] = _mm512_xor_si512(_mm512_unpacklo_epi64(t2, t3), _mm512_set1_epi8(-128));
+ qx[3] = _mm512_xor_si512(_mm512_unpackhi_epi64(t2, t3), _mm512_set1_epi8(-128));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y256 = _mm256_loadu_si256((const __m256i *)q8y[iy] + i);
+ auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1);
+ if constexpr (nrc_y <= 4) {
+ acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+ acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+ acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+ acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+ } else {
+ acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+ acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+ acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+ acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+ }
+ }
+ }
+ auto scales_x = _mm256_loadu_ps(dx);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ if constexpr (nrc_y <= 4) {
+ auto ss = _mm512_add_epi32(_mm512_add_epi32(acc[2*iy+0], acc[2*iy+1]), _mm512_set1_epi32(sy[iy]));
+ auto sum1 = _mm_add_epi32(_mm512_extracti32x4_epi32(ss, 0), _mm512_extracti32x4_epi32(ss, 1));
+ auto sum2 = _mm_add_epi32(_mm512_extracti32x4_epi32(ss, 2), _mm512_extracti32x4_epi32(ss, 3));
+ auto scale = _mm256_mul_ps(scales_x, _mm256_set1_ps(dy[iy]));
+ info.store(ix+0, iy, _mm_mul_ps(_mm256_castps256_ps128(scale), _mm_cvtepi32_ps(sum1)));
+ info.store(ix+4, iy, _mm_mul_ps(_mm256_extractf128_ps(scale, 1), _mm_cvtepi32_ps(sum2)));
+ acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_si512();
+ } else {
+ acc[iy] = _mm512_add_epi32(acc[iy], _mm512_set1_epi32(sy[iy]));
+ auto sum1 = _mm_add_epi32(_mm512_extracti32x4_epi32(acc[iy], 0), _mm512_extracti32x4_epi32(acc[iy], 1));
+ auto sum2 = _mm_add_epi32(_mm512_extracti32x4_epi32(acc[iy], 2), _mm512_extracti32x4_epi32(acc[iy], 3));
+ auto scale = _mm256_mul_ps(scales_x, _mm256_set1_ps(dy[iy]));
+ info.store(ix+0, iy, _mm_mul_ps(_mm256_castps256_ps128(scale), _mm_cvtepi32_ps(sum1)));
+ info.store(ix+4, iy, _mm_mul_ps(_mm256_extractf128_ps(scale, 1), _mm_cvtepi32_ps(sum2)));
+ acc[iy] = _mm512_setzero_si512();
+ }
+ }
+ }
+}
+#endif
+
template <int nrc_y>
static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
@@ -8208,6 +8287,22 @@ template <typename Q8, typename Q8x4, typename Dot, bool can_pack = true> struct
return _mm256_add_epi32(_mm256_unpacklo_epi64(p01, p23), _mm256_unpackhi_epi64(p01, p23)); // 0,1,2,3, 0,1,2,3
}
}
+ inline __m256i compute(__m256i x, __m256i y) const { return dot.compute(x, y); }
+};
+
+template <typename Q8, typename Q8x4> struct Sum4q4 {
+ inline __m256i compute(const __m256i * qx, const Q8 * y) const {
+ const Q8x4 * y4 = (const Q8x4 *)y;
+ auto p0 = _mm256_maddubs_epi16(qx[0], _mm256_loadu_si256((const __m256i *)y4->qs+0)); // 16x block 0
+ auto p1 = _mm256_maddubs_epi16(qx[1], _mm256_loadu_si256((const __m256i *)y4->qs+1)); // 16x block 1
+ auto p2 = _mm256_maddubs_epi16(qx[2], _mm256_loadu_si256((const __m256i *)y4->qs+2)); // 16x block 2
+ auto p3 = _mm256_maddubs_epi16(qx[3], _mm256_loadu_si256((const __m256i *)y4->qs+3)); // 16x block 3
+ auto p01 = _mm256_add_epi16(_mm256_unpacklo_epi32(p0, p1), _mm256_unpackhi_epi32(p0, p1)); // 0,0, 1,1, 0,0, 1,1, 0,0, 1,1, 0,0, 1,1
+ auto p23 = _mm256_add_epi16(_mm256_unpacklo_epi32(p2, p3), _mm256_unpackhi_epi32(p2, p3)); // 2,2, 3,3, 2,2, 3,3, 2,2, 3,3, 2,2, 3,3
+ auto p0123 = _mm256_add_epi16(_mm256_unpacklo_epi64(p01, p23), _mm256_unpackhi_epi64(p01, p23)); // 0,0, 1,1, 2,2, 3,3, 0,0, 1,1, 2,2, 3,3
+ return _mm256_madd_epi16(_mm256_set1_epi16(1), p0123);
+ }
+ inline __m256i compute(__m256i x, __m256i y) const { return _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(x, y)); }
};
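// Note on Sum4q4 (an illustrative remark, not part of the patch): unlike Sum4, it
// keeps the four _mm256_maddubs_epi16 partial sums in 16 bits and issues a single
// _mm256_madd_epi16 at the end. With 4-bit quants (0..15) each 16-bit lane after
// maddubs is at most 2*15*127 = 3810, and the two unpack-add stages raise that to
// at most 4*3810 = 15240 < 32767, so the 16-bit accumulation presumably cannot
// overflow; for 8-bit quants the same shortcut would not be safe. This is what
// lets Q4_0_1_Unpacker switch its Sum4T to Sum4q4 further down.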
struct ScaleHelperQ8_0 {
@@ -8362,6 +8457,7 @@ struct MinusType0 {
inline __m256 compute(__m128 d, int) const { return _mm256_set_m128(d, d); }
inline float compute(float d, int) const { return d; }
inline float result(__m256 acc, int) const { return hsum_float_8(acc); }
+ inline __m256 vresult(__m256 acc, int) const { return acc; }
};
template <int nrc_y> struct MinusType1 {
@@ -8381,6 +8477,9 @@ template <int nrc_y> struct MinusType1 {
const __m128 sum = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));
return hsum_float_4(_mm_add_ps(sum, accm[iy]));
}
+ inline __m256 vresult(__m256 acc, int iy) const {
+ return _mm256_add_ps(acc, _mm256_insertf128_ps(_mm256_setzero_ps(), accm[iy], 0));
+ }
};
template <typename Minus, int nrc_y, bool is_multiple_of_4> struct AccumT {
@@ -8408,7 +8507,7 @@ template <typename Minus, int nrc_y, bool is_multiple_of_4> struct AccumT {
for (int iy = 0; iy < nrc_y; ++iy) {
auto s12 = scales.prepare1(other_scales, y[iy] + i);
auto d = accm.compute(s12, iy);
- const __m256i p0 = sum.dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs));
+ const __m256i p0 = sum.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs));
acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(p0), acc[iy]);
}
}
@@ -8417,6 +8516,36 @@ template <typename Minus, int nrc_y, bool is_multiple_of_4> struct AccumT {
info.store(ix, iy, accm.result(acc[iy], iy));
}
}
+ template <typename Unpacker, typename Scales, typename Sum, typename Q8>
+ inline void compute(int nb, Unpacker& unp, Scales& scales, Sum& sum, const Q8 ** y, __m256 * result) {
+ auto qx = unp.quants();
+ __m256 dall[nrc_y];
+ for (int i = 0; i < nb/4; ++i) {
+ auto other_scales = unp.set_block_4(i);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto s12 = scales.prepare4(other_scales, y[iy] + 4*i);
+ dall[iy] = accm.compute(s12, iy);
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto pall = sum.compute(qx, y[iy] + 4*i);
+ acc[iy] = _mm256_fmadd_ps(dall[iy], _mm256_cvtepi32_ps(pall), acc[iy]);
+ }
+ }
+ if (!is_multiple_of_4) {
+ for (int i = 4*(nb/4); i < nb; ++i) {
+ auto other_scales = unp.set_block(i);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto s12 = scales.prepare1(other_scales, y[iy] + i);
+ auto d = accm.compute(s12, iy);
+ const __m256i p0 = sum.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs));
+ acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(p0), acc[iy]);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ result[iy] = accm.vresult(acc[iy], iy);
+ }
+ }
};
template <int nrc_y, bool is_multiple_of_4>
@@ -8425,10 +8554,7 @@ using AccumType0 = AccumT<MinusType0, nrc_y, is_multiple_of_4>;
template <int nrc_y, bool is_multiple_of_4>
using AccumType1 = AccumT<MinusType1<nrc_y>, nrc_y, is_multiple_of_4>;
-using Sum4Type0 = Sum4<block_q8_0, block_q8_0_x4, SignedDot>;
-using Sum4Type1 = Sum4<block_q8_1, block_q8_1_x4, UnsignedDot>;
using Sum4TypeQ80 = Sum4<block_q8_0, block_q8_0_x4, SignedDot, false>;
-//using Sum4TypeQ81 = Sum4<block_q8_1, block_q8_1_x4, UnsignedDot, false>;
using Sum4TypeQ82 = Sum4<block_q8_2, block_q8_2_x4, UnsignedDot, false>;
template <typename Unpacker, typename AccumType, typename Scales, typename Q8, int nrc_y>
@@ -8443,6 +8569,19 @@ void mul_mat_qX_q8_Helper(int nb, const void * vx, size_t bx, const DataInfo& in
}
}
+template <typename Unpacker, typename AccumType, typename Scales, typename Q8, int nrc_y>
+void mul_mat_qX_q8_Helper_x2(int nb, const void * vx, size_t bx, const DataInfo& info, const Q8 ** y, int nrc_x) {
+ GGML_ASSERT(nrc_x%2 == 0);
+ Unpacker unp(vx, bx);
+ typename Unpacker::Sum4T sum4;
+ Scales scales;
+ for (int ix = 0; ix < nrc_x; ix += 2) {
+ unp.set_row(ix);
+ AccumType accum;
+ accum.compute(nb, unp, scales, sum4, y, info, ix);
+ }
+}
+
template <typename Unpacker, int nrc_y>
void mul_mat_qX_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n%Unpacker::block_size() == 0);
@@ -8459,6 +8598,63 @@ void mul_mat_qX_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info
}
}
+inline __m256 hsum_float_8x8(__m256 * accm) {
+ for (int i = 0; i < 4; ++i) {
+ accm[i] = _mm256_add_ps(_mm256_permute2f128_ps(accm[i], accm[i+4], 0x20), _mm256_permute2f128_ps(accm[i], accm[i+4], 0x31));
+ //accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)),
+ // _mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1)));
+ }
+ for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2]));
+ return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1]));
+}
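// hsum_float_8x8() reduces eight 8-lane accumulators into a single vector whose
// i-th lane holds the full horizontal sum of accm[i] (a three-stage cross-vector
// reduction). A scalar reference with the same contract, for clarity only
// (assumes <immintrin.h>; unlike the SIMD version it does not clobber accm):
static __m256 hsum_float_8x8_ref(const __m256 * accm) {
    float out[8];
    for (int i = 0; i < 8; ++i) {
        float lanes[8];
        _mm256_storeu_ps(lanes, accm[i]);
        float s = 0.f;
        for (int l = 0; l < 8; ++l) s += lanes[l];
        out[i] = s;
    }
    return _mm256_loadu_ps(out);
}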
+
+template <typename Unpacker, int nrc_y, int nrc_x>
+void mul_mat_qX_0_q8_0_Tx(int n, const void * vx, size_t bx, const DataInfo& info, int) {
+ static_assert(8%nrc_y == 0);
+ Q8<nrc_y, block_q8_0> q8(info);
+ int nb = n/Unpacker::block_size();
+ Unpacker unp(vx, bx);
+ typename Unpacker::Sum4T sum4;
+ ScaleHelperQ8_0 scales;
+ __m256 result[8];
+ auto store = [&info, &result] (int ix0) {
+ if constexpr (nrc_y == 1) {
+ info.store(ix0, 0, hsum_float_8x8(result));
+ }
+ else if constexpr (nrc_y == 2) {
+ auto value = hsum_float_8x8(result);
+ auto value1 = _mm256_extractf128_ps(value, 1);
+ info.store(ix0, 0, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0x88));
+ info.store(ix0, 1, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0xdd));
+ }
+ else {
+ float val[8];
+ _mm256_storeu_ps(val, hsum_float_8x8(result));
+ for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]);
+ }
+ };
+ if (nb%4 == 0) {
+ for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) {
+ for (int ix = 0; ix < 8/nrc_y; ++ix) {
+ unp.set_row(ix0 + ix);
+ AccumType0<nrc_y, true> accum;
+ accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix);
+ }
+ store(ix0);
+ }
+ } else {
+ for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) {
+ for (int ix = 0; ix < 8/nrc_y; ++ix) {
+ unp.set_row(ix0 + ix);
+ AccumType0<nrc_y, false> accum;
+ accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix);
+ }
+ store(ix0);
+ }
+ }
+}
+
+
template <typename Unpacker, int nrc_y>
void mul_mat_qX_1_q8_1_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n%Unpacker::block_size() == 0);
@@ -8491,6 +8687,52 @@ void mul_mat_qX_1_q8_2_T(int n, const void * vx, size_t bx, const DataInfo& info
}
}
+template <typename Unpacker, int nrc_y, int nrc_x>
+void mul_mat_qX_0_q8_2_Tx(int n, const void * vx, size_t bx, const DataInfo& info, int) {
+ static_assert(8%nrc_y == 0);
+ Q8<nrc_y, block_q8_2> q8(info);
+ int nb = n/Unpacker::block_size();
+ Unpacker unp(vx, bx);
+ typename Unpacker::Sum4T sum4;
+ ScaleHelperQ8_2 scales;
+ __m256 result[8];
+ auto store = [&info, &result] (int ix0) {
+ if constexpr (nrc_y == 1) {
+ info.store(ix0, 0, hsum_float_8x8(result));
+ }
+ else if constexpr (nrc_y == 2) {
+ auto value = hsum_float_8x8(result);
+ auto value1 = _mm256_extractf128_ps(value, 1);
+ info.store(ix0, 0, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0x88));
+ info.store(ix0, 1, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0xdd));
+ }
+ else {
+ float val[8];
+ _mm256_storeu_ps(val, hsum_float_8x8(result));
+ for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]);
+ }
+ };
+ if (nb%4 == 0) {
+ for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) {
+ for (int ix = 0; ix < 8/nrc_y; ++ix) {
+ unp.set_row(ix0 + ix);
+ AccumType1<nrc_y, true> accum;
+ accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix);
+ }
+ store(ix0);
+ }
+ } else {
+ for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) {
+ for (int ix = 0; ix < 8/nrc_y; ++ix) {
+ unp.set_row(ix0 + ix);
+ AccumType1<nrc_y, false> accum;
+ accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix);
+ }
+ store(ix0);
+ }
+ }
+}
+
struct Dequantizer4bit {
const __m256i m4 = _mm256_set1_epi8(0xf);
inline __m256i dequant(const uint8_t * qs) const {
@@ -8640,7 +8882,8 @@ struct Q4_0_Unpacker final : public Q_Unpacker<block_q4_0, ScaleHelperQ_0, Q4_0_
};
struct Q4_0_1_Unpacker final : public Q_Unpacker<block_q4_0, ScaleHelperQ_0_1<8>, Q4_0_1_Dequantizer> {
Q4_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
- using Sum4T = Sum4TypeQ82;
+ //using Sum4T = Sum4TypeQ82;
+ using Sum4T = Sum4q4<block_q8_2, block_q8_2_x4>;
inline static int block_size() { return QK4_0; }
};
#ifdef HAVE_FANCY_SIMD
@@ -15168,6 +15411,13 @@ struct F16 {
auto v256 = _mm256_set_m128(v128, v128);
return _mm512_insertf32x8(_mm512_castps256_ps512(v256), v256, 1);
}
+ static inline void set4(const float * ptr, Data * vs) {
+ auto v = set4(ptr);
+ vs[0] = _mm512_shuffle_ps(v, v, 0x00);
+ vs[1] = _mm512_shuffle_ps(v, v, 0x55);
+ vs[2] = _mm512_shuffle_ps(v, v, 0xaa);
+ vs[3] = _mm512_shuffle_ps(v, v, 0xff);
+ }
static inline Data fmadd_lane0(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0x00), prev); }
static inline Data fmadd_lane1(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0x55), prev); }
static inline Data fmadd_lane2(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0xaa), prev); }
@@ -15193,6 +15443,13 @@ struct F16 {
auto v128 = _mm_loadu_ps(ptr);
return _mm256_set_m128(v128, v128);
}
+ static inline void set4(const float * ptr, Data * vs) {
+ auto v = set4(ptr);
+ vs[0] = _mm256_shuffle_ps(v, v, 0x00);
+ vs[1] = _mm256_shuffle_ps(v, v, 0x55);
+ vs[2] = _mm256_shuffle_ps(v, v, 0xaa);
+ vs[3] = _mm256_shuffle_ps(v, v, 0xff);
+ }
static inline Data fmadd_lane0(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0x00), prev); }
static inline Data fmadd_lane1(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0x55), prev); }
static inline Data fmadd_lane2(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0xaa), prev); }
@@ -15388,7 +15645,119 @@ struct HelperQ80 final : public BaseHelper<step> {
}
}
};
+}
+
+void * iqk_repack_k(int int_type_k, int nek0, int nek1, int nek2, int nek3, long nbk1, long nbk2, long nbk3,
+ const void * data, void * work, int ith, int nth, int& repacked_type, uint64_t& row_size) {
+ repacked_type = int_type_k;
+ auto type_k = ggml_type(int_type_k);
+ if (type_k != GGML_TYPE_Q8_0 || nek0%QK8_0 != 0) return work;
+ int nrows = nek1*nek2*nek3;
+ if (nrows%8 != 0) return work;
+ repacked_type = int(GGML_TYPE_Q8_0_R8);
+ row_size = ggml_row_size(GGML_TYPE_Q8_0, nek0);
+ void * result = (char *)work + nrows*row_size;
+ int npt = 8*((nrows/8 + nth - 1)/nth);
+ int first = npt*ith;
+ if (first >= nrows) return result;
+ int last = std::min(first + npt, nrows);
+ const block_q8_0 * x8[8];
+ auto y = (block_q8_0_r8 *)((char *)work + first*row_size);
+ int nblock = nek0/QK8_0;
+#ifdef __ARM_NEON
+ int8x16x2_t m0, m1, m2, m3;
+#endif
+ for (int row = first; row < last; row += 8) {
+ int ik3 = row/(nek1*nek2);
+ int ik2 = (row - ik3*nek1*nek2)/nek1;
+ int ik1 = row - ik3*nek1*nek2 - ik2*nek1;
+ auto this_data = (const char *)data + ik1*nbk1 + ik2*nbk2 + ik3*nbk3;
+ for (int k = 0; k < 8; ++k) x8[k] = (const block_q8_0 *)(this_data + k*nbk1);
+ for (int ib = 0; ib < nblock; ++ib) {
+ for (int k = 0; k < 8; ++k) y[ib].d[k] = x8[k][ib].d;
+#ifdef __AVX2__
+ auto m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4][ib].qs), _mm_loadu_si128((const __m128i *)x8[0][ib].qs));
+ auto m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5][ib].qs), _mm_loadu_si128((const __m128i *)x8[1][ib].qs));
+ auto m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6][ib].qs), _mm_loadu_si128((const __m128i *)x8[2][ib].qs));
+ auto m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7][ib].qs), _mm_loadu_si128((const __m128i *)x8[3][ib].qs));
+ auto t0 = _mm256_unpacklo_epi32(m0, m1);
+ auto t1 = _mm256_unpacklo_epi32(m2, m3);
+ auto t2 = _mm256_unpackhi_epi32(m0, m1);
+ auto t3 = _mm256_unpackhi_epi32(m2, m3);
+ m0 = _mm256_unpacklo_epi64(t0, t1);
+ m1 = _mm256_unpackhi_epi64(t0, t1);
+ m2 = _mm256_unpacklo_epi64(t2, t3);
+ m3 = _mm256_unpackhi_epi64(t2, t3);
+ //#ifdef HAVE_FANCY_SIMD
+ // m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127));
+ // m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127));
+ // m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127));
+ // m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127));
+ //#endif
+ _mm256_storeu_si256((__m256i *)y[ib].qs + 0, m0);
+ _mm256_storeu_si256((__m256i *)y[ib].qs + 1, m1);
+ _mm256_storeu_si256((__m256i *)y[ib].qs + 2, m2);
+ _mm256_storeu_si256((__m256i *)y[ib].qs + 3, m3);
+ m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[0][ib].qs+1));
+ m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[1][ib].qs+1));
+ m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[2][ib].qs+1));
+ m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[3][ib].qs+1));
+ t0 = _mm256_unpacklo_epi32(m0, m1);
+ t1 = _mm256_unpacklo_epi32(m2, m3);
+ t2 = _mm256_unpackhi_epi32(m0, m1);
+ t3 = _mm256_unpackhi_epi32(m2, m3);
+ m0 = _mm256_unpacklo_epi64(t0, t1);
+ m1 = _mm256_unpackhi_epi64(t0, t1);
+ m2 = _mm256_unpacklo_epi64(t2, t3);
+ m3 = _mm256_unpackhi_epi64(t2, t3);
+ //#ifdef HAVE_FANCY_SIMD
+ // m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127));
+ // m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127));
+ // m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127));
+ // m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127));
+ //#endif
+ _mm256_storeu_si256((__m256i *)y[ib].qs + 4, m0);
+ _mm256_storeu_si256((__m256i *)y[ib].qs + 5, m1);
+ _mm256_storeu_si256((__m256i *)y[ib].qs + 6, m2);
+ _mm256_storeu_si256((__m256i *)y[ib].qs + 7, m3);
+#elif defined __ARM_NEON
+ for (int l = 0; l < 2; ++l) {
+ m0.val[0] = vld1q_s8(x8[0][ib].qs+16*l); m0.val[1] = vld1q_s8(x8[4][ib].qs+16*l);
+ m1.val[0] = vld1q_s8(x8[1][ib].qs+16*l); m1.val[1] = vld1q_s8(x8[5][ib].qs+16*l);
+ m2.val[0] = vld1q_s8(x8[2][ib].qs+16*l); m2.val[1] = vld1q_s8(x8[6][ib].qs+16*l);
+ m3.val[0] = vld1q_s8(x8[3][ib].qs+16*l); m3.val[1] = vld1q_s8(x8[7][ib].qs+16*l);
+ auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0]));
+ auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0]));
+ m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1]));
+ row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1]));
+ m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ vst1q_s8_x2(y[ib].qs + 0 + 128*l, m0);
+ vst1q_s8_x2(y[ib].qs + 32 + 128*l, m1);
+ vst1q_s8_x2(y[ib].qs + 64 + 128*l, m2);
+ vst1q_s8_x2(y[ib].qs + 96 + 128*l, m3);
+ }
+#else
+ for (int l = 0; l < 4; ++l) {
+ for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) {
+ y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0];
+ y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16];
+ }
+ }
+#endif
+ }
+ y += nblock;
+ }
+ return result;
+}
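// A couple of remarks on iqk_repack_k() (illustrative, not part of the patch): the
// repacked Q8_0_R8 data occupies the first nrows*row_size bytes of `work` (the same
// total size as the Q8_0 original, which is what the qsize term in ggml_graph_plan()
// reserves), and the returned pointer marks the first byte after it, left free for
// the per-step partial results. Rows are split over threads in contiguous groups of
// eight; restated as a standalone helper:
static void repack_thread_range(int nrows, int ith, int nth, int& first, int& last) {
    int npt = 8*((nrows/8 + nth - 1)/nth);          // rows per thread, multiple of 8
    first   = npt*ith;
    last    = first + npt < nrows ? first + npt : nrows;
    if (first >= nrows) first = last = nrows;       // nothing left for this thread
}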
+namespace {
template <int D, int step>
struct HelperQ80R8 : public BaseHelper<step> {
using Base = BaseHelper<step>;
@@ -15399,24 +15768,21 @@ struct HelperQ80R8 : public BaseHelper<step> {
constexpr static int block_size_q = QK8_0;
using block_q8 = block_q8_0;
#endif
+ HelperQ80R8(const char * data, int stride) : Base(data, stride) {}
HelperQ80R8(int nk, const HelperQ80<D, step>& q8) : Base(q8.data, q8.stride) {
r4 = repack(nk, q8);
Base::data = (const char *)r4.data();
Base::stride = (D/QK8_0)*sizeof(block_q8_0);
}
- static std::vector<block_q8_0_r8> repack(int nk, const HelperQ80<D, step>& q8) {
- static_assert(D%QK8_0 == 0);
- GGML_ASSERT(nk%8 == 0);
+ static void repack(int nk, const char * q8_data, int q8_stride, block_q8_0_r8 * y) {
constexpr int nblock = D/QK8_0;
- std::vector<block_q8_0_r8> result(nblock * nk/8);
- auto y = result.data();
const block_q8_0 * x8[8];
#ifdef __ARM_NEON
int8x16x2_t m0, m1, m2, m3;
#endif
for (int row = 0; row < nk; row += 8) {
- for (int k = 0; k < 8; ++k) x8[k] = (const block_q8_0 *)(q8.data + (row + k)*q8.stride);
+ for (int k = 0; k < 8; ++k) x8[k] = (const block_q8_0 *)(q8_data + (row + k)*q8_stride);
for (int ib = 0; ib < nblock; ++ib) {
for (int k = 0; k < 8; ++k) y[ib].d[k] = x8[k][ib].d;
#ifdef __AVX2__
@@ -15498,6 +15864,15 @@ struct HelperQ80R8 : public BaseHelper<step> {
}
y += nblock;
}
+ }
+
+ static std::vector<block_q8_0_r8> repack(int nk, const HelperQ80<D, step>& q8) {
+ static_assert(D%QK8_0 == 0);
+ GGML_ASSERT(nk%8 == 0);
+ constexpr int nblock = D/QK8_0;
+ std::vector<block_q8_0_r8> result(nblock * nk/8);
+ auto y = result.data();
+ repack(nk, q8.data, q8.stride, y);
return result;
}
@@ -15952,12 +16327,13 @@ struct FlashMS {
}
return F16::reduce_max<k_step>(vk);
}
- static inline __m256 apply_mask(int l, const char * mask, __m256 val, __m256 vinf) {
- auto m128 = _mm_loadu_si128((const __m128i *)mask+l);
- m128 = _mm_cmpeq_epi16(m128, _mm_setzero_si128());
- auto m256 = _mm256_cvtepi16_epi32(m128);
- auto mf = _mm256_castsi256_ps(_mm256_or_si256(m256, _mm256_slli_epi32(m256, 16)));
- return _mm256_or_ps(_mm256_and_ps(mf, val), _mm256_andnot_ps(mf, vinf));
+ static inline __m256 apply_mask(int l, const char * mask, __m256 val, [[maybe_unused]] __m256 vinf) {
+ return _mm256_add_ps(val, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)mask+l)));
+ //auto m128 = _mm_loadu_si128((const __m128i *)mask+l);
+ //m128 = _mm_cmpeq_epi16(m128, _mm_setzero_si128());
+ //auto m256 = _mm256_cvtepi16_epi32(m128);
+ //auto mf = _mm256_castsi256_ps(_mm256_or_si256(m256, _mm256_slli_epi32(m256, 16)));
+ //return _mm256_or_ps(_mm256_and_ps(mf, val), _mm256_andnot_ps(mf, vinf));
}
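// The simplification above relies on the mask being stored as additive ggml_half
// biases (0.0f for visible positions, -inf for masked ones; the no-alibi entry
// point rejects max_bias > 0), so converting with _mm256_cvtph_ps and adding is
// equivalent to the old compare-and-select against vinf, at a fraction of the cost.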
#ifdef __AVX512F__
static inline __m512 apply_mask(int l, const char * mask, __m512 val, __m512 vinf) {
@@ -16087,7 +16463,6 @@ struct FlashQKV {
accumulate_qkv_1(vh, fms);
return;
}
- F16::Data v[8];
for (int j = 0; j < q_step; ++j) {
auto R = qkv_cache + D*j;
if (fms.need_scaling[j] == 2) {
@@ -16100,6 +16475,43 @@ struct FlashQKV {
}
}
}
+#ifdef __AVX512F__
+ if constexpr ((D/F16::block_size)%4 == 0) {
+ F16::Data v[16];
+ F16::Data vs[4];
+ for (int i = 0; i < D/F16::block_size; i += 4) {
+ for (int l = 0; l < k_step; l += 4) {
+ for (int k = 0; k < 4; ++k) {
+ vh.load(l+k, i+0, v[4*k+0], v[4*k+1]);
+ vh.load(l+k, i+2, v[4*k+2], v[4*k+3]);
+ }
+ for (int j = 0; j < q_step; ++j) {
+ auto R = qkv_cache + D*j;
+ auto s1 = F16::load(R + F16::block_size*(i+0));
+ auto s2 = F16::load(R + F16::block_size*(i+1));
+ auto s3 = F16::load(R + F16::block_size*(i+2));
+ auto s4 = F16::load(R + F16::block_size*(i+3));
+ F16::set4(fms.cache + k_step*j + l, vs);
+ for (int k = 0; k < 4; ++k) {
+ s1 = F16::fmadd(s1, v[4*k+0], vs[k]);
+ s2 = F16::fmadd(s2, v[4*k+1], vs[k]);
+ s3 = F16::fmadd(s3, v[4*k+2], vs[k]);
+ s4 = F16::fmadd(s4, v[4*k+3], vs[k]);
+ }
+ F16::store(R + F16::block_size*(i+0), s1);
+ F16::store(R + F16::block_size*(i+1), s2);
+ F16::store(R + F16::block_size*(i+2), s3);
+ F16::store(R + F16::block_size*(i+3), s4);
+ }
+ }
+ }
+ return;
+ }
+#endif
+ F16::Data v[8];
+#ifdef __AVX2__
+ F16::Data vs[4];
+#endif
for (int i = 0; i < D/F16::block_size; i += 2) {
for (int l = 0; l < k_step; l += 4) {
vh.load(l+0, i, v[0], v[4]);
@@ -16110,6 +16522,13 @@ struct FlashQKV {
auto R = qkv_cache + D*j;
auto s1 = F16::load(R + F16::block_size*(i+0));
auto s2 = F16::load(R + F16::block_size*(i+1));
+#ifdef __AVX2__
+ F16::set4(fms.cache + k_step*j + l, vs);
+ for (int k = 0; k < 4; ++k) {
+ s1 = F16::fmadd(s1, v[k+0], vs[k]);
+ s2 = F16::fmadd(s2, v[k+4], vs[k]);
+ }
+#else
auto vs = F16::set4(fms.cache + k_step*j + l);
s1 = F16::fmadd_lane0(s1, v[0], vs);
s2 = F16::fmadd_lane0(s2, v[4], vs);
@@ -16119,6 +16538,7 @@ struct FlashQKV {
s2 = F16::fmadd_lane2(s2, v[6], vs);
s1 = F16::fmadd_lane3(s1, v[3], vs);
s2 = F16::fmadd_lane3(s2, v[7], vs);
+#endif
F16::store(R + F16::block_size*(i+0), s1);
F16::store(R + F16::block_size*(i+1), s2);
}
@@ -16239,7 +16659,8 @@ struct FlashQKV {
// As a result, we get an infinite stream of warnings about uninitialized variable use (one for each
// combination of D, q_step, k_step), which is extremely annoying. Hence, I succumb to the trend of
// constantly being saved by others (the compiler in this case), and add this 100% unnecessary initialization.
- qkv_cache_t qkv_cache[D*q_step] = {};
+ qkv_cache_t qkv_cache[D*q_step]; // = {};
+ //qkv_cache_t * qkv_cache;
};
template <int D, int q_step, int k_step>
@@ -16481,8 +16902,14 @@ struct FlashQKfp32 {
MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80, nq);
#else
#ifdef HAVE_FANCY_SIMD
+ if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q8_0_1_Unpacker, 1, k_step>, 1);
+ if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q8_0_1_Unpacker, 2, k_step>, 2);
+ if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q8_0_1_Unpacker, 4, k_step>, 4);
MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q8_0_1_Unpacker, nq);
#else
+ if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_0_Tx<Q8_0_Unpacker, 1, k_step>, 1);
+ if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_0_Tx<Q8_0_Unpacker, 2, k_step>, 2);
+ if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_0_Tx<Q8_0_Unpacker, 4, k_step>, 4);
MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, nq);
#endif
#endif
@@ -16493,10 +16920,15 @@ struct FlashQKfp32 {
if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1, 1);
MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq);
#else
+ if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1<1>, 1);
#ifdef HAVE_FANCY_SIMD
- if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16);
+ if constexpr (D%32 == 0 && k_step%8 == 0) {
+ if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV_8<16>, 16);
+ MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV_8, nq);
+ } else {
+ if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16);
+ }
#endif
- if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1<1>, 1);
MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq);
#endif
}
@@ -16514,17 +16946,23 @@ struct FlashQKfp32 {
#ifdef __aarch64__
MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ60, nq);
#else
+ if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q6_0_1_Unpacker, 1, k_step>, 1);
+ if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q6_0_1_Unpacker, 2, k_step>, 2);
+ if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q6_0_1_Unpacker, 4, k_step>, 4);
MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q6_0_1_Unpacker, nq);
#endif
}
-#if GGML_IQK_FA_ALL_QUANTS
else if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) {
#ifdef __aarch64__
MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ40, nq);
#else
+ if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q4_0_1_Unpacker, 1, k_step>, 1);
+ if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q4_0_1_Unpacker, 2, k_step>, 2);
+ if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q4_0_1_Unpacker, 4, k_step>, 4);
MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q4_0_1_Unpacker, nq);
#endif
}
+#if GGML_IQK_FA_ALL_QUANTS
else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) {
#ifdef __aarch64__
MAKE_FUNCS(mul_mat_qX_1_q8_1<DequantizerQ41, nq);
@@ -16664,8 +17102,29 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
FlashMS<q_step, k_step>& fms,
FlashQKV<Dv, q_step, k_step>& fqkv,
const float * q, const char * mask, float * qkv,
- float * M, float * S) {
- typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)];
+ float * M, float * S, char * qptr) {
+ auto q8 = (typename KHelper::block_q8 *)qptr;
+ if constexpr (q_step > 1 && std::is_same_v<KHelper, HelperQ80<Dk, k_step>>) {
+ if (nq1 == q_step) {
+ fms.init_qstep();
+ kh.reset_block();
+ vh.reset_block();
+ block_q8_0_r8 q8r8[Dk/QK8_0 * k_step/8];
+ HelperQ80R8<Dk, k_step> khr8((const char *)q8r8, Dk/QK8_0*sizeof(block_q8_0));
+ HelperQ80<Dk, QK8_0>::convert(q_step, stride_q, q, q8);
+ auto mr = mask;
+ for (int k1 = 0; k1 < nk1/k_step; ++k1) {
+ HelperQ80R8<Dk, k_step>::repack(k_step, kh.data, kh.stride, q8r8);
+ KQHelper::mul_mask_kq(khr8, stride_m, q8, mr, fms);
+ fqkv.accumulate_qkv(vh, fms);
+ kh.next_block();
+ vh.next_block();
+ mr += k_step*sizeof(ggml_half);
+ }
+ fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
+ return;
+ }
+ }
#if FA_TIMING
Perf perf(false);
#endif
@@ -16731,6 +17190,12 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
#endif
}
+char * get_q_storage(size_t size) {
+ thread_local std::vector<char> q_storage;
+ if (q_storage.size() < size) q_storage.resize(size);
+ return q_storage.data();
+}
+
// Some of the methods in FlashAttn have two identical implementations that only differ by
// one version using a loop over the template parameter q_step, while the other using a loop
// over an input parameter nq (these are loops over the rows of q^T). I dislike this a lot,
@@ -16753,44 +17218,57 @@ struct FlashAttn {
template <typename KHelper, typename VHelper>
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
const float * q, const char * mask, float * qkv, [[maybe_unused]] float * M, [[maybe_unused]] float * S) {
- if constexpr (std::is_same_v<KHelper, HelperQ40<Dk, k_step>> || std::is_same_v<KHelper, HelperQ41<Dk, k_step>> ||
+ if constexpr (std::is_same_v<KHelper, HelperQ40<Dk, k_step>> ||
+ std::is_same_v<KHelper, HelperQ41<Dk, k_step>> ||
std::is_same_v<KHelper, HelperIQ4nl<Dk, k_step>> ||
- std::is_same_v<KHelper, HelperQ60<Dk, k_step>>) {
- compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
- kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
- }
- else if constexpr (std::is_same_v<KHelper, HelperQ80<Dk, k_step>>) {
- if (nq1 >= 8) {
+ std::is_same_v<KHelper, HelperQ60<Dk, k_step>> ||
+ std::is_same_v<KHelper, HelperQ80R8<Dk, k_step>> ||
+ std::is_same_v<KHelper, HelperQ80<Dk, k_step>> ||
+ std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>> ||
+ std::is_same_v<KHelper, HelperQ8KVR8<Dk, k_step>>) {
+ constexpr size_t kMaxOnStackSize = 576;
+ auto q_size = q_step*(Dk/KHelper::block_size_q)*sizeof(typename KHelper::block_q8);
+ q_size = GGML_PAD(q_size, 64);
+ if (q_size > kMaxOnStackSize) {
+ auto qptr = get_q_storage(q_size);
+ if (nq1 >= 8) {
+ if constexpr (std::is_same_v<KHelper, HelperQ80<Dk, k_step>>) {
#if FA_TIMING
- auto t1 = Perf::cur_time();
- HelperQ80R8<Dk, k_step> khr4(nk1, kh);
- Perf::instance().accum(4, t1);
+ auto t1 = Perf::cur_time();
+ HelperQ80R8<Dk, k_step> khr4(nk1, kh);
+ Perf::instance().accum(4, t1);
#else
- HelperQ80R8<Dk, k_step> khr4(nk1, kh);
+ HelperQ80R8<Dk, k_step> khr4(nk1, kh);
#endif
- compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
- khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
- } else{
- compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
- kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
- }
- }
- else if constexpr (std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>>) {
- if (nq1 >= 8) {
+ compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
+ khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
+ return;
+
+ }
+ if constexpr (std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>>) {
#if FA_TIMING
- auto t1 = Perf::cur_time();
- HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
- Perf::instance().accum(4, t1);
+ auto t1 = Perf::cur_time();
+ HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
+ Perf::instance().accum(4, t1);
#else
- HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
+ HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
#endif
- compute_helper_q<Dk, Dv, q_step, k_step, HelperQ8KVR8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
- khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
- } else{
+ compute_helper_q<Dk, Dv, q_step, k_step, HelperQ8KVR8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
+ khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
+ return;
+ }
+ }
compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
- kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
+ kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
+
}
- } else {
+ else {
+ typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)];
+ compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
+ kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, (char *)q8);
+ }
+ }
+ else {
compute_helper<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
}
@@ -17234,39 +17712,61 @@ template <int Dk, int Dv, int k_step, typename KHelper, typename VHelper>
inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
- if (nk1 >= 256) { //4096) {
+ auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) {
+ nq1 -= n;
+ if (nq1 == 0) return true;
+ q += n*stride_q;
+ mask += n*stride_m;
+ qkv += n*stride_qkv;
+ if (M && S) { M += n; S += n; }
+ return false;
+ };
+ if (nk1 >= 512) {
+ if (nq1 >= 128) {
+ int n_step = nq1/128;
+ FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
+ fa.compute(kh, vh, 128*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+ if (update(128*n_step)) return;
+ }
if (nq1 >= 64) {
+ int n_step = nq1/64;
FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
- return;
+ fa.compute(kh, vh, 64*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+ if (update(64*n_step)) return;
}
if (nq1 >= 32) {
+ int n_step = nq1/32;
FlashAttn<Dk, Dv, 32, k_step> fa(scale, softcap);
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
- return;
+ fa.compute(kh, vh, 32*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+ if (update(32*n_step)) return;
}
if (nq1 >= 16) {
+ int n_step = nq1/16;
FlashAttn<Dk, Dv, 16, k_step> fa(scale, softcap);
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
- return;
+ fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+ if (update(16*n_step)) return;
}
}
if (nq1 >= 8) {
+ int n_step = nq1/8;
FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap);
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+ fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+ if (update(8*n_step)) return;
}
else if (nq1 >= 4) {
+ int n_step = nq1/4;
FlashAttn<Dk, Dv, 4, k_step> fa(scale, softcap);
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+ fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+ if (update(4*n_step)) return;
}
else if (nq1 >= 2) {
+ int n_step = nq1/2;
FlashAttn<Dk, Dv, 2, k_step> fa(scale, softcap);
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
- }
- else {
- FlashAttn<Dk, Dv, 1, k_step> fa(scale, softcap);
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+ fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+ if (update(2*n_step)) return;
}
+ FlashAttn<Dk, Dv, 1, k_step> fa(scale, softcap);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
}
#ifdef __AVX512BF16__
@@ -17327,11 +17827,11 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
HelperQ60<Dv, k_step> vh(v, stride_v);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
} break;
-#if GGML_IQK_FA_ALL_QUANTS
case GGML_TYPE_Q4_0: {
HelperQ40<Dv, k_step> vh(v, stride_v);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
} break;
+#if GGML_IQK_FA_ALL_QUANTS
case GGML_TYPE_Q4_1: {
HelperQ41<Dv, k_step> vh(v, stride_v);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
@@ -17360,6 +17860,10 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
HelperQ80<Dk, k_step> kh(k, stride_k);
iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
} break;
+ case GGML_TYPE_Q8_0_R8: {
+ HelperQ80R8<Dk, k_step> kh(k, stride_k);
+ iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
+ } break;
case GGML_TYPE_Q8_KV: {
HelperQ8KV<Dk, k_step> kh(k, stride_k);
iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
@@ -17368,11 +17872,11 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
HelperQ60<Dk, k_step> kh(k, stride_k);
iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
} break;
-#if GGML_IQK_FA_ALL_QUANTS
case GGML_TYPE_Q4_0: {
HelperQ40<Dk, k_step> kh(k, stride_k);
iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
} break;
+#if GGML_IQK_FA_ALL_QUANTS
case GGML_TYPE_Q4_1: {
HelperQ41<Dk, k_step> kh(k, stride_k);
iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
@@ -17393,9 +17897,10 @@ inline bool flash_attn_is_supported(ggml_type type) {
#endif
#if GGML_IQK_FA_ALL_QUANTS
if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 ||
- type == GGML_TYPE_Q6_0 || type == GGML_TYPE_IQ4_NL) return true;
+ type == GGML_TYPE_Q6_0 || type == GGML_TYPE_IQ4_NL || type == GGML_TYPE_Q8_0_R8) return true;
#else
- if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0 || type == GGML_TYPE_Q8_KV) return true;
+ if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0 || type == GGML_TYPE_Q8_KV || type == GGML_TYPE_Q8_0_R8
+ || type == GGML_TYPE_Q4_0) return true;
#endif
return false;
}
@@ -17404,25 +17909,35 @@ template <int step_k, typename KHelper, typename VHelper>
inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh,
int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
+ auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) {
+ nq1 -= n;
+ if (nq1 == 0) return true;
+ q += n*stride_q;
+ mask += n*stride_m;
+ qkv += n*stride_qkv;
+ if (M && S) { M += n; S += n; }
+ return false;
+ };
if (nq1 >= 8) {
+ int n_step = nq1/8;
FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
+ fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
+ if (update(8*n_step)) return;
}
- else if (nq1 >= 4) {
+ if (nq1 >= 4) {
+ int n_step = nq1/4;
FlashAttn<576, 512, 4, step_k> fa(scale, softcap);
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
+ fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
+ if (update(4*n_step)) return;
}
- else {
- FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
- }
- //if (nq1 % 8 == 0) {
- // FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
- // fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
- //} else {
- // FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
- // fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
- //}
+ if (nq1 >= 2) {
+ int n_step = nq1/2;
+ FlashAttn<576, 512, 2, step_k> fa(scale, softcap);
+ fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
+ if (update(2*n_step)) return;
+ }
+ FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
}
template <int step_k>
@@ -17436,6 +17951,12 @@ inline bool iqk_deepseek_helper(ggml_type type_k,
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
return true;
}
+ if (type_k == GGML_TYPE_Q8_0_R8) {
+ HelperQ80R8<576, step_k> kh((const char *)k, stride_k);
+ HelperQ80<512, step_k> vh((const char *)v, stride_v);
+ iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
+ return true;
+ }
if (type_k == GGML_TYPE_Q6_0) {
HelperQ60<576, step_k> kh((const char *)k, stride_k);
HelperQ60<512, step_k> vh((const char *)v, stride_v);
@@ -17558,6 +18079,23 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k
}
#endif
+ if (nk1%128 == 0) {
+ switch (Dk) {
+ case 64:
+ iqk_flash_helper_T< 64, 64, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
+ case 96:
+ iqk_flash_helper_T< 96, 96, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
+ case 128:
+ iqk_flash_helper_T<128, 128, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
+ case 192:
+ iqk_flash_helper_T<192, 128, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
+ case 256:
+ iqk_flash_helper_T<256, 256, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
+ default:
+ return false;
+ }
+ return true;
+ }
if (nk1%64 == 0) {
switch (Dk) {
case 64: