summaryrefslogtreecommitdiff
path: root/gguf-py
diff options
context:
space:
mode:
authorXuan Son Nguyen <thichthat@gmail.com>2024-04-28 17:36:18 +0200
committerGitHub <noreply@github.com>2024-04-28 17:36:18 +0200
commit7bb36ccf91b8a2e92b182dd75624f1fd7cb205ac (patch)
treeab92b14895245a23730553dc06af68e75995c69c /gguf-py
parentce023f6f2ff34fbe840e32e65d443d2fed7393de (diff)
gguf : enforce that tensor names are unique (#6905)
* not allow adding duplicated tensor name * no duplicated tensor while reading gguf * typo * throw exception inside llama_model_loader Co-authored-by: slaren <slarengh@gmail.com> --------- Co-authored-by: slaren <slarengh@gmail.com>
Diffstat (limited to 'gguf-py')
-rw-r--r--gguf-py/gguf/gguf_reader.py8
-rw-r--r--gguf-py/gguf/gguf_writer.py5
2 files changed, 12 insertions, 1 deletions
diff --git a/gguf-py/gguf/gguf_reader.py b/gguf-py/gguf/gguf_reader.py
index 33afac55..48ef6d4a 100644
--- a/gguf-py/gguf/gguf_reader.py
+++ b/gguf-py/gguf/gguf_reader.py
@@ -234,8 +234,14 @@ class GGUFReader:
def _build_tensors(self, start_offs: int, fields: list[ReaderField]) -> None:
tensors = []
+ tensor_names = set() # keep track of name to prevent duplicated tensors
for field in fields:
_name_len, name_data, _n_dims, dims, raw_dtype, offset_tensor = field.parts
+ # check if there's any tensor having same name already in the list
+ tensor_name = str(bytes(name_data), encoding = 'utf-8')
+ if tensor_name in tensor_names:
+ raise ValueError(f'Found duplicated tensor with name {tensor_name}')
+ tensor_names.add(tensor_name)
ggml_type = GGMLQuantizationType(raw_dtype[0])
n_elems = np.prod(dims)
block_size, type_size = GGML_QUANT_SIZES[ggml_type]
@@ -267,7 +273,7 @@ class GGUFReader:
item_count = n_bytes
item_type = np.uint8
tensors.append(ReaderTensor(
- name = str(bytes(name_data), encoding = 'utf-8'),
+ name = tensor_name,
tensor_type = ggml_type,
shape = dims,
n_elements = n_elems,
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index e3dbca45..ec44ac9f 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -63,6 +63,7 @@ class GGUFWriter:
self.kv_data_count = 0
self.ti_data = bytearray()
self.ti_data_count = 0
+ self.ti_names = set()
self.use_temp_file = use_temp_file
self.temp_file = None
self.tensors = []
@@ -197,6 +198,10 @@ class GGUFWriter:
if self.state is not WriterState.EMPTY:
raise ValueError(f'Expected output file to be empty, got {self.state}')
+ if name in self.ti_names:
+ raise ValueError(f'Duplicated tensor name {name}')
+ self.ti_names.add(name)
+
encoded_name = name.encode("utf8")
self.ti_data += self._pack("Q", len(encoded_name))
self.ti_data += encoded_name