summaryrefslogtreecommitdiff
path: root/ggml/src/ggml.c
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2025-03-03 15:17:51 +0200
committerGitHub <noreply@github.com>2025-03-03 15:17:51 +0200
commita87e54db6ec2409284a55f029d4abe9e50990064 (patch)
tree920bb8ce4fbd35e54bda3b61a86d0f87c2ac0ede /ggml/src/ggml.c
parenta89adaa78f505675be7be6180f419b4b0158c15a (diff)
Flash MLA (CPU only) (#240)
* FlashMLA - it finally works (on the CPU) * FlashMLA: allow for f16 and bf16 cache in addition to q8_0 * It works with ggml FA, not with iqk FA * WIP * FlashMLA: it now works with iqk I had forgotten to divide the Q stride by sizeof(float) and that's why, very cobfusingly, it was working for TG but not for PP. * WIP * FlashMLA: that should be it for now --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/src/ggml.c')
-rw-r--r--ggml/src/ggml.c6
1 files changed, 4 insertions, 2 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 31fbc57e..46e1a548 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -10451,7 +10451,7 @@ static void ggml_compute_forward_dup_bytes(
ne00 == ne0 &&
nb00 == type_size && nb0 == type_size) {
// copy by rows
- const size_t rs = ne00 * type_size;
+ const size_t rs = ggml_row_size(src0->type, ne00); //ne00 * type_size;
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ir0; i01 < ir1; i01++) {
@@ -17871,6 +17871,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
#if GGML_USE_IQK_MULMAT
if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) {
+ //if (ith == 0) printf("k: %ld x %ld x %ld, q: %ld x %ld x %ld, v: %ld x %ld x %ld mask: %ld x %ld x %ld\n",
+ // k->ne[0], k->ne[1], k->ne[2], q->ne[0], q->ne[1], q->ne[2], v->ne[0], v->ne[1], v->ne[2], mask->ne[0], mask->ne[1], mask->ne[2]);
// I keep changing my mind what is the best strategy to split the threads when processing
// multiple heads. This is my current thinking, the commented out code below was the previous.
int ntg = nth/simple_gcd(neq2*neq3, nth);
@@ -17906,8 +17908,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
}
return;
IQK_Flash_Attn_NotAvailable:;
+ printf("iqk_flash was rejected\n");
}
-
#endif
const uint32_t n_head = neq2;