From b83bab15a5d2a1e7807d09613a9b34309d86cfaa Mon Sep 17 00:00:00 2001
From: compilade
Date: Fri, 24 May 2024 21:11:48 -0400
Subject: gguf-py : fix and simplify quantized shape round-trip (#7483)

* gguf-py : fix and simplify quantized shape round-trip

* gguf-py : remove unused import
---
 gguf-py/gguf/quants.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

(limited to 'gguf-py/gguf/quants.py')

diff --git a/gguf-py/gguf/quants.py b/gguf-py/gguf/quants.py
index e7fc0eae..b22eec16 100644
--- a/gguf-py/gguf/quants.py
+++ b/gguf-py/gguf/quants.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import Callable
+from typing import Callable, Sequence
 
 from numpy.typing import DTypeLike
 
@@ -9,6 +9,20 @@ from .lazy import LazyNumpyTensor
 import numpy as np
 
 
+def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType):
+    block_size, type_size = GGML_QUANT_SIZES[quant_type]
+    if shape[-1] % block_size != 0:
+        raise ValueError(f"Quantized tensor row size ({shape[-1]}) is not a multiple of {quant_type.name} block size ({block_size})")
+    return (*shape[:-1], shape[-1] // block_size * type_size)
+
+
+def quant_shape_from_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType):
+    block_size, type_size = GGML_QUANT_SIZES[quant_type]
+    if shape[-1] % type_size != 0:
+        raise ValueError(f"Quantized tensor bytes per row ({shape[-1]}) is not a multiple of {quant_type.name} type size ({type_size})")
+    return (*shape[:-1], shape[-1] // type_size * block_size)
+
+
 # same as ggml_compute_fp32_to_bf16 in ggml-impl.h
 def __compute_fp32_to_bf16(n: np.ndarray) -> np.ndarray:
     n = n.astype(np.float32, copy=False).view(np.int32)
-- 
cgit v1.2.3