| author | Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> | 2024-03-02 01:00:46 +0530 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-03-01 21:30:46 +0200 |
| commit | c29af7e2252d288f2ea58a7d437c1cb7c0abf160 (patch) | |
| tree | b17451289ae835cb33f10c79db82a1e91004e225 /convert-hf-to-gguf.py | |
| parent | 38d16b142624bdd7c41d9955752b7f7b59c5e048 (diff) | |
llama : add StarCoder2 support (#5795)
* Add support for starcoder2
* handle rope type
* skip serializing rope freq and rotary embeddings
* resolve comments
* Update llama.cpp
* remove redundant changes
* handle `rope-theta`
* llama : change starcoder2 rope type
* address comment
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
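As a rough sketch of what the `rope-theta` handling in the bullets above amounts to: optional Hugging Face config fields are forwarded to the GGUF writer only when present. The `StubWriter` class and the hyperparameter values below are illustrative stand-ins, not taken from llama.cpp or a real StarCoder2 checkpoint; only the method names `add_rope_freq_base` and `add_layer_norm_eps` come from the diff further down.

```python
# Minimal sketch (not llama.cpp code) of the hparam-to-GGUF-metadata pattern
# this patch relies on. StubWriter mimics only the two writer methods named
# in the diff; the hparam values are made-up placeholders.
class StubWriter:
    def __init__(self) -> None:
        self.metadata: dict[str, float] = {}

    def add_rope_freq_base(self, value: float) -> None:
        self.metadata["rope_freq_base"] = value

    def add_layer_norm_eps(self, value: float) -> None:
        self.metadata["layer_norm_eps"] = value


hparams = {"rope_theta": 1_000_000.0, "norm_epsilon": 1e-5}  # placeholder values
writer = StubWriter()

# Same walrus-operator pattern as the patched set_gguf_parameters():
# write the field only if the config actually provides it.
if (rope_theta := hparams.get("rope_theta")) is not None:
    writer.add_rope_freq_base(rope_theta)
if (norm_eps := hparams.get("norm_epsilon")) is not None:
    writer.add_layer_norm_eps(norm_eps)

print(writer.metadata)  # {'rope_freq_base': 1000000.0, 'layer_norm_eps': 1e-05}
```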
Diffstat (limited to 'convert-hf-to-gguf.py')
-rwxr-xr-x | convert-hf-to-gguf.py | 8 |
1 file changed, 7 insertions, 1 deletion
diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index d3e8ec1f..28b92ac3 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -96,9 +96,11 @@ class Model:
         if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
             self.gguf_writer.add_head_count_kv(n_head_kv)
 
+        if (rope_theta := self.hparams.get("rope_theta")) is not None:
+            self.gguf_writer.add_rope_freq_base(rope_theta)
         if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
             self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps)
-        if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon"], optional=True)) is not None:
+        if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None:
             self.gguf_writer.add_layer_norm_eps(f_norm_eps)
         if (n_experts := self.hparams.get("num_local_experts")) is not None:
             self.gguf_writer.add_expert_count(n_experts)
@@ -220,6 +222,8 @@ class Model:
             return NomicBertModel
         if model_architecture == "GemmaForCausalLM":
             return GemmaModel
+        if model_architecture == "Starcoder2ForCausalLM":
+            return Model
         return Model
 
     def _is_model_safetensors(self) -> bool:
@@ -281,6 +285,8 @@ class Model:
             return gguf.MODEL_ARCH.NOMIC_BERT
         if arch == "GemmaForCausalLM":
             return gguf.MODEL_ARCH.GEMMA
+        if arch == "Starcoder2ForCausalLM":
+            return gguf.MODEL_ARCH.STARCODER2
         raise NotImplementedError(f'Architecture "{arch}" not supported!')
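The first hunk also widens the alias list passed to `find_hparam`. The sketch below is an assumed, simplified reimplementation of how that helper presumably behaves; the function body is not copied from convert-hf-to-gguf.py, only the call signature with an alias list and `optional=True` appears in the diff.

```python
from typing import Any

# Assumed behaviour of the find_hparam helper used in the hunk above
# (simplified sketch, not the actual implementation): return the first
# alias present in hparams, or None when optional, instead of raising.
def find_hparam(hparams: dict[str, Any], keys: list[str], optional: bool = False) -> Any:
    for key in keys:
        if key in hparams:
            return hparams[key]
    if optional:
        return None
    raise KeyError(f"could not find any of {keys} in hparams")


# StarCoder2-style configs expose the layer-norm epsilon as "norm_epsilon",
# which is presumably why the patch adds that name to the alias list.
hparams = {"norm_epsilon": 1e-5}
eps = find_hparam(hparams, ["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)
assert eps == 1e-5
```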