summaryrefslogtreecommitdiff
path: root/ggml.c
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2024-04-30 12:16:08 +0300
committerGitHub <noreply@github.com>2024-04-30 12:16:08 +0300
commit9c67c2773d4b706cf71d70ecf4aa180b62501960 (patch)
treebe51cbda5b15ae1bb3a465a2551e7dbe6d3101d7 /ggml.c
parent952d03dbead16e4dbdd1d3458486340673cc2465 (diff)
ggml : add Flash Attention (#5021)
* ggml : add ggml_flash_attn_ext API * ggml : fix GQA support in ggml_flash_attn_ext * ggml : online attention (CPU) * metal : initial implementation * metal : f16 precision * metal : reduce branches * metal : specialize for head size * wip : 8 rows per simd group * wip : 4 rows per simd group * wip : template for rows per warp * metal : parallelize across KV size * metal : parallel reduce across heads * metal : efficient flash_attn_f16 implementation * metal : avoid redundant loads of the attention * metal : scale and mask in matrix form * metal : fix comment * llama : avoid ggml_cast, use F32 query * metal : add parallel reduce version (disabled) * metal : move output into local memory + optimize - the result from each simdgroup now stays in the registers - significantly reduced SRAM usage - more efficient skipping of -INF blocks - avoid simdgroup barrier in hot loop - add comments * metal : add tests, fix scaling, support C > 32 * metal : improve precision * ggml : fix f16 mad * metal : minor * metal : support Q > 8 * tests : add ATTN tests * metal : disable buffer allocation logs * tests : more * metal : faster inner loop for C == 32 * metal : fix array initialization * tests : ifdef * ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext * ggml : fix ggml_soft_max mask requirement * cuda : fix soft_max to use correct mask size * cuda : add flash_attn kernel (wip) * metal : optimize softmax for C > 32 * metal : optimize softmax * tests : minor fix * cuda : avoid zeroing fragments * tests : update dims * cuda : fix __hisinf() result check * cuda : avoid warp_reduce for smax * cuda : use int instead of int64_t Noticeably improves performance (thanks to Johannes) * cuda : make loops use the same loop values Thanks Johannes again for the tip * cuda : unroll some of the loops * cuda : avoid __hisinf branches * cuda : use half2 in softmax * cuda : switch to 1 warp for bs > 16 * cuda : speed-up reduce part of the kernel * cuda : unroll Q*K^T loop * 
cuda : fix -INF block check * cuda : simplify softmax * cuda : fix matrix names * cuda : minor * llama : adapt to F16 KQ_pos * llama : adapt new models to F16 KQ_mask * ggml : fix F16 store (ARM NEON) * llama : fix type of KQ_mask and KQ_pos * ggml : fix CPU soft_max * tests : add hs=256 * cuda : fix build * metal : improve perf via smaller int registers * cuda : adapt soft_max to F16 mask and pos * CUDA: faster FlashAttention, kernel for bs == 1 * 16 cols for Phi-2 * no vec for hs, no hs==256 ncols==32 for Volta * adjust kernel selection logic * 4 warps, 256 stride for all D * no ncols == 64 * Multiple parallel blocks for batch size 1 * fix compile warnings * fix excessive KQ_b loads * fix cmake build * fix KV cache padding, NaN from INFINITY (#6438) * llama : flash_attn cparam + fix defrag * server: support flash_attn param * server: bench: enable flash_attn param * CUDA: refactor host code, dyn. par. blocks * fix flash_attn_vec_f16 race condition * flush softmax exp below threshold to 0 * store temp KQ in registers * Calculate KQ as FP32 if KQV has GGML_PREC_F32 * Add __hgt2_mask implementation for CUDA 11 * fix KQ FP32 precision fpr parallel_blocks > 1 * llama-bench : add -fa,--flash-attn arg * metal : add BS=1 kernel for flash attention (#6508) * metal : add BS=1 kernel for flash attention (wip) * metal : support more than 1 warps * metal : opts * metal : opt * metal : switch to parallel reduce * metal : reduce registers * metal : simplify * metal : initial FA vec kernel * metal : use F32 attention accumulators * batched-bench : add fattn arg * llama : simplify llama_build_kv_store ggml-ci * llama : adapt build_olmo to changes * ggml : fix arm fp16 store on windows * metal : clean-up * metal : clean-up kernel code * metal : minor * tests : remove benchmarks ggml-ci * ggml : fix avx512 const correctness ggml-ci * ggml : fix soft_max with bias on CPU ggml-ci * common : print --flash-attn in help * ggml : fix num dimensions in ggml_flash_attn_ext * llama : force 
disable flash attention for incompatible models * ggml : ggml_soft_max support F16/F32 mask/pos ggml-ci * cuda : uint -> uint32_t * cuda : "constexpr dim3" -> "const dim3" ggml-ci * cuda : try to fix __hgt2_mask ggml-ci * ggml : add TODO's for F16/F32 mask/pos support in other backends * llama : replace bool need_kq_pos with use_alibi * llama : prep ALiBi support for BERT models ggml-ci * llama : fix n_batch requirements ggml-ci * cont * server : add help for --flash-attn arg * llama : disable FA for AMD * tests : remove TMP_ATTN_BENCH ggml-ci * llama : support save/load state with FA enabled ggml-ci * ci : add CUDA save-load-state tests ggml-ci * llama : llama_kv_cache_clear zeroes data + fix save-load seq ggml-ci * llama : fix copy-paste errors, add TODO * llama : disallow incompatible states * llama : update llama_state_get_size after v_trans field * metal : remove tmp log * llama : add static reminder for llama_state_get_size * metal : fix max nsg ggml-ci * ci : fix arg order ggml-ci --------- Co-authored-by: Johannes Gäßler <johannesg@5d6.de> Co-authored-by: Pierrick HYMBERT <pierrick.hymbert@gmail.com>
Diffstat (limited to 'ggml.c')
-rw-r--r--ggml.c375
1 files changed, 360 insertions, 15 deletions
diff --git a/ggml.c b/ggml.c
index cb273061..74ecd592 100644
--- a/ggml.c
+++ b/ggml.c
@@ -951,7 +951,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
#define GGML_F16_VEC_ZERO GGML_F16x8_ZERO
#define GGML_F16_VEC_SET1 GGML_F16x8_SET1
#define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p)
- #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i])
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((ggml_fp16_internal_t *)(p), r[i])
#define GGML_F16_VEC_FMA GGML_F16x8_FMA
#define GGML_F16_VEC_ADD GGML_F16x8_ADD
#define GGML_F16_VEC_MUL GGML_F16x8_MUL
@@ -977,7 +977,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
#define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
#define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
- #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((ggml_fp16_internal_t *)(p), r[i])
#define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
#define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
#define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
@@ -1046,7 +1046,7 @@ do { \
// unlike _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in AVX512F
// so F16C guard isn't required
-#define GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(x)))
+#define GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
#define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0))
#define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
@@ -1144,7 +1144,7 @@ do { \
#if defined(__F16C__)
// the _mm256_cvt intrinsics require F16C
-#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x)))
+#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
#define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
#else
static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
@@ -1662,6 +1662,37 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
#endif
}
+// y[i] += x[i]*v for i in [0, n) - x and y hold F16 values, v is F32
+inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) {
+    // first index that still has to be handled by the scalar loop
+    int i0 = 0;
+
+#if defined(GGML_SIMD)
+    // the vectorized loop only consumes whole blocks of GGML_F16_STEP elements
+    const int n_simd = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC vv = GGML_F16_VEC_SET1(v);
+
+    GGML_F16_VEC xr[GGML_F16_ARR];
+    GGML_F16_VEC yr[GGML_F16_ARR];
+
+    for (int ib = 0; ib < n_simd; ib += GGML_F16_STEP) {
+        for (int k = 0; k < GGML_F16_ARR; k++) {
+            const int off = ib + k*GGML_F16_EPR;
+
+            xr[k] = GGML_F16_VEC_LOAD(x + off, k);
+            yr[k] = GGML_F16_VEC_LOAD(y + off, k);
+            yr[k] = GGML_F16_VEC_FMA(yr[k], xr[k], vv);
+
+            GGML_F16_VEC_STORE(y + off, yr, k);
+        }
+    }
+
+    i0 = n_simd;
+#endif
+
+    // scalar path: the full range without SIMD, otherwise only the leftovers
+    for (int i = i0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+    }
+}
+
// xs and vs are byte strides of x and v
inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {
@@ -1746,6 +1777,35 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
#endif
}
+// y[i] *= v for i in [0, n) - y holds F16 values, v is F32
+inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
+    // first index that still has to be handled by the scalar loop
+    int i0 = 0;
+
+#if defined(GGML_SIMD)
+    // the vectorized loop only consumes whole blocks of GGML_F16_STEP elements
+    const int n_simd = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC vv = GGML_F16_VEC_SET1(v);
+
+    GGML_F16_VEC yr[GGML_F16_ARR];
+
+    for (int ib = 0; ib < n_simd; ib += GGML_F16_STEP) {
+        for (int k = 0; k < GGML_F16_ARR; k++) {
+            const int off = ib + k*GGML_F16_EPR;
+
+            yr[k] = GGML_F16_VEC_LOAD(y + off, k);
+            yr[k] = GGML_F16_VEC_MUL(yr[k], vv);
+
+            GGML_F16_VEC_STORE(y + off, yr, k);
+        }
+    }
+
+    i0 = n_simd;
+#endif
+
+    // scalar path: the full range without SIMD, otherwise only the leftovers
+    for (int i = i0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
+    }
+}
+
inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); }
inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
@@ -2000,6 +2060,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"LEAKY_RELU",
"FLASH_ATTN",
+ "FLASH_ATTN_EXT",
"FLASH_FF",
"FLASH_ATTN_BACK",
"SSM_CONV",
@@ -2026,7 +2087,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS_BACK",
};
-static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -2090,6 +2151,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"leaky_relu(x)",
"flash_attn(x)",
+ "flash_attn_ext(x)",
"flash_ff(x)",
"flash_attn_back(x)",
"ssm_conv(x)",
@@ -2116,7 +2178,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss_back(x,y)",
};
-static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -4559,6 +4621,8 @@ struct ggml_tensor * ggml_mul_mat(
void ggml_mul_mat_set_prec(
struct ggml_tensor * a,
enum ggml_prec prec) {
+ GGML_ASSERT(a->op == GGML_OP_MUL_MAT);
+
const int32_t prec_i32 = (int32_t) prec;
ggml_set_op_params_i32(a, 0, prec_i32);
@@ -5397,17 +5461,23 @@ static struct ggml_tensor * ggml_soft_max_impl(
GGML_ASSERT(ggml_is_contiguous(a));
if (mask) {
+ GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(mask));
GGML_ASSERT(ggml_is_matrix(mask));
- GGML_ASSERT(ggml_can_repeat_rows(mask, a));
+ GGML_ASSERT(mask->ne[0] == a->ne[0]);
+ GGML_ASSERT(mask->ne[1] >= a->ne[1]);
}
if (pos) {
GGML_ASSERT(ggml_is_vector(pos));
- GGML_ASSERT(pos->type == GGML_TYPE_F32);
+ GGML_ASSERT(pos->type == GGML_TYPE_F16 || pos->type == GGML_TYPE_F32);
GGML_ASSERT(pos->ne[0] == a->ne[0]);
}
+ if (pos && mask) {
+ GGML_ASSERT(pos->type == mask->type);
+ }
+
if (max_bias > 0.0f) {
GGML_ASSERT(pos);
}
@@ -6216,6 +6286,59 @@ struct ggml_tensor * ggml_flash_attn(
return result;
}
+// ggml_flash_attn_ext
+
+// creates a GGML_OP_FLASH_ATTN_EXT node
+//
+//   q, k, v - become src[0], src[1], src[2]; k must satisfy ggml_can_mul_mat(k, q)
+//   mask    - optional (may be NULL), becomes src[3]; contiguous, ne[2] == ne[3] == 1 and
+//             at least GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) rows
+//   scale   - stored in op param 0
+//
+// the result is an F32 tensor with the dims of q permuted as (0, 2, 1, 3)
+struct ggml_tensor * ggml_flash_attn_ext(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * mask,
+        float                 scale) {
+    GGML_ASSERT(ggml_can_mul_mat(k, q));
+    // TODO: check if vT can be multiplied by (k*qT)
+
+    if (mask != NULL) {
+        GGML_ASSERT(ggml_is_contiguous(mask));
+        GGML_ASSERT(mask->ne[2] == 1);
+        GGML_ASSERT(mask->ne[3] == 1);
+        GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
+                "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
+        //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
+    }
+
+    // a gradient tensor is attached to the result when any of q, k, v has one
+    const bool is_node = q->grad || k->grad || v->grad;
+
+    // permute(0, 2, 1, 3)
+    int64_t ne_dst[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_dst);
+
+    // op param 0: scale
+    ggml_set_op_params(result, &scale, sizeof(scale));
+
+    result->op   = GGML_OP_FLASH_ATTN_EXT;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+
+    result->src[0] = q;
+    result->src[1] = k;
+    result->src[2] = v;
+    result->src[3] = mask;
+
+    return result;
+}
+
+// stores the requested precision on a GGML_OP_FLASH_ATTN_EXT node (asserts on any other op)
+void ggml_flash_attn_ext_set_prec(
+        struct ggml_tensor * a,
+        enum ggml_prec       prec) {
+    GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
+
+    // op param 0 is taken by the scale, so the precision goes into slot 1
+    ggml_set_op_params_i32(a, 1, (int32_t) prec);
+}
+
// ggml_flash_ff
struct ggml_tensor * ggml_flash_ff(
@@ -12255,7 +12378,7 @@ static void ggml_compute_forward_soft_max_f32(
GGML_TENSOR_UNARY_OP_LOCALS
- const int64_t ne11 = src1 ? src1->ne[1] : 1;
+ //const int64_t ne11 = src1 ? src1->ne[1] : 1;
// TODO: is this supposed to be ceil instead of floor?
// https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -12278,19 +12401,31 @@ static void ggml_compute_forward_soft_max_f32(
float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
// when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
- float * pos = src2 ? (float *) src2->data : src0->data;
+ ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data;
+ float * pos_f32 = src2 ? (float *) src2->data : src0->data;
+
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
for (int i1 = ir0; i1 < ir1; i1++) {
float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
// broadcast the mask across rows
- float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
+ ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
+ float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
ggml_vec_cpy_f32 (nc, wp, sp);
ggml_vec_scale_f32(nc, wp, scale);
- if (mp) {
- ggml_vec_acc_f32(nc, wp, mp);
+ if (mp_f32) {
+ if (use_f16) {
+ for (int i = 0; i < nc; ++i) {
+ wp[i] += GGML_FP16_TO_FP32(mp_f16[i]);
+ }
+ } else {
+ for (int i = 0; i < nc; ++i) {
+ wp[i] += mp_f32[i];
+ }
+ }
}
// ALiBi bias
@@ -12298,8 +12433,14 @@ static void ggml_compute_forward_soft_max_f32(
const uint32_t h = (i1/ne01)%ne02; // head
const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
- for (int i = 0; i < nc; i++) {
- wp[i] = wp[i] + slope*pos[i];
+ if (use_f16) {
+ for (int i = 0; i < nc; ++i) {
+ wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[i]);
+ }
+ } else {
+ for (int i = 0; i < nc; ++i) {
+ wp[i] += slope*pos_f32[i];
+ }
}
}
@@ -14569,6 +14710,198 @@ static void ggml_compute_forward_flash_attn(
}
}
+// ggml_compute_forward_flash_attn_ext
+
+// CPU kernel for GGML_OP_FLASH_ATTN_EXT: single-pass attention using an online softmax
+// ref: https://arxiv.org/pdf/2112.05682.pdf
+//
+//   q    - F32 rows of length D
+//   k, v - F16 rows of length D
+//   mask - optional F16 mask, row iq1 is added to the scaled scores of query row iq1;
+//          KV positions whose mask value is -INF are skipped
+//   dst  - F32, receives one row of length D per query row, stored in permute(0, 2, 1, 3) order
+//
+// the q rows are split into contiguous ranges, one per thread, and each thread works in its own
+// slice of params->wdata
+// note: the running max (M) and sum (S) are F32, the V accumulator itself is kept in F16
+static void ggml_compute_forward_flash_attn_ext_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * mask,
+        struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t D = neq0;
+    const int64_t N = neq1;
+
+    GGML_ASSERT(ne0 == D);
+    GGML_ASSERT(ne2 == N);
+
+    GGML_ASSERT(nbq0 == sizeof(float));
+    GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
+
+    // the mask rows are read through a ggml_fp16_t pointer below
+    GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
+
+    GGML_ASSERT(neq0 == D);
+    GGML_ASSERT(nek0 == D);
+    GGML_ASSERT(nev0 == D);
+
+    GGML_ASSERT(neq1 == N);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    // broadcast factors
+    const int64_t rk2 = neq2/nek2;
+    const int64_t rk3 = neq3/nek3;
+
+    const int64_t rv2 = neq2/nev2;
+    const int64_t rv3 = neq3/nev3;
+
+    if (params->type == GGML_TASK_TYPE_INIT) {
+        return;
+    }
+
+    if (params->type == GGML_TASK_TYPE_FINALIZE) {
+        return;
+    }
+
+    // parallelize by q rows using ggml_vec_dot_f32
+
+    // total rows in q
+    const int nr = neq1*neq2*neq3;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    float scale = 1.0f;
+    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+
+    // loop over n_batch and n_head
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // q indices
+        const int iq3 = ir/(neq2*neq1);
+        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
+        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
+
+        float S = 0.0f;      // running sum of the softmax weights
+        float M = -INFINITY; // running maximum of the scores
+
+        // per-thread scratch: V32 receives the final F32 row, V16 is the F16 accumulator
+        float       * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
+        ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory
+        ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);
+
+        memset(V16, 0, D*sizeof(ggml_fp16_t));
+
+        // convert the Q row to F16 once - it does not change across KV positions
+        // Q16 aliases the start of V32, which is only written after the KV loop has finished
+        {
+            const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+
+            for (int64_t d = 0; d < D; ++d) {
+                Q16[d] = GGML_FP32_TO_FP16(pq[d]);
+            }
+        }
+
+        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
+
+        // k indices
+        const int ik3 = iq3 / rk3;
+        const int ik2 = iq2 / rk2;
+
+        // v indices
+        const int iv3 = iq3 / rv3;
+        const int iv2 = iq2 / rv2;
+
+        // online softmax / attention
+        // loop over n_kv and n_head_kv
+        // ref: https://arxiv.org/pdf/2112.05682.pdf
+        for (int64_t ic = 0; ic < nek1; ++ic) {
+            const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
+            if (mv == -INFINITY) {
+                continue;
+            }
+
+            float s;
+
+            ggml_vec_dot_f16(D,
+                    &s, 0,
+                    (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
+                    Q16, 0, 1);
+
+            s = s*scale + mv;
+
+            const float Mold = M;
+
+            float ms = 1.0f; // rescale factor for the values accumulated so far
+            float vs = 1.0f; // weight of the current V row
+
+            if (s > M) {
+                // new maximum: rescale the accumulator, the current row gets weight 1
+                M = s;
+                ms = expf(Mold - M);
+
+                // V = V*expf(Mold - M)
+                ggml_vec_scale_f16(D, V16, ms);
+            } else {
+                vs = expf(s - M);
+            }
+
+            const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
+
+            // V += v*expf(s - M)
+            ggml_vec_mad_f16(D, V16, v16, vs);
+
+            S = S*ms + vs;
+        }
+
+        // V /= S
+        // note: S is still 0 here if every KV position was skipped
+        for (int64_t d = 0; d < D; ++d) {
+            V32[d] = GGML_FP16_TO_FP32(V16[d])/S;
+        }
+
+        // dst indices
+        const int i1 = iq1;
+        const int i2 = iq2;
+        const int i3 = iq3;
+
+        // permute(0, 2, 1, 3)
+        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1);
+    }
+}
+
+// selects the FLASH_ATTN_EXT kernel based on the precision stored in op param 1 of dst
+static void ggml_compute_forward_flash_attn_ext(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * mask,
+        struct ggml_tensor * dst) {
+    const int32_t prec = dst->op_params[1];
+
+    if (prec == GGML_PREC_DEFAULT || prec == GGML_PREC_F32) {
+        // both precisions are served by the same kernel
+        ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
+    } else {
+        // any other precision value is not supported
+        GGML_ASSERT(false);
+    }
+}
+
// ggml_compute_forward_flash_ff
static void ggml_compute_forward_flash_ff_f16(
@@ -16376,6 +16709,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
const bool masked = t != 0;
ggml_compute_forward_flash_attn(params, masked, tensor);
} break;
+ case GGML_OP_FLASH_ATTN_EXT:
+ {
+ ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
+ } break;
case GGML_OP_FLASH_FF:
{
ggml_compute_forward_flash_ff(params, tensor);
@@ -17388,6 +17725,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
GGML_ASSERT(false); // TODO: not implemented
} break;
case GGML_OP_FLASH_ATTN:
+ case GGML_OP_FLASH_ATTN_EXT:
{
struct ggml_tensor * flash_grad = NULL;
if (src0->grad || src1->grad || tensor->src[2]->grad) {
@@ -18160,6 +18498,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
n_tasks = n_threads;
} break;
case GGML_OP_FLASH_ATTN:
+ case GGML_OP_FLASH_ATTN_EXT:
{
n_tasks = n_threads;
} break;
@@ -18563,6 +18902,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
}
} break;
+ case GGML_OP_FLASH_ATTN_EXT:
+ {
+ const int64_t ne00 = node->src[0]->ne[0]; // D
+
+ cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size
+ } break;
case GGML_OP_FLASH_FF:
{
if (node->src[1]->type == GGML_TYPE_F32) {