diff options
author | Georgi Gerganov <ggerganov@gmail.com> | 2023-12-07 13:03:17 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-07 13:03:17 +0200 |
commit | bcc0eb4591bec5ec02fad3f2bdcb1b265052ea56 (patch) | |
tree | 5082f49b7cb13d8e4f08c14ecf436606a1ae2ff8 /common/common.cpp | |
parent | 81bc9214a389362010f7a57f4cbc30e5f83a2d28 (diff) |
llama : per-layer KV cache + quantum K cache (#4309)
* per-layer KV
* remove unnecessary copies
* less code duplication, offload k and v separately
* llama : offload KV cache per-layer
* llama : offload K shift tensors
* llama : offload for rest of the model arches
* llama : enable offload debug temporarily
* llama : keep the KV related layers on the device
* llama : remove mirrors, perform Device -> Host when partial offload
* common : add command-line arg to disable KV cache offloading
* llama : update session save/load
* llama : support quantum K cache (#4312)
* llama : support quantum K cache (wip)
* metal : add F32 -> Q8_0 copy kernel
* cuda : add F32 -> Q8_0 copy kernel
ggml-ci
* cuda : use mmv kernel for quantum cache ops
* llama : pass KV cache type through API
* llama : fix build
ggml-ci
* metal : add F32 -> Q4_0 copy kernel
* metal : add F32 -> Q4_1 copy kernel
* cuda : wip
* cuda : add F32 -> Q4_0 and F32 -> Q4_1 copy kernels
* llama-bench : support type_k/type_v
* metal : use mm kernel only for quantum KV cache
* cuda : add comment
* llama : remove memory_f16 and kv_f16 flags
---------
Co-authored-by: slaren <slarengh@gmail.com>
* readme : add API change notice
---------
Co-authored-by: slaren <slarengh@gmail.com>
Diffstat (limited to 'common/common.cpp')
-rw-r--r-- | common/common.cpp | 45 |
1 files changed, 39 insertions, 6 deletions
diff --git a/common/common.cpp b/common/common.cpp index 4e823c52..4a61ae59 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -278,8 +278,6 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { break; } params.yarn_beta_slow = std::stof(argv[i]); - } else if (arg == "--memory-f32") { - params.memory_f16 = false; } else if (arg == "--samplers") { if (++i >= argc) { invalid_param = true; @@ -510,6 +508,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.infill = true; } else if (arg == "-dkvc" || arg == "--dump-kv-cache") { params.dump_kv_cache = true; + } else if (arg == "-nkvo" || arg == "--no-kv-offload") { + params.no_kv_offload = true; + } else if (arg == "-ctk" || arg == "--cache-type-k") { + params.cache_type_k = argv[++i]; + } else if (arg == "-ctv" || arg == "--cache-type-v") { + params.cache_type_v = argv[++i]; } else if (arg == "--multiline-input") { params.multiline_input = true; } else if (arg == "--simple-io") { @@ -858,8 +862,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast); printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n"); printf(" --no-penalize-nl do not penalize newline token\n"); - printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); - printf(" not recommended: doubles context memory required and no measurable increase in quality\n"); printf(" --temp N temperature (default: %.1f)\n", (double)sparams.temp); printf(" --logits-all return logits for all tokens in the batch (default: disabled)\n"); printf(" --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n"); @@ -900,6 +902,12 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --verbose-prompt print prompt before generation\n"); printf(" -dkvc, --dump-kv-cache\n"); printf(" verbose print of the KV cache\n"); + printf(" -nkvo, --no-kv-offload\n"); + printf(" disable KV offload\n"); + printf(" -ctk TYPE, --cache-type-k TYPE\n"); + printf(" KV cache data type for K (default: %s)\n", params.cache_type_k.c_str()); + printf(" -ctv TYPE, --cache-type-v TYPE\n"); + printf(" KV cache data type for V (default: %s)\n", params.cache_type_v.c_str()); printf(" --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n"); printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n"); printf(" --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n"); @@ -1015,6 +1023,29 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & return mparams; } +static ggml_type kv_cache_type_from_str(const std::string & s) { + if (s == "f16") { + return GGML_TYPE_F16; + } + if (s == "q8_0") { + return GGML_TYPE_Q8_0; + } + if (s == "q4_0") { + return GGML_TYPE_Q4_0; + } + if (s == "q4_1") { + return GGML_TYPE_Q4_1; + } + if (s == "q5_0") { + return GGML_TYPE_Q5_0; + } + if (s == "q5_1") { + return GGML_TYPE_Q5_1; + } + + throw std::runtime_error("Invalid cache type: " + s); +} + struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) { auto cparams = llama_context_default_params(); @@ -1024,7 +1055,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; cparams.mul_mat_q = params.mul_mat_q; cparams.seed = params.seed; - cparams.f16_kv = params.memory_f16; cparams.logits_all = params.logits_all; cparams.embedding = params.embedding; cparams.rope_scaling_type = params.rope_scaling_type; @@ -1035,6 +1065,10 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.yarn_beta_fast = params.yarn_beta_fast; cparams.yarn_beta_slow = params.yarn_beta_slow; cparams.yarn_orig_ctx = params.yarn_orig_ctx; + cparams.offload_kqv = !params.no_kv_offload; + + cparams.type_k = kv_cache_type_from_str(params.cache_type_k); + cparams.type_v = kv_cache_type_from_str(params.cache_type_v); return cparams; } @@ -1447,7 +1481,6 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l } fprintf(stream, "lora_base: %s\n", params.lora_base.c_str()); fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu); - fprintf(stream, "memory_f32: %s # default: false\n", !params.memory_f16 ? "true" : "false"); fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat); fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau); fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta); |