summaryrefslogtreecommitdiff
path: root/ggml/src/ggml.c
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/ggml.c')
-rw-r--r--ggml/src/ggml.c6
1 files changed, 3 insertions, 3 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 771bc8ca..45fddca5 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -16150,8 +16150,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
}
#if GGML_USE_IQK_MULMAT
- if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 && v->type == GGML_TYPE_F16 &&
- mask && mask->type == GGML_TYPE_F16) {
+ if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) {
int64_t work_per_slice = D*nek1*neq1;
int ntg = 1;
if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8;
@@ -16165,7 +16164,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
if (counter++ % (nth/ntg) == ith/ntg) {
int iq1 = (ith%ntg)*neq1/ntg;
- if (!iqk_flash_attn_noalibi(D, neq1/ntg, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
+ if (!iqk_flash_attn_noalibi(k->type, v->type,
+ D, neq1/ntg, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
(const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]),
(const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]),
(const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]),