Diffstat (limited to 'ggml/src/ggml.c')
-rw-r--r--  ggml/src/ggml.c | 61
1 file changed, 35 insertions, 26 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 3867cf00..7b631177 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -8473,8 +8473,12 @@ struct ggml_tensor * ggml_flash_attn_ext(
         is_node = true;
     }
 
+    // k*q will be  { k->ne[1], q->ne[2], q->ne[1], q->ne[3] }
+    // v^T is       { v->ne[1], v->ne[0], v->ne[2], v->ne[3] }
+    // => result is { v->ne[0], q->ne[2], q->ne[1], q->ne[3] }
     // permute(0, 2, 1, 3)
-    int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
+    //int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
+    int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
     float params[] = { scale, max_bias, softcap };
@@ -17436,10 +17440,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int64_t D = neq0;
-    const int64_t N = neq1;
+    const int64_t Dk = nek0;
+    const int64_t Dv = nev0;
+    const int64_t N = neq1;
 
-    GGML_ASSERT(ne0 == D);
+    GGML_ASSERT(ne0 == Dv);
     GGML_ASSERT(ne2 == N);
 
     // input tensor rows must be contiguous
@@ -17447,12 +17452,12 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_ASSERT(nbk0 == ggml_type_size(k->type));
     GGML_ASSERT(nbv0 == ggml_type_size(v->type));
 
-    GGML_ASSERT(neq0 == D);
-    GGML_ASSERT(nek0 == D);
-    GGML_ASSERT(nev0 == D);
+    GGML_ASSERT(neq0 == Dk);
+    GGML_ASSERT(nek0 == Dk);
+    GGML_ASSERT(nev0 == Dv);
 
     GGML_ASSERT(neq1 == N);
-    GGML_ASSERT(nev0 == D);
+    GGML_ASSERT(nev0 == Dv);
 
     // dst cannot be transposed or permuted
     GGML_ASSERT(nb0 == sizeof(float));
@@ -17516,7 +17521,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             int iq1 = (ith%ntg)*neq1g;
             int this_neq1 = MIN(neq1g, neq1-iq1);
             if (!iqk_flash_attn_noalibi(k->type, v->type,
-                        D, this_neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
+                        Dk, Dv, this_neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
                         (const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]),
                         (const void  *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]),
                         (const void  *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]),
@@ -17543,6 +17548,8 @@ IQK_Flash_Attn_NotAvailable:;
     ggml_vec_dot_t    const kq_vec_dot   = type_traits[k->type].vec_dot;
     ggml_to_float_t   const v_to_float   = type_traits[v->type].to_float;
 
+    const int64_t Dkv = MAX(Dk, Dv);
+
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
         // q indices
@@ -17556,15 +17563,15 @@ IQK_Flash_Attn_NotAvailable:;
         float S = 0.0f;      // sum
         float M = -INFINITY; // maximum KQ value
 
-        float       * VKQ32 = (float       *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
-        float       * V32   =                 (VKQ32 + 1*D); // (temporary) FP32 V buffer
-        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
-        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
+        float       * VKQ32 = (float       *) params->wdata + ith*(3*Dkv + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
+        float       * V32   =                 (VKQ32 + 1*Dkv); // (temporary) FP32 V buffer
+        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*Dkv); // (temporary) FP16 VKQ accumulator
+        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*Dkv); // (temporary) buffer for Q converted to quantized/FP16
 
         if (v->type == GGML_TYPE_F16) {
-            memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
+            memset(VKQ16, 0, Dkv*sizeof(ggml_fp16_t));
         } else {
-            memset(VKQ32, 0, D*sizeof(float));
+            memset(VKQ32, 0, Dkv*sizeof(float));
         }
 
         const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
@@ -17578,7 +17585,7 @@ IQK_Flash_Attn_NotAvailable:;
         const int iv2 = iq2 / rv2;
 
         const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
-        q_to_vec_dot(pq, Q_q, D);
+        q_to_vec_dot(pq, Q_q, Dk);
 
         // online softmax / attention
        // loop over n_kv and n_head_kv
@@ -17592,7 +17599,7 @@ IQK_Flash_Attn_NotAvailable:;
             float s; // KQ value
 
             const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
-            kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
+            kq_vec_dot(Dk, &s, 0, k_data, 0, Q_q, 0, 1);
 
             s = softcap == 0.0f ? s*scale + mv : softcap*tanhf(s*scale) + mv; // scale KQ value and apply mask
 
@@ -17610,14 +17617,14 @@ IQK_Flash_Attn_NotAvailable:;
                     ms = expf(Mold - M);
 
                     // V = V*expf(Mold - M)
-                    ggml_vec_scale_f16(D, VKQ16, ms);
+                    ggml_vec_scale_f16(Dv, VKQ16, ms);
                 } else {
                     // no new maximum, ms == 1.0f, vs != 1.0f
                     vs = expf(s - M);
                 }
 
                 // V += v*expf(s - M)
-                ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
+                ggml_vec_mad_f16(Dv, VKQ16, (const ggml_fp16_t *) v_data, vs);
             } else {
                 if (s > M) {
                     // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
@@ -17625,30 +17632,30 @@ IQK_Flash_Attn_NotAvailable:;
                     ms = expf(Mold - M);
 
                     // V = V*expf(Mold - M)
-                    ggml_vec_scale_f32(D, VKQ32, ms);
+                    ggml_vec_scale_f32(Dv, VKQ32, ms);
                 } else {
                     // no new maximum, ms == 1.0f, vs != 1.0f
                     vs = expf(s - M);
                 }
 
-                v_to_float(v_data, V32, D);
+                v_to_float(v_data, V32, Dv);
 
                 // V += v*expf(s - M)
-                ggml_vec_mad_f32(D, VKQ32, V32, vs);
+                ggml_vec_mad_f32(Dv, VKQ32, V32, vs);
             }
 
             S = S*ms + vs; // scale and increment sum with partial sum
         }
 
         if (v->type == GGML_TYPE_F16) {
-            for (int64_t d = 0; d < D; ++d) {
+            for (int64_t d = 0; d < Dv; ++d) {
                 VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
             }
         }
 
         // V /= S
         const float S_inv = 1.0f/S;
-        ggml_vec_scale_f32(D, VKQ32, S_inv);
+        ggml_vec_scale_f32(Dv, VKQ32, S_inv);
 
         // dst indices
         const int i1 = iq1;
@@ -21112,9 +21119,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                 } break;
             case GGML_OP_FLASH_ATTN_EXT:
                 {
-                    const int64_t ne00 = node->src[0]->ne[0]; // D
+                    const int64_t Dk = node->src[0]->ne[0];
+                    const int64_t Dv = node->src[2]->ne[0];
+                    const int64_t D = MAX(Dk, Dv);
 
-                    cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
+                    cur = 3*sizeof(float)*D*n_tasks; // 3x head size/thread
                 } break;
            case GGML_OP_FLASH_ATTN_BACK:
                {
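
For readers following the shape bookkeeping, below is a minimal standalone sketch, not part of the commit, that mirrors the two pieces of arithmetic the diff changes: the result tensor now takes its row length from v->ne[0] instead of q->ne[0], and the per-thread work buffer in ggml_graph_plan() is sized by MAX(Dk, Dv). The head sizes (Dk = 192, Dv = 128) and the thread count are hypothetical, chosen only to show K and V heads of different dimension.

// sketch.c -- standalone illustration with hypothetical sizes; compile with: cc sketch.c
#include <stdint.h>
#include <stdio.h>

#define MAX(a, b) ((a) > (b) ? (a) : (b))

int main(void) {
    // ggml-style ne[4] extents: q = { Dk, n_q, n_head, n_batch },
    // k = { Dk, n_kv, n_head_kv, n_batch }, v = { Dv, n_kv, n_head_kv, n_batch }
    const int64_t q_ne[4] = { 192,  32, 8, 1 };
    const int64_t k_ne[4] = { 192, 512, 8, 1 };
    const int64_t v_ne[4] = { 128, 512, 8, 1 };

    // k*q is { k->ne[1], q->ne[2], q->ne[1], q->ne[3] }; multiplying by v^T
    // leaves rows of length v->ne[0], so the result shape comes from v, not q
    const int64_t ne[4] = { v_ne[0], q_ne[2], q_ne[1], q_ne[3] };
    printf("result ne = { %lld, %lld, %lld, %lld }\n",
           (long long) ne[0], (long long) ne[1], (long long) ne[2], (long long) ne[3]);

    // per-thread scratch sized by the larger of the two head dimensions,
    // matching the ggml_graph_plan() line in the diff (padding handled elsewhere)
    const int     n_tasks = 8; // hypothetical thread count
    const int64_t Dk = k_ne[0], Dv = v_ne[0];
    const size_t  cur = 3*sizeof(float)*MAX(Dk, Dv)*n_tasks;
    printf("flash-attn work buffer: %zu bytes\n", cur);
    return 0;
}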