diff options
author | Kawrakow <48489457+ikawrakow@users.noreply.github.com> | 2024-09-04 07:20:55 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-09-04 07:20:55 +0300 |
commit | 8c94dcd43350b6bde8f5618f7e0e9f0b400a2ac6 (patch) | |
tree | 8a109e151d38447d1659fd4494cb160a3d585ca3 /ggml/src/ggml.c | |
parent | 9b53c2533fb8c236f319b874c5ff592de8fcd3b4 (diff) |
Zen4 Flash Attnetion 2 (#36)
* Zen4 Flash Attnetion: WIP generalize to other types
Now loading of data from K and V is done via a template parameter,
so this should make it easy to generalize to typ[es other than
F16 for the K and V cache.
* Zen4 Flash Attnetion: it works for q4_0 and q8_0
* Zen4 Flash Attnetion: small q8_0 performance improvement
* Zen4 Flash Attnetion: add q4_1
* Delete unused stuff
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/src/ggml.c')
-rw-r--r-- | ggml/src/ggml.c | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 771bc8ca..45fddca5 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -16150,8 +16150,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( } #if GGML_USE_IQK_MULMAT - if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 && v->type == GGML_TYPE_F16 && - mask && mask->type == GGML_TYPE_F16) { + if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) { int64_t work_per_slice = D*nek1*neq1; int ntg = 1; if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8; @@ -16165,7 +16164,8 @@ static void ggml_compute_forward_flash_attn_ext_f16( for (int64_t iq2 = 0; iq2 < neq2; iq2++) { if (counter++ % (nth/ntg) == ith/ntg) { int iq1 = (ith%ntg)*neq1/ntg; - if (!iqk_flash_attn_noalibi(D, neq1/ntg, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float), + if (!iqk_flash_attn_noalibi(k->type, v->type, + D, neq1/ntg, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float), (const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]), (const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]), (const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]), |