summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ggml/src/ggml.c6
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp405
-rw-r--r--ggml/src/iqk/iqk_mul_mat.h4
3 files changed, 337 insertions, 78 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 771bc8ca..45fddca5 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -16150,8 +16150,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
}
#if GGML_USE_IQK_MULMAT
- if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 && v->type == GGML_TYPE_F16 &&
- mask && mask->type == GGML_TYPE_F16) {
+ if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) {
int64_t work_per_slice = D*nek1*neq1;
int ntg = 1;
if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8;
@@ -16165,7 +16164,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
if (counter++ % (nth/ntg) == ith/ntg) {
int iq1 = (ith%ntg)*neq1/ntg;
- if (!iqk_flash_attn_noalibi(D, neq1/ntg, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
+ if (!iqk_flash_attn_noalibi(k->type, v->type,
+ D, neq1/ntg, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
(const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]),
(const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]),
(const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]),
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 7dd817b9..13e6420b 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -6057,6 +6057,206 @@ inline __m256 v_tanh(__m256 x) {
namespace {
+template <int k_step>
+struct BaseHelper {
+ BaseHelper(const char * data, int stride) : data(data), block(data), stride(stride) {}
+
+ inline void set_block(int k1) { block = data + k1*k_step*stride; }
+ inline void reset_block() { block = data; }
+ inline void next_block() { block += k_step*stride; }
+ inline const char * lblock(int l1) const { return block + l1*stride; }
+
+ const char * data;
+ const char * block;
+ int stride;
+
+};
+
+template <int D, int step>
+struct HelperF16 final : public BaseHelper<step> {
+ using Base = BaseHelper<step>;
+ HelperF16(const char * data, int stride) : Base(data, stride) {}
+
+ inline void load(int l1, __m512 * vk) const {
+ auto dr = Base::lblock(l1);
+ for (int i = 0; i < D/16; ++i) vk[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dr + i));
+ }
+
+ inline void load(int l1, int i, __m512& v1, __m512& v2) const {
+ auto dr = (const ggml_half *)Base::lblock(l1);
+ v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dr + i + 0));
+ v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dr + i + 1));
+ }
+
+ inline void load_2(int l1, __m512 * vk) const {
+ load(l1+0, vk+0);
+ load(l1+1, vk+D/16);
+ }
+
+};
+
+template <int D, int step>
+struct HelperQ80 final : public BaseHelper<step> {
+ static_assert(step == QK8_0);
+ using Base = BaseHelper<step>;
+ HelperQ80(const char * data, int stride) : Base(data, stride) {}
+
+ //inline void load(int l1, __m512 * vk) const {
+ // auto dl = (const block_q8_0 *)Base::lblock(l1);
+ // for (int i = 0; i < D/32; ++i) {
+ // auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl[i].d));
+ // vk[2*i+0] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl[i].qs+0))));
+ // vk[2*i+1] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl[i].qs+1))));
+ // }
+ //}
+ inline void load(int l1, __m512 * vk) const {
+ auto dl = (const block_q8_0_x4 *)Base::lblock(l1);
+ if constexpr (D >= 128) {
+ __m512 vd[4];
+ for (int ib = 0; ib < D/128; ++ib) {
+ const auto& b8 = dl[ib];
+ auto scales4 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)b8.d));
+ auto scales8 = _mm256_insertf128_ps(_mm256_castps128_ps256(scales4), scales4, 1);
+ auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales8), scales8, 1);
+ vd[0] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(0, 0, 0, 0));
+ vd[1] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(1, 1, 1, 1));
+ vd[2] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(2, 2, 2, 2));
+ vd[3] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(3, 3, 3, 3));
+ for (int i = 0; i < 4; ++i) {
+ vk[8*ib+2*i+0] = _mm512_mul_ps(vd[i], _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*i+0))));
+ vk[8*ib+2*i+1] = _mm512_mul_ps(vd[i], _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*i+1))));
+ }
+ }
+ } else {
+ for (int i = 0; i < D/32; ++i) {
+ const auto& b8 = dl[i/4];
+ int ii = i%4;
+ auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(b8.d[ii]));
+ vk[2*i+0] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*ii+0))));
+ vk[2*i+1] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)b8.qs+2*ii+1))));
+ }
+ }
+ }
+
+ inline void load(int l1, int i, __m512& v1, __m512& v2) const {
+ auto dl = (const block_q8_0_x4 *)Base::lblock(l1) + i/8;
+ int ii = (i/2)%4;
+ auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl->d[ii]));
+ v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+2*ii+0))));
+ v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+2*ii+1))));
+ }
+
+ inline void load_2(int l1, __m512 * vk) const {
+ load(l1+0, vk+0);
+ load(l1+1, vk+D/16);
+ }
+};
+
+template <int D, int step>
+struct HelperQ40 final : public BaseHelper<step> {
+ static_assert(step == QK4_0);
+ using Base = BaseHelper<step>;
+ HelperQ40(const char * data, int stride) : Base(data, stride) {}
+
+
+ inline void load(int l1, __m512 * vk) const {
+ auto dl = (const block_q4_0 *)Base::lblock(l1);
+ if constexpr (D >= 128) {
+ ggml_half aux[4];
+ __m512 vd[4];
+ for (int ib = 0; ib < D/128; ++ib) {
+ for (int i = 0; i < 4; ++i) {
+ auto& b4 = dl[4*ib+i];
+ aux[i] = b4.d;
+ auto q = _mm_loadu_si128((const __m128i *)b4.qs);
+ auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8);
+ auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8);
+ vk[8*ib+2*i+0] = _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql));
+ vk[8*ib+2*i+1] = _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh));
+ }
+ auto scales4 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)aux));
+ auto scales8 = _mm256_insertf128_ps(_mm256_castps128_ps256(scales4), scales4, 1);
+ auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales8), scales8, 1);
+ vd[0] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(0, 0, 0, 0));
+ vd[1] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(1, 1, 1, 1));
+ vd[2] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(2, 2, 2, 2));
+ vd[3] = _mm512_shuffle_ps(scales, scales, _MM_SHUFFLE(3, 3, 3, 3));
+ for (int i = 0; i < 4; ++i) {
+ vk[8*ib+2*i+0] = _mm512_mul_ps(vd[i], vk[8*ib+2*i+0]);
+ vk[8*ib+2*i+1] = _mm512_mul_ps(vd[i], vk[8*ib+2*i+1]);
+ }
+ }
+ } else {
+ for (int i = 0; i < D/32; ++i) {
+ auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl[i].d));
+ auto q = _mm_loadu_si128((const __m128i *)dl[i].qs);
+ auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8);
+ auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8);
+ vk[2*i+0] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
+ vk[2*i+1] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)));
+ }
+ }
+ }
+
+ inline void load(int l1, int i, __m512& v1, __m512& v2) const {
+ auto dl = (const block_q4_0 *)Base::lblock(l1) + i/2;
+ auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl->d));
+ auto q = _mm_loadu_si128((const __m128i *)dl->qs);
+ auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8);
+ auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8);
+ v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
+ v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)));
+ }
+
+ inline void load_2(int l1, __m512 * vk) const {
+ load(l1+0, vk+0);
+ load(l1+1, vk+D/16);
+ }
+
+ const __m128i mask = _mm_set1_epi8(0xf);
+ const __m128i m8 = _mm_set1_epi8(-8);
+};
+
+template <int D, int step>
+struct HelperQ41 final : public BaseHelper<step> {
+ static_assert(step == QK4_1);
+ using Base = BaseHelper<step>;
+ HelperQ41(const char * data, int stride) : Base(data, stride) {}
+
+
+ inline void load(int l1, __m512 * vk) const {
+ auto dl = (const block_q4_1 *)Base::lblock(l1);
+ for (int i = 0; i < D/32; ++i) {
+ auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl[i].d));
+ auto vm = _mm512_set1_ps(GGML_FP16_TO_FP32(dl[i].m));
+ auto q = _mm_loadu_si128((const __m128i *)dl[i].qs);
+ auto ql = _mm_and_si128(q, mask);
+ auto qh = _mm_and_si128(_mm_srli_epi16(q, 4), mask);
+ vk[2*i+0] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)), vm);
+ vk[2*i+1] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)), vm);
+ }
+ }
+
+ inline void load(int l1, int i, __m512& v1, __m512& v2) const {
+ auto dl = (const block_q4_1 *)Base::lblock(l1) + i/2;
+ auto vd = _mm512_set1_ps(GGML_FP16_TO_FP32(dl->d));
+ auto vm = _mm512_set1_ps(GGML_FP16_TO_FP32(dl->m));
+ auto q = _mm_loadu_si128((const __m128i *)dl->qs);
+ auto ql = _mm_and_si128(q, mask);
+ auto qh = _mm_and_si128(_mm_srli_epi16(q, 4), mask);
+ v1 = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)), vm);
+ v2 = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)), vm);
+ }
+
+ inline void load_2(int l1, __m512 * vk) const {
+ load(l1+0, vk+0);
+ load(l1+1, vk+D/16);
+ }
+
+ const __m128i mask = _mm_set1_epi8(0xf);
+};
+
+
// Some of the methods in FlashAttn have two identical implementations that only differ by
// one version using a loop over the template parameter q_step, while the other using a loop
// over an input parameter nq (these are loops over the rows of q^T). I dislike this a lot,
@@ -6098,13 +6298,13 @@ struct FlashAttn {
auto qr = q + m1*stride_q;
for (int i = 0; i < D/16; ++i) qv[i] = _mm512_loadu_ps(qr + 16*i);
if (mp[l1+0] != h_inf) {
- auto vsum = _mm512_mul_ps(vk[0], qv[0]);
- for (int i = 1; i < D/16; ++i) vsum = _mm512_fmadd_ps(vk[i], qv[i], vsum);
+ auto vsum = _mm512_setzero_ps();
+ for (int i = 0; i < D/16; ++i) vsum = _mm512_fmadd_ps(vk[i], qv[i], vsum);
cache[k_step*m1 + l1 + 0] = _mm512_reduce_add_ps(vsum);
}
if (mp[l1+1] != h_inf) {
- auto vsum = _mm512_mul_ps(vk[D/16], qv[0]);
- for (int i = 1; i < D/16; ++i) vsum = _mm512_fmadd_ps(vk[i+D/16], qv[i], vsum);
+ auto vsum = _mm512_setzero_ps();
+ for (int i = 0; i < D/16; ++i) vsum = _mm512_fmadd_ps(vk[i+D/16], qv[i], vsum);
cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum);
}
}
@@ -6119,8 +6319,8 @@ struct FlashAttn {
return;
}
auto qr = q + m1*stride_q;
- auto vsum = _mm512_mul_ps(vk[0], _mm512_loadu_ps(qr));
- for (int i = 1; i < D/16; ++i) {
+ auto vsum = _mm512_setzero_ps();
+ for (int i = 0; i < D/16; ++i) {
vsum = _mm512_fmadd_ps(vk[i], _mm512_loadu_ps(qr + 16*i), vsum);
}
cache[k_step*m1 + l1] = _mm512_reduce_add_ps(vsum);
@@ -6191,78 +6391,70 @@ struct FlashAttn {
}
}
- template <bool small = is_small_head, class = std::enable_if<small>>
- inline void mult_mask_kq(int stride_k, int stride_q, int stride_m,
- const char * k, const float * q, const char * mask) {
+ template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>>
+ inline void mult_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask) {
__m512 qv[D/16];
for (int l1 = 0; l1 < k_step; l1 += 2) {
- auto kr1 = (const ggml_half *)(k + (l1 + 0)*stride_k);
- auto kr2 = (const ggml_half *)(k + (l1 + 1)*stride_k);
- for (int i = 0; i < D/16; ++i) vk[i+ 0] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr1 + i));
- for (int i = 0; i < D/16; ++i) vk[i+D/16] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr2 + i));
+ kh.load_2(l1, vk);
for (int m1 = 0; m1 < q_step; ++m1) {
mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, qv);
}
}
}
- template <bool small = is_small_head, class = std::enable_if<!small>>
- inline void mult_mask_kq_l(int stride_k, int stride_q, int stride_m,
- const char * k, const float * q, const char * mask) {
+ template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>>
+ inline void mult_mask_kq_l(const KHelper& kh, int stride_q, int stride_m,
+ const float * q, const char * mask) {
for (int l1 = 0; l1 < k_step; ++l1) {
- auto kr = (const ggml_half *)(k + l1*stride_k);
- for (int i = 0; i < D/16; ++i) vk[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i));
+ kh.load(l1, vk);
for (int m1 = 0; m1 < q_step; ++m1) {
mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask);
}
}
}
- template <bool small = is_small_head, class = std::enable_if<small>>
- inline void mult_mask_kq(int nq, int stride_k, int stride_q, int stride_m,
- const char * k, const float * q, const char * mask) {
+ template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>>
+ inline void mult_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask) {
__m512 qv[D/16];
for (int l1 = 0; l1 < k_step; l1 += 2) {
- auto kr1 = (const ggml_half *)(k + (l1 + 0)*stride_k);
- auto kr2 = (const ggml_half *)(k + (l1 + 1)*stride_k);
- for (int i = 0; i < D/16; ++i) vk[i+ 0] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr1 + i));
- for (int i = 0; i < D/16; ++i) vk[i+D/16] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr2 + i));
+ kh.load_2(l1, vk);
for (int m1 = 0; m1 < nq; ++m1) {
mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, qv);
}
}
}
- template <bool small = is_small_head, class = std::enable_if<!small>>
- inline void mult_mask_kq_l(int nq, int stride_k, int stride_q, int stride_m,
- const char * k, const float * q, const char * mask) {
+ template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>>
+ inline void mult_mask_kq_l(int nq, const KHelper& kh, int stride_q, int stride_m,
+ const float * q, const char * mask) {
for (int l1 = 0; l1 < k_step; ++l1) {
- auto kr = (const ggml_half *)(k + l1*stride_k);
- for (int i = 0; i < D/16; ++i) vk[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i));
+ kh.load(l1, vk);
for (int m1 = 0; m1 < nq; ++m1) {
mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask);
}
}
}
- inline void multiply_mask_kq(int stride_k, int stride_q, int stride_m, const char * k, const float * q, const char * mask) {
+ template <typename KHelper>
+ inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask) {
if constexpr (is_small_head) {
- mult_mask_kq(stride_k, stride_q, stride_m, k, q, mask);
+ mult_mask_kq(kh, stride_q, stride_m, q, mask);
}
else {
- mult_mask_kq_l(stride_k, stride_q, stride_m, k, q, mask);
+ mult_mask_kq_l(kh, stride_q, stride_m, q, mask);
}
for (int j = 0; j < q_step; ++j) {
update_M_S(j);
}
}
- inline void multiply_mask_kq(int nq, int stride_k, int stride_q, int stride_m, const char * k, const float * q, const char * mask) {
+ template <typename KHelper>
+ inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask) {
if constexpr (is_small_head) {
- mult_mask_kq(nq, stride_k, stride_q, stride_m, k, q, mask);
+ mult_mask_kq(nq, kh, stride_q, stride_m, q, mask);
}
else {
- mult_mask_kq_l(nq, stride_k, stride_q, stride_m, k, q, mask);
+ mult_mask_kq_l(nq, kh, stride_q, stride_m, q, mask);
}
for (int j = 0; j < nq; ++j) {
update_M_S(j);
@@ -6271,7 +6463,8 @@ struct FlashAttn {
// This fails for head sizes of 80 and 112 as D/16 is odd, so we cannot do steps of 2
// Hence, for now, we will not handle head sizes of 80 and 112
- inline void accumulate_qkv(int stride_v, const char * v) {
+ template <typename VHelper>
+ inline void accumulate_qkv(const VHelper& vh) {
for (int i = 0; i < D/16; i += 2) {
for (int j = 0; j < q_step; ++j) {
if (need_scaling[j] == 2) {
@@ -6286,10 +6479,9 @@ struct FlashAttn {
}
}
}
+ __m512 v1, v2;
for (int l1 = 0; l1 < k_step; ++l1) {
- auto vr = (const ggml_half *)(v + l1*stride_v);
- auto v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+0));
- auto v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+1));
+ vh.load(l1, i, v1, v2);
for (int j = 0; j < q_step; ++j) {
auto vs = _mm512_set1_ps(cache[k_step*j + l1]);
vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]);
@@ -6304,8 +6496,8 @@ struct FlashAttn {
}
}
- template <int Nq = q_step, class = std::enable_if<Nq >= 2>>
- inline void accumulate_qkv(int nq1, int stride_v, const char * v) {
+ template <typename VHelper, int Nq = q_step, class = std::enable_if<Nq >= 2>>
+ inline void accumulate_qkv(int nq1, const VHelper& vh) {
for (int i = 0; i < D/16; i += 2) {
for (int j = 0; j < nq1; ++j) {
if (need_scaling[j] == 2) {
@@ -6320,10 +6512,9 @@ struct FlashAttn {
}
}
}
+ __m512 v1, v2;
for (int l1 = 0; l1 < k_step; ++l1) {
- auto vr = (const ggml_half *)(v + l1*stride_v);
- auto v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+0));
- auto v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+1));
+ vh.load(l1, i, v1, v2);
for (int j = 0; j < nq1; ++j) {
auto vs = _mm512_set1_ps(cache[k_step*j + l1]);
vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]);
@@ -6338,18 +6529,19 @@ struct FlashAttn {
}
}
- void compute(int nq1, int nk1, int stride_k, int stride_q, int stride_m, int stride_v, int stride_qkv,
- const char * k, const float * q, const char * mask, const char * v, float * qkv) {
+ template <typename KHelper, typename VHelper>
+ void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
+ const float * q, const char * mask, float * qkv) {
for (int i1 = 0; i1 < nq1/q_step; ++i1) {
init_qstep();
- auto kr = k;
- auto vr = v;
+ kh.reset_block();
+ vh.reset_block();
auto mr = mask;
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
- multiply_mask_kq(stride_k, stride_q, stride_m, kr, q, mr);
- accumulate_qkv(stride_v, vr);
- kr += k_step*stride_k;
- vr += k_step*stride_v;
+ multiply_mask_kq(kh, stride_q, stride_m, q, mr);
+ accumulate_qkv(vh);
+ kh.next_block();
+ vh.next_block();
mr += k_step*sizeof(ggml_half);
}
normalize_and_store(stride_qkv, qkv);
@@ -6361,14 +6553,14 @@ struct FlashAttn {
int n_left = nq1 - q_step*(nq1/q_step);
if (n_left > 0) {
init_qstep();
- auto kr = k;
- auto vr = v;
+ kh.reset_block();
+ vh.reset_block();
auto mr = mask;
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
- multiply_mask_kq(n_left, stride_k, stride_q, stride_m, kr, q, mr);
- accumulate_qkv(n_left, stride_v, vr);
- kr += k_step*stride_k;
- vr += k_step*stride_v;
+ multiply_mask_kq(n_left, kh, stride_q, stride_m, q, mr);
+ accumulate_qkv(n_left, vh);
+ kh.next_block();
+ vh.next_block();
mr += k_step*sizeof(ggml_half);
}
normalize_and_store(n_left, stride_qkv, qkv);
@@ -6405,24 +6597,82 @@ struct FlashAttn {
}
};
-template <int D, int q_step, int k_step>
-inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
- const float * q, const char * k, const char * v, const char * mask,
- float scale, float softcap, float * qkv) {
+template <int D, int q_step, int k_step, typename KHelper, typename VHelper>
+inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
+ const float * q, const char * mask, float scale, float softcap, float * qkv) {
if (nq1 >= q_step) {
FlashAttn<D, q_step, k_step> fa(scale, softcap);
- fa.compute(nq1, nk1, stride_k, stride_q, stride_m, stride_v, stride_qkv,
- (const char *)k, q, (const char *)mask, (const char *)v, qkv);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
} else {
FlashAttn<D, 1, k_step> fa(scale, softcap);
- fa.compute(nq1, nk1, stride_k, stride_q, stride_m, stride_v, stride_qkv,
- (const char *)k, q, (const char *)mask, (const char *)v, qkv);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+ }
+}
+
+template <int D, int q_step, int k_step, typename KHelper>
+inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
+ int nq1, int nk1, int stride_q, int stride_v, int stride_m, int stride_qkv,
+ const float * q, const char * v, const char * mask,
+ float scale, float softcap, float * qkv) {
+
+ switch (type_v) {
+ case GGML_TYPE_F16: {
+ HelperF16<D, k_step> vh(v, stride_v);
+ iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ } break;
+ case GGML_TYPE_Q8_0: {
+ HelperQ80<D, k_step> vh(v, stride_v);
+ iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ } break;
+ case GGML_TYPE_Q4_0: {
+ HelperQ40<D, k_step> vh(v, stride_v);
+ iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ } break;
+ case GGML_TYPE_Q4_1: {
+ HelperQ41<D, k_step> vh(v, stride_v);
+ iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ } break;
+ default: break;
}
}
+
+template <int D, int q_step, int k_step>
+inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
+ int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
+ const float * q, const char * k, const char * v, const char * mask,
+ float scale, float softcap, float * qkv) {
+
+ switch (type_k) {
+ case GGML_TYPE_F16: {
+ HelperF16<D, k_step> kh(k, stride_k);
+ iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+ } break;
+ case GGML_TYPE_Q8_0: {
+ HelperQ80<D, k_step> kh(k, stride_k);
+ iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+ } break;
+ case GGML_TYPE_Q4_0: {
+ HelperQ40<D, k_step> kh(k, stride_k);
+ iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+ } break;
+ case GGML_TYPE_Q4_1: {
+ HelperQ41<D, k_step> kh(k, stride_k);
+ iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+ } break;
+ default: break;
+ }
+
+}
+
+inline bool flash_attn_is_supported(ggml_type type) {
+ return type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1;
+}
}
-bool iqk_flash_attn_noalibi(int D, // head size
+bool iqk_flash_attn_noalibi(int int_type_k, // type of k
+ int int_type_v, // type of v
+ int D, // head size
int nq1, // number of columns in q
int nk1, // number of rows in k
int stride_q, // distance between q columns in bytes
@@ -6438,6 +6688,9 @@ bool iqk_flash_attn_noalibi(int D, // head size
float softcap, // if > 0, a "soft-cap" operation is applied before softmax
float * qkv) { // v*softmax(scale*(k*q))
+ auto type_k = ggml_type(int_type_k);
+ auto type_v = ggml_type(int_type_v);
+ if (!flash_attn_is_supported(type_k) || !flash_attn_is_supported(type_v)) return false;
if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32
auto ck = (const char *)k;
@@ -6448,19 +6701,19 @@ bool iqk_flash_attn_noalibi(int D, // head size
switch (D) {
case 64:
- iqk_flash_helper_T< 64, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T< 64, 4, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
// Disable until we fix accumulate_qkv for odd D/16
//case 80:
// iqk_flash_helper_T< 80, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
case 96:
- iqk_flash_helper_T< 96, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T< 96, 4, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
// Disable until we fix accumulate_qkv for odd D/16
//case 112:
// iqk_flash_helper_T<112, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
case 128:
- iqk_flash_helper_T<128, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T<128, 8, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
case 256:
- iqk_flash_helper_T<256, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T<256, 8, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
default:
return false;
}
@@ -6470,7 +6723,9 @@ bool iqk_flash_attn_noalibi(int D, // head size
#else
// TODO
-bool iqk_flash_attn_noalibi([[maybe_unused]] int D, // head size
+bool iqk_flash_attn_noalibi([[maybe_unused]] int int_type_k, // type of k
+ [[maybe_unused]] int int_type_v, // type of v
+ [[maybe_unused]] int D, // head size
[[maybe_unused]] int nq, // number of columns in q
[[maybe_unused]] int nk, // number of rows in k
[[maybe_unused]] int stride_q, // distance between q columns in bytes
@@ -6501,7 +6756,9 @@ bool iqk_mul_mat_moe(long, long, long, int, int, const void *, long, int, const
return false;
}
-bool iqk_flash_attn_noalibi([[maybe_unused]] int D, // head size
+bool iqk_flash_attn_noalibi([[maybe_unused]] int int_type_k, // type of k
+ [[maybe_unused]] int int_type_v, // type of v
+ [[maybe_unused]] int D, // head size
[[maybe_unused]] int nq, // number of columns in q
[[maybe_unused]] int nk, // number of rows in k
[[maybe_unused]] int stride_q, // distance between q columns in bytes
diff --git a/ggml/src/iqk/iqk_mul_mat.h b/ggml/src/iqk/iqk_mul_mat.h
index de50bc6b..6e27c614 100644
--- a/ggml/src/iqk/iqk_mul_mat.h
+++ b/ggml/src/iqk/iqk_mul_mat.h
@@ -21,7 +21,9 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
int typeB, const void * B, long strideB,
float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth);
-bool iqk_flash_attn_noalibi(int D, // head size
+bool iqk_flash_attn_noalibi(int type_k, // type of k
+ int type_v, // type of v
+ int D, // head size
int nq, // number of columns in q
int nk, // number of rows in k
int stride_q, // distance between q columns in bytes