Diffstat (limited to 'convert-hf-to-gguf.py')
-rwxr-xr-x  convert-hf-to-gguf.py | 142
1 file changed, 142 insertions(+), 0 deletions(-)
diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 6a2ce187..18337839 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -773,6 +773,148 @@ class BaichuanModel(Model):
         return weights[r * n_part:r * n_part + r, ...]
+@Model.register("XverseForCausalLM")
+class XverseModel(Model):
+    model_arch = gguf.MODEL_ARCH.XVERSE
+
+    def set_vocab(self):
+        assert (self.dir_model / "tokenizer.json").is_file()
+        dir_model = self.dir_model
+        hparams = self.hparams
+
+        tokens: list[bytearray] = []
+        toktypes: list[int] = []
+
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(dir_model)
+        vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
+        assert max(tokenizer.vocab.values()) < vocab_size
+
+        reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
+        added_vocab = tokenizer.get_added_vocab()
+
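+        # classify each token as a raw byte token, an added (control / user-defined) token, or a normal vocab entry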
+        for token_id in range(vocab_size):
+            token_text = reverse_vocab[token_id].encode('utf-8')
+            # replace "\x00" with a placeholder string of length > 0
+            if token_text == b"\x00":
+                toktype = gguf.TokenType.BYTE  # special
+                token_text = f"<{token_text}>".encode('utf-8')
+            elif re.fullmatch(br"<0x[0-9A-Fa-f]{2}>", token_text):
+                toktype = gguf.TokenType.BYTE  # special
+            elif reverse_vocab[token_id] in added_vocab:
+                if tokenizer.added_tokens_decoder[token_id].special:
+                    toktype = gguf.TokenType.CONTROL
+                else:
+                    toktype = gguf.TokenType.USER_DEFINED
+            else:
+                toktype = gguf.TokenType.NORMAL
+
+            tokens.append(token_text)
+            toktypes.append(toktype)
+
+        self.gguf_writer.add_tokenizer_model("llama")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+
+        special_vocab = gguf.SpecialVocab(dir_model, n_vocab=len(tokens))
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+    def set_gguf_parameters(self):
+        block_count = self.hparams["num_hidden_layers"]
+        head_count = self.hparams["num_attention_heads"]
+        head_count_kv = self.hparams.get("num_key_value_heads", head_count)
+        hf_repo = self.hparams.get("_name_or_path", "")
+
+        ctx_length = 0
+        if "max_sequence_length" in self.hparams:
+            ctx_length = self.hparams["max_sequence_length"]
+        elif "max_position_embeddings" in self.hparams:
+            ctx_length = self.hparams["max_position_embeddings"]
+        elif "model_max_length" in self.hparams:
+            ctx_length = self.hparams["model_max_length"]
+        else:
+            print("gguf: cannot find ctx length parameter.")
+            sys.exit()
+
+        self.gguf_writer.add_name(self.dir_model.name)
+        self.gguf_writer.add_source_hf_repo(hf_repo)
+        self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
+        self.gguf_writer.add_context_length(ctx_length)
+        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+        self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
+        self.gguf_writer.add_head_count(head_count)
+        self.gguf_writer.add_head_count_kv(head_count_kv)
+        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
+
+        if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
+            if self.hparams["rope_scaling"].get("type") == "linear":
+                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+                self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
+
+    def write_tensors(self):
+        # Collect tensors from generator object
+        model_kv = dict(self.get_tensors())
+        block_count = self.hparams["num_hidden_layers"]
+        head_count = self.hparams["num_attention_heads"]
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+        head_count_kv = self.hparams.get("num_key_value_heads", head_count)
+
+        for name, data_torch in model_kv.items():
+            # we don't need these
+            if name.endswith(".rotary_emb.inv_freq"):
+                continue
+
+            old_dtype = data_torch.dtype
+
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
+            # HF models permute some of the tensors, so we need to undo that
+            if name.endswith("q_proj.weight"):
+                data_torch = self._reverse_hf_permute(data_torch, head_count, head_count)
+            if name.endswith("k_proj.weight"):
+                data_torch = self._reverse_hf_permute(data_torch, head_count, head_count_kv)
+
+            data = data_torch.squeeze().numpy()
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Cannot map tensor {name!r}")
+                sys.exit()
+
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
+            # if f32 desired, convert any float16 to float32
+            if self.ftype == 0 and data_dtype == np.float16:
+                data = data.astype(np.float32)
+
+            # TODO: Why can't we use these float16 values as-is? There should be no reason to store float16 as float32
+            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+                data = data.astype(np.float32)
+
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+                data = data.astype(np.float16)
+
+            print(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+            self.gguf_writer.add_tensor(new_name, data)
+
+    def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
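+        # undo the query/key weight permutation applied by the HF checkpoint layout (see the note in write_tensors)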
+        if n_kv_head is not None and n_head != n_kv_head:
+            n_head //= n_kv_head
+
+        return (
+            weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
+            .swapaxes(1, 2)
+            .reshape(weights.shape)
+        )
+
+
@Model.register("FalconForCausalLM", "RWForCausalLM")
class FalconModel(Model):
model_arch = gguf.MODEL_ARCH.FALCON