summaryrefslogtreecommitdiff
path: root/gguf-py/gguf/gguf_reader.py
diff options
context:
space:
mode:
authorcompilade <git@compilade.net>2024-05-24 21:11:48 -0400
committerGitHub <noreply@github.com>2024-05-25 11:11:48 +1000
commitb83bab15a5d2a1e7807d09613a9b34309d86cfaa (patch)
tree449b4201f8b8929f674fc2ad7654406ba2c50a4b /gguf-py/gguf/gguf_reader.py
parentd041d2ceaaf50e058622d92921b3e680ffa4e9e7 (diff)
gguf-py : fix and simplify quantized shape round-trip (#7483)
* gguf-py : fix and simplify quantized shape round-trip * gguf-py : remove unused import
Diffstat (limited to 'gguf-py/gguf/gguf_reader.py')
-rw-r--r--gguf-py/gguf/gguf_reader.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/gguf-py/gguf/gguf_reader.py b/gguf-py/gguf/gguf_reader.py
index 21b089f8..e48bc00c 100644
--- a/gguf-py/gguf/gguf_reader.py
+++ b/gguf-py/gguf/gguf_reader.py
@@ -12,6 +12,8 @@ from typing import Any, Literal, NamedTuple, TypeVar, Union
import numpy as np
import numpy.typing as npt
+from .quants import quant_shape_to_byte_shape
+
if __name__ == "__main__":
import sys
from pathlib import Path
@@ -251,6 +253,7 @@ class GGUFReader:
tensor_names.add(tensor_name)
ggml_type = GGMLQuantizationType(raw_dtype[0])
n_elems = int(np.prod(dims))
+ np_dims = tuple(reversed(dims.tolist()))
block_size, type_size = GGML_QUANT_SIZES[ggml_type]
n_bytes = n_elems * type_size // block_size
data_offs = int(start_offs + offset_tensor[0])
@@ -279,6 +282,7 @@ class GGUFReader:
else:
item_count = n_bytes
item_type = np.uint8
+ np_dims = quant_shape_to_byte_shape(np_dims, ggml_type)
tensors.append(ReaderTensor(
name = tensor_name,
tensor_type = ggml_type,
@@ -286,7 +290,7 @@ class GGUFReader:
n_elements = n_elems,
n_bytes = n_bytes,
data_offset = data_offs,
- data = self._get(data_offs, item_type, item_count),
+ data = self._get(data_offs, item_type, item_count).reshape(np_dims),
field = field,
))
self.tensors = tensors