diff options
Diffstat (limited to 'gguf-py/gguf/gguf.py')
-rw-r--r-- | gguf-py/gguf/gguf.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/gguf-py/gguf/gguf.py b/gguf-py/gguf/gguf.py index bda13ac0..7f7204ea 100644 --- a/gguf-py/gguf/gguf.py +++ b/gguf-py/gguf/gguf.py @@ -333,6 +333,7 @@ class TensorNameMap: tensor_name = tensor_names.get(tensor) if tensor_name is None: continue + mapping[tensor_name] = (tensor, tensor_name) for key in keys: mapping[key] = (tensor, tensor_name) for bid in range(n_blocks): @@ -341,11 +342,12 @@ class TensorNameMap: if tensor_name is None: continue tensor_name = tensor_name.format(bid = bid) + mapping[tensor_name] = (tensor, tensor_name) for key in keys: key = key.format(bid = bid) mapping[key] = (tensor, tensor_name) - def get_type_and_name(self, key: str, try_suffixes: Sequence[str]) -> tuple[MODEL_TENSOR, str] | None: + def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None: result = self.mapping.get(key) if result is not None: return result @@ -356,13 +358,13 @@ class TensorNameMap: return (result[0], result[1] + suffix) return None - def get_name(self, key: str, try_suffixes: Sequence[str]) -> str | None: + def get_name(self, key: str, try_suffixes: Sequence[str] = ()) -> str | None: result = self.get_type_and_name(key, try_suffixes = try_suffixes) if result is None: return None return result[1] - def get_type(self, key: str, try_suffixes: Sequence[str]) -> MODEL_TENSOR | None: + def get_type(self, key: str, try_suffixes: Sequence[str] = ()) -> MODEL_TENSOR | None: result = self.get_type_and_name(key, try_suffixes = try_suffixes) if result is None: return None |