summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorcompilade <git@compilade.net>2024-06-08 22:47:25 -0400
committerGitHub <noreply@github.com>2024-06-09 12:47:25 +1000
commit5795b941827fdec6c1662986de962badff456718 (patch)
tree551e2e6de458a97763af41df5b70433276512d7d
parented9f2521185706481501a5e6d5315397b11802ff (diff)
convert-hf : match model part name prefix and suffix (#7687)
In #7075, to fix the conversion of (some) models using model-00001-of-00001.safetensors instead of model.safetensors for a single model part we simply used the same logic as the part count to get the part names. But this doesn't always work correctly, like when unusual additional model files like consolidated.safetensors in https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3 are present. This commit matching both the prefix and the suffix of the model part names should fix this problem without breaking any previously-supported upstream models. But according to report by @teleprint-me there is still some persistent problem, but shall do in the meantime.
-rwxr-xr-xconvert-hf-to-gguf.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 0327712d..b38f48ed 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -73,10 +73,10 @@ class Model:
self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
self.use_temp_file = use_temp_file
self.lazy = not eager
- self.part_names = Model.get_model_part_names(self.dir_model, ".safetensors")
+ self.part_names = Model.get_model_part_names(self.dir_model, "model", ".safetensors")
self.is_safetensors = len(self.part_names) > 0
if not self.is_safetensors:
- self.part_names = Model.get_model_part_names(self.dir_model, ".bin")
+ self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
self.hparams = Model.load_hparams(self.dir_model)
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
@@ -335,10 +335,10 @@ class Model:
self.gguf_writer.close()
@staticmethod
- def get_model_part_names(dir_model: Path, suffix: str) -> list[str]:
+ def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]:
part_names: list[str] = []
for filename in os.listdir(dir_model):
- if filename.endswith(suffix):
+ if filename.startswith(prefix) and filename.endswith(suffix):
part_names.append(filename)
part_names.sort()