summaryrefslogtreecommitdiff
path: root/convert-hf-to-gguf.py
diff options
context:
space:
mode:
Diffstat (limited to 'convert-hf-to-gguf.py')
-rwxr-xr-xconvert-hf-to-gguf.py21
1 files changed, 20 insertions, 1 deletions
diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index bced1f56..e46a7813 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -77,8 +77,18 @@ class Model:
self.gguf_writer.add_embedding_length(n_embd)
if (n_ff := self.hparams.get("intermediate_size")) is not None:
self.gguf_writer.add_feed_forward_length(n_ff)
- if (n_head := self.hparams.get("num_attention_head")) is not None:
+ if (n_head := self.hparams.get("num_attention_heads")) is not None:
self.gguf_writer.add_head_count(n_head)
+ if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
+ self.gguf_writer.add_head_count_kv(n_head_kv)
+
+ if (n_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
+ self.gguf_writer.add_layer_norm_rms_eps(n_rms_eps)
+ if (n_experts := self.hparams.get("num_local_experts")) is not None:
+ self.gguf_writer.add_expert_count(n_experts)
+ if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
+ self.gguf_writer.add_expert_used_count(n_experts_used)
+
self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True))
def write_tensors(self):
@@ -170,6 +180,8 @@ class Model:
return StableLMModel
if model_architecture == "QWenLMHeadModel":
return QwenModel
+ if model_architecture == "MixtralForCausalLM":
+ return MixtralModel
return Model
def _is_model_safetensors(self) -> bool:
@@ -207,6 +219,8 @@ class Model:
return gguf.MODEL_ARCH.STABLELM
if arch == "QWenLMHeadModel":
return gguf.MODEL_ARCH.QWEN
+ if arch == "MixtralForCausalLM":
+ return gguf.MODEL_ARCH.LLAMA
raise NotImplementedError(f'Architecture "{arch}" not supported!')
@@ -837,6 +851,11 @@ class StableLMModel(Model):
self.gguf_writer.add_layer_norm_eps(1e-5)
+class MixtralModel(Model):
+ def set_vocab(self):
+ self._set_vocab_sentencepiece()
+
+
class QwenModel(Model):
@staticmethod
def token_bytes_to_string(b):