summaryrefslogtreecommitdiff
path: root/ggml.h
diff options
context:
space:
mode:
Diffstat (limited to 'ggml.h')
-rw-r--r--ggml.h20
1 files changed, 20 insertions, 0 deletions
diff --git a/ggml.h b/ggml.h
index 86e5a8dc..a1179597 100644
--- a/ggml.h
+++ b/ggml.h
@@ -475,6 +475,7 @@ extern "C" {
GGML_OP_LEAKY_RELU,
GGML_OP_FLASH_ATTN,
+ GGML_OP_FLASH_ATTN_EXT,
GGML_OP_FLASH_FF,
GGML_OP_FLASH_ATTN_BACK,
GGML_OP_SSM_CONV,
@@ -1722,6 +1723,25 @@ extern "C" {
struct ggml_tensor * v,
bool masked);
+#define GGML_KQ_MASK_PAD 32
+
+ // q: [n_embd, n_batch, n_head, 1]
+ // k: [n_embd, n_kv, n_head_kv, 1]
+ // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !!
+ // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
+ // res: [n_embd, n_head, n_batch, 1] !! permuted !!
+ GGML_API struct ggml_tensor * ggml_flash_attn_ext(
+ struct ggml_context * ctx,
+ struct ggml_tensor * q,
+ struct ggml_tensor * k,
+ struct ggml_tensor * v,
+ struct ggml_tensor * mask,
+ float scale);
+
+ GGML_API void ggml_flash_attn_ext_set_prec(
+ struct ggml_tensor * a,
+ enum ggml_prec prec);
+
GGML_API struct ggml_tensor * ggml_flash_attn_back(
struct ggml_context * ctx,
struct ggml_tensor * q,