diff options
author | Xuan Son Nguyen <thichthat@gmail.com> | 2024-04-28 17:36:18 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-28 17:36:18 +0200 |
commit | 7bb36ccf91b8a2e92b182dd75624f1fd7cb205ac (patch) | |
tree | ab92b14895245a23730553dc06af68e75995c69c /gguf-py | |
parent | ce023f6f2ff34fbe840e32e65d443d2fed7393de (diff) |
gguf : enforce that tensor names are unique (#6905)
* do not allow adding a duplicate tensor name
* reject duplicate tensor names while reading gguf
* typo
* throw exception inside llama_model_loader
Co-authored-by: slaren <slarengh@gmail.com>
---------
Co-authored-by: slaren <slarengh@gmail.com>
Diffstat (limited to 'gguf-py')
-rw-r--r-- | gguf-py/gguf/gguf_reader.py | 8 | ||||
-rw-r--r-- | gguf-py/gguf/gguf_writer.py | 5 |
2 files changed, 12 insertions, 1 deletion
diff --git a/gguf-py/gguf/gguf_reader.py b/gguf-py/gguf/gguf_reader.py
index 33afac55..48ef6d4a 100644
--- a/gguf-py/gguf/gguf_reader.py
+++ b/gguf-py/gguf/gguf_reader.py
@@ -234,8 +234,14 @@ class GGUFReader:
 
     def _build_tensors(self, start_offs: int, fields: list[ReaderField]) -> None:
         tensors = []
+        tensor_names = set() # keep track of name to prevent duplicated tensors
         for field in fields:
             _name_len, name_data, _n_dims, dims, raw_dtype, offset_tensor = field.parts
+            # check if there's any tensor having same name already in the list
+            tensor_name = str(bytes(name_data), encoding = 'utf-8')
+            if tensor_name in tensor_names:
+                raise ValueError(f'Found duplicated tensor with name {tensor_name}')
+            tensor_names.add(tensor_name)
             ggml_type = GGMLQuantizationType(raw_dtype[0])
             n_elems = np.prod(dims)
             block_size, type_size = GGML_QUANT_SIZES[ggml_type]
@@ -267,7 +273,7 @@ class GGUFReader:
                 item_count = n_bytes
                 item_type = np.uint8
             tensors.append(ReaderTensor(
-                name = str(bytes(name_data), encoding = 'utf-8'),
+                name = tensor_name,
                 tensor_type = ggml_type,
                 shape = dims,
                 n_elements = n_elems,
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index e3dbca45..ec44ac9f 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -63,6 +63,7 @@ class GGUFWriter:
         self.kv_data_count = 0
         self.ti_data = bytearray()
         self.ti_data_count = 0
+        self.ti_names = set()
         self.use_temp_file = use_temp_file
         self.temp_file = None
         self.tensors = []
@@ -197,6 +198,10 @@ class GGUFWriter:
         if self.state is not WriterState.EMPTY:
             raise ValueError(f'Expected output file to be empty, got {self.state}')
 
+        if name in self.ti_names:
+            raise ValueError(f'Duplicated tensor name {name}')
+        self.ti_names.add(name)
+
         encoded_name = name.encode("utf8")
         self.ti_data += self._pack("Q", len(encoded_name))
         self.ti_data += encoded_name