author | Kawrakow <iwankawrakow@gmail.com> | 2025-03-13 12:07:43 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-03-13 12:07:43 +0200 |
commit | 305fabfc3b694d603fdb05d671dd59e2d4c7d58e (patch) | |
tree | 645b23c154fa8af405f55138f38d264e05faa2ce | |
parent | 3f23ed68f17583a8ee63afd0c214f5b39226226c (diff) |
FlashMLA-2 (CPU): faster and smaller compute buffer size (#253)
* FlashMLA-2: eliminate intermediate f32 tensors
This works on the CPU. PP performance is ~13% better for 16k tokens, and the compute buffer is quite a bit smaller (see the row-wise conversion sketch after this list).
* FlashMLA-2: enable fast path only on the CPU for now
I did implement the necessary ops on CUDA, but something is still wrong there, so the fast path is only used when running CPU-only.
* FlashMLA-2: slightly smaller compute buffer size (see the buffer-size sketch after this list)
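The mechanism behind the first item is the new iqk_quantize_any helper shown in the diff below: instead of materializing a full f32 copy of a non-f32 src1, each thread dequantizes one row into a small f32 scratch buffer and immediately re-quantizes it to the vec_dot type. A minimal sketch of that row-wise round trip, assuming a flat row layout; the convert_rows_any name and signature are illustrative only, while to_float/from_float mirror the type-trait callbacks used in the diff:

```cpp
#include <algorithm>
#include <cstdint>

// Per-type callbacks, analogous to the to_float_t/from_float_t typedefs added in iqk_quantize.h.
using to_float_t   = void (*)(const void * x, float * y, int64_t k);
using from_float_t = void (*)(const float * x, void * y, int64_t k);

// Sketch: convert nrows rows of an arbitrary type to the vec_dot type via one f32 scratch row,
// with the rows split into contiguous chunks across nth threads (this thread has index ith).
static void convert_rows_any(const char * x, int64_t src_row_size,
                             char * y, int64_t dst_row_size,
                             int64_t ne0, int64_t nrows,
                             float * scratch,                     // ne0 floats, one buffer per thread
                             to_float_t to_float, from_float_t from_float,
                             int ith, int nth) {
    const int64_t rows_per_thread = (nrows + nth - 1)/nth;
    const int64_t first_row = ith*rows_per_thread;
    if (first_row >= nrows) return;
    const int64_t last_row = std::min(first_row + rows_per_thread, nrows);
    for (int64_t row = first_row; row < last_row; ++row) {
        to_float(x + row*src_row_size, scratch, ne0);      // dequantize one source row to f32
        from_float(scratch, y + row*dst_row_size, ne0);    // re-quantize straight to the vec_dot type
    }
}
```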
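On the compute-buffer side, the ggml_graph_plan hunk below sizes the MUL_MAT work buffer as the re-quantized src1 plus one f32 scratch row per task when src1 is not already f32. A rough sketch of that accounting, with a hypothetical helper name used purely to spell out the arithmetic:

```cpp
#include <cstddef>
#include <cstdint>

// Sketch of the MUL_MAT work-size accounting: storage for src1 converted to vec_dot_type,
// plus one f32 scratch row per task for the src1 -> f32 -> vec_dot_type round trip.
static size_t mul_mat_work_size(size_t vec_dot_row_size, int64_t src1_nrows,
                                int64_t src1_ne0, bool src1_is_f32, int n_tasks) {
    size_t cur = vec_dot_row_size * (size_t)src1_nrows;
    if (!src1_is_f32) {
        cur += (size_t)n_tasks * (size_t)src1_ne0 * sizeof(float);   // per-thread scratch rows
    }
    return cur;
}
```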
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | ggml/src/ggml-backend.c | 3
-rw-r--r-- | ggml/src/ggml.c | 123
-rw-r--r-- | ggml/src/iqk/iqk_quantize.cpp | 28
-rw-r--r-- | ggml/src/iqk/iqk_quantize.h | 8
-rw-r--r-- | src/llama.cpp | 127
5 files changed, 233 insertions, 56 deletions
diff --git a/ggml/src/ggml-backend.c b/ggml/src/ggml-backend.c
index 0458bd0c..fd538f50 100644
--- a/ggml/src/ggml-backend.c
+++ b/ggml/src/ggml-backend.c
@@ -843,7 +843,8 @@ GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const
                 op->type != GGML_TYPE_IQ1_S &&
                 op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
         case GGML_OP_MUL_MAT:
-            return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type;
+            return true;
+            //return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type;
         default:
             return true;
     }
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 88820438..a904464e 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -12589,6 +12589,43 @@ static void ggml_compute_forward_repeat_f16(
     }
 }
 
+static void ggml_compute_forward_repeat_any(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src = dst->src[0];
+
+    GGML_ASSERT(ggml_can_repeat(src, dst));
+    GGML_ASSERT(src->type == dst->type);
+    GGML_ASSERT(src->nb[0] == ggml_type_size(src->type));
+    int64_t src_row_size = ggml_row_size(src->type, src->ne[0]);
+    GGML_ASSERT((int64_t )dst->nb[1] == src_row_size*dst->ne[0]/src->ne[0]);
+
+    int ith = params->ith;
+    int nth = params->nth;
+
+    int64_t nrows = ggml_nrows(dst);
+    int64_t nrows_per_thread = (nrows + nth - 1)/nth;
+    int64_t first_row = ith*nrows_per_thread;
+    if (first_row >= nrows) return;
+    int64_t last_row = MIN(first_row + nrows_per_thread, nrows);
+
+    for (int64_t row = first_row; row < last_row; ++row) {
+        int64_t i3 = row/(dst->ne[1]*dst->ne[2]);
+        int64_t i2 = (row - i3*dst->ne[1]*dst->ne[2])/dst->ne[1];
+        int64_t i1 = row - i3*dst->ne[1]*dst->ne[2] - i2*dst->ne[1];
+        char * y = (char *)dst->data + i1*dst->nb[1] + i2*dst->nb[2] + i3*dst->nb[3];
+        int64_t i03 = i3 % src->ne[3];
+        int64_t i02 = i2 % src->ne[2];
+        int64_t i01 = i1 % src->ne[1];
+        const char * x = (const char *)src->data + i01*src->nb[1] + i02*src->nb[2] + i03*src->nb[3];
+        for (int64_t ir = 0; ir < dst->ne[0]/src->ne[0]; ++ir) {
+            memcpy(y, x, src_row_size);
+            y += src_row_size;
+        }
+    }
+}
+
 static void ggml_compute_forward_repeat(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -12609,7 +12646,8 @@ static void ggml_compute_forward_repeat(
             } break;
         default:
             {
-                GGML_ABORT("fatal error");
+                ggml_compute_forward_repeat_any(params, dst);
+                //GGML_ABORT("fatal error");
             }
     }
 }
@@ -12762,6 +12800,44 @@ static void ggml_compute_forward_concat_f32(
     }
 }
 
+static void ggml_compute_forward_concat_any(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == src1->type && src0->type == dst->type);
+
+    const int32_t dim = ggml_get_op_params_i32(dst, 0);
+    // Let's do it for dim = 0 only for now
+    GGML_ASSERT(dim == 0);
+
+    int ith = params->ith;
+    int nth = params->nth;
+
+    int64_t nrows = ggml_nrows(dst);
+    int64_t nrows_per_thread = (nrows + nth - 1)/nth;
+    int64_t first_row = ith*nrows_per_thread;
+    if (first_row >= nrows) return;
+    int64_t last_row = MIN(first_row + nrows_per_thread, nrows);
+
+    int64_t src0_row_size = ggml_row_size(src0->type, src0->ne[0]);
+    int64_t src1_row_size = ggml_row_size(src1->type, src1->ne[0]);
+
+    for (int64_t row = first_row; row < last_row; ++row) {
+        int64_t i3 = row/(dst->ne[1]*dst->ne[2]);
+        int64_t i2 = (row - i3*dst->ne[1]*dst->ne[2])/dst->ne[1];
+        int64_t i1 = row - i3*dst->ne[1]*dst->ne[2] - i2*dst->ne[1];
+        char * y = (char *)dst->data + i1*dst->nb[1] + i2*dst->nb[2] + i3*dst->nb[3];
+        const char * x0 = (const char *)src0->data + i1*src0->nb[1] + i2*src0->nb[2] + i3*src0->nb[3];
+        const char * x1 = (const char *)src1->data + i1*src1->nb[1] + i2*src1->nb[2] + i3*src1->nb[3];
+        memcpy(y, x0, src0_row_size);
+        memcpy(y + src0_row_size, x1, src1_row_size);
+    }
+
+}
+
 static void ggml_compute_forward_concat(
     const struct ggml_compute_params * params,
     struct ggml_tensor * dst) {
@@ -12776,7 +12852,8 @@ static void ggml_compute_forward_concat(
             } break;
         default:
             {
-                GGML_ABORT("fatal error");
+                ggml_compute_forward_concat_any(params, dst);
+                //GGML_ABORT("fatal error");
            }
     }
 }
@@ -14302,7 +14379,17 @@ UseGgmlGemm1:;
         const size_t nbw3 = nbw2*ne12;
 
         assert(params->wsize >= ne13*nbw3);
-        GGML_ASSERT(src1->type == GGML_TYPE_F32);
+        if (src1->type != GGML_TYPE_F32) {
+#if GGML_USE_IQK_MULMAT
+            char * work_buffer = wdata + ne13*nbw3 + ith*ne10*sizeof(float);
+            GGML_ASSERT(params->wsize >= ne13*nbw3 + nth*ne10*sizeof(float));
+            iqk_quantize_any(src1->type, vec_dot_type, ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
+                    src1->data, wdata, work_buffer, type_traits[src1->type].to_float, from_float, ith, nth);
+#else
+            GGML_ABORT("fatal error");
+#endif
+        }
+        else {
 
 //#ifdef GGML_USE_IQK_MULMAT
 //        int ts = type_traits[vec_dot_type].type_size;
@@ -14348,6 +14435,7 @@ UseGgmlGemm1:;
             }
         }
 //#endif
+        }
 
         ggml_barrier(params->shared);
 
@@ -16250,28 +16338,28 @@ static void ggml_compute_forward_soft_max_f32(
             }
         }
 
-#ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
-            //printf("p[%d] = %f\n", i, p[i]);
-            assert(!isnan(wp[i]));
-        }
-#endif
+//#ifndef NDEBUG
+//        for (int i = 0; i < nc; ++i) {
+//            //printf("p[%d] = %f\n", i, p[i]);
+//            assert(!isnan(wp[i]));
+//        }
+//#endif
 
         float max = -INFINITY;
         ggml_vec_max_f32(nc, &max, wp);
 
         ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
-        assert(sum > 0.0);
+        //assert(sum > 0.0);
 
         sum = 1.0/sum;
         ggml_vec_scale_f32(nc, dp, sum);
 
-#ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
-            assert(!isnan(dp[i]));
-            assert(!isinf(dp[i]));
-        }
-#endif
+//#ifndef NDEBUG
+//        for (int i = 0; i < nc; ++i) {
+//            assert(!isnan(dp[i]));
+//            assert(!isinf(dp[i]));
+//        }
+//#endif
     }
 }
 
@@ -21498,6 +21586,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
 
                     if (node->src[1]->type != vec_dot_type) {
                         cur = ggml_row_size(vec_dot_type, node->src[1]->ne[0]) * ggml_nrows(node->src[1]);
+                        if (node->src[1]->type != GGML_TYPE_F32) {
+                            cur += n_tasks*node->src[1]->ne[0]*sizeof(float); // src1->type -> f32 -> vec_dot_type
+                        }
                     }
                 } break;
             case GGML_OP_MUL_MAT_ID:
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index b61ae2db..fb6a5db4 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -185,6 +185,34 @@ void IQ1BNQuantizer::quantize_one_row_2bn(const float * src, block_iq2_bn * y, i
 
 }
 
+void iqk_quantize_any(int from_type, int to_type,
+                      int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3,
+                      uint64_t nb0, uint64_t nb1, uint64_t nb2, uint64_t nb3,
+                      const void * x, void * y, void * work_buffer,
+                      to_float_t to_float, from_float_t from_float, int ith, int nth) {
+    auto type_x = ggml_type(from_type);
+    GGML_ASSERT(ggml_type_size(type_x) == nb0);
+    auto type_y = ggml_type(to_type);
+    auto row_size_y = ggml_row_size(type_y, ne0);
+    int64_t nrows = ne1*ne2*ne3;
+    int64_t nrows_per_thread = (nrows + nth - 1)/nth;
+    int64_t first_row = nrows_per_thread*ith;
+    if (first_row >= nrows) return;
+    int64_t last_row = std::min(first_row + nrows_per_thread, nrows);
+    for (int64_t row = first_row; row < last_row; ++row) {
+        int64_t i3 = row/(ne1*ne2);
+        int64_t i2 = (row - i3*ne1*ne2)/ne1;
+        int64_t i1 = row - i3*ne1*ne2 - i2*ne1;
+        const char * cx = (const char *)x + i1*nb1 + i2*nb2 + i3*nb3;
+        // TODO: special case common types such as f16, q8_0
+        //       (although the performance gains may be too small to justify the added complexity)
+        to_float((const void *)cx, (float *)work_buffer, ne0);
+        auto cy = (char *)y + (i3*ne1*ne2 + i2*ne1 + i1)*row_size_y;
+        from_float((const float *)work_buffer, (void *)cy, ne0);
+    }
+}
+
+
 size_t quantize_iq1_bn(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
     IQ1BNQuantizer iq1bn;
     auto row_size = ggml_row_size(GGML_TYPE_IQ1_BN, n_per_row);
diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h
index 76fbac3b..d447705b 100644
--- a/ggml/src/iqk/iqk_quantize.h
+++ b/ggml/src/iqk/iqk_quantize.h
@@ -248,6 +248,14 @@ bool iqk_modify_tensor(struct ggml_tensor * tensor);
 // So we can re-pack Microsoft's BitNet I2_S quants
 void dequantize_row_ms_i2s(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 
+typedef void (*to_float_t)  (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+typedef void (*from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void iqk_quantize_any(int from_type, int to_type,
+                      int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3,
+                      uint64_t nb0, uint64_t nb1, uint64_t nb2, uint64_t nb3,
+                      const void * GGML_RESTRICT x, void * GGML_RESTRICT y, void * work_buffer,
+                      to_float_t to_float, from_float_t from_float, int ith, int nth);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/src/llama.cpp b/src/llama.cpp
index ba5c5052..cc15cf33 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -13630,45 +13630,94 @@ struct llm_build_context {
 
         if (lctx.cparams.mla_attn > 1 && lctx.cparams.flash_attn && (pp_opt || lctx.cparams.mla_attn > 2)) {
-            // Hahaha, we need to convert the KV cache for this layer to f32 because the general purpose ML library ggml does not
-            // provide ops on (almost) anything other than f32. In this case, the cache will be the second operand to a matrix
-            // multiplication, which *must* be f32.
-            auto kv_cache_view = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_self.kv_l[il]->ne[0], n_kv, kv_self.kv_l[il]->nb[1], 0);
-            auto kv_cache_view_f32 = ggml_cast(ctx0, kv_cache_view, GGML_TYPE_F32);
-            cb(kv_cache_view_f32, "kv_cache_view_f32", il);
-
-            // The no- and rotational position encoding portions of the KV cache
-            auto kv_cache_nope = ggml_view_2d(ctx0, kv_cache_view_f32, kv_lora_rank, n_kv, kv_cache_view_f32->nb[1], 0);
-            auto kv_cache_rope = ggml_view_3d(ctx0, kv_cache_view_f32, n_embd_head_qk_rope, 1, n_kv,
-                    kv_cache_view_f32->nb[1], kv_cache_view_f32->nb[1], ggml_row_size(kv_cache_view_f32->type, kv_lora_rank));
-
-            auto kv_f32 = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache_nope);
-            cb(kv_f32, "kv_f32", il);
-
-            auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_head,
-                    ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
-                    ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0);
-            cb(k_nope_f32, "k_nope_f32", il);
-
-            ggml_tensor repeater;
-            repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_head; repeater.ne[2] = n_kv; repeater.ne[3] = 1;
-            auto k_rope_f32 = ggml_permute(ctx0, ggml_repeat(ctx0, kv_cache_rope, &repeater), 0, 2, 1, 3);
-            cb(k_rope_f32, "k_rope_f32", il);
-
-            auto k_f32 = ggml_concat(ctx0, k_nope_f32, k_rope_f32, 0);
-            cb(k_f32, "k_f32", il);
-
-            auto k = ggml_cast(ctx0, k_f32, kv_self.kv_l[il]->type);
-            cb(k, "k", il);
-
-            auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, n_head,
-                    ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
-                    ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
-                    ggml_row_size(kv_f32->type, n_embd_head_qk_nope));
-            cb(v_f32, "v_f32", il);
-
-            auto v = ggml_cast(ctx0, v_f32, kv_self.kv_l[il]->type);
-            cb(v, "v", il);
+
+            ggml_tensor * k;
+            ggml_tensor * v;
+
+            // For now this only works in the CPU implementation, so we only use it if there is just the CPU backend.
+            // If the code was compiled with CUDA (and/or Metal, Vulkan, whatever) support, this branch will not
+            // be taken even if no layers were offloaded to the GPU.
+            if (lctx.backends.size() == 1 && lctx.backends.front() == lctx.backend_cpu) {
+
+                auto kv_cache_nope = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, kv_self.kv_l[il]->nb[1], 0);
+
+                auto kv_f32 = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache_nope);
+                cb(kv_f32, "kv_f32", il);
+
+                auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, n_head,
+                        ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                        ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
+                        ggml_row_size(kv_f32->type, n_embd_head_qk_nope));
+                cb(v_f32, "v_f32", il);
+
+                v = ggml_cast(ctx0, v_f32, kv_self.kv_l[il]->type);
+                cb(v, "v", il);
+
+                auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_head,
+                        ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                        ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0);
+                cb(k_nope_f32, "k_nope_f32", il);
+
+                auto k_nope = ggml_cast(ctx0, k_nope_f32, kv_self.kv_l[il]->type);
+                cb(k_nope, "k_nope", il);
+
+                ggml_build_forward_expand(gf, k_nope);
+                ggml_build_forward_expand(gf, v);
+
+                auto kv_cache_rope = ggml_view_3d(ctx0, kv_self.kv_l[il], n_embd_head_qk_rope, n_kv, 1,
+                        kv_self.kv_l[il]->nb[1], kv_self.kv_l[il]->nb[2], ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank));
+
+                ggml_tensor repeater;
+                repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_kv; repeater.ne[2] = n_head; repeater.ne[3] = 1;
+                auto k_rope = ggml_repeat(ctx0, kv_cache_rope, &repeater);
+                cb(k_rope, "k_rope", il);
+
+                k = ggml_concat(ctx0, k_nope, k_rope, 0);
+                cb(k, "k", il);
+
+                ggml_build_forward_expand(gf, k);
+            }
+            else {
+                // Hahaha, we need to convert the KV cache for this layer to f32 because the general purpose ML library ggml does not
+                // provide ops on (almost) anything other than f32. In this case, the cache will be the second operand to a matrix
+                // multiplication, which *must* be f32.
+                auto kv_cache_view = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_self.kv_l[il]->ne[0], n_kv, kv_self.kv_l[il]->nb[1], 0);
+                auto kv_cache_view_f32 = ggml_cast(ctx0, kv_cache_view, GGML_TYPE_F32);
+                cb(kv_cache_view_f32, "kv_cache_view_f32", il);
+
+                // The no- and rotational position encoding portions of the KV cache
+                auto kv_cache_nope = ggml_view_2d(ctx0, kv_cache_view_f32, kv_lora_rank, n_kv, kv_cache_view_f32->nb[1], 0);
+                auto kv_cache_rope = ggml_view_3d(ctx0, kv_cache_view_f32, n_embd_head_qk_rope, 1, n_kv,
+                        kv_cache_view_f32->nb[1], kv_cache_view_f32->nb[1], ggml_row_size(kv_cache_view_f32->type, kv_lora_rank));
+
+                auto kv_f32 = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache_nope);
+                cb(kv_f32, "kv_f32", il);
+
+                auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_head,
+                        ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                        ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0);
+                cb(k_nope_f32, "k_nope_f32", il);
+
+                ggml_tensor repeater;
+                repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_head; repeater.ne[2] = n_kv; repeater.ne[3] = 1;
+                auto k_rope_f32 = ggml_permute(ctx0, ggml_repeat(ctx0, kv_cache_rope, &repeater), 0, 2, 1, 3);
+                cb(k_rope_f32, "k_rope_f32", il);
+
+                auto k_f32 = ggml_concat(ctx0, k_nope_f32, k_rope_f32, 0);
+                cb(k_f32, "k_f32", il);
+
+                k = ggml_cast(ctx0, k_f32, kv_self.kv_l[il]->type);
+                cb(k, "k", il);
+
+                auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, n_head,
+                        ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                        ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
+                        ggml_row_size(kv_f32->type, n_embd_head_qk_nope));
+                cb(v_f32, "v_f32", il);
+
+                v = ggml_cast(ctx0, v_f32, kv_self.kv_l[il]->type);
+                cb(v, "v", il);
+            }
 
             auto q = ggml_concat(ctx0, q_nope, q_rope, 0);
             q = ggml_permute(ctx0, q, 0, 2, 1, 3);