summaryrefslogtreecommitdiff
path: root/gguf-py/gguf/gguf_writer.py
diff options
context:
space:
mode:
Diffstat (limited to 'gguf-py/gguf/gguf_writer.py')
-rw-r--r--gguf-py/gguf/gguf_writer.py8
1 files changed, 3 insertions, 5 deletions
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 8b41b54e..c194dd5d 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -13,7 +13,6 @@ from string import ascii_letters, digits
import numpy as np
from .constants import (
- GGML_QUANT_SIZES,
GGUF_DEFAULT_ALIGNMENT,
GGUF_MAGIC,
GGUF_VERSION,
@@ -26,6 +25,8 @@ from .constants import (
TokenType,
)
+from .quants import quant_shape_from_byte_shape
+
logger = logging.getLogger(__name__)
@@ -229,10 +230,7 @@ class GGUFWriter:
else:
dtype = raw_dtype
if tensor_dtype == np.uint8:
- block_size, type_size = GGML_QUANT_SIZES[raw_dtype]
- if tensor_shape[-1] % type_size != 0:
- raise ValueError(f"Quantized tensor row size ({tensor_shape[-1]}) is not a multiple of {dtype.name} type size ({type_size})")
- tensor_shape = tuple(tensor_shape[:-1]) + (tensor_shape[-1] // type_size * block_size,)
+ tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype)
n_dims = len(tensor_shape)
self.ti_data += self._pack("I", n_dims)
for i in range(n_dims):