summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: slaren <slarengh@gmail.com>, 2023-08-24 21:10:39 +0200
committer: GitHub <noreply@github.com>, 2023-08-24 21:10:39 +0200
commit: d0f77b1353fc820d1ff1e6b87bc6bedde315938d (patch)
tree: 859fc803bdf8b40fa6bc9ea8912601d368541783
parent: 0d3094f0c742ce61f84feb6e4f0b59beee6194d7 (diff)
convert.py : try to determine n_ctx automatically for CodeLlama (#2770)
-rwxr-xr-x convert.py 12
1 files changed, 11 insertions, 1 deletions
diff --git a/convert.py b/convert.py
index f335d008..10276bf6 100755
--- a/convert.py
+++ b/convert.py
@@ -200,13 +200,23 @@ class Params:
n_embd = config["dim"]
n_layer = config["n_layers"]
n_mult = config["multiple_of"]
- n_ctx = 2048 if config["norm_eps"] == 1e-06 else 4096 # hack to determine LLaMA v1 vs v2
n_ff = -1
n_head = config["n_heads"]
n_head_kv = config["n_kv_heads"] if "n_kv_heads" in config else n_head
f_norm_eps = config["norm_eps"]
f_rope_freq_base = config["rope_theta"] if "rope_theta" in config else None
+ # hack to determine LLaMA v1 vs v2 vs CodeLlama
+ if f_rope_freq_base and f_rope_freq_base == 1000000:
+ # CodeLlama
+ n_ctx = 16384
+ elif config["norm_eps"] == 1e-05:
+ # LLaMA v2
+ n_ctx = 4096
+ else:
+ # LLaMA v1
+ n_ctx = 2048
+
if n_vocab == -1:
n_vocab = model["tok_embeddings.weight"].shape[0]