author     Kawrakow <iwankawrakow@gmail.com>    2025-03-18 07:36:42 +0100
committer  GitHub <noreply@github.com>          2025-03-18 07:36:42 +0100
commit     dcdfad29f7d2b831f1c84751f00bda14cc359a84 (patch)
tree       7576224579bf2c95734a407e29ac16fabc8efc9d
parent     f91b2e38d028c77cc5631295ba0937749e684749 (diff)
FlashMLA-2: reduce compute buffer size (CUDA and CPU) (#260)
* FlashMLA-2: eliminate intermediate f32 tensors
This works on the CPU. PP (prompt processing) performance is ~13% better
for 16k tokens, and the compute buffer is quite a bit smaller.
* FlashMLA-2: enable fast path only on the CPU for now
I did implement the necessary ops on CUDA, but something is
still wrong there, so for now we only use it when running
CPU-only.
* FlashMLA-2: slightly smaller compute buffer size
* Prepare wk_b when loading DeepSeek models (if wk_b is missing)
* Add some comments
* Fix case where wkv_b is quantized with k- or i-quants.
* Fix CUDA
There is an issue with quantized GEMV on CUDA when the left operand
(the matrix) is not contiguous. So, for now, we also create wv_b
during model loading and use that instead of the 3D view of wkv_b.
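
For context, a rough, self-contained sketch of why such a 3D view is non-contiguous
(dummy F16 tensor and made-up dimensions standing in for kv_lora_rank,
n_embd_head_qk_nope, n_embd_head_v and n_head; this is not the actual loader code):

    // Sketch only: the per-head V slice of the packed wkv_b weight is a strided
    // (non-contiguous) view, which is why a separate, contiguous wv_b is
    // materialized at load time instead. All dimensions below are illustrative.
    #include "ggml.h"
    #include <cstdio>

    int main() {
        struct ggml_init_params params = { /*mem_size =*/ 16*1024*1024, /*mem_buffer =*/ NULL, /*no_alloc =*/ false };
        struct ggml_context * ctx = ggml_init(params);

        const int64_t kv_lora_rank        = 512;   // illustrative
        const int64_t n_embd_head_qk_nope = 128;   // illustrative
        const int64_t n_embd_head_v       = 128;   // illustrative
        const int64_t n_head              = 16;    // illustrative

        // wkv_b packs, per head, n_embd_head_qk_nope K rows followed by n_embd_head_v V rows.
        struct ggml_tensor * wkv_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F16,
                kv_lora_rank, n_head*(n_embd_head_qk_nope + n_embd_head_v));

        // 3D view selecting only the V rows of each head: there is a row offset and a
        // stride gap between heads (the K rows are skipped), so the view is not contiguous.
        struct ggml_tensor * wv_b_view = ggml_view_3d(ctx, wkv_b,
                kv_lora_rank, n_embd_head_v, n_head,
                wkv_b->nb[1],
                wkv_b->nb[1]*(n_embd_head_qk_nope + n_embd_head_v),
                wkv_b->nb[1]*n_embd_head_qk_nope);

        printf("wv_b view contiguous: %d\n", ggml_is_contiguous(wv_b_view) ? 1 : 0);  // prints 0

        ggml_free(ctx);
        return 0;
    }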
* FlashMLA-2: avoid conversions to f32 also on CUDA
* Be able to compute for more than 65535 tokens
On CUDA this is just a quick hack that allows us to concatenate tensors
with more than 65535 rows along the zeroth dimension, as needed by
FlashMLA-2. It also needed some care in the perplexity tool to
avoid int overflows when evaluating the computed logits.
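
The perplexity part is plain 32-bit index arithmetic: with a large vocabulary,
i*n_vocab no longer fits in an int after a few thousand evaluated positions, so
the index has to be widened before multiplying, as the patch does. A minimal
sketch of the pattern (illustrative numbers, not the tool's actual values):

    #include <cstdint>
    #include <cstdio>

    // Illustrative only: the offset logits + i*n_vocab in perplexity.cpp overflows
    // 32-bit int arithmetic for long evaluations; the fix widens i to int64_t first.
    int main() {
        const int64_t n_vocab = 129280;   // DeepSeek-sized vocabulary (illustrative)
        const int64_t i       = 20000;    // evaluated token position (illustrative)

        const int64_t offset    = i * n_vocab;                   // what the patched code computes
        const int32_t truncated = static_cast<int32_t>(offset);  // what 32-bit arithmetic would keep

        printf("correct offset:          %lld\n", (long long) offset);   // 2585600000
        printf("after 32-bit truncation: %d\n",   truncated);            // negative garbage
        return 0;
    }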
* Reduce memory usage for FlashMLA-2
Oh, and also fix an int overflow in the CUDA concat implementation.
It is funny how the llama.cpp 64-bit police have gone (almost) everywhere
and replaced 32-bit ints with 64-bit ints, needed or not,
but haven't done it where it is actually needed.
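
On the CUDA side the constraint is the hard 65535 limit on gridDim.y and
gridDim.z; the patched concat works around it by launching the dim-0 kernel in
slices of at most 32768 rows while keeping the offsets 64-bit. A standalone
sketch of that chunking pattern, with a trivial toy kernel and made-up sizes
(not the actual concat code):

    #include <cuda_runtime.h>
    #include <cstdio>

    // Toy kernel: blockIdx.y indexes the row within the current slice.
    __global__ void touch_rows(float * dst, int64_t ne0, int64_t row0) {
        const int64_t i = threadIdx.x + (int64_t) blockIdx.x * blockDim.x;
        if (i >= ne0) return;
        const int64_t row = row0 + blockIdx.y;   // global row index
        dst[row*ne0 + i] = (float) row;          // 64-bit offset so large tensors don't overflow
    }

    int main() {
        const int64_t ne0 = 256, ne1 = 200000;   // ne1 > 65535 would not fit in gridDim.y directly
        float * d = nullptr;
        cudaMalloc(&d, ne0*ne1*sizeof(float));

        const int block = 256;
        const int nblk0 = (int) ((ne0 + block - 1)/block);

        // Launch in slices of at most 32768 rows, as the concat workaround does.
        const int64_t step = 32768;
        for (int64_t i1 = 0; i1 < ne1; i1 += step) {
            const int64_t n1 = i1 + step <= ne1 ? step : ne1 - i1;
            dim3 grid(nblk0, (unsigned) n1, 1);
            touch_rows<<<grid, block>>>(d, ne0, i1);
        }
        cudaDeviceSynchronize();
        printf("done: %s\n", cudaGetErrorString(cudaGetLastError()));
        cudaFree(d);
        return 0;
    }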
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--   examples/perplexity/perplexity.cpp |   6
-rw-r--r--   ggml/src/ggml-cuda.cu              |   2
-rw-r--r--   ggml/src/ggml-cuda/binbcast.cu     |  40
-rw-r--r--   ggml/src/ggml-cuda/concat.cu       | 135
-rw-r--r--   src/llama.cpp                      | 120
5 files changed, 181 insertions, 122 deletions
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 372684f0..95aedce6 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -166,7 +166,7 @@ static void process_logits(
             break;
         }
         lock.unlock();
-        const results_log_softmax results = log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]);
+        const results_log_softmax results = log_softmax(n_vocab, logits + int64_t(i)*n_vocab, tokens[i+1]);
         const double v = -results.log_softmax;
         local_nll += v;
         local_nll2 += v*v;
@@ -200,7 +200,7 @@ static void process_logits(std::ostream& out, int n_vocab, const float * logits,
             break;
         }
         lock.unlock();
-        const double v = log_softmax(n_vocab, logits + i*n_vocab, log_probs.data() + i*nv, tokens[i+1]);
+        const double v = log_softmax(n_vocab, logits + i*n_vocab, log_probs.data() + int64_t(i)*nv, tokens[i+1]);
         local_nll += v;
         local_nll2 += v*v;
     }
@@ -618,7 +618,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
 
         if (num_batches > 1 && n_outputs > 0) {
             const auto * batch_logits = llama_get_logits(ctx);
-            logits.insert(logits.end(), batch_logits, batch_logits + n_outputs * n_vocab);
+            logits.insert(logits.end(), batch_logits, batch_logits + int64_t(n_outputs) * n_vocab);
         }
     }
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 1bb869c3..58a44cf7 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -3354,7 +3354,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
             if (op->op == GGML_OP_MOE_FUSED_UP_GATE && a->type != op->src[1]->type) {
                 return false;
             }
-            if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
+            if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16 && !ggml_is_quantized(a->type)) {
                 return false;
             }
             if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu
index 5abbd43c..a2508350 100644
--- a/ggml/src/ggml-cuda/binbcast.cu
+++ b/ggml/src/ggml-cuda/binbcast.cu
@@ -248,17 +248,35 @@ static void ggml_cuda_op_bin_bcast(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const void * src0_dd, const void * src1_dd, void * dst_dd, cudaStream_t stream) {
 
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-        op()(src0, src1, dst, (const float *)src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
-    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
-        op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (half *) dst_dd, stream);
-    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
-        op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
-    } else {
-        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
-            ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
+    //GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    if (src1->type == GGML_TYPE_F32) {
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            op()(src0, src1, dst, (const float *)src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
+        } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+            op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (half *) dst_dd, stream);
+        } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+            op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
+        } else {
+            fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
+                ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
+            GGML_ABORT("fatal error");
+        }
+    }
+    else if (src1->type == GGML_TYPE_F16) {
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            op()(src0, src1, dst, (const float *)src0_dd, (const half *)src1_dd, (float *)dst_dd, stream);
+        } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+            op()(src0, src1, dst, (const half *) src0_dd, (const half *)src1_dd, (half *) dst_dd, stream);
+        } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+            op()(src0, src1, dst, (const half *) src0_dd, (const half *)src1_dd, (float *)dst_dd, stream);
+        } else {
+            fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
+                ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
+            GGML_ABORT("fatal error");
+        }
+    }
+    else {
         GGML_ABORT("fatal error");
     }
 }
diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu
index 4bde6d69..b40617f6 100644
--- a/ggml/src/ggml-cuda/concat.cu
+++ b/ggml/src/ggml-cuda/concat.cu
@@ -1,7 +1,7 @@
 #include "concat.cuh"
 
 // contiguous kernels
-static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) {
+static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int64_t ne0, const int64_t ne00) {
     int nidx = threadIdx.x + blockIdx.x * blockDim.x;
     if (nidx >= ne0) {
         return;
@@ -27,7 +27,35 @@ static __global__ void concat_f32_dim0(const float * x, const float * y, float *
     }
 }
 
-static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int ne0, const int ne01) {
+// contiguous kernels
+static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int64_t ne0, const int64_t ne00,
+        int64_t nb02, int64_t nb12, int64_t nb2) {
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * nb2;
+
+    if (nidx < ne00) { // src0
+        int offset_src =
+            nidx +
+            blockIdx.y * ne00 +
+            blockIdx.z * nb02;
+        dst[offset_dst] = x[offset_src];
+    } else {
+        int offset_src =
+            (nidx - ne00) +
+            blockIdx.y * (ne0 - ne00) +
+            blockIdx.z * nb12;
+        dst[offset_dst] = y[offset_src];
+    }
+}
+
+static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int64_t ne0, const int64_t ne01) {
     int nidx = threadIdx.x + blockIdx.x * blockDim.x;
     if (nidx >= ne0) {
         return;
@@ -53,7 +81,7 @@ static __global__ void concat_f32_dim1(const float * x, const float * y, float *
     }
 }
 
-static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int ne0, const int ne02) {
+static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int64_t ne0, const int64_t ne02) {
     int nidx = threadIdx.x + blockIdx.x * blockDim.x;
     if (nidx >= ne0) {
         return;
@@ -81,9 +109,23 @@ static __global__ void concat_f32_dim2(const float * x, const float * y, float *
 
 static void concat_f32_cuda(const float * x, const float * y, float * dst, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, int dim, cudaStream_t stream) {
     int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;
+    if (dim == 0 && ne1 >= 65536) {
+        int64_t nstep = (ne1 + 32767)/32768;
+        for (int64_t istep = 0; istep < nstep; ++istep) {
+            int64_t i1 = 32768*istep;
+            int64_t n1 = i1 + 32768 <= ne1 ? 32768 : ne1 - i1;
+            dim3 gridDim(num_blocks, n1, ne2);
+            const float * xi = x + i1*ne00;
+            const float * yi = y + i1*(ne0 - ne00);
+            float * dst_i = dst + i1*ne0;
+            concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(xi, yi, dst_i, ne0, ne00, ne00*ne01, (ne0-ne00)*ne01, ne0*ne1);
+        }
+        return;
+    }
     dim3 gridDim(num_blocks, ne1, ne2);
     if (dim == 0) {
         concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne00);
+        //concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne00, ne00*ne01, (ne0-ne00)*ne01, ne0*ne1);
         return;
     }
     if (dim == 1) {
@@ -150,52 +192,77 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
 
+    GGML_ASSERT(src0->type == src1->type && src0->type == dst->type);
+
     cudaStream_t stream = ctx.stream();
 
     const int32_t dim = ((int32_t *) dst->op_params)[0];
 
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) &&
+        (dim == 3 || (dim == 2 && dst->ne[3] == 1) || (dim == 1 && dst->ne[2]*dst->ne[3] == 1))) {
+        const size_t size0 = ggml_nbytes(src0);
+        const size_t size1 = ggml_nbytes(src1);
+        CUDA_CHECK(cudaMemcpyAsync((char *)dst->data, src0->data, size0, cudaMemcpyDeviceToDevice, stream));
+        CUDA_CHECK(cudaMemcpyAsync((char *)dst->data + size0, src1->data, size1, cudaMemcpyDeviceToDevice, stream));
+        return;
+    }
+
+    if (dim == 0 && src0->nb[0] == ggml_type_size(src0->type) && src1->nb[0] == ggml_type_size(src1->type) &&
+        src0->nb[1] % sizeof(float) == 0 && src1->nb[1] % sizeof(float) == 0) {
+        if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
+            //if (dst->ne[1] >= 65536 || dst->ne[2] >= 65536) {
+            //    fprintf(stderr, "%s: ne1 = %ld, ne2 = %ld exceed max. blocks when computing %s\n", __func__, dst->ne[1], dst->ne[2], dst->name);
+            //    GGML_ABORT("fatal error");
+            //}
+            const float * src0_d = (const float *)src0->data;
+            const float * src1_d = (const float *)src1->data;
+            float * dst_d = (float *)dst->data;
+            for (int i3 = 0; i3 < dst->ne[3]; i3++) {
+                concat_f32_cuda(
+                    src0_d + i3 * (src0->nb[3] / 4),
+                    src1_d + i3 * (src1->nb[3] / 4),
+                    dst_d  + i3 * ( dst->nb[3] / 4),
+                    src0->ne[0]*src0->nb[0]/sizeof(float), src0->ne[1], src0->ne[2],
+                    dst->ne[0]*dst->nb[0]/sizeof(float), dst->ne[1], dst->ne[2], dim, stream);
+            }
+        } else {
+            dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
+            concat_f32_non_cont<<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
+                (const char *)src0->data,
+                (const char *)src1->data,
+                (      char *)dst->data,
+                src0->ne[0]*src0->nb[0]/sizeof(float), src0->ne[1], src0->ne[2], src0->ne[3],
+                sizeof(float), src0->nb[1], src0->nb[2], src0->nb[3],
+                src1->ne[0]*src1->nb[0]/sizeof(float), src1->ne[1], src1->ne[2], src1->ne[3],
+                sizeof(float), src1->nb[1], src1->nb[2], src1->nb[3],
+                dst->ne[0]*dst->nb[0]/sizeof(float), dst->ne[1], dst->ne[2], dst->ne[3],
+                sizeof(float), dst->nb[1], dst->nb[2], dst->nb[3], dim);
+        }
+        return;
+    }
+
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
     if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
+        //if (dst->ne[1] >= 65536 || dst->ne[2] >= 65536) {
+        //    fprintf(stderr, "%s: ne1 = %ld, ne2 = %ld exceed max. blocks when computing %s\n", __func__, dst->ne[1], dst->ne[2], dst->name);
+        //    GGML_ABORT("fatal error");
+        //}
         const float * src0_d = (const float *)src0->data;
         const float * src1_d = (const float *)src1->data;
         float * dst_d = (float *)dst->data;
 
-        if (dim == 3 || (dim == 2 && dst->ne[3] == 1) || (dim == 1 && dst->ne[2]*dst->ne[3] == 1)) {
-            const size_t size0 = ggml_nbytes(src0);
-            const size_t size1 = ggml_nbytes(src1);
-            CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
-            CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
-        } else {
-            for (int i3 = 0; i3 < dst->ne[3]; i3++) {
-                concat_f32_cuda(
-                    src0_d + i3 * (src0->nb[3] / 4),
-                    src1_d + i3 * (src1->nb[3] / 4),
-                    dst_d + i3 * ( dst->nb[3] / 4),
-                    src0->ne[0], src0->ne[1], src0->ne[2],
-                    dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
-            }
+        for (int i3 = 0; i3 < dst->ne[3]; i3++) {
+            concat_f32_cuda(
+                src0_d + i3 * (src0->nb[3] / 4),
+                src1_d + i3 * (src1->nb[3] / 4),
+                dst_d  + i3 * ( dst->nb[3] / 4),
+                src0->ne[0], src0->ne[1], src0->ne[2],
+                dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
         }
-
-        //if (dim != 3) {
-        //    for (int i3 = 0; i3 < dst->ne[3]; i3++) {
-        //        concat_f32_cuda(
-        //            src0_d + i3 * (src0->nb[3] / 4),
-        //            src1_d + i3 * (src1->nb[3] / 4),
-        //            dst_d + i3 * ( dst->nb[3] / 4),
-        //            src0->ne[0], src0->ne[1], src0->ne[2],
-        //            dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
-        //    }
-        //} else {
-        //    const size_t size0 = ggml_nbytes(src0);
-        //    const size_t size1 = ggml_nbytes(src1);
-
-        //    CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
-        //    CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
-        //}
     } else {
         dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
         concat_f32_non_cont<<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
diff --git a/src/llama.cpp b/src/llama.cpp
index 34934a15..605e5d36 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -13755,31 +13755,52 @@ struct llm_build_context {
         if (lctx.cparams.mla_attn > 1 && lctx.cparams.flash_attn && (pp_opt || lctx.cparams.mla_attn > 2)) {
 
+            auto kv_cache_nope = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, kv_self.kv_l[il]->nb[1], 0);
-            ggml_tensor * k;
-            ggml_tensor * v;
+            auto kv_f32_size = model.layers[il].wkv_b->ne[1] * kv_cache_nope->ne[1] * sizeof(float) / (1024*1024);
+            int n_max_head = n_head;
+            if (cparams.attn_max_batch > 0 && kv_f32_size > cparams.attn_max_batch) {
+                while (n_max_head%2 == 0 && kv_f32_size > cparams.attn_max_batch) {
+                    n_max_head /= 2; kv_f32_size /= 2;
+                }
+            }
+            GGML_ASSERT(n_head % n_max_head == 0);
+
+            auto n_per_head = model.layers[il].wkv_b->ne[1] / n_head;
+
+            auto kv_cache_rope = ggml_view_3d(ctx0, kv_self.kv_l[il], n_embd_head_qk_rope, n_kv, 1,
+                    kv_self.kv_l[il]->nb[1], kv_self.kv_l[il]->nb[2], ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank));
+
+            ggml_tensor repeater;
+            repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_kv; repeater.ne[2] = n_max_head; repeater.ne[3] = 1;
+            auto k_rope = ggml_repeat(ctx0, kv_cache_rope, &repeater);
+            cb(k_rope, "k_rope", il);
-            // For now this only works in the CPU implementation, so we only use it if there is just the CPU backend.
-            // If the code was compiled with CUDA (and/or Metal, Vulkan, whatever) support, this branch will not
-            // be taken even if no layers were offloaded to the GPU.
-            if (lctx.backends.size() == 1 && lctx.backends.front() == lctx.backend_cpu) {
+            auto q = ggml_concat(ctx0, q_nope, q_rope, 0);
+            q = ggml_permute(ctx0, q, 0, 2, 1, 3);
+            cb(q, "q_concat", il);
+
+            ggml_build_forward_expand(gf, q);
-                auto kv_cache_nope = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, kv_self.kv_l[il]->nb[1], 0);
+            for (int iter = 0; iter < n_head/n_max_head; ++iter) {
-                auto kv_f32 = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache_nope);
+                auto wkv_b = ggml_view_2d(ctx0, model.layers[il].wkv_b, model.layers[il].wkv_b->ne[0], n_per_head*n_max_head,
+                        model.layers[il].wkv_b->nb[1], model.layers[il].wkv_b->nb[1]*n_per_head*n_max_head*iter);
+
+                auto kv_f32 = ggml_mul_mat(ctx0, wkv_b, kv_cache_nope);
                 cb(kv_f32, "kv_f32", il);
 
-                auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, n_head,
-                        ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, n_max_head,
+                        ggml_row_size(kv_f32->type, n_max_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
                         ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
                         ggml_row_size(kv_f32->type, n_embd_head_qk_nope));
                 cb(v_f32, "v_f32", il);
 
-                v = ggml_cast(ctx0, v_f32, kv_self.kv_l[il]->type);
+                auto v = ggml_cast(ctx0, v_f32, kv_self.kv_l[il]->type);
                 cb(v, "v", il);
 
-                auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_head,
-                        ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_max_head,
+                        ggml_row_size(kv_f32->type, n_max_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
                         ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0);
                 cb(k_nope_f32, "k_nope_f32", il);
@@ -13789,74 +13810,27 @@ struct llm_build_context {
                 ggml_build_forward_expand(gf, k_nope);
                 ggml_build_forward_expand(gf, v);
 
-                auto kv_cache_rope = ggml_view_3d(ctx0, kv_self.kv_l[il], n_embd_head_qk_rope, n_kv, 1,
-                        kv_self.kv_l[il]->nb[1], kv_self.kv_l[il]->nb[2], ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank));
-
-                ggml_tensor repeater;
-                repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_kv; repeater.ne[2] = n_head; repeater.ne[3] = 1;
-                auto k_rope = ggml_repeat(ctx0, kv_cache_rope, &repeater);
-                cb(k_rope, "k_rope", il);
-
-                k = ggml_concat(ctx0, k_nope, k_rope, 0);
+                auto k = ggml_concat(ctx0, k_nope, k_rope, 0);
                 cb(k, "k", il);
 
                 ggml_build_forward_expand(gf, k);
-            }
-            else {
-                // Hahaha, we need to convert the KV cache for this layer to f32 because the general purpose ML library ggml does not
-                // provide ops on (almost) anything other than f32. In this case, the cache will be the second operand to a matrix
-                // multiplication, which *must* be f32.
-                auto kv_cache_view = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_self.kv_l[il]->ne[0], n_kv, kv_self.kv_l[il]->nb[1], 0);
-                auto kv_cache_view_f32 = ggml_cast(ctx0, kv_cache_view, GGML_TYPE_F32);
-                cb(kv_cache_view_f32, "kv_cache_view_f32", il);
-
-                // The no- and rotational position encoding portions of the KV cache
-                auto kv_cache_nope = ggml_view_2d(ctx0, kv_cache_view_f32, kv_lora_rank, n_kv, kv_cache_view_f32->nb[1], 0);
-                auto kv_cache_rope = ggml_view_3d(ctx0, kv_cache_view_f32, n_embd_head_qk_rope, 1, n_kv,
-                        kv_cache_view_f32->nb[1], kv_cache_view_f32->nb[1], ggml_row_size(kv_cache_view_f32->type, kv_lora_rank));
-
-                auto kv_f32 = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache_nope);
-                cb(kv_f32, "kv_f32", il);
-
-                auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_head,
-                        ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
-                        ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0);
-                cb(k_nope_f32, "k_nope_f32", il);
-                ggml_tensor repeater;
-                repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_head; repeater.ne[2] = n_kv; repeater.ne[3] = 1;
-                auto k_rope_f32 = ggml_permute(ctx0, ggml_repeat(ctx0, kv_cache_rope, &repeater), 0, 2, 1, 3);
-                cb(k_rope_f32, "k_rope_f32", il);
+                auto q_iter = ggml_view_3d(ctx0, q, q->ne[0], q->ne[1], n_max_head,
+                        q->nb[1], q->nb[2], q->nb[2]*n_max_head*iter);
-                auto k_f32 = ggml_concat(ctx0, k_nope_f32, k_rope_f32, 0);
-                cb(k_f32, "k_f32", il);
-
-                k = ggml_cast(ctx0, k_f32, kv_self.kv_l[il]->type);
-                cb(k, "k", il);
-
-                auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, n_head,
-                        ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
-                        ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
-                        ggml_row_size(kv_f32->type, n_embd_head_qk_nope));
-                cb(v_f32, "v_f32", il);
-
-                v = ggml_cast(ctx0, v_f32, kv_self.kv_l[il]->type);
-                cb(v, "v", il);
-            }
-
-            auto q = ggml_concat(ctx0, q_nope, q_rope, 0);
-            q = ggml_permute(ctx0, q, 0, 2, 1, 3);
-            cb(q, "q_concat", il);
+                kqv = ggml_flash_attn_ext(ctx0, q_iter, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f);
+                if (q->ne[1] <= 8) {
+                    ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
+                }
+                cb(kqv, "kqv", il);
-            ggml_build_forward_expand(gf, q);
+                if (iter == 0) {
+                    cur = ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_max_head, n_tokens);
+                } else {
+                    cur = ggml_concat(ctx0, cur, ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_max_head, n_tokens), 0);
+                }
-            kqv = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f);
-            if (q->ne[1] <= 8) {
-                ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
-            }
-            cb(kqv, "kqv", il);
-
-            cur = ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens);
             }
             else {