Diffstat (limited to 'convert-hf-to-gguf.py')
 convert-hf-to-gguf.py | 86 ++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 85 insertions(+), 1 deletion(-)
diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index e71a96c4..303d0817 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -184,6 +184,8 @@ class Model:
             return MixtralModel
         if model_architecture == "PhiForCausalLM":
             return Phi2Model
+        if model_architecture == "PlamoForCausalLM":
+            return PlamoModel
         return Model

     def _is_model_safetensors(self) -> bool:
@@ -225,6 +227,8 @@ class Model:
             return gguf.MODEL_ARCH.LLAMA
         if arch == "PhiForCausalLM":
             return gguf.MODEL_ARCH.PHI2
+        if arch == "PlamoForCausalLM":
+            return gguf.MODEL_ARCH.PLAMO

         raise NotImplementedError(f'Architecture "{arch}" not supported!')
@@ -1002,11 +1006,91 @@ class Phi2Model(Model):
         self.gguf_writer.add_add_bos_token(False)


+class PlamoModel(Model):
+    def set_vocab(self):
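+        # the PLaMo repo provides a sentencepiece tokenizer.model, which the
+        # shared _set_vocab_sentencepiece helper reads directly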
+        self._set_vocab_sentencepiece()
+
+    def set_gguf_parameters(self):
+        hparams = self.hparams
+        block_count = hparams["num_hidden_layers"]
+
+        self.gguf_writer.add_name("PLaMo")
+        self.gguf_writer.add_context_length(4096)  # not in config.json
+        self.gguf_writer.add_embedding_length(hparams["hidden_size"])
+        self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_head_count(hparams["num_attention_heads"])
+        self.gguf_writer.add_head_count_kv(5)  # hparams["num_key_value_heads"] is wrong
+        self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
+
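+    # The two shuffles below regroup attention heads for broadcasting of GQA
+    # in ggml_mul_mat. Inferred from the hardcoded shapes rather than from
+    # config.json: hidden_size 5120 = 40 heads x head_dim 128, stored as 8
+    # interleaved groups over the 5 KV heads; shuffling to (5, 8) makes the 8
+    # query heads that share each KV head contiguous. A quick offline sanity
+    # check (hypothetical, not part of the conversion; self is unused):
+    #   w = torch.arange(5120 * 5120).reshape(5120, 5120)
+    #   q = PlamoModel.shuffle_attn_q_weight(None, w)
+    #   assert torch.equal(q[128:256], w[640:768])  # (k=0, g=1) head block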
+    def shuffle_attn_q_weight(self, data_torch):
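+        # row (g*5 + k)*128 + d of the original weight moves to row (k*8 + g)*128 + d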
+        assert data_torch.size() == (5120, 5120)
+        data_torch = data_torch.reshape(8, 5, 128, 5120)
+        data_torch = torch.permute(data_torch, (1, 0, 2, 3))
+        data_torch = torch.reshape(data_torch, (5120, 5120))
+        return data_torch
+
+    def shuffle_attn_output_weight(self, data_torch):
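+        # the same (8, 5) -> (5, 8) regrouping, applied to the columns so the
+        # output projection consumes head outputs in the shuffled order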
+        assert data_torch.size() == (5120, 5120)
+        data_torch = data_torch.reshape(5120, 8, 5, 128)
+        data_torch = torch.permute(data_torch, (0, 2, 1, 3))
+        data_torch = torch.reshape(data_torch, (5120, 5120))
+        return data_torch
+
+    def write_tensors(self):
+        block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+
+        for name, data_torch in self.get_tensors():
+            if "self_attn.rotary_emb.inv_freq" in name:
+                continue
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Cannot map tensor {name!r}")
+                sys.exit()
+
+            # shuffle for broadcasting of gqa in ggml_mul_mat
+            if new_name.endswith("attn_q.weight"):
+                data_torch = self.shuffle_attn_q_weight(data_torch)
+            elif new_name.endswith("attn_output.weight"):
+                data_torch = self.shuffle_attn_output_weight(data_torch)
+
+            old_dtype = data_torch.dtype
+
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
+            data = data_torch.squeeze().numpy()
+
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
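+            # self.ftype comes from --outtype: 0 = f32 output, 1 = f16 output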
+            # if f32 desired, convert any float16 to float32
+            if self.ftype == 0 and data_dtype == np.float16:
+                data = data.astype(np.float32)
+
+            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
+            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+                data = data.astype(np.float32)
+
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+                data = data.astype(np.float16)
+
+            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+
+            self.gguf_writer.add_tensor(new_name, data)
+
+
 ###### CONVERSION LOGIC ######


 def parse_args() -> argparse.Namespace:
-    parser = argparse.ArgumentParser(description="Convert a huggingface model to a GGML compatible file")
+    parser = argparse.ArgumentParser(
+        description="Convert a huggingface model to a GGML compatible file")
     parser.add_argument(
         "--vocab-only", action="store_true",
         help="extract only the vocab",