author     slaren <slarengh@gmail.com>     2023-08-25 14:08:53 +0200
committer  GitHub <noreply@github.com>     2023-08-25 14:08:53 +0200
commit     12e2e33a977af73e75885eeee91c5575a77f4e5f (patch)
tree       7fc209639e998fc6347a4b9e50c2e2ab0f20ced2
parent     29674ab4e847fcaba60cc6558f0d46d5f74ae279 (diff)
convert.py : export rope freq_base when converting CodeLlama from an HF model (#2773)
-rwxr-xr-x  convert.py  34
1 file changed, 18 insertions, 16 deletions
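The change reads the RoPE base frequency from the Hugging Face config.json: CodeLlama checkpoints publish it under the rope_theta key (1000000.0, versus the LLaMA default of 10000.0), while older configs omit it entirely. Below is a minimal, self-contained sketch of that lookup pattern; example_config and read_rope_freq_base are hypothetical names used only for illustration, not part of convert.py.

example_config = {                      # hypothetical, trimmed-down CodeLlama-style config.json
    "hidden_size": 4096,
    "num_attention_heads": 32,
    "rms_norm_eps": 1e-5,
    "rope_theta": 1000000.0,            # extended RoPE base frequency used by CodeLlama
}

def read_rope_freq_base(config: dict):
    # Same pattern the patch adds: use rope_theta when the HF config provides it,
    # otherwise return None so the converter can fall back to its default.
    return config["rope_theta"] if "rope_theta" in config else None

print(read_rope_freq_base(example_config))          # -> 1000000.0
print(read_rope_freq_base({"hidden_size": 4096}))   # -> None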
diff --git a/convert.py b/convert.py
index 10276bf6..e58ea46e 100755
--- a/convert.py
+++ b/convert.py
@@ -160,13 +160,14 @@ class Params:
     def loadHFTransformerJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
         config = json.load(open(config_path))

-        n_vocab    = config["vocab_size"]
-        n_embd     = config["hidden_size"]
-        n_layer    = config["num_hidden_layers"]
-        n_ff       = config["intermediate_size"]
-        n_head     = config["num_attention_heads"]
-        n_head_kv  = config["num_key_value_heads"] if "num_key_value_heads" in config else n_head
-        f_norm_eps = config["rms_norm_eps"]
+        n_vocab          = config["vocab_size"]
+        n_embd           = config["hidden_size"]
+        n_layer          = config["num_hidden_layers"]
+        n_ff             = config["intermediate_size"]
+        n_head           = config["num_attention_heads"]
+        n_head_kv        = config["num_key_value_heads"] if "num_key_value_heads" in config else n_head
+        f_norm_eps       = config["rms_norm_eps"]
+        f_rope_freq_base = config["rope_theta"] if "rope_theta" in config else None

         n_mult = Params.find_n_mult(n_ff, n_embd)
@@ -179,15 +180,16 @@ class Params:
"Suggestion: provide 'config.json' of the model in the same directory containing model files.")
return Params(
- n_vocab = n_vocab,
- n_embd = n_embd,
- n_mult = n_mult,
- n_layer = n_layer,
- n_ctx = n_ctx,
- n_ff = n_ff,
- n_head = n_head,
- n_head_kv = n_head_kv,
- f_norm_eps = f_norm_eps,
+ n_vocab = n_vocab,
+ n_embd = n_embd,
+ n_mult = n_mult,
+ n_layer = n_layer,
+ n_ctx = n_ctx,
+ n_ff = n_ff,
+ n_head = n_head,
+ n_head_kv = n_head_kv,
+ f_norm_eps = f_norm_eps,
+ f_rope_freq_base = f_rope_freq_base,
)
# LLaMA v2 70B params.json