summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/perplexity/perplexity.cpp6
-rw-r--r--ggml/src/ggml-cuda.cu2
-rw-r--r--ggml/src/ggml-cuda/binbcast.cu40
-rw-r--r--ggml/src/ggml-cuda/concat.cu135
-rw-r--r--src/llama.cpp120
5 files changed, 181 insertions, 122 deletions
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 372684f0..95aedce6 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -166,7 +166,7 @@ static void process_logits(
break;
}
lock.unlock();
- const results_log_softmax results = log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]);
+ const results_log_softmax results = log_softmax(n_vocab, logits + int64_t(i)*n_vocab, tokens[i+1]);
const double v = -results.log_softmax;
local_nll += v;
local_nll2 += v*v;
@@ -200,7 +200,7 @@ static void process_logits(std::ostream& out, int n_vocab, const float * logits,
break;
}
lock.unlock();
- const double v = log_softmax(n_vocab, logits + i*n_vocab, log_probs.data() + i*nv, tokens[i+1]);
+ const double v = log_softmax(n_vocab, logits + i*n_vocab, log_probs.data() + int64_t(i)*nv, tokens[i+1]);
local_nll += v;
local_nll2 += v*v;
}
@@ -618,7 +618,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
if (num_batches > 1 && n_outputs > 0) {
const auto * batch_logits = llama_get_logits(ctx);
- logits.insert(logits.end(), batch_logits, batch_logits + n_outputs * n_vocab);
+ logits.insert(logits.end(), batch_logits, batch_logits + int64_t(n_outputs) * n_vocab);
}
}
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 1bb869c3..58a44cf7 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -3354,7 +3354,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
if (op->op == GGML_OP_MOE_FUSED_UP_GATE && a->type != op->src[1]->type) {
return false;
}
- if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
+ if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16 && !ggml_is_quantized(a->type)) {
return false;
}
if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu
index 5abbd43c..a2508350 100644
--- a/ggml/src/ggml-cuda/binbcast.cu
+++ b/ggml/src/ggml-cuda/binbcast.cu
@@ -248,17 +248,35 @@ static void ggml_cuda_op_bin_bcast(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const void * src0_dd, const void * src1_dd, void * dst_dd, cudaStream_t stream) {
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- op()(src0, src1, dst, (const float *)src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
- } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
- op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (half *) dst_dd, stream);
- } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
- op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
- } else {
- fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
- ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
+ //GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+ if (src1->type == GGML_TYPE_F32) {
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ op()(src0, src1, dst, (const float *)src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
+ } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+ op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (half *) dst_dd, stream);
+ } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+ op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
+ } else {
+ fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
+ ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
+ GGML_ABORT("fatal error");
+ }
+ }
+ else if (src1->type == GGML_TYPE_F16) {
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ op()(src0, src1, dst, (const float *)src0_dd, (const half *)src1_dd, (float *)dst_dd, stream);
+ } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+ op()(src0, src1, dst, (const half *) src0_dd, (const half *)src1_dd, (half *) dst_dd, stream);
+ } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+ op()(src0, src1, dst, (const half *) src0_dd, (const half *)src1_dd, (float *)dst_dd, stream);
+ } else {
+ fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
+ ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
+ GGML_ABORT("fatal error");
+ }
+ }
+ else {
GGML_ABORT("fatal error");
}
}
diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu
index 4bde6d69..b40617f6 100644
--- a/ggml/src/ggml-cuda/concat.cu
+++ b/ggml/src/ggml-cuda/concat.cu
@@ -1,7 +1,7 @@
#include "concat.cuh"
// contiguous kernels
-static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) {
+static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int64_t ne0, const int64_t ne00) {
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
return;
@@ -27,7 +27,35 @@ static __global__ void concat_f32_dim0(const float * x, const float * y, float *
}
}
-static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int ne0, const int ne01) {
+// contiguous kernels
+static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int64_t ne0, const int64_t ne00,
+ int64_t nb02, int64_t nb12, int64_t nb2) {
+ int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+ if (nidx >= ne0) {
+ return;
+ }
+
+ int offset_dst =
+ nidx +
+ blockIdx.y * ne0 +
+ blockIdx.z * nb2;
+
+ if (nidx < ne00) { // src0
+ int offset_src =
+ nidx +
+ blockIdx.y * ne00 +
+ blockIdx.z * nb02;
+ dst[offset_dst] = x[offset_src];
+ } else {
+ int offset_src =
+ (nidx - ne00) +
+ blockIdx.y * (ne0 - ne00) +
+ blockIdx.z * nb12;
+ dst[offset_dst] = y[offset_src];
+ }
+}
+
+static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int64_t ne0, const int64_t ne01) {
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
return;
@@ -53,7 +81,7 @@ static __global__ void concat_f32_dim1(const float * x, const float * y, float *
}
}
-static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int ne0, const int ne02) {
+static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int64_t ne0, const int64_t ne02) {
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
return;
@@ -81,9 +109,23 @@ static __global__ void concat_f32_dim2(const float * x, const float * y, float *
static void concat_f32_cuda(const float * x, const float * y, float * dst, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, int dim, cudaStream_t stream) {
int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;
+ if (dim == 0 && ne1 >= 65536) {
+ int64_t nstep = (ne1 + 32767)/32768;
+ for (int64_t istep = 0; istep < nstep; ++istep) {
+ int64_t i1 = 32768*istep;
+ int64_t n1 = i1 + 32768 <= ne1 ? 32768 : ne1 - i1;
+ dim3 gridDim(num_blocks, n1, ne2);
+ const float * xi = x + i1*ne00;
+ const float * yi = y + i1*(ne0 - ne00);
+ float * dst_i = dst + i1*ne0;
+ concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(xi, yi, dst_i, ne0, ne00, ne00*ne01, (ne0-ne00)*ne01, ne0*ne1);
+ }
+ return;
+ }
dim3 gridDim(num_blocks, ne1, ne2);
if (dim == 0) {
concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne00);
+ //concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne00, ne00*ne01, (ne0-ne00)*ne01, ne0*ne1);
return;
}
if (dim == 1) {
@@ -150,52 +192,77 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
+ GGML_ASSERT(src0->type == src1->type && src0->type == dst->type);
+
cudaStream_t stream = ctx.stream();
const int32_t dim = ((int32_t *) dst->op_params)[0];
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) &&
+ (dim == 3 || (dim == 2 && dst->ne[3] == 1) || (dim == 1 && dst->ne[2]*dst->ne[3] == 1))) {
+ const size_t size0 = ggml_nbytes(src0);
+ const size_t size1 = ggml_nbytes(src1);
+ CUDA_CHECK(cudaMemcpyAsync((char *)dst->data, src0->data, size0, cudaMemcpyDeviceToDevice, stream));
+ CUDA_CHECK(cudaMemcpyAsync((char *)dst->data + size0, src1->data, size1, cudaMemcpyDeviceToDevice, stream));
+ return;
+ }
+
+ if (dim == 0 && src0->nb[0] == ggml_type_size(src0->type) && src1->nb[0] == ggml_type_size(src1->type) &&
+ src0->nb[1] % sizeof(float) == 0 && src1->nb[1] % sizeof(float) == 0) {
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
+ //if (dst->ne[1] >= 65536 || dst->ne[2] >= 65536) {
+ // fprintf(stderr, "%s: ne1 = %ld, ne2 = %ld exceed max. blocks when computing %s\n", __func__, dst->ne[1], dst->ne[2], dst->name);
+ // GGML_ABORT("fatal error");
+ //}
+ const float * src0_d = (const float *)src0->data;
+ const float * src1_d = (const float *)src1->data;
+ float * dst_d = (float *)dst->data;
+ for (int i3 = 0; i3 < dst->ne[3]; i3++) {
+ concat_f32_cuda(
+ src0_d + i3 * (src0->nb[3] / 4),
+ src1_d + i3 * (src1->nb[3] / 4),
+ dst_d + i3 * ( dst->nb[3] / 4),
+ src0->ne[0]*src0->nb[0]/sizeof(float), src0->ne[1], src0->ne[2],
+ dst->ne[0]*dst->nb[0]/sizeof(float), dst->ne[1], dst->ne[2], dim, stream);
+ }
+ } else {
+ dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
+ concat_f32_non_cont<<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
+ (const char *)src0->data,
+ (const char *)src1->data,
+ ( char *)dst->data,
+ src0->ne[0]*src0->nb[0]/sizeof(float), src0->ne[1], src0->ne[2], src0->ne[3],
+ sizeof(float), src0->nb[1], src0->nb[2], src0->nb[3],
+ src1->ne[0]*src1->nb[0]/sizeof(float), src1->ne[1], src1->ne[2], src1->ne[3],
+ sizeof(float), src1->nb[1], src1->nb[2], src1->nb[3],
+ dst->ne[0]*dst->nb[0]/sizeof(float), dst->ne[1], dst->ne[2], dst->ne[3],
+ sizeof(float), dst->nb[1], dst->nb[2], dst->nb[3], dim);
+ }
+ return;
+ }
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
+ //if (dst->ne[1] >= 65536 || dst->ne[2] >= 65536) {
+ // fprintf(stderr, "%s: ne1 = %ld, ne2 = %ld exceed max. blocks when computing %s\n", __func__, dst->ne[1], dst->ne[2], dst->name);
+ // GGML_ABORT("fatal error");
+ //}
const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
float * dst_d = (float *)dst->data;
- if (dim == 3 || (dim == 2 && dst->ne[3] == 1) || (dim == 1 && dst->ne[2]*dst->ne[3] == 1)) {
- const size_t size0 = ggml_nbytes(src0);
- const size_t size1 = ggml_nbytes(src1);
- CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
- CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
- } else {
- for (int i3 = 0; i3 < dst->ne[3]; i3++) {
- concat_f32_cuda(
- src0_d + i3 * (src0->nb[3] / 4),
- src1_d + i3 * (src1->nb[3] / 4),
- dst_d + i3 * ( dst->nb[3] / 4),
- src0->ne[0], src0->ne[1], src0->ne[2],
- dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
- }
+ for (int i3 = 0; i3 < dst->ne[3]; i3++) {
+ concat_f32_cuda(
+ src0_d + i3 * (src0->nb[3] / 4),
+ src1_d + i3 * (src1->nb[3] / 4),
+ dst_d + i3 * ( dst->nb[3] / 4),
+ src0->ne[0], src0->ne[1], src0->ne[2],
+ dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
}
-
- //if (dim != 3) {
- // for (int i3 = 0; i3 < dst->ne[3]; i3++) {
- // concat_f32_cuda(
- // src0_d + i3 * (src0->nb[3] / 4),
- // src1_d + i3 * (src1->nb[3] / 4),
- // dst_d + i3 * ( dst->nb[3] / 4),
- // src0->ne[0], src0->ne[1], src0->ne[2],
- // dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
- // }
- //} else {
- // const size_t size0 = ggml_nbytes(src0);
- // const size_t size1 = ggml_nbytes(src1);
-
- // CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
- // CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
- //}
} else {
dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
concat_f32_non_cont<<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
diff --git a/src/llama.cpp b/src/llama.cpp
index 34934a15..605e5d36 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -13755,31 +13755,52 @@ struct llm_build_context {
if (lctx.cparams.mla_attn > 1 && lctx.cparams.flash_attn && (pp_opt || lctx.cparams.mla_attn > 2)) {
+ auto kv_cache_nope = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, kv_self.kv_l[il]->nb[1], 0);
- ggml_tensor * k;
- ggml_tensor * v;
+ auto kv_f32_size = model.layers[il].wkv_b->ne[1] * kv_cache_nope->ne[1] * sizeof(float) / (1024*1024);
+ int n_max_head = n_head;
+ if (cparams.attn_max_batch > 0 && kv_f32_size > cparams.attn_max_batch) {
+ while (n_max_head%2 == 0 && kv_f32_size > cparams.attn_max_batch) {
+ n_max_head /= 2; kv_f32_size /= 2;
+ }
+ }
+ GGML_ASSERT(n_head % n_max_head == 0);
+
+ auto n_per_head = model.layers[il].wkv_b->ne[1] / n_head;
+
+ auto kv_cache_rope = ggml_view_3d(ctx0, kv_self.kv_l[il], n_embd_head_qk_rope, n_kv, 1,
+ kv_self.kv_l[il]->nb[1], kv_self.kv_l[il]->nb[2], ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank));
+
+ ggml_tensor repeater;
+ repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_kv; repeater.ne[2] = n_max_head; repeater.ne[3] = 1;
+ auto k_rope = ggml_repeat(ctx0, kv_cache_rope, &repeater);
+ cb(k_rope, "k_rope", il);
- // For now this only works in the CPU implementation, so we only use it if there is just the CPU backend.
- // If the code was compiled with CUDA (and/or Metal, Vulkan, whatever) support, this branch will not
- // be taken even if no layers were offloaded to the GPU.
- if (lctx.backends.size() == 1 && lctx.backends.front() == lctx.backend_cpu) {
+ auto q = ggml_concat(ctx0, q_nope, q_rope, 0);
+ q = ggml_permute(ctx0, q, 0, 2, 1, 3);
+ cb(q, "q_concat", il);
+
+ ggml_build_forward_expand(gf, q);
- auto kv_cache_nope = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, kv_self.kv_l[il]->nb[1], 0);
+ for (int iter = 0; iter < n_head/n_max_head; ++iter) {
- auto kv_f32 = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache_nope);
+ auto wkv_b = ggml_view_2d(ctx0, model.layers[il].wkv_b, model.layers[il].wkv_b->ne[0], n_per_head*n_max_head,
+ model.layers[il].wkv_b->nb[1], model.layers[il].wkv_b->nb[1]*n_per_head*n_max_head*iter);
+
+ auto kv_f32 = ggml_mul_mat(ctx0, wkv_b, kv_cache_nope);
cb(kv_f32, "kv_f32", il);
- auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, n_head,
- ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+ auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, n_max_head,
+ ggml_row_size(kv_f32->type, n_max_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
ggml_row_size(kv_f32->type, n_embd_head_qk_nope));
cb(v_f32, "v_f32", il);
- v = ggml_cast(ctx0, v_f32, kv_self.kv_l[il]->type);
+ auto v = ggml_cast(ctx0, v_f32, kv_self.kv_l[il]->type);
cb(v, "v", il);
- auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_head,
- ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+ auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_max_head,
+ ggml_row_size(kv_f32->type, n_max_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0);
cb(k_nope_f32, "k_nope_f32", il);
@@ -13789,74 +13810,27 @@ struct llm_build_context {
ggml_build_forward_expand(gf, k_nope);
ggml_build_forward_expand(gf, v);
- auto kv_cache_rope = ggml_view_3d(ctx0, kv_self.kv_l[il], n_embd_head_qk_rope, n_kv, 1,
- kv_self.kv_l[il]->nb[1], kv_self.kv_l[il]->nb[2], ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank));
-
- ggml_tensor repeater;
- repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_kv; repeater.ne[2] = n_head; repeater.ne[3] = 1;
- auto k_rope = ggml_repeat(ctx0, kv_cache_rope, &repeater);
- cb(k_rope, "k_rope", il);
-
- k = ggml_concat(ctx0, k_nope, k_rope, 0);
+ auto k = ggml_concat(ctx0, k_nope, k_rope, 0);
cb(k, "k", il);
ggml_build_forward_expand(gf, k);
- }
- else {
- // Hahaha, we need to convert the KV cache for this layer to f32 because the general purpose ML library ggml does not
- // provide ops on (almost) anything other than f32. In this case, the cache will be the second operand to a matrix
- // multiplication, which *must* be f32.
- auto kv_cache_view = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_self.kv_l[il]->ne[0], n_kv, kv_self.kv_l[il]->nb[1], 0);
- auto kv_cache_view_f32 = ggml_cast(ctx0, kv_cache_view, GGML_TYPE_F32);
- cb(kv_cache_view_f32, "kv_cache_view_f32", il);
-
- // The no- and rotational position encoding portions of the KV cache
- auto kv_cache_nope = ggml_view_2d(ctx0, kv_cache_view_f32, kv_lora_rank, n_kv, kv_cache_view_f32->nb[1], 0);
- auto kv_cache_rope = ggml_view_3d(ctx0, kv_cache_view_f32, n_embd_head_qk_rope, 1, n_kv,
- kv_cache_view_f32->nb[1], kv_cache_view_f32->nb[1], ggml_row_size(kv_cache_view_f32->type, kv_lora_rank));
-
- auto kv_f32 = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache_nope);
- cb(kv_f32, "kv_f32", il);
-
- auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_head,
- ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
- ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0);
- cb(k_nope_f32, "k_nope_f32", il);
- ggml_tensor repeater;
- repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_head; repeater.ne[2] = n_kv; repeater.ne[3] = 1;
- auto k_rope_f32 = ggml_permute(ctx0, ggml_repeat(ctx0, kv_cache_rope, &repeater), 0, 2, 1, 3);
- cb(k_rope_f32, "k_rope_f32", il);
+ auto q_iter = ggml_view_3d(ctx0, q, q->ne[0], q->ne[1], n_max_head,
+ q->nb[1], q->nb[2], q->nb[2]*n_max_head*iter);
- auto k_f32 = ggml_concat(ctx0, k_nope_f32, k_rope_f32, 0);
- cb(k_f32, "k_f32", il);
-
- k = ggml_cast(ctx0, k_f32, kv_self.kv_l[il]->type);
- cb(k, "k", il);
-
- auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, n_head,
- ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
- ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
- ggml_row_size(kv_f32->type, n_embd_head_qk_nope));
- cb(v_f32, "v_f32", il);
-
- v = ggml_cast(ctx0, v_f32, kv_self.kv_l[il]->type);
- cb(v, "v", il);
- }
-
- auto q = ggml_concat(ctx0, q_nope, q_rope, 0);
- q = ggml_permute(ctx0, q, 0, 2, 1, 3);
- cb(q, "q_concat", il);
+ kqv = ggml_flash_attn_ext(ctx0, q_iter, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f);
+ if (q->ne[1] <= 8) {
+ ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
+ }
+ cb(kqv, "kqv", il);
- ggml_build_forward_expand(gf, q);
+ if (iter == 0) {
+ cur = ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_max_head, n_tokens);
+ } else {
+ cur = ggml_concat(ctx0, cur, ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_max_head, n_tokens), 0);
+ }
- kqv = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f);
- if (q->ne[1] <= 8) {
- ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
}
- cb(kqv, "kqv", il);
-
- cur = ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens);
}
else {