summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--gguf-py/gguf/gguf.py8
-rw-r--r--gguf-py/pyproject.toml2
2 files changed, 6 insertions, 4 deletions
diff --git a/gguf-py/gguf/gguf.py b/gguf-py/gguf/gguf.py
index bda13ac0..7f7204ea 100644
--- a/gguf-py/gguf/gguf.py
+++ b/gguf-py/gguf/gguf.py
@@ -333,6 +333,7 @@ class TensorNameMap:
tensor_name = tensor_names.get(tensor)
if tensor_name is None:
continue
+ mapping[tensor_name] = (tensor, tensor_name)
for key in keys:
mapping[key] = (tensor, tensor_name)
for bid in range(n_blocks):
@@ -341,11 +342,12 @@ class TensorNameMap:
if tensor_name is None:
continue
tensor_name = tensor_name.format(bid = bid)
+ mapping[tensor_name] = (tensor, tensor_name)
for key in keys:
key = key.format(bid = bid)
mapping[key] = (tensor, tensor_name)
- def get_type_and_name(self, key: str, try_suffixes: Sequence[str]) -> tuple[MODEL_TENSOR, str] | None:
+ def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None:
result = self.mapping.get(key)
if result is not None:
return result
@@ -356,13 +358,13 @@ class TensorNameMap:
return (result[0], result[1] + suffix)
return None
- def get_name(self, key: str, try_suffixes: Sequence[str]) -> str | None:
+ def get_name(self, key: str, try_suffixes: Sequence[str] = ()) -> str | None:
result = self.get_type_and_name(key, try_suffixes = try_suffixes)
if result is None:
return None
return result[1]
- def get_type(self, key: str, try_suffixes: Sequence[str]) -> MODEL_TENSOR | None:
+ def get_type(self, key: str, try_suffixes: Sequence[str] = ()) -> MODEL_TENSOR | None:
result = self.get_type_and_name(key, try_suffixes = try_suffixes)
if result is None:
return None
diff --git a/gguf-py/pyproject.toml b/gguf-py/pyproject.toml
index 8da60de1..9489ccd6 100644
--- a/gguf-py/pyproject.toml
+++ b/gguf-py/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "gguf"
-version = "0.3.2"
+version = "0.3.3"
description = "Write ML models in GGUF for GGML"
authors = ["GGML <ggml@ggml.ai>"]
packages = [