summaryrefslogtreecommitdiff
path: root/convert-hf-to-gguf.py
diff options
context:
space:
mode:
authorGuoteng <32697156+SolenoidWGT@users.noreply.github.com>2024-02-01 17:19:51 +0800
committerGitHub <noreply@github.com>2024-02-01 11:19:51 +0200
commitce32060198b7e2d6a13a9b8e1e1369e3c295ae2a (patch)
tree546b5a7a327d3f3f1370549331915c4d6373d51d /convert-hf-to-gguf.py
parent1cfb5372cf5707c8ec6dde7c874f4a44a6c4c915 (diff)
llama : support InternLM2 (#5184)
* support InternLM2 inference * add add_space_prefix KV pair
Diffstat (limited to 'convert-hf-to-gguf.py')
-rwxr-xr-xconvert-hf-to-gguf.py152
1 files changed, 152 insertions, 0 deletions
diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 6ab7f486..4ebab07b 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -203,6 +203,8 @@ class Model:
return CodeShellModel
if model_architecture == "OrionForCausalLM":
return OrionModel
+ if model_architecture == "InternLM2ForCausalLM":
+ return InternLM2Model
return Model
def _is_model_safetensors(self) -> bool:
@@ -254,6 +256,8 @@ class Model:
return gguf.MODEL_ARCH.CODESHELL
if arch == "OrionForCausalLM":
return gguf.MODEL_ARCH.ORION
+ if arch == "InternLM2ForCausalLM":
+ return gguf.MODEL_ARCH.INTERNLM2
raise NotImplementedError(f'Architecture "{arch}" not supported!')
@@ -1344,6 +1348,154 @@ class CodeShellModel(Model):
self.gguf_writer.add_tensor("output.weight", data)
print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")
+
+class InternLM2Model(Model):
+ def set_vocab(self):
+ # (TODO): Is there a better way?
+ # Copy from _set_vocab_sentencepiece, The only difference is that we will treat the character
+ # \x00 specially and convert it into an emoji character to prevent it from being mistakenly
+ # recognized as an empty string in C++.
+ from sentencepiece import SentencePieceProcessor
+ from sentencepiece import sentencepiece_model_pb2 as model
+
+ tokenizer_path = self.dir_model / 'tokenizer.model'
+
+ tokens: list[bytes] = []
+ scores: list[float] = []
+ toktypes: list[int] = []
+
+ if not tokenizer_path.is_file():
+ print(f'Error: Missing {tokenizer_path}', file=sys.stderr)
+ sys.exit(1)
+
+ sentencepiece_model = model.ModelProto()
+ sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
+ add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
+
+ tokenizer = SentencePieceProcessor(str(tokenizer_path))
+ vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
+
+ for token_id in range(vocab_size):
+ piece = tokenizer.id_to_piece(token_id)
+ text = piece.encode("utf-8")
+ score = tokenizer.get_score(token_id)
+ if text == b"\x00":
+ # (TODO): fixme
+ # Hack here and replace the \x00 characters.
+ print(f"InternLM2 convert token '{text}' to '🐉'!")
+ text = "🐉"
+
+ toktype = SentencePieceTokenTypes.NORMAL
+ if tokenizer.is_unknown(token_id):
+ toktype = SentencePieceTokenTypes.UNKNOWN
+ elif tokenizer.is_control(token_id):
+ toktype = SentencePieceTokenTypes.CONTROL
+ elif tokenizer.is_unused(token_id):
+ toktype = SentencePieceTokenTypes.UNUSED
+ elif tokenizer.is_byte(token_id):
+ toktype = SentencePieceTokenTypes.BYTE
+
+ tokens.append(text)
+ scores.append(score)
+ toktypes.append(toktype)
+
+ added_tokens_file = self.dir_model / 'added_tokens.json'
+ if added_tokens_file.is_file():
+ with open(added_tokens_file, "r", encoding="utf-8") as f:
+ added_tokens_json = json.load(f)
+
+ for key in added_tokens_json:
+ tokens.append(key.encode("utf-8"))
+ scores.append(-1000.0)
+ toktypes.append(SentencePieceTokenTypes.USER_DEFINED)
+
+ self.gguf_writer.add_tokenizer_model("llama")
+ self.gguf_writer.add_token_list(tokens)
+ self.gguf_writer.add_token_scores(scores)
+ self.gguf_writer.add_token_types(toktypes)
+ self.gguf_writer.add_add_space_prefix(add_prefix)
+
+ special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
+ special_vocab.add_to_gguf(self.gguf_writer)
+
+ def set_gguf_parameters(self):
+ self.gguf_writer.add_name("InternLM2")
+ self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
+ self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
+ self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
+ self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+ self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
+ self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
+ self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
+ self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
+
+ def post_write_tensors(self, tensor_map, name, data_torch):
+ old_dtype = data_torch.dtype
+
+ # convert any unsupported data types to float32
+ if data_torch.dtype not in (torch.float16, torch.float32):
+ data_torch = data_torch.to(torch.float32)
+
+ data = data_torch.squeeze().numpy()
+
+ # map tensor names
+ new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+ if new_name is None:
+ print(f"Can not map tensor {name!r}")
+ sys.exit()
+
+ n_dims = len(data.shape)
+ data_dtype = data.dtype
+
+ # if f32 desired, convert any float16 to float32
+ if self.ftype == 0 and data_dtype == np.float16:
+ data = data.astype(np.float32)
+
+ # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
+ if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+ data = data.astype(np.float32)
+
+ # if f16 desired, convert any float32 2-dim weight tensors to float16
+ if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+ data = data.astype(np.float16)
+
+ print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+ self.gguf_writer.add_tensor(new_name, data)
+
+ def write_tensors(self):
+ from einops import rearrange
+
+ num_heads = self.hparams.get("num_attention_heads")
+ num_kv_heads = self.hparams.get("num_key_value_heads")
+ hidden_size = self.hparams.get("hidden_size")
+ q_per_kv = num_heads // num_kv_heads
+ head_dim = hidden_size // num_heads
+ num_groups = num_heads // q_per_kv
+
+ block_count = self.hparams["num_hidden_layers"]
+ model_kv = dict(self.get_tensors())
+ tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+ qkv_pattern = r"model\.layers\.(\d+)\.attention\.wqkv"
+ for name, data_torch in model_kv.items():
+ # we don't need these
+ if name.endswith(".rotary_emb.inv_freq"):
+ continue
+
+ if re.match(qkv_pattern, name):
+ bid = re.findall(qkv_pattern, name)[0]
+ qkv = data_torch
+ qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
+ q, k, v = qkv[..., : q_per_kv, :], qkv[..., q_per_kv: q_per_kv + 1, :], qkv[..., q_per_kv + 1: q_per_kv + 2, :]
+ q = rearrange(q, " o g n i -> o (g n i)").T
+ k = rearrange(k, " o g n i -> o (g n i)").T
+ v = rearrange(v, " o g n i -> o (g n i)").T
+ self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wq.weight", q)
+ self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wk.weight", k)
+ self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wv.weight", v)
+ else:
+ self.post_write_tensors(tensor_map, name, data_torch)
+
+
###### CONVERSION LOGIC ######