author     Georgi Gerganov <ggerganov@gmail.com>   2023-12-07 13:03:17 +0200
committer  GitHub <noreply@github.com>             2023-12-07 13:03:17 +0200
commit     bcc0eb4591bec5ec02fad3f2bdcb1b265052ea56 (patch)
tree       5082f49b7cb13d8e4f08c14ecf436606a1ae2ff8 /common
parent     81bc9214a389362010f7a57f4cbc30e5f83a2d28 (diff)
llama : per-layer KV cache + quantum K cache (#4309)
* per-layer KV
* remove unnecessary copies
* less code duplication, offload k and v separately
* llama : offload KV cache per-layer
* llama : offload K shift tensors
* llama : offload for rest of the model arches
* llama : enable offload debug temporarily
* llama : keep the KV related layers on the device
* llama : remove mirrors, perform Device -> Host when partial offload
* common : add command-line arg to disable KV cache offloading
* llama : update session save/load
* llama : support quantum K cache (#4312)
* llama : support quantum K cache (wip)
* metal : add F32 -> Q8_0 copy kernel
* cuda : add F32 -> Q8_0 copy kernel
ggml-ci
* cuda : use mmv kernel for quantum cache ops
* llama : pass KV cache type through API
* llama : fix build
ggml-ci
* metal : add F32 -> Q4_0 copy kernel
* metal : add F32 -> Q4_1 copy kernel
* cuda : wip
* cuda : add F32 -> Q4_0 and F32 -> Q4_1 copy kernels
* llama-bench : support type_k/type_v
* metal : use mm kernel only for quantum KV cache
* cuda : add comment
* llama : remove memory_f16 and kv_f16 flags
---------
Co-authored-by: slaren <slarengh@gmail.com>
* readme : add API change notice
---------
Co-authored-by: slaren <slarengh@gmail.com>
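
The bullets above boil down to two new knobs on llama_context_params: the KV cache data types (type_k/type_v) and the offload switch (offload_kqv), which this diff wires into the common helpers. A minimal sketch of how a caller might use them follows; the struct fields and GGML type constants are confirmed by this diff, while the llama_new_context_with_model() step mentioned in the comment is assumed from the llama.cpp C API of this period:

    #include "llama.h"

    // Sketch: quantize the K cache to Q8_0, keep V in F16, and keep the
    // KV cache on the host (the equivalent of -nkvo / --no-kv-offload).
    llama_context_params cparams = llama_context_default_params();
    cparams.type_k      = GGML_TYPE_Q8_0;   // quantum K cache added by this commit
    cparams.type_v      = GGML_TYPE_F16;    // V cache stays in f16
    cparams.offload_kqv = false;            // disable per-layer KV offload
    // cparams would then be passed to llama_new_context_with_model() as usual.

From the command line, the same configuration corresponds to "-ctk q8_0 -ctv f16 -nkvo", using the new arguments added in the diff below.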
Diffstat (limited to 'common')
-rw-r--r--  common/common.cpp   45
-rw-r--r--  common/common.h      7
2 files changed, 44 insertions, 8 deletions
diff --git a/common/common.cpp b/common/common.cpp
index 4e823c52..4a61ae59 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -278,8 +278,6 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
params.yarn_beta_slow = std::stof(argv[i]);
- } else if (arg == "--memory-f32") {
- params.memory_f16 = false;
} else if (arg == "--samplers") {
if (++i >= argc) {
invalid_param = true;
@@ -510,6 +508,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
params.infill = true;
} else if (arg == "-dkvc" || arg == "--dump-kv-cache") {
params.dump_kv_cache = true;
+ } else if (arg == "-nkvo" || arg == "--no-kv-offload") {
+ params.no_kv_offload = true;
+ } else if (arg == "-ctk" || arg == "--cache-type-k") {
+ params.cache_type_k = argv[++i];
+ } else if (arg == "-ctv" || arg == "--cache-type-v") {
+ params.cache_type_v = argv[++i];
} else if (arg == "--multiline-input") {
params.multiline_input = true;
} else if (arg == "--simple-io") {
@@ -858,8 +862,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
printf(" --no-penalize-nl do not penalize newline token\n");
- printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
- printf(" not recommended: doubles context memory required and no measurable increase in quality\n");
printf(" --temp N temperature (default: %.1f)\n", (double)sparams.temp);
printf(" --logits-all return logits for all tokens in the batch (default: disabled)\n");
printf(" --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n");
@@ -900,6 +902,12 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --verbose-prompt print prompt before generation\n");
printf(" -dkvc, --dump-kv-cache\n");
printf(" verbose print of the KV cache\n");
+ printf(" -nkvo, --no-kv-offload\n");
+ printf(" disable KV offload\n");
+ printf(" -ctk TYPE, --cache-type-k TYPE\n");
+ printf(" KV cache data type for K (default: %s)\n", params.cache_type_k.c_str());
+ printf(" -ctv TYPE, --cache-type-v TYPE\n");
+ printf(" KV cache data type for V (default: %s)\n", params.cache_type_v.c_str());
printf(" --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n");
printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
printf(" --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n");
@@ -1015,6 +1023,29 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
return mparams;
}
+static ggml_type kv_cache_type_from_str(const std::string & s) {
+ if (s == "f16") {
+ return GGML_TYPE_F16;
+ }
+ if (s == "q8_0") {
+ return GGML_TYPE_Q8_0;
+ }
+ if (s == "q4_0") {
+ return GGML_TYPE_Q4_0;
+ }
+ if (s == "q4_1") {
+ return GGML_TYPE_Q4_1;
+ }
+ if (s == "q5_0") {
+ return GGML_TYPE_Q5_0;
+ }
+ if (s == "q5_1") {
+ return GGML_TYPE_Q5_1;
+ }
+
+ throw std::runtime_error("Invalid cache type: " + s);
+}
+
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
auto cparams = llama_context_default_params();
@@ -1024,7 +1055,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
cparams.mul_mat_q = params.mul_mat_q;
cparams.seed = params.seed;
- cparams.f16_kv = params.memory_f16;
cparams.logits_all = params.logits_all;
cparams.embedding = params.embedding;
cparams.rope_scaling_type = params.rope_scaling_type;
@@ -1035,6 +1065,10 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.yarn_beta_fast = params.yarn_beta_fast;
cparams.yarn_beta_slow = params.yarn_beta_slow;
cparams.yarn_orig_ctx = params.yarn_orig_ctx;
+ cparams.offload_kqv = !params.no_kv_offload;
+
+ cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
+ cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
return cparams;
}
@@ -1447,7 +1481,6 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
}
fprintf(stream, "lora_base: %s\n", params.lora_base.c_str());
fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
- fprintf(stream, "memory_f32: %s # default: false\n", !params.memory_f16 ? "true" : "false");
fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat);
fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau);
fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta);
diff --git a/common/common.h b/common/common.h
index 02467938..e87ce113 100644
--- a/common/common.h
+++ b/common/common.h
@@ -100,7 +100,6 @@ struct gpt_params {
size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score
bool mul_mat_q = true; // if true, use mul_mat_q kernels instead of cuBLAS
- bool memory_f16 = true; // use f16 instead of f32 for memory kv
bool random_prompt = false; // do not randomize prompt if none provided
bool use_color = false; // use color to distinguish generations and inputs
bool interactive = false; // interactive mode
@@ -125,10 +124,14 @@ struct gpt_params {
bool verbose_prompt = false; // print prompt tokens before generation
bool infill = false; // use infill mode
bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes
+ bool no_kv_offload = false; // disable KV offloading
+
+ std::string cache_type_k = "f16"; // KV cache data type for the K
+ std::string cache_type_v = "f16"; // KV cache data type for the V
// multimodal models (see examples/llava)
std::string mmproj = ""; // path to multimodal projector
- std::string image = ""; // path to an image file
+ std::string image = ""; // path to an image file
};
bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params);
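
Taken together, the flow introduced here is: -ctk/-ctv/-nkvo fill the new gpt_params fields above, and llama_context_params_from_gpt_params() converts them via kv_cache_type_from_str(), which throws std::runtime_error for an unsupported type name. A small sketch of that mapping, assuming the common.h declarations of this period; the main() wrapper and example values are illustrative, not part of the diff:

    #include <cstdio>
    #include "common.h"

    int main() {
        gpt_params params;
        params.cache_type_k  = "q4_0";   // as set by -ctk q4_0
        params.cache_type_v  = "f16";    // as set by -ctv f16 (also the default)
        params.no_kv_offload = true;     // as set by -nkvo / --no-kv-offload

        llama_context_params cparams = llama_context_params_from_gpt_params(params);

        // Expected result of the mapping introduced in this commit:
        //   cparams.type_k == GGML_TYPE_Q4_0, cparams.type_v == GGML_TYPE_F16,
        //   cparams.offload_kqv == false.
        // An unsupported string such as "q3_k" would instead make
        // kv_cache_type_from_str() throw std::runtime_error.
        printf("type_k=%d type_v=%d offload_kqv=%d\n",
               (int) cparams.type_k, (int) cparams.type_v, (int) cparams.offload_kqv);
        return 0;
    }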