summaryrefslogtreecommitdiff
path: root/ggml/src/ggml.c
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/ggml.c')
-rw-r--r--ggml/src/ggml.c61
1 files changed, 35 insertions, 26 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 3867cf00..7b631177 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -8473,8 +8473,12 @@ struct ggml_tensor * ggml_flash_attn_ext(
is_node = true;
}
+ // k*q will be { k->ne[1], q->ne[2], q->ne[1], q->ne[3] }
+ // v^T is { v->ne[1], v->ne[0], v->ne[2], v->ne[3] }
+ // => result is { v->ne[0], q->ne[2], q->ne[1], q->ne[3] }
// permute(0, 2, 1, 3)
- int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
+ //int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
+ int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
float params[] = { scale, max_bias, softcap };
@@ -17436,10 +17440,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
const int ith = params->ith;
const int nth = params->nth;
- const int64_t D = neq0;
- const int64_t N = neq1;
+ const int64_t Dk = nek0;
+ const int64_t Dv = nev0;
+ const int64_t N = neq1;
- GGML_ASSERT(ne0 == D);
+ GGML_ASSERT(ne0 == Dv);
GGML_ASSERT(ne2 == N);
// input tensor rows must be contiguous
@@ -17447,12 +17452,12 @@ static void ggml_compute_forward_flash_attn_ext_f16(
GGML_ASSERT(nbk0 == ggml_type_size(k->type));
GGML_ASSERT(nbv0 == ggml_type_size(v->type));
- GGML_ASSERT(neq0 == D);
- GGML_ASSERT(nek0 == D);
- GGML_ASSERT(nev0 == D);
+ GGML_ASSERT(neq0 == Dk);
+ GGML_ASSERT(nek0 == Dk);
+ GGML_ASSERT(nev0 == Dv);
GGML_ASSERT(neq1 == N);
- GGML_ASSERT(nev0 == D);
+ GGML_ASSERT(nev0 == Dv);
// dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float));
@@ -17516,7 +17521,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
int iq1 = (ith%ntg)*neq1g;
int this_neq1 = MIN(neq1g, neq1-iq1);
if (!iqk_flash_attn_noalibi(k->type, v->type,
- D, this_neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
+ Dk, Dv, this_neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
(const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]),
(const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]),
(const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]),
@@ -17543,6 +17548,8 @@ IQK_Flash_Attn_NotAvailable:;
ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot;
ggml_to_float_t const v_to_float = type_traits[v->type].to_float;
+ const int64_t Dkv = MAX(Dk, Dv);
+
// loop over n_batch and n_head
for (int ir = ir0; ir < ir1; ++ir) {
// q indices
@@ -17556,15 +17563,15 @@ IQK_Flash_Attn_NotAvailable:;
float S = 0.0f; // sum
float M = -INFINITY; // maximum KQ value
- float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
- float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer
- ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
- ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
+ float * VKQ32 = (float *) params->wdata + ith*(3*Dkv + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
+ float * V32 = (VKQ32 + 1*Dkv); // (temporary) FP32 V buffer
+ ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*Dkv); // (temporary) FP16 VKQ accumulator
+ ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*Dkv); // (temporary) buffer for Q converted to quantized/FP16
if (v->type == GGML_TYPE_F16) {
- memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
+ memset(VKQ16, 0, Dkv*sizeof(ggml_fp16_t));
} else {
- memset(VKQ32, 0, D*sizeof(float));
+ memset(VKQ32, 0, Dkv*sizeof(float));
}
const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
@@ -17578,7 +17585,7 @@ IQK_Flash_Attn_NotAvailable:;
const int iv2 = iq2 / rv2;
const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
- q_to_vec_dot(pq, Q_q, D);
+ q_to_vec_dot(pq, Q_q, Dk);
// online softmax / attention
// loop over n_kv and n_head_kv
@@ -17592,7 +17599,7 @@ IQK_Flash_Attn_NotAvailable:;
float s; // KQ value
const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
- kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
+ kq_vec_dot(Dk, &s, 0, k_data, 0, Q_q, 0, 1);
s = softcap == 0.0f ? s*scale + mv : softcap*tanhf(s*scale) + mv; // scale KQ value and apply mask
@@ -17610,14 +17617,14 @@ IQK_Flash_Attn_NotAvailable:;
ms = expf(Mold - M);
// V = V*expf(Mold - M)
- ggml_vec_scale_f16(D, VKQ16, ms);
+ ggml_vec_scale_f16(Dv, VKQ16, ms);
} else {
// no new maximum, ms == 1.0f, vs != 1.0f
vs = expf(s - M);
}
// V += v*expf(s - M)
- ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
+ ggml_vec_mad_f16(Dv, VKQ16, (const ggml_fp16_t *) v_data, vs);
} else {
if (s > M) {
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
@@ -17625,30 +17632,30 @@ IQK_Flash_Attn_NotAvailable:;
ms = expf(Mold - M);
// V = V*expf(Mold - M)
- ggml_vec_scale_f32(D, VKQ32, ms);
+ ggml_vec_scale_f32(Dv, VKQ32, ms);
} else {
// no new maximum, ms == 1.0f, vs != 1.0f
vs = expf(s - M);
}
- v_to_float(v_data, V32, D);
+ v_to_float(v_data, V32, Dv);
// V += v*expf(s - M)
- ggml_vec_mad_f32(D, VKQ32, V32, vs);
+ ggml_vec_mad_f32(Dv, VKQ32, V32, vs);
}
S = S*ms + vs; // scale and increment sum with partial sum
}
if (v->type == GGML_TYPE_F16) {
- for (int64_t d = 0; d < D; ++d) {
+ for (int64_t d = 0; d < Dv; ++d) {
VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
}
}
// V /= S
const float S_inv = 1.0f/S;
- ggml_vec_scale_f32(D, VKQ32, S_inv);
+ ggml_vec_scale_f32(Dv, VKQ32, S_inv);
// dst indices
const int i1 = iq1;
@@ -21112,9 +21119,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
} break;
case GGML_OP_FLASH_ATTN_EXT:
{
- const int64_t ne00 = node->src[0]->ne[0]; // D
+ const int64_t Dk = node->src[0]->ne[0];
+ const int64_t Dv = node->src[2]->ne[0];
+ const int64_t D = MAX(Dk, Dv);
- cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
+ cur = 3*sizeof(float)*D*n_tasks; // 3x head size/thread
} break;
case GGML_OP_FLASH_ATTN_BACK:
{