Diffstat (limited to 'convert-hf-to-gguf.py')
-rwxr-xr-x  convert-hf-to-gguf.py  94
1 file changed, 94 insertions, 0 deletions
diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 0d4ea03b..cae1551a 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -209,6 +209,8 @@ class Model:
             return InternLM2Model
         if model_architecture == "MiniCPMForCausalLM":
             return MiniCPMModel
+        if model_architecture == "BertModel":
+            return BertModel
         return Model
 
     def _is_model_safetensors(self) -> bool:
@@ -264,6 +266,8 @@ class Model:
             return gguf.MODEL_ARCH.INTERNLM2
         if arch == "MiniCPMForCausalLM":
             return gguf.MODEL_ARCH.MINICPM
+        if arch == "BertModel":
+            return gguf.MODEL_ARCH.BERT
 
         raise NotImplementedError(f'Architecture "{arch}" not supported!')
@@ -1629,6 +1633,96 @@ in chat mode so that the conversation can end normally.")
             self.post_write_tensors(tensor_map, name, data_torch)
+class BertModel(Model):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.block_count = self.hparams["num_hidden_layers"]
+
+    def set_gguf_parameters(self):
+        # TODO(cebtenzzre): merge with parent class
+        self.gguf_writer.add_name(self.dir_model.name)
+        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
+        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
+        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
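+        # BERT is a bidirectional encoder, so record the attention as non-causal
+        # (no causal mask is wanted when computing embeddings)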
+        self.gguf_writer.add_causal_attention(False)
+        self.gguf_writer.add_file_type(self.ftype)
+
+    def set_vocab(self):
+        path = self.dir_model
+        added_tokens_path = self.dir_model if self.dir_model.exists() else None
+
+        # use huggingface vocab to get all tokens
+        vocab = HfVocab(path, added_tokens_path)
+        tokens, scores, toktypes = zip(*vocab.all_tokens())
+        assert len(tokens) == vocab.vocab_size
+
+        # we need this to validate the size of the token_type embeddings
+        # though currently we are passing all zeros to the token_type embeddings
+        n_token_types = len(set(toktypes))
+        self.gguf_writer.add_token_type_count(n_token_types)
+
+        # convert to phantom space vocab
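+        # (WordPiece marks continuation pieces with a leading "##"; strip that marker
+        # and instead prefix word-initial pieces with U+2581 (b"\xe2\x96\x81"), the
+        # SentencePiece-style word-boundary marker; bracketed special tokens such as
+        # [CLS] and [SEP] are left untouched)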
+        def phantom(tok, typ):
+            if tok.startswith(b"[") and tok.endswith(b"]"):
+                return tok
+            if tok.startswith(b"##"):
+                return tok[2:]
+            return b"\xe2\x96\x81" + tok
+        tokens = [phantom(t, y) for t, y in zip(tokens, toktypes)]
+
+        # set up bos and eos tokens (cls and sep)
+        self.gguf_writer.add_bos_token_id(vocab.tokenizer.cls_token_id)
+        self.gguf_writer.add_eos_token_id(vocab.tokenizer.sep_token_id)
+
+        # add vocab to gguf
+        self.gguf_writer.add_tokenizer_model("bert")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+
+        # handle special tokens
+        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+    def write_tensors(self):
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+        tensors = dict(self.get_tensors())
+        for name, data_torch in tensors.items():
+            # we are only using BERT for embeddings so we don't need the pooling layer
+            if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"):
+                continue  # we don't need these
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
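+            # squeeze out any singleton dimensions and convert to numpy before writing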
+            data = data_torch.squeeze().numpy()
+            n_dims = len(data.shape)
+            new_dtype: type[np.floating[Any]]
+
+            if (
+                self.ftype == 1 and name.endswith(".weight") and n_dims == 2
+                and name != "embeddings.token_type_embeddings.weight"  # not used with get_rows, must be F32
+            ):
+                # if f16 desired, convert any float32 2-dim weight tensors to float16
+                new_dtype = np.float16
+            else:
+                # if f32 desired, convert any float16 to float32
+                new_dtype = np.float32
+
+            print(f"{new_name}, n_dims = {n_dims}, {data_torch.dtype} --> {new_dtype}")
+
+            if data.dtype != new_dtype:
+                data = data.astype(new_dtype)
+
+            self.gguf_writer.add_tensor(new_name, data)
+
+
###### CONVERSION LOGIC ######
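With this change, a BERT-style checkpoint can be converted like the other supported architectures. A minimal invocation sketch, assuming a local model directory and the script's existing --outfile/--outtype options (the paths and file names below are illustrative only):

    python convert-hf-to-gguf.py ./bert-model --outfile bert-model-f16.gguf --outtype f16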