diff options
author | Georgi Gerganov <ggerganov@gmail.com> | 2023-12-07 13:03:17 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-07 13:03:17 +0200 |
commit | bcc0eb4591bec5ec02fad3f2bdcb1b265052ea56 (patch) | |
tree | 5082f49b7cb13d8e4f08c14ecf436606a1ae2ff8 /llama.h | |
parent | 81bc9214a389362010f7a57f4cbc30e5f83a2d28 (diff) |
llama : per-layer KV cache + quantum K cache (#4309)
* per-layer KV
* remove unnecessary copies
* less code duplication, offload k and v separately
* llama : offload KV cache per-layer
* llama : offload K shift tensors
* llama : offload for rest of the model arches
* llama : enable offload debug temporarily
* llama : keep the KV related layers on the device
* llama : remove mirrors, perform Device -> Host when partial offload
* common : add command-line arg to disable KV cache offloading
* llama : update session save/load
* llama : support quantum K cache (#4312)
* llama : support quantum K cache (wip)
* metal : add F32 -> Q8_0 copy kernel
* cuda : add F32 -> Q8_0 copy kernel
ggml-ci
* cuda : use mmv kernel for quantum cache ops
* llama : pass KV cache type through API
* llama : fix build
ggml-ci
* metal : add F32 -> Q4_0 copy kernel
* metal : add F32 -> Q4_1 copy kernel
* cuda : wip
* cuda : add F32 -> Q4_0 and F32 -> Q4_1 copy kernels
* llama-bench : support type_k/type_v
* metal : use mm kernel only for quantum KV cache
* cuda : add comment
* llama : remove memory_f16 and kv_f16 flags
---------
Co-authored-by: slaren <slarengh@gmail.com>
* readme : add API change notice
---------
Co-authored-by: slaren <slarengh@gmail.com>
Diffstat (limited to 'llama.h')
-rw-r--r-- | llama.h | 13 |
1 files changed, 8 insertions, 5 deletions
@@ -42,7 +42,7 @@ #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN -#define LLAMA_SESSION_VERSION 2 +#define LLAMA_SESSION_VERSION 3 #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL) // Defined when llama.cpp is compiled with support for offloading model layers to GPU. @@ -211,11 +211,14 @@ extern "C" { float yarn_beta_slow; // YaRN high correction dim uint32_t yarn_orig_ctx; // YaRN original context size + enum ggml_type type_k; // data type for K cache + enum ggml_type type_v; // data type for V cache + // Keep the booleans together to avoid misalignment during copy-by-value. - bool mul_mat_q; // if true, use experimental mul_mat_q kernels (DEPRECATED - always true) - bool f16_kv; // use fp16 for KV cache, fp32 otherwise - bool logits_all; // the llama_eval() call computes all logits, not just the last one - bool embedding; // embedding mode only + bool mul_mat_q; // if true, use experimental mul_mat_q kernels (DEPRECATED - always true) + bool logits_all; // the llama_eval() call computes all logits, not just the last one + bool embedding; // embedding mode only + bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU }; // model quantization parameters |