diff options
author | compilade <git@compilade.net> | 2024-05-24 21:11:48 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-25 11:11:48 +1000 |
commit | b83bab15a5d2a1e7807d09613a9b34309d86cfaa (patch) | |
tree | 449b4201f8b8929f674fc2ad7654406ba2c50a4b /gguf-py/gguf/gguf_writer.py | |
parent | d041d2ceaaf50e058622d92921b3e680ffa4e9e7 (diff) |
gguf-py : fix and simplify quantized shape round-trip (#7483)
* gguf-py : fix and simplify quantized shape round-trip
* gguf-py : remove unused import
Diffstat (limited to 'gguf-py/gguf/gguf_writer.py')
-rw-r--r-- | gguf-py/gguf/gguf_writer.py | 8 |
1 file changed, 3 insertions, 5 deletions
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 8b41b54e..c194dd5d 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -13,7 +13,6 @@ from string import ascii_letters, digits
 import numpy as np
 
 from .constants import (
-    GGML_QUANT_SIZES,
     GGUF_DEFAULT_ALIGNMENT,
     GGUF_MAGIC,
     GGUF_VERSION,
@@ -26,6 +25,8 @@ from .constants import (
     TokenType,
 )
 
+from .quants import quant_shape_from_byte_shape
+
 logger = logging.getLogger(__name__)
 
 
@@ -229,10 +230,7 @@ class GGUFWriter:
         else:
             dtype = raw_dtype
             if tensor_dtype == np.uint8:
-                block_size, type_size = GGML_QUANT_SIZES[raw_dtype]
-                if tensor_shape[-1] % type_size != 0:
-                    raise ValueError(f"Quantized tensor row size ({tensor_shape[-1]}) is not a multiple of {dtype.name} type size ({type_size})")
-                tensor_shape = tuple(tensor_shape[:-1]) + (tensor_shape[-1] // type_size * block_size,)
+                tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype)
         n_dims = len(tensor_shape)
         self.ti_data += self._pack("I", n_dims)
         for i in range(n_dims):