summaryrefslogtreecommitdiff
path: root/gguf-py/gguf/quants.py
diff options
context:
space:
mode:
authorcompilade <git@compilade.net>2024-05-24 21:11:48 -0400
committerGitHub <noreply@github.com>2024-05-25 11:11:48 +1000
commitb83bab15a5d2a1e7807d09613a9b34309d86cfaa (patch)
tree449b4201f8b8929f674fc2ad7654406ba2c50a4b /gguf-py/gguf/quants.py
parentd041d2ceaaf50e058622d92921b3e680ffa4e9e7 (diff)
gguf-py : fix and simplify quantized shape round-trip (#7483)
* gguf-py : fix and simplify quantized shape round-trip * gguf-py : remove unused import
Diffstat (limited to 'gguf-py/gguf/quants.py')
-rw-r--r--gguf-py/gguf/quants.py16
1 files changed, 15 insertions, 1 deletions
diff --git a/gguf-py/gguf/quants.py b/gguf-py/gguf/quants.py
index e7fc0eae..b22eec16 100644
--- a/gguf-py/gguf/quants.py
+++ b/gguf-py/gguf/quants.py
@@ -1,5 +1,5 @@
from __future__ import annotations
-from typing import Callable
+from typing import Callable, Sequence
from numpy.typing import DTypeLike
@@ -9,6 +9,20 @@ from .lazy import LazyNumpyTensor
import numpy as np
+def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType):
+ block_size, type_size = GGML_QUANT_SIZES[quant_type]
+ if shape[-1] % block_size != 0:
+ raise ValueError(f"Quantized tensor row size ({shape[-1]}) is not a multiple of {quant_type.name} block size ({block_size})")
+ return (*shape[:-1], shape[-1] // block_size * type_size)
+
+
+def quant_shape_from_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType):
+ block_size, type_size = GGML_QUANT_SIZES[quant_type]
+ if shape[-1] % type_size != 0:
+ raise ValueError(f"Quantized tensor bytes per row ({shape[-1]}) is not a multiple of {quant_type.name} type size ({type_size})")
+ return (*shape[:-1], shape[-1] // type_size * block_size)
+
+
# same as ggml_compute_fp32_to_bf16 in ggml-impl.h
def __compute_fp32_to_bf16(n: np.ndarray) -> np.ndarray:
n = n.astype(np.float32, copy=False).view(np.int32)