diff options
author | Julius Arkenberg <arki05@users.noreply.github.com> | 2024-03-23 17:41:53 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-23 18:41:53 +0200 |
commit | 476b0251b27fb64c575507024a671e639d675594 (patch) | |
tree | 444579ca4cde7158fb79768bba055dcfae3cdb53 /convert-hf-to-gguf.py | |
parent | 21cad01b6e6e1a96f99391f95e8ea8ae25c8288e (diff) |
llama : add grok-1 support (#6204)
* Add support for Grok model architecture
* Revert convert-hf-to-gguf to default options
* Fixed f_norm_rms_eps bug
* Fix whitespaces
* llama : fix grok rope type
* llama : minor
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'convert-hf-to-gguf.py')
-rwxr-xr-x | convert-hf-to-gguf.py | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 1e49d56c..723ea18e 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -93,31 +93,42 @@ class Model(ABC): if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None: self.gguf_writer.add_context_length(n_ctx) + print(f"gguf: context length = {n_ctx}") n_embd = self.find_hparam(["hidden_size", "n_embd"]) self.gguf_writer.add_embedding_length(n_embd) + print(f"gguf: embedding length = {n_embd}") if (n_ff := self.find_hparam(["intermediate_size", "n_inner"], optional=True)) is not None: self.gguf_writer.add_feed_forward_length(n_ff) + print(f"gguf: feed forward length = {n_ff}") n_head = self.find_hparam(["num_attention_heads", "n_head"]) self.gguf_writer.add_head_count(n_head) + print(f"gguf: head count = {n_head}") if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None: self.gguf_writer.add_head_count_kv(n_head_kv) + print(f"gguf: key-value head count = {n_head_kv}") if (rope_theta := self.hparams.get("rope_theta")) is not None: self.gguf_writer.add_rope_freq_base(rope_theta) + print(f"gguf: rope theta = {rope_theta}") if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None: self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps) + print(f"gguf: rms norm epsilon = {f_rms_eps}") if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None: self.gguf_writer.add_layer_norm_eps(f_norm_eps) + print(f"gguf: layer norm epsilon = {f_norm_eps}") if (n_experts := self.hparams.get("num_local_experts")) is not None: self.gguf_writer.add_expert_count(n_experts) + print(f"gguf: expert count = {n_experts}") if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None: self.gguf_writer.add_expert_used_count(n_experts_used) + print(f"gguf: experts used count = {n_experts_used}") self.gguf_writer.add_file_type(self.ftype) + print(f"gguf: file type = {self.ftype}") def write_tensors(self): block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))) @@ -1051,6 +1062,21 @@ class MixtralModel(Model): self._set_vocab_sentencepiece() +@Model.register("GrokForCausalLM") +class GrokModel(Model): + model_arch = gguf.MODEL_ARCH.GROK + + def set_vocab(self): + self._set_vocab_sentencepiece() + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_name("Grok") + + @Model.register("MiniCPMForCausalLM") class MiniCPMModel(Model): model_arch = gguf.MODEL_ARCH.MINICPM |