author     Kawrakow <iwankawrakow@gmail.com>    2025-02-11 14:46:30 +0200
committer  GitHub <noreply@github.com>          2025-02-11 14:46:30 +0200
commit     3c98bfb33d149a0d9d3bb91604dd12709721e3cf (patch)
tree       6a1e5fc373032bb18a62ec3616625eedf1a9f1f3
parent     a366a3d17d8f2de0eb8c3d9eddc7b5840fb5761a (diff)
DeepSeek FA support (CPU only) (#200)
* Adding support for K head size != V head size

  This is relevant for DeepSeek models. At this point ggml CPU FA works.
  Now I need to go and change iqk FA to make it work with Dk != Dv.

* iqk support for K head size != V head size

  To not have compilation time explode, just Dk = 192, Dv = 128 for now (DeepSeek)

* FA: very slightly faster for nq = 1 (TG)

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
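
The key structural change is that the flash-attention result now takes its row size from V rather than Q: DeepSeek's MLA attention uses Dk = 192 for the Q·K dot products but Dv = 128 for the output. A minimal sketch of the shape bookkeeping, mirroring the comment added to ggml_flash_attn_ext below (the tensor-layout labels are illustrative, not taken from the patch):

#include <array>
#include <cstdint>
#include <cstdio>

int main() {
    // Illustrative DeepSeek-style shapes: q = { Dk, n_tokens, n_head, n_batch },
    // v = { Dv, n_kv, n_head_kv, n_batch }.
    std::array<int64_t, 4> q = {192, 7, 16, 1};
    std::array<int64_t, 4> v = {128, 4096, 16, 1};

    // k*q is { k->ne[1], q->ne[2], q->ne[1], q->ne[3] }, v^T is { v->ne[1], v->ne[0], v->ne[2], v->ne[3] },
    // so the result is { v->ne[0], q->ne[2], q->ne[1], q->ne[3] } -- Dv rows, not Dk.
    std::array<int64_t, 4> ne = { v[0], q[2], q[1], q[3] };  // was { q[0], ... } before this patch
    std::printf("FA result: %lld x %lld x %lld x %lld\n",
                (long long)ne[0], (long long)ne[1], (long long)ne[2], (long long)ne[3]);
    // -> FA result: 128 x 16 x 7 x 1
}
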
-rw-r--r--  ggml/src/ggml.c               61
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.cpp  278
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.h    3
-rw-r--r--  src/llama.cpp                 8
4 files changed, 221 insertions, 129 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 3867cf00..7b631177 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -8473,8 +8473,12 @@ struct ggml_tensor * ggml_flash_attn_ext(
is_node = true;
}
+ // k*q will be { k->ne[1], q->ne[2], q->ne[1], q->ne[3] }
+ // v^T is { v->ne[1], v->ne[0], v->ne[2], v->ne[3] }
+ // => result is { v->ne[0], q->ne[2], q->ne[1], q->ne[3] }
// permute(0, 2, 1, 3)
- int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
+ //int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
+ int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
float params[] = { scale, max_bias, softcap };
@@ -17436,10 +17440,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
const int ith = params->ith;
const int nth = params->nth;
- const int64_t D = neq0;
- const int64_t N = neq1;
+ const int64_t Dk = nek0;
+ const int64_t Dv = nev0;
+ const int64_t N = neq1;
- GGML_ASSERT(ne0 == D);
+ GGML_ASSERT(ne0 == Dv);
GGML_ASSERT(ne2 == N);
// input tensor rows must be contiguous
@@ -17447,12 +17452,12 @@ static void ggml_compute_forward_flash_attn_ext_f16(
GGML_ASSERT(nbk0 == ggml_type_size(k->type));
GGML_ASSERT(nbv0 == ggml_type_size(v->type));
- GGML_ASSERT(neq0 == D);
- GGML_ASSERT(nek0 == D);
- GGML_ASSERT(nev0 == D);
+ GGML_ASSERT(neq0 == Dk);
+ GGML_ASSERT(nek0 == Dk);
+ GGML_ASSERT(nev0 == Dv);
GGML_ASSERT(neq1 == N);
- GGML_ASSERT(nev0 == D);
+ GGML_ASSERT(nev0 == Dv);
// dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float));
@@ -17516,7 +17521,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
int iq1 = (ith%ntg)*neq1g;
int this_neq1 = MIN(neq1g, neq1-iq1);
if (!iqk_flash_attn_noalibi(k->type, v->type,
- D, this_neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
+ Dk, Dv, this_neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
(const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]),
(const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]),
(const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]),
@@ -17543,6 +17548,8 @@ IQK_Flash_Attn_NotAvailable:;
ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot;
ggml_to_float_t const v_to_float = type_traits[v->type].to_float;
+ const int64_t Dkv = MAX(Dk, Dv);
+
// loop over n_batch and n_head
for (int ir = ir0; ir < ir1; ++ir) {
// q indices
@@ -17556,15 +17563,15 @@ IQK_Flash_Attn_NotAvailable:;
float S = 0.0f; // sum
float M = -INFINITY; // maximum KQ value
- float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
- float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer
- ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
- ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
+ float * VKQ32 = (float *) params->wdata + ith*(3*Dkv + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
+ float * V32 = (VKQ32 + 1*Dkv); // (temporary) FP32 V buffer
+ ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*Dkv); // (temporary) FP16 VKQ accumulator
+ ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*Dkv); // (temporary) buffer for Q converted to quantized/FP16
if (v->type == GGML_TYPE_F16) {
- memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
+ memset(VKQ16, 0, Dkv*sizeof(ggml_fp16_t));
} else {
- memset(VKQ32, 0, D*sizeof(float));
+ memset(VKQ32, 0, Dkv*sizeof(float));
}
const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
@@ -17578,7 +17585,7 @@ IQK_Flash_Attn_NotAvailable:;
const int iv2 = iq2 / rv2;
const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
- q_to_vec_dot(pq, Q_q, D);
+ q_to_vec_dot(pq, Q_q, Dk);
// online softmax / attention
// loop over n_kv and n_head_kv
@@ -17592,7 +17599,7 @@ IQK_Flash_Attn_NotAvailable:;
float s; // KQ value
const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
- kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
+ kq_vec_dot(Dk, &s, 0, k_data, 0, Q_q, 0, 1);
s = softcap == 0.0f ? s*scale + mv : softcap*tanhf(s*scale) + mv; // scale KQ value and apply mask
@@ -17610,14 +17617,14 @@ IQK_Flash_Attn_NotAvailable:;
ms = expf(Mold - M);
// V = V*expf(Mold - M)
- ggml_vec_scale_f16(D, VKQ16, ms);
+ ggml_vec_scale_f16(Dv, VKQ16, ms);
} else {
// no new maximum, ms == 1.0f, vs != 1.0f
vs = expf(s - M);
}
// V += v*expf(s - M)
- ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
+ ggml_vec_mad_f16(Dv, VKQ16, (const ggml_fp16_t *) v_data, vs);
} else {
if (s > M) {
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
@@ -17625,30 +17632,30 @@ IQK_Flash_Attn_NotAvailable:;
ms = expf(Mold - M);
// V = V*expf(Mold - M)
- ggml_vec_scale_f32(D, VKQ32, ms);
+ ggml_vec_scale_f32(Dv, VKQ32, ms);
} else {
// no new maximum, ms == 1.0f, vs != 1.0f
vs = expf(s - M);
}
- v_to_float(v_data, V32, D);
+ v_to_float(v_data, V32, Dv);
// V += v*expf(s - M)
- ggml_vec_mad_f32(D, VKQ32, V32, vs);
+ ggml_vec_mad_f32(Dv, VKQ32, V32, vs);
}
S = S*ms + vs; // scale and increment sum with partial sum
}
if (v->type == GGML_TYPE_F16) {
- for (int64_t d = 0; d < D; ++d) {
+ for (int64_t d = 0; d < Dv; ++d) {
VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
}
}
// V /= S
const float S_inv = 1.0f/S;
- ggml_vec_scale_f32(D, VKQ32, S_inv);
+ ggml_vec_scale_f32(Dv, VKQ32, S_inv);
// dst indices
const int i1 = iq1;
@@ -21112,9 +21119,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
} break;
case GGML_OP_FLASH_ATTN_EXT:
{
- const int64_t ne00 = node->src[0]->ne[0]; // D
+ const int64_t Dk = node->src[0]->ne[0];
+ const int64_t Dv = node->src[2]->ne[0];
+ const int64_t D = MAX(Dk, Dv);
- cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
+ cur = 3*sizeof(float)*D*n_tasks; // 3x head size/thread
} break;
case GGML_OP_FLASH_ATTN_BACK:
{
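
Taken together, the ggml.c changes replace the single head size D with Dk (length of the K·Q dot product) and Dv (length of the V accumulator), and size the per-thread scratch by max(Dk, Dv). Below is a scalar sketch of the same online-softmax pass with the two sizes kept apart, assuming contiguous rows and folding mask/softcap into the scale; the names are illustrative, not the ggml API:

#include <cmath>
#include <vector>

// Hedged scalar sketch of one query row of flash attention with Dk != Dv,
// mirroring the Dk/Dv split in ggml_compute_forward_flash_attn_ext_f16.
// K and V are assumed to have the same number of rows (n_kv).
std::vector<float> fa_one_row(const std::vector<float>& q,              // Dk values
                              const std::vector<std::vector<float>>& K, // n_kv rows of Dk
                              const std::vector<std::vector<float>>& V, // n_kv rows of Dv
                              float scale) {
    const size_t Dk = q.size();
    const size_t Dv = V.empty() ? 0 : V[0].size();
    std::vector<float> vkq(Dv, 0.0f);     // VKQ accumulator has Dv entries, not Dk
    float S = 0.0f;                       // running sum of exp terms
    float M = -INFINITY;                  // running maximum of the scaled KQ values

    for (size_t ic = 0; ic < K.size(); ++ic) {
        float s = 0.0f;
        for (size_t d = 0; d < Dk; ++d) s += K[ic][d] * q[d];    // dot product over Dk
        s *= scale;

        float ms = 1.0f, vs = 1.0f;
        if (s > M) {
            ms = std::exp(M - s);                                 // rescale the old accumulator
            M  = s;
            for (size_t d = 0; d < Dv; ++d) vkq[d] *= ms;         // scale over Dv
        } else {
            vs = std::exp(s - M);
        }
        for (size_t d = 0; d < Dv; ++d) vkq[d] += vs * V[ic][d];  // accumulate over Dv
        S = S * ms + vs;
    }
    for (size_t d = 0; d < Dv; ++d) vkq[d] /= S;                  // final 1/S normalization
    return vkq;                                                   // Dv outputs per query row
}

The ggml_graph_plan change follows directly: VKQ32, V32/VKQ16 and Q_q share one buffer of 3*max(Dk, Dv) floats per thread, so the planner now reserves 3*sizeof(float)*MAX(Dk, Dv)*n_tasks instead of 3*sizeof(float)*ne00*n_tasks.
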
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index ee0af7e9..3b58495e 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -14879,10 +14879,60 @@ struct FlashQKV {
using qkv_cache_t = float;
#endif
+ template <typename VHelper>
+ inline void accumulate_qkv_1(const VHelper& vh, const FlashMS<q_step, k_step>& fms) {
+ F16::Data vq[D/F16::block_size];
+ if (fms.need_scaling[0] == 2) {
+ for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::zero();
+ } else {
+ for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::load(qkv_cache + F16::block_size*i);
+ if (fms.need_scaling[0] == 1) {
+ auto vms = F16::set1(fms.vms[0]);
+ for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::mul(vms, vq[i]);
+ }
+ }
+ //F16::Data v[8];
+ F16::Data v0, v1;
+ for (int l = 0; l < k_step; l += 4) {
+ auto vs0 = F16::set1(fms.cache[l + 0]);
+ auto vs1 = F16::set1(fms.cache[l + 1]);
+ auto vs2 = F16::set1(fms.cache[l + 2]);
+ auto vs3 = F16::set1(fms.cache[l + 3]);
+ //auto vs = F16::set4(fms.cache + l);
+ for (int i = 0; i < D/F16::block_size; i += 2) {
+ vh.load(l+0, i, v0, v1);
+ vq[i+0] = F16::fmadd(vq[i+0], v0, vs0);
+ vq[i+1] = F16::fmadd(vq[i+1], v1, vs0);
+ vh.load(l+1, i, v0, v1);
+ vq[i+0] = F16::fmadd(vq[i+0], v0, vs1);
+ vq[i+1] = F16::fmadd(vq[i+1], v1, vs1);
+ vh.load(l+2, i, v0, v1);
+ vq[i+0] = F16::fmadd(vq[i+0], v0, vs2);
+ vq[i+1] = F16::fmadd(vq[i+1], v1, vs2);
+ vh.load(l+3, i, v0, v1);
+ vq[i+0] = F16::fmadd(vq[i+0], v0, vs3);
+ vq[i+1] = F16::fmadd(vq[i+1], v1, vs3);
+ //vq[i+0] = F16::fmadd_lane0(vq[i+0], v[0], vs);
+ //vq[i+1] = F16::fmadd_lane0(vq[i+1], v[4], vs);
+ //vq[i+0] = F16::fmadd_lane1(vq[i+0], v[1], vs);
+ //vq[i+1] = F16::fmadd_lane1(vq[i+1], v[5], vs);
+ //vq[i+0] = F16::fmadd_lane2(vq[i+0], v[2], vs);
+ //vq[i+1] = F16::fmadd_lane2(vq[i+1], v[6], vs);
+ //vq[i+0] = F16::fmadd_lane3(vq[i+0], v[3], vs);
+ //vq[i+1] = F16::fmadd_lane3(vq[i+1], v[7], vs);
+ }
+ }
+ for (int i = 0; i < D/F16::block_size; ++i) F16::store(qkv_cache + F16::block_size*i, vq[i]);
+ }
+
// This fails for head sizes of 80 and 112 as D/16 is odd, so we cannot do steps of 2
// Hence, for now, we will not handle head sizes of 80 and 112
template <typename VHelper>
inline void accumulate_qkv(const VHelper& vh, const FlashMS<q_step, k_step>& fms) {
+ if constexpr (q_step == 1) {
+ accumulate_qkv_1(vh, fms);
+ return;
+ }
F16::Data v[8];
for (int j = 0; j < q_step; ++j) {
auto R = qkv_cache + D*j;
@@ -14924,6 +14974,10 @@ struct FlashQKV {
template <typename VHelper>
inline void accumulate_qkv(int nq1, const VHelper& vh, const FlashMS<q_step, k_step>& fms) {
+ if (nq1 == 1) {
+ accumulate_qkv_1(vh, fms);
+ return;
+ }
F16::Data v[8];
for (int j = 0; j < nq1; ++j) {
auto R = qkv_cache + D*j;
@@ -15346,13 +15400,13 @@ struct FlashQKfp32 {
}
};
-template <int D, int q_step, int k_step, typename KHelper, typename VHelper, typename KQHelper>
+template <int Dk, int Dv, int q_step, int k_step, typename KHelper, typename VHelper, typename KQHelper>
void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
FlashMS<q_step, k_step>& fms,
- FlashQKV<D, q_step, k_step>& fqkv,
+ FlashQKV<Dv, q_step, k_step>& fqkv,
const float * q, const char * mask, float * qkv) {
#ifdef __aarch64__
- float16_t q_f16[D*q_step];
+ float16_t q_f16[Dk*q_step];
#endif
for (int i1 = 0; i1 < nq1/q_step; ++i1) {
@@ -15365,7 +15419,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
auto mr = mask;
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
#ifdef __aarch64__
- KQHelper::multiply_mask_kq(kh, D, stride_m, q_f16, mr, fms);
+ KQHelper::multiply_mask_kq(kh, Dk, stride_m, q_f16, mr, fms);
#else
KQHelper::multiply_mask_kq(kh, stride_q, stride_m, q, mr, fms);
#endif
@@ -15391,7 +15445,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
auto mr = mask;
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
#ifdef __aarch64__
- KQHelper::multiply_mask_kq(n_left, kh, D, stride_m, q_f16, mr, fms);
+ KQHelper::multiply_mask_kq(n_left, kh, Dk, stride_m, q_f16, mr, fms);
#else
KQHelper::multiply_mask_kq(n_left, kh, stride_q, stride_m, q, mr, fms);
#endif
@@ -15404,12 +15458,12 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
}
}
-template <int D, int q_step, int k_step, typename KHelper, typename VHelper, typename KQHelper>
+template <int Dk, int Dv, int q_step, int k_step, typename KHelper, typename VHelper, typename KQHelper>
void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
FlashMS<q_step, k_step>& fms,
- FlashQKV<D, q_step, k_step>& fqkv,
+ FlashQKV<Dv, q_step, k_step>& fqkv,
const float * q, const char * mask, float * qkv) {
- typename KHelper::block_q8 q8[q_step*(D/QK8_0)];
+ typename KHelper::block_q8 q8[q_step*(Dk/QK8_0)];
#if FA_TIMING
Perf perf(false);
#endif
@@ -15420,7 +15474,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
fms.init_qstep();
kh.reset_block();
vh.reset_block();
- HelperQ80<D, QK8_0>::convert(q_step, stride_q, q, q8);
+ HelperQ80<Dk, QK8_0>::convert(q_step, stride_q, q, q8);
#if FA_TIMING
perf.accum_nolock(0, t1);
#endif
@@ -15458,7 +15512,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
fms.init_qstep();
kh.reset_block();
vh.reset_block();
- HelperQ80<D, QK8_0>::convert(n_left, stride_q, q, q8);
+ HelperQ80<Dk, QK8_0>::convert(n_left, stride_q, q, q8);
auto mr = mask;
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
KQHelper::mul_mask_kq(n_left, kh, stride_m, q8, mr, fms);
@@ -15484,9 +15538,10 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
// rows (if Nq is not a multiple of q_step). One could have made the number of q^T rows to
// process a template parameter of such functions, but this would result in the compiler generating
// q_step-1 versions of these functions for us, which I thought was too much with q_step = 8.
-template <int D, int q_step, int k_step>
+template <int Dk, int Dv, int q_step, int k_step>
struct FlashAttn {
- static_assert(D%F16::block_size == 0 && D <= 256);
+ static_assert(Dk%F16::block_size == 0 && Dk <= 256);
+ static_assert(Dv%F16::block_size == 0 && Dv <= 256);
static_assert(k_step%F16::block_size == 0);
static_assert(q_step <= 4 || q_step%4 == 0);
@@ -15495,35 +15550,35 @@ struct FlashAttn {
template <typename KHelper, typename VHelper>
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
const float * q, const char * mask, float * qkv) {
- if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>> ||
- std::is_same_v<KHelper, HelperIQ4nl<D, k_step>> ||
- std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
- compute_helper_q<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
+ if constexpr (std::is_same_v<KHelper, HelperQ40<Dk, k_step>> || std::is_same_v<KHelper, HelperQ41<Dk, k_step>> ||
+ std::is_same_v<KHelper, HelperIQ4nl<Dk, k_step>> ||
+ std::is_same_v<KHelper, HelperQ60<Dk, k_step>>) {
+ compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
}
- else if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) {
+ else if constexpr (std::is_same_v<KHelper, HelperQ80<Dk, k_step>>) {
if (nq1 >= 8) {
#if FA_TIMING
auto t1 = Perf::cur_time();
- HelperQ80R4<D, k_step> khr4(nk1, kh);
+ HelperQ80R4<Dk, k_step> khr4(nk1, kh);
Perf::instance().accum(4, t1);
#else
- HelperQ80R4<D, k_step> khr4(nk1, kh);
+ HelperQ80R4<Dk, k_step> khr4(nk1, kh);
#endif
- compute_helper_q<D, q_step, k_step, HelperQ80R4<D, k_step>, VHelper, FlashQKfp32<D, q_step, k_step>>(
+ compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R4<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
} else{
- compute_helper_q<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
+ compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
}
} else {
- compute_helper<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
+ compute_helper<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
}
}
- FlashMS<q_step, k_step> fms;
- FlashQKV<D, q_step, k_step> fqkv;
+ FlashMS<q_step, k_step> fms;
+ FlashQKV<Dv, q_step, k_step> fqkv;
};
@@ -15756,7 +15811,22 @@ struct FlashQKbf16 {
static inline void multiply_mask_kq(const KHelper& kh, int stride_m, const ggml_bf16_t * q,
const char * mask, FlashMS<q_step, k_step>& fms) {
#endif
- {
+ if constexpr (q_step == 1) {
+ __m512bh vq[D/32];
+ __m512bh vk[D/32];
+ __m256 sum[8];
+ for (int i = 0; i < D/32; ++i) vq[i] = __m512bh(_mm512_loadu_si512((const __m512i *)q + i));
+ for (int l = 0; l < k_step; l += 8) {
+ for (int k = 0; k < 8; ++k) {
+ kh.load(l+k, vk);
+ auto vsum = _mm512_setzero_ps();
+ for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vk[i], vq[i]);
+ sum[k] = _mm256_add_ps(_mm512_castps512_ps256(vsum), _mm512_extractf32x8_ps(vsum, 1));
+ }
+ _mm256_storeu_ps(fms.cache + l, hsum_float_8x8(sum));
+ }
+ }
+ else {
__m512bh qv[D/32];
if constexpr (D <= 128) {
__m512bh vkh[D/4];
@@ -15856,9 +15926,10 @@ struct FlashQKbf16 {
}
};
-template <int D, int q_step, int k_step>
+template <int Dk, int Dv, int q_step, int k_step>
struct FlashAttnBF16 {
- static_assert(D%32 == 0 && D <= 256);
+ static_assert(Dk%32 == 0 && Dk <= 256);
+ static_assert(Dv%32 == 0 && Dv <= 256);
static_assert(k_step%32 == 0);
static_assert(q_step <= 4 || q_step%4 == 0);
@@ -15867,7 +15938,7 @@ struct FlashAttnBF16 {
template <typename KHelper, typename VHelper>
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
const float * q, const char * mask, float * qkv) {
- ggml_bf16_t q_bf16[q_step*D];
+ ggml_bf16_t q_bf16[q_step*Dk];
#if FA_TIMING
Perf perf(false);
#endif
@@ -15878,7 +15949,7 @@ struct FlashAttnBF16 {
fms.init_qstep();
kh.reset_block();
vh.reset_block();
- FlashQKbf16<D, q_step, k_step>::convert(stride_q, q, q_bf16);
+ FlashQKbf16<Dk, q_step, k_step>::convert(stride_q, q, q_bf16);
#if FA_TIMING
perf.accum_nolock(0, t1);
#endif
@@ -15886,13 +15957,13 @@ struct FlashAttnBF16 {
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
#if FA_TIMING
//t1 = Perf::cur_time();
- FlashQKbf16<D, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms, perf);
+ FlashQKbf16<Dk, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms, perf);
//perf.accum_nolock(1, t1);
t1 = Perf::cur_time();
fqkv.accumulate_qkv(vh, fms);
perf.accum_nolock(3, t1);
#else
- FlashQKbf16<D, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms);
+ FlashQKbf16<Dk, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms);
fqkv.accumulate_qkv(vh, fms);
#endif
kh.next_block();
@@ -15916,10 +15987,10 @@ struct FlashAttnBF16 {
fms.init_qstep();
kh.reset_block();
vh.reset_block();
- FlashQKbf16<D, q_step, k_step>::convert(n_left, stride_q, q, q_bf16);
+ FlashQKbf16<Dk, q_step, k_step>::convert(n_left, stride_q, q, q_bf16);
auto mr = mask;
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
- FlashQKbf16<D, q_step, k_step>::multiply_mask_kq(n_left, kh, stride_m, q_bf16, mr, fms);
+ FlashQKbf16<Dk, q_step, k_step>::multiply_mask_kq(n_left, kh, stride_m, q_bf16, mr, fms);
fqkv.accumulate_qkv(n_left, vh, fms);
kh.next_block();
vh.next_block();
@@ -15932,72 +16003,72 @@ struct FlashAttnBF16 {
#endif
}
- FlashMS<q_step, k_step> fms;
- FlashQKV<D, q_step, k_step> fqkv;
+ FlashMS<q_step, k_step> fms;
+ FlashQKV<Dv, q_step, k_step> fqkv;
};
#endif
-template <int D, int k_step, typename KHelper, typename VHelper>
+template <int Dk, int Dv, int k_step, typename KHelper, typename VHelper>
inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
const float * q, const char * mask, float scale, float softcap, float * qkv) {
if (nk1 >= 256) { //4096) {
if (nq1 >= 64) {
- FlashAttn<D, 64, k_step> fa(scale, softcap);
+ FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
return;
}
if (nq1 >= 32) {
- FlashAttn<D, 32, k_step> fa(scale, softcap);
+ FlashAttn<Dk, Dv, 32, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
return;
}
if (nq1 >= 16) {
- FlashAttn<D, 16, k_step> fa(scale, softcap);
+ FlashAttn<Dk, Dv, 16, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
return;
}
}
if (nq1 >= 8) {
- FlashAttn<D, 8, k_step> fa(scale, softcap);
+ FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
}
else {
- FlashAttn<D, 1, k_step> fa(scale, softcap);
+ FlashAttn<Dk, Dv, 1, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
}
}
#ifdef __AVX512BF16__
-template <int D, int k_step>
+template <int Dk, int Dv, int k_step>
inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
const float * q, const char * k, const char * v, const char * mask,
float scale, float softcap, float * qkv) {
- HelperBF16<D, k_step> kh(k, stride_k);
- HelperBF16<D, k_step> vh(v, stride_v);
+ HelperBF16<Dk, k_step> kh(k, stride_k);
+ HelperBF16<Dv, k_step> vh(v, stride_v);
if (nk1 >= 4096) {
if (nq1 >= 64) {
- FlashAttnBF16<D, 64, k_step> fa(scale, softcap);
+ FlashAttnBF16<Dk, Dv, 64, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
return;
}
else if (nq1 >= 16) {
- FlashAttnBF16<D, 16, k_step> fa(scale, softcap);
+ FlashAttnBF16<Dk, Dv, 16, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
return;
}
}
if (nq1 >= 8) {
- FlashAttnBF16<D, 8, k_step> fa(scale, softcap);
+ FlashAttnBF16<Dk, Dv, 8, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
} else {
- FlashAttnBF16<D, 1, k_step> fa(scale, softcap);
+ FlashAttnBF16<Dk, Dv, 1, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
}
}
#endif
-template <int D, int k_step, typename KHelper>
+template <int Dk, int Dv, int k_step, typename KHelper>
inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
int nq1, int nk1, int stride_q, int stride_v, int stride_m, int stride_qkv,
const float * q, const char * v, const char * mask,
@@ -16005,42 +16076,42 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
switch (type_v) {
case GGML_TYPE_F16: {
- HelperF16<D, k_step> vh(v, stride_v);
- iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ HelperF16<Dv, k_step> vh(v, stride_v);
+ iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
#ifdef HAVE_FANCY_SIMD
case GGML_TYPE_BF16: {
- HelperBF16<D, k_step> vh(v, stride_v);
- iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ HelperBF16<Dv, k_step> vh(v, stride_v);
+ iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
#endif
case GGML_TYPE_Q8_0: {
- HelperQ80<D, k_step> vh(v, stride_v);
- iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ HelperQ80<Dv, k_step> vh(v, stride_v);
+ iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
case GGML_TYPE_Q6_0: {
- HelperQ60<D, k_step> vh(v, stride_v);
- iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ HelperQ60<Dv, k_step> vh(v, stride_v);
+ iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
#if GGML_IQK_FA_ALL_QUANTS
case GGML_TYPE_Q4_0: {
- HelperQ40<D, k_step> vh(v, stride_v);
- iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ HelperQ40<Dv, k_step> vh(v, stride_v);
+ iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
case GGML_TYPE_Q4_1: {
- HelperQ41<D, k_step> vh(v, stride_v);
- iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ HelperQ41<Dv, k_step> vh(v, stride_v);
+ iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
case GGML_TYPE_IQ4_NL: {
- HelperIQ4nl<D, k_step> vh(v, stride_v);
- iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ HelperIQ4nl<Dv, k_step> vh(v, stride_v);
+ iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
#endif
default: break;
}
}
-template <int D, int k_step>
+template <int Dk, int Dv, int k_step>
inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
const float * q, const char * k, const char * v, const char * mask,
@@ -16048,29 +16119,29 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
switch (type_k) {
case GGML_TYPE_F16: {
- HelperF16<D, k_step> kh(k, stride_k);
- iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+ HelperF16<Dk, k_step> kh(k, stride_k);
+ iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
case GGML_TYPE_Q8_0: {
- HelperQ80<D, k_step> kh(k, stride_k);
- iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+ HelperQ80<Dk, k_step> kh(k, stride_k);
+ iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
case GGML_TYPE_Q6_0: {
- HelperQ60<D, k_step> kh(k, stride_k);
- iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+ HelperQ60<Dk, k_step> kh(k, stride_k);
+ iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
#if GGML_IQK_FA_ALL_QUANTS
case GGML_TYPE_Q4_0: {
- HelperQ40<D, k_step> kh(k, stride_k);
- iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+ HelperQ40<Dk, k_step> kh(k, stride_k);
+ iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
case GGML_TYPE_Q4_1: {
- HelperQ41<D, k_step> kh(k, stride_k);
- iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+ HelperQ41<Dk, k_step> kh(k, stride_k);
+ iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
case GGML_TYPE_IQ4_NL: {
- HelperIQ4nl<D, k_step> kh(k, stride_k);
- iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+ HelperIQ4nl<Dk, k_step> kh(k, stride_k);
+ iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
#endif
default: break;
@@ -16094,7 +16165,8 @@ inline bool flash_attn_is_supported(ggml_type type) {
bool iqk_flash_attn_noalibi(int int_type_k, // type of k
int int_type_v, // type of v
- int D, // head size
+ int Dk, // K head size
+ int Dv, // V head size
int nq1, // number of columns in q
int nk1, // number of rows in k
int stride_q, // distance between q columns in bytes
@@ -16114,7 +16186,9 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
auto type_v = ggml_type(int_type_v);
if (!flash_attn_is_supported(type_k) || !flash_attn_is_supported(type_v)) return false;
if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32
- if (D != 64 && D != 96 && D != 128 && D != 256) return false;
+ if (Dk != Dv && Dk != 192 && Dv != 128) return false;
+ if (Dv != 64 && Dv != 96 && Dv != 128 && Dv != 256) return false;
+ if (Dk != 64 && Dk != 96 && Dk != 128 && Dk != 192 && Dv != 256) return false;
auto ck = (const char *)k;
auto cv = (const char *)v;
@@ -16126,30 +16200,34 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
if (type_k == GGML_TYPE_BF16) {
if (nk1%64 == 0) {
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
- switch (D) {
+ switch (Dk) {
case 64:
- iqk_flash_helper_T< 64, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T< 64, 64, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
case 96:
- iqk_flash_helper_T< 96, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T< 96, 96, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
case 128:
- iqk_flash_helper_T<128, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T<128, 128, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ case 192:
+ iqk_flash_helper_T<192, 128, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
case 256:
- iqk_flash_helper_T<256, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T<256, 256, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
default:
return false;
}
return true;
}
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
- switch (D) {
+ switch (Dk) {
case 64:
- iqk_flash_helper_T< 64, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T< 64, 64, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
case 96:
- iqk_flash_helper_T< 96, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T< 96, 96, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
case 128:
- iqk_flash_helper_T<128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T<128, 128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ case 192:
+ iqk_flash_helper_T<192, 128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
case 256:
- iqk_flash_helper_T<256, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T<256, 256, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
default:
return false;
}
@@ -16159,41 +16237,45 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
#endif
if (nk1%64 == 0) {
- switch (D) {
+ switch (Dk) {
case 64:
- iqk_flash_helper_T< 64, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T< 64, 64, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
// Disable until we fix accumulate_qkv for odd D/16
//case 80:
// iqk_flash_helper_T< 80, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
case 96:
- iqk_flash_helper_T< 96, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T< 96, 96, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
// Disable until we fix accumulate_qkv for odd D/16
//case 112:
// iqk_flash_helper_T<112, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
case 128:
- iqk_flash_helper_T<128, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T<128, 128, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ case 192:
+ iqk_flash_helper_T<192, 128, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
case 256:
- iqk_flash_helper_T<256, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T<256, 256, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
default:
return false;
}
return true;
}
- switch (D) {
+ switch (Dk) {
case 64:
- iqk_flash_helper_T< 64, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T< 64, 64, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
// Disable until we fix accumulate_qkv for odd D/16
//case 80:
// iqk_flash_helper_T< 80, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
case 96:
- iqk_flash_helper_T< 96, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T< 96, 96, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
// Disable until we fix accumulate_qkv for odd D/16
//case 112:
// iqk_flash_helper_T<112, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
case 128:
- iqk_flash_helper_T<128, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T<128, 128, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ case 192:
+ iqk_flash_helper_T<192, 128, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
case 256:
- iqk_flash_helper_T<256, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T<256, 256, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
default:
return false;
}
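
On the iqk side the templates gain a second head-size parameter, but only a fixed set of (Dk, Dv) pairs is instantiated to keep compilation time in check; per the commit message only Dk = 192 / Dv = 128 is added for DeepSeek. A hedged sketch of the resulting support gate (the helper name is illustrative, not part of the patch):

// Hedged reading of the (Dk, Dv) combinations dispatched after this patch:
// the equal head sizes supported before, plus the DeepSeek pair 192/128.
static bool fa_head_sizes_supported(int Dk, int Dv) {
    if (Dk == Dv) {
        return Dk == 64 || Dk == 96 || Dk == 128 || Dk == 256;
    }
    return Dk == 192 && Dv == 128;   // the only mixed combination instantiated for now
}

Each accepted pair maps to an instantiation such as iqk_flash_helper_T<192, 128, 64> or iqk_flash_helper_T<192, 128, 32>, chosen by whether nk1 is a multiple of 64.
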
diff --git a/ggml/src/iqk/iqk_mul_mat.h b/ggml/src/iqk/iqk_mul_mat.h
index 6e27c614..b24dc7b2 100644
--- a/ggml/src/iqk/iqk_mul_mat.h
+++ b/ggml/src/iqk/iqk_mul_mat.h
@@ -23,7 +23,8 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
bool iqk_flash_attn_noalibi(int type_k, // type of k
int type_v, // type of v
- int D, // head size
+ int Dk, // K head size
+ int Dv, // V head size
int nq, // number of columns in q
int nk, // number of rows in k
int stride_q, // distance between q columns in bytes
diff --git a/src/llama.cpp b/src/llama.cpp
index b2553802..0817c53c 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -17768,10 +17768,10 @@ struct llama_context * llama_new_context_with_model(
params.flash_attn = false;
}
- if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
- LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
- params.flash_attn = false;
- }
+ //if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
+ // LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
+ // params.flash_attn = false;
+ //}
if (params.type_v != GGML_TYPE_F16 && params.type_v != GGML_TYPE_BF16 && !params.flash_attn) {
LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);