| author    | slaren <slarengh@gmail.com>                      | 2023-09-29 18:42:32 +0200 |
|-----------|--------------------------------------------------|---------------------------|
| committer | GitHub <noreply@github.com>                      | 2023-09-29 18:42:32 +0200 |
| commit    | 40e07a60f9ce06e79f3ccd4c903eba300fb31b5e (patch) |                           |
| tree      | 1bbac7bdc74e106b4eeb9fc24ba4464fcccb3a4b         |                           |
| parent    | bc34dd4f5b5a7c10ae3ed85a265ce6f2ed2fab79 (diff)  |                           |
llama.cpp : add documentation about rope_freq_base and scale values (#3401)
* llama.cpp : add documentation about rope_freq_base and scale values
* add notice to hot topics
| -rw-r--r-- | README.md | 1  |
| -rw-r--r-- | llama.h   | 10 |
2 files changed, 6 insertions, 5 deletions
--- a/README.md
+++ b/README.md
@@ -11,6 +11,7 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++
 
 ### Hot topics
 
+- ‼️ Breaking change: `rope_freq_base` and `rope_freq_scale` must be set to zero to use the model default values: [#3401](https://github.com/ggerganov/llama.cpp/pull/3401)
 - Parallel decoding + continuous batching support added: [#3228](https://github.com/ggerganov/llama.cpp/pull/3228) \
   **Devs should become familiar with the new API**
 - Local Falcon 180B inference on Mac Studio
--- a/llama.h
+++ b/llama.h
@@ -167,18 +167,18 @@ extern "C" {
 
     struct llama_context_params {
         uint32_t seed;              // RNG seed, -1 for random
-        uint32_t n_ctx;             // text context
-        uint32_t n_batch;           // prompt processing batch size
+        uint32_t n_ctx;             // text context, 0 = from model
+        uint32_t n_batch;           // prompt processing maximum batch size
         uint32_t n_threads;         // number of threads to use for generation
         uint32_t n_threads_batch;   // number of threads to use for batch processing
 
         // ref: https://github.com/ggerganov/llama.cpp/pull/2054
-        float    rope_freq_base;    // RoPE base frequency
-        float    rope_freq_scale;   // RoPE frequency scaling factor
+        float    rope_freq_base;    // RoPE base frequency, 0 = from model
+        float    rope_freq_scale;   // RoPE frequency scaling factor, 0 = from model
 
         // Keep the booleans together to avoid misalignment during copy-by-value.
         bool mul_mat_q;  // if true, use experimental mul_mat_q kernels
-        bool f16_kv;     // use fp16 for KV cache
+        bool f16_kv;     // use fp16 for KV cache, fp32 otherwise
         bool logits_all; // the llama_eval() call computes all logits, not just the last one
         bool embedding;  // embedding mode only
     };
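To show what the new zero-value convention looks like from the caller's side, here is a minimal sketch, not code from this commit. Only the `llama_context_params` fields and their "0 = from model" meaning are taken from the diff above; the surrounding calls (`llama_backend_init`, `llama_model_default_params`, `llama_load_model_from_file`, `llama_context_default_params`, `llama_new_context_with_model`, `llama_free`, `llama_free_model`, `llama_backend_free`) are assumed to match the llama.cpp C API of this period.

```c
#include "llama.h"
#include <stdio.h>

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s <model.gguf>\n", argv[0]);
        return 1;
    }

    llama_backend_init(false); // numa = false

    // Model loading itself is unchanged by this commit.
    struct llama_model_params mparams = llama_model_default_params();
    struct llama_model * model = llama_load_model_from_file(argv[1], mparams);
    if (model == NULL) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    struct llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx           = 0;    // 0 = take the context size from the model
    cparams.rope_freq_base  = 0.0f; // 0 = take the RoPE base frequency from the model
    cparams.rope_freq_scale = 0.0f; // 0 = take the RoPE frequency scale from the model

    struct llama_context * ctx = llama_new_context_with_model(model, cparams);
    if (ctx == NULL) {
        fprintf(stderr, "failed to create context\n");
        llama_free_model(model);
        return 1;
    }

    // ... evaluate prompts / generate tokens here ...

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}
```

Any non-zero value still acts as an explicit override of the model metadata, so the zero values are the way to request the model's own RoPE settings, as the hot-topics note above points out.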