Diffstat (limited to 'src/llama.cpp')
-rw-r--r--  src/llama.cpp  49
1 file changed, 40 insertions(+), 9 deletions(-)
diff --git a/src/llama.cpp b/src/llama.cpp
index 8c4a966d..0257a0a3 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -3180,6 +3180,10 @@ static bool llama_kv_cache_init(
for (int i = 0; i < (int) n_layer; i++) {
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
+ const uint32_t n_head = hparams.n_head(i);
+ const uint32_t n_head_kv = hparams.n_head_kv(i);
+ const uint32_t n_embd_head_k = hparams.n_embd_head_k;
+
struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
ggml_tensor * k;
@@ -3201,7 +3205,8 @@ static bool llama_kv_cache_init(
const uint32_t kv_lora_rank = hparams.n_lora_kv;
LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank);
#if MLA_USE_TRANSPOSED_CACHE
- ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_k, (kv_lora_rank + n_embd_head_qk_rope)*kv_size);
+ ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_k, kv_lora_rank + n_embd_head_qk_rope, kv_size);
+ //ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_k, (kv_lora_rank + n_embd_head_qk_rope)*kv_size);
#else
ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_v, (kv_lora_rank + n_embd_head_qk_rope)*kv_size);
#endif
@@ -3215,7 +3220,10 @@ static bool llama_kv_cache_init(
n_mla++;
}
else {
- k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
+ //printf("Creating cache tensors:\n");
+ //printf("n_embd_k_gqa = %d, kv_size = %d, n_head = %d, n_head_kv = %d, n_embd_head_k = %d\n", (int)n_embd_k_gqa, (int)kv_size, (int)n_head, (int)n_head_kv, (int)n_embd_head_k);
+ //k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
+ k = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, n_head_kv*kv_size);
v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
ggml_format_name(k, "cache_k_l%d", i);
ggml_format_name(v, "cache_v_l%d", i);
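
Note on the new K-cache layout: the allocation size is unchanged (ignoring the n_embd_k_s() term, n_embd_k_gqa == n_head_kv * n_embd_head_k), but a row now spans a single KV head rather than all KV heads of one cache slot, so row-wise cache quantization operates per head. A minimal sketch of the resulting row indexing (the helper below is illustrative, not part of the patch):

    #include <cstdint>

    // Row index of head `h` of cache slot `t` in the 2-D cache_k_l%d tensor:
    // ne[0] = n_embd_head_k, ne[1] = n_head_kv*kv_size, heads varying fastest.
    static inline int64_t k_cache_row(int64_t t, int64_t h, int64_t n_head_kv) {
        return t*n_head_kv + h;
    }
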
@@ -4002,6 +4010,7 @@ struct llama_model_loader {
case GGML_TYPE_Q5_1: ftype = LLAMA_FTYPE_MOSTLY_Q5_1; break;
case GGML_TYPE_Q6_0: ftype = LLAMA_FTYPE_MOSTLY_Q6_0; break;
case GGML_TYPE_Q8_0: ftype = LLAMA_FTYPE_MOSTLY_Q8_0; break;
+ case GGML_TYPE_Q8_KV: ftype = LLAMA_FTYPE_MOSTLY_Q8_KV; break;
case GGML_TYPE_Q2_K: ftype = LLAMA_FTYPE_MOSTLY_Q2_K; break;
case GGML_TYPE_Q3_K: ftype = LLAMA_FTYPE_MOSTLY_Q3_K_M; break;
case GGML_TYPE_Q3_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q3_K_R4; break;
@@ -4012,6 +4021,7 @@ struct llama_model_loader {
case GGML_TYPE_Q6_K: ftype = LLAMA_FTYPE_MOSTLY_Q6_K; break;
case GGML_TYPE_Q6_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q6_K_R4; break;
case GGML_TYPE_Q8_K_R8: ftype = LLAMA_FTYPE_MOSTLY_Q8_K_R8; break;
+ case GGML_TYPE_Q8_KV_R8: ftype = LLAMA_FTYPE_MOSTLY_Q8_KV_R8; break;
case GGML_TYPE_IQ2_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS; break;
case GGML_TYPE_IQ2_XXS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4; break;
case GGML_TYPE_IQ2_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS; break;
@@ -4730,6 +4740,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1";
case LLAMA_FTYPE_MOSTLY_Q6_0: return "Q6_0";
case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0";
+ case LLAMA_FTYPE_MOSTLY_Q8_KV: return "Q8_KV";
case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium";
case LLAMA_FTYPE_MOSTLY_Q2_K_R4: return "Q2_K_R4";
case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small";
@@ -4746,6 +4757,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K";
case LLAMA_FTYPE_MOSTLY_Q6_K_R4: return "Q6_K_R4";
case LLAMA_FTYPE_MOSTLY_Q8_K_R8: return "Q8_K_R8";
+ case LLAMA_FTYPE_MOSTLY_Q8_KV_R8: return "Q8_KV_R8";
case LLAMA_FTYPE_MOSTLY_IQ2_XXS: return "IQ2_XXS - 2.0625 bpw";
case LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4:return "IQ2_XXS_R4 - 2.0625 bpw";
case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw";
@@ -8283,11 +8295,20 @@ static void llm_build_kv_store(
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+ const int64_t n_head = hparams.n_head(il);
+ const int64_t n_head_kv = hparams.n_head_kv(il);
+ const int64_t n_embd_head_k = hparams.n_embd_head_k;
+ const int64_t n_embd_head_v = hparams.n_embd_head_v;
+
GGML_ASSERT(kv.size == n_ctx);
- struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa,
- (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head);
- cb(k_cache_view, "k_cache_view", il);
+ //struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa,
+ // (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head);
+ //cb(k_cache_view, "k_cache_view", il);
+
+ auto k_row_size = ggml_row_size(kv.k_l[il]->type, n_embd_head_k);
+ ggml_tensor * k_cache_view = ggml_view_2d(ctx, kv.k_l[il], n_embd_head_k, n_tokens*n_head_kv,
+ k_row_size, k_row_size*n_head_kv*kv_head);
// note: storing RoPE-ed version of K in the KV cache
ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
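
The write path matches that layout: the 2-D view begins at the first row owned by slot kv_head and covers n_tokens*n_head_kv consecutive per-head rows. A sketch of the offset arithmetic used above (illustrative helper, assuming the layout from llama_kv_cache_init):

    #include <cstddef>
    #include <cstdint>

    // Byte offset of the first K row of cache slot `kv_head`, where
    // row_bytes = ggml_row_size(kv.k_l[il]->type, n_embd_head_k).
    static inline size_t k_store_offset(size_t row_bytes, int64_t n_head_kv, int64_t kv_head) {
        return row_bytes*(size_t)(n_head_kv*kv_head); // n_head_kv rows per slot
    }
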
@@ -8706,7 +8727,7 @@ static struct ggml_tensor * llm_build_kqv(
struct ggml_tensor * k =
ggml_view_3d(ctx, kv.k_l[il],
n_embd_head_k, n_kv, n_head_kv,
- ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa),
+ ggml_row_size(kv.k_l[il]->type, n_embd_head_k)*n_head_kv, //n_embd_k_gqa),
ggml_row_size(kv.k_l[il]->type, n_embd_head_k),
0);
cb(k, "k", il);
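
The nb1 stride of this view changes accordingly: the per-position stride is now expressed as n_head_kv per-head rows instead of one n_embd_k_gqa-wide row. For simple block types the two byte counts coincide; for cache types whose row size is not strictly proportional to the element count (e.g. per-row metadata) they differ, which appears to be the reason for the per-head form. A quick comparison, assuming an example head geometry (builds against ggml.h):

    #include <cstdio>
    #include "ggml.h"

    int main() {
        const int64_t n_embd_head_k = 128, n_head_kv = 8;  // example sizes
        const size_t old_nb1 = ggml_row_size(GGML_TYPE_Q8_0, n_head_kv*n_embd_head_k);
        const size_t new_nb1 = ggml_row_size(GGML_TYPE_Q8_0, n_embd_head_k)*n_head_kv;
        printf("old nb1 = %zu, new nb1 = %zu\n", old_nb1, new_nb1);  // identical for Q8_0
        return 0;
    }
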
@@ -13507,8 +13528,9 @@ struct llm_build_context {
ggml_tensor * kvr = ggml_concat(ctx0, kv_compressed, ggml_permute(ctx0, k_rope, 0, 2, 1, 3), 0);
cb(kvr, "kvr", il);
- ggml_tensor * kv_cache_view = ggml_view_1d(ctx0, kv_self.kv_l[il], n_tokens*(kv_lora_rank + n_embd_head_qk_rope),
- ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope)*kv_head);
+ auto row_size = ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope);
+ ggml_tensor * kv_cache_view = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_self.kv_l[il]->ne[0], n_tokens,
+ row_size, row_size*kv_head);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, kvr, kv_cache_view));
ggml_tensor * kv_cache = ggml_view_2d(ctx0, kv_self.kv_l[il],
kv_lora_rank + n_embd_head_qk_rope, n_kv,
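
Same pattern for the MLA cache: with the 2-D allocation from llama_kv_cache_init, one row holds the kv_lora_rank + n_embd_head_qk_rope compressed values of one cache slot, so this copy targets rows [kv_head, kv_head + n_tokens). A minimal sketch (illustrative helper, not part of the patch):

    #include <cstddef>
    #include <cstdint>

    // Byte offset of MLA cache slot `kv_head`, where
    // row_size = ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope).
    static inline size_t mla_slot_offset(size_t row_size, int64_t kv_head) {
        return row_size*(size_t)kv_head; // one compressed slot per row
    }
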
@@ -16164,7 +16186,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
new_type = GGML_TYPE_IQ5_K;
}
else if (new_type != GGML_TYPE_Q8_0 && new_type != GGML_TYPE_Q8_0_R8 && new_type != GGML_TYPE_IQ6_K && new_type != GGML_TYPE_Q6_K_R4 &&
- new_type != GGML_TYPE_Q8_K_R8) {
+ new_type != GGML_TYPE_Q8_K_R8 && new_type != GGML_TYPE_Q8_KV && new_type != GGML_TYPE_Q8_KV_R8) {
new_type = GGML_TYPE_Q6_K;
}
}
@@ -16218,6 +16240,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
else if (new_type == GGML_TYPE_Q8_K_R8) {
new_type = GGML_TYPE_Q8_0;
}
+ else if (new_type == GGML_TYPE_Q8_KV_R8) {
+ new_type = GGML_TYPE_Q8_0;
+ }
else if (new_type == GGML_TYPE_IQ2_K_R4) {
new_type = GGML_TYPE_IQ2_K;
}
@@ -16728,6 +16753,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break;
case LLAMA_FTYPE_MOSTLY_Q6_0: default_type = GGML_TYPE_Q6_0; break;
case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break;
+ case LLAMA_FTYPE_MOSTLY_Q8_KV: default_type = GGML_TYPE_Q8_KV; break;
case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break;
case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break;
case LLAMA_FTYPE_MOSTLY_BF16_R16: default_type = GGML_TYPE_BF16_R16; break;
@@ -16751,6 +16777,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
case LLAMA_FTYPE_MOSTLY_Q6_K: default_type = GGML_TYPE_Q6_K; break;
case LLAMA_FTYPE_MOSTLY_Q6_K_R4: default_type = GGML_TYPE_Q6_K_R4; break;
case LLAMA_FTYPE_MOSTLY_Q8_K_R8: default_type = GGML_TYPE_Q8_K_R8; break;
+ case LLAMA_FTYPE_MOSTLY_Q8_KV_R8: default_type = GGML_TYPE_Q8_KV_R8; break;
case LLAMA_FTYPE_MOSTLY_IQ2_XXS: default_type = GGML_TYPE_IQ2_XXS; break;
case LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4:default_type = GGML_TYPE_IQ2_XXS_R4; break;
case LLAMA_FTYPE_MOSTLY_IQ2_XS: default_type = GGML_TYPE_IQ2_XS; break;
@@ -17194,6 +17221,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
if (tensor->ne[1] % 8 != 0) new_type = GGML_TYPE_Q8_0;
else chunk_size_multiplier = 8;
}
+ else if (new_type == GGML_TYPE_Q8_KV_R8) {
+ if (tensor->ne[1] % 8 != 0) new_type = GGML_TYPE_Q8_0;
+ else chunk_size_multiplier = 8;
+ }
else if (new_type == GGML_TYPE_IQ2_BN_R4) {
if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ2_BN;
else chunk_size_multiplier = 4;
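
For reference, the Q8_KV_R8 fallback added above follows the same rule as the other interleaved (_R4/_R8) types: the repacked layout needs the row count to be a multiple of the interleave factor, otherwise the tensor drops back to plain Q8_0; when it does apply, chunk_size_multiplier = 8 presumably keeps each quantization chunk covering whole groups of 8 repacked rows. A condensed sketch of the rule (illustrative, not part of llama.cpp):

    #include "ggml.h"

    // Choose Q8_KV_R8 only when the 8-row interleaving is possible.
    static enum ggml_type q8_kv_r8_or_fallback(const struct ggml_tensor * t) {
        return t->ne[1] % 8 == 0 ? GGML_TYPE_Q8_KV_R8 : GGML_TYPE_Q8_0;
    }
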