author | Xiao-Yong Jin <jinxiaoyong@gmail.com> | 2023-07-15 06:34:16 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-15 13:34:16 +0300 |
commit | 6e7cca404748dd4b1a3affd0d1296e37f4ac0a6f (patch) | |
tree | dcbb7be0dbc8da79e0bf54d57a55b4b78b1dd461 /examples/common.h | |
parent | a6803cab946c817fb7aaf2a40b317f5d3e373bd1 (diff) |
llama : add custom RoPE (#2054)
* Implement customizable RoPE
The original RoPE has pre-defined parameters
theta_i = 10000^(-2(i-1)/d), for i in [1, 2, ..., d/2]
Our customizable RoPE, ggml_rope_custom_inplace, uses
theta_i = scale * base^(-2(i-1)/d), for i in [1, 2, ..., d/2]
with defaults that match the original:
scale = 1.0
base = 10000
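As an illustration only (not part of this patch; the helper name rope_theta and the head size d = 128 are made up for the example), the frequencies above can be computed like this:

```cpp
#include <cmath>
#include <cstdio>

// theta_i = scale * base^(-2(i-1)/d), for i in [1, 2, ..., d/2]
static float rope_theta(int i, int d, float base, float scale) {
    return scale * std::pow(base, -2.0f * (i - 1) / d);
}

int main() {
    const int d = 128; // head dimension (illustrative)
    for (int i = 1; i <= d / 2; i += 16) {
        std::printf("i=%3d  original=%.3e  custom(base=80000, scale=0.5)=%.3e\n",
                    i,
                    rope_theta(i, d, 10000.0f, 1.0f),  // defaults reproduce the original RoPE
                    rope_theta(i, d, 80000.0f, 0.5f)); // example custom parameters
    }
    return 0;
}
```

With scale = 1.0 and base = 10000 the custom variant reproduces the original frequencies exactly.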
The new command line arguments
--rope-freq-base
--rope-freq-scale
set the two new RoPE parameters.
Recent research shows that changing these two parameters can extend the context limit with minimal loss:
1. Extending Context to 8K
kaiokendev
https://kaiokendev.github.io/til#extending-context-to-8k
2. Extending Context Window of Large Language Models via Positional Interpolation
Shouyuan Chen, Sherman Wong, Liangjian Chen, Yuandong Tian
https://arxiv.org/abs/2306.15595
3. NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation.
https://www.reddit.com/user/bloc97
https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
For the bold, try adding the following command line parameters to your favorite model:
-c 16384 --rope-freq-base 80000 --rope-freq-scale 0.5
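As a rough guide (taken from the references above, not from this patch): linear positional interpolation (reference 2) suggests rope_freq_scale ≈ trained context / target context, e.g. 2048 / 4096 = 0.5 for a model trained at 2048 tokens, while the NTK-aware approach (reference 3) instead raises the base, e.g. from 10000 to 80000. The example command above combines both ideas.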
* ggml-metal: fix custom rope
* common: fix argument names in help
* llama: increase MEM_REQ_EVAL for MODEL_3B
It avoids crashing with quantized weights on CPU.
A better way to calculate the required buffer size is still needed.
* llama: make MEM_REQ_EVAL depend on n_ctx
* server: use proper Content-Type in curl examples
Without the header Content-Type: application/json, curl will POST with
Content-Type: application/x-www-form-urlencoded
Although our simple server doesn't care, the bundled httplib.h caps such payloads at CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH (8192).
With Content-Type: application/json, we can send large JSON data.
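For illustration (the /completion endpoint and the JSON fields below follow the server example's usual API and are an assumption here, not part of this change), a request with the proper header looks like:

```sh
curl --request POST \
     --url http://localhost:8080/completion \
     --header "Content-Type: application/json" \
     --data '{"prompt": "Building a website can be done in 10 simple steps:", "n_predict": 64}'
```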
* style : minor fixes, mostly indentations
* ggml : fix asserts
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'examples/common.h')
-rw-r--r-- | examples/common.h | 2 |
1 file changed, 2 insertions, 0 deletions
```diff
diff --git a/examples/common.h b/examples/common.h
index 6315df96..f52fef62 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -32,6 +32,8 @@ struct gpt_params {
     int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
     float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
     int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
+    float rope_freq_base = 10000.0f; // RoPE base frequency
+    float rope_freq_scale = 1.0f; // RoPE frequency scaling factor
 
     // sampling parameters
     std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
```
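To make the two new fields concrete, here is a minimal, self-contained sketch of how the new flags could feed them; the struct name gpt_params_rope and the parsing loop are illustrative assumptions, not the actual common.cpp change from this commit:

```cpp
#include <cstdio>
#include <cstdlib>
#include <string>

// Stand-in for the relevant part of gpt_params (illustrative only)
struct gpt_params_rope {
    float rope_freq_base  = 10000.0f; // RoPE base frequency
    float rope_freq_scale = 1.0f;     // RoPE frequency scaling factor
};

// Fill the RoPE fields from the new command line flags
static void parse_rope_args(int argc, char ** argv, gpt_params_rope & params) {
    for (int i = 1; i < argc; ++i) {
        const std::string arg = argv[i];
        if (arg == "--rope-freq-base" && i + 1 < argc) {
            params.rope_freq_base = std::strtof(argv[++i], nullptr);
        } else if (arg == "--rope-freq-scale" && i + 1 < argc) {
            params.rope_freq_scale = std::strtof(argv[++i], nullptr);
        }
    }
}

int main(int argc, char ** argv) {
    gpt_params_rope params;
    parse_rope_args(argc, argv, params);
    std::printf("rope_freq_base = %g, rope_freq_scale = %g\n",
                params.rope_freq_base, params.rope_freq_scale);
    return 0;
}
```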