diff options
author | compilade <git@compilade.net> | 2024-05-24 21:11:48 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-25 11:11:48 +1000 |
commit | b83bab15a5d2a1e7807d09613a9b34309d86cfaa (patch) | |
tree | 449b4201f8b8929f674fc2ad7654406ba2c50a4b /gguf-py/gguf/quants.py | |
parent | d041d2ceaaf50e058622d92921b3e680ffa4e9e7 (diff) |
gguf-py : fix and simplify quantized shape round-trip (#7483)
* gguf-py : fix and simplify quantized shape round-trip
* gguf-py : remove unused import
Diffstat (limited to 'gguf-py/gguf/quants.py')
-rw-r--r-- | gguf-py/gguf/quants.py | 16 |
1 file changed, 15 insertions, 1 deletion
```diff
diff --git a/gguf-py/gguf/quants.py b/gguf-py/gguf/quants.py
index e7fc0eae..b22eec16 100644
--- a/gguf-py/gguf/quants.py
+++ b/gguf-py/gguf/quants.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import Callable
+from typing import Callable, Sequence
 
 from numpy.typing import DTypeLike
 
@@ -9,6 +9,20 @@ from .lazy import LazyNumpyTensor
 import numpy as np
 
 
+def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType):
+    block_size, type_size = GGML_QUANT_SIZES[quant_type]
+    if shape[-1] % block_size != 0:
+        raise ValueError(f"Quantized tensor row size ({shape[-1]}) is not a multiple of {quant_type.name} block size ({block_size})")
+    return (*shape[:-1], shape[-1] // block_size * type_size)
+
+
+def quant_shape_from_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType):
+    block_size, type_size = GGML_QUANT_SIZES[quant_type]
+    if shape[-1] % type_size != 0:
+        raise ValueError(f"Quantized tensor bytes per row ({shape[-1]}) is not a multiple of {quant_type.name} type size ({type_size})")
+    return (*shape[:-1], shape[-1] // type_size * block_size)
+
+
 # same as ggml_compute_fp32_to_bf16 in ggml-impl.h
 def __compute_fp32_to_bf16(n: np.ndarray) -> np.ndarray:
     n = n.astype(np.float32, copy=False).view(np.int32)
```