diff options
author | Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com> | 2023-09-14 10:32:26 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-14 19:32:26 +0300 |
commit | e394084166baac09e8ee9a08a4686f907f7e5291 (patch) | |
tree | 7dd33e82154b7fac7efe46262f2f4afda9e726ad | |
parent | 4c8643dd6ea1a163bc5979cb69c1e7ab0975bc93 (diff) |
gguf-py : support identity operation in TensorNameMap (#3095)
Make try_suffixes keyword param optional.
-rw-r--r-- | gguf-py/gguf/gguf.py | 8 | ||||
-rw-r--r-- | gguf-py/pyproject.toml | 2 |
2 files changed, 6 insertions, 4 deletions
diff --git a/gguf-py/gguf/gguf.py b/gguf-py/gguf/gguf.py index bda13ac0..7f7204ea 100644 --- a/gguf-py/gguf/gguf.py +++ b/gguf-py/gguf/gguf.py @@ -333,6 +333,7 @@ class TensorNameMap: tensor_name = tensor_names.get(tensor) if tensor_name is None: continue + mapping[tensor_name] = (tensor, tensor_name) for key in keys: mapping[key] = (tensor, tensor_name) for bid in range(n_blocks): @@ -341,11 +342,12 @@ class TensorNameMap: if tensor_name is None: continue tensor_name = tensor_name.format(bid = bid) + mapping[tensor_name] = (tensor, tensor_name) for key in keys: key = key.format(bid = bid) mapping[key] = (tensor, tensor_name) - def get_type_and_name(self, key: str, try_suffixes: Sequence[str]) -> tuple[MODEL_TENSOR, str] | None: + def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None: result = self.mapping.get(key) if result is not None: return result @@ -356,13 +358,13 @@ class TensorNameMap: return (result[0], result[1] + suffix) return None - def get_name(self, key: str, try_suffixes: Sequence[str]) -> str | None: + def get_name(self, key: str, try_suffixes: Sequence[str] = ()) -> str | None: result = self.get_type_and_name(key, try_suffixes = try_suffixes) if result is None: return None return result[1] - def get_type(self, key: str, try_suffixes: Sequence[str]) -> MODEL_TENSOR | None: + def get_type(self, key: str, try_suffixes: Sequence[str] = ()) -> MODEL_TENSOR | None: result = self.get_type_and_name(key, try_suffixes = try_suffixes) if result is None: return None diff --git a/gguf-py/pyproject.toml b/gguf-py/pyproject.toml index 8da60de1..9489ccd6 100644 --- a/gguf-py/pyproject.toml +++ b/gguf-py/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "gguf" -version = "0.3.2" +version = "0.3.3" description = "Write ML models in GGUF for GGML" authors = ["GGML <ggml@ggml.ai>"] packages = [ |