summaryrefslogtreecommitdiff
path: root/gguf-py/tests/test_quants.py
diff options
context:
space:
mode:
Diffstat (limited to 'gguf-py/tests/test_quants.py')
-rwxr-xr-xgguf-py/tests/test_quants.py237
1 files changed, 237 insertions, 0 deletions
diff --git a/gguf-py/tests/test_quants.py b/gguf-py/tests/test_quants.py
new file mode 100755
index 00000000..8b7a85c2
--- /dev/null
+++ b/gguf-py/tests/test_quants.py
@@ -0,0 +1,237 @@
+#!/usr/bin/env python3
+
+# Test gguf.quants so that it exactly matches the C implementation of the (de)quantization
+
+# NOTE: this is kind of a mess, but at least it worked for initially testing the Python implementations.
+
+from __future__ import annotations
+
+import argparse
+from math import prod
+import os
+import sys
+from pathlib import Path
+import ctypes
+import logging
+import numpy as np
+
+# Necessary to load the local gguf package
+if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
+ sys.path.insert(0, str(Path(__file__).parent.parent))
+
+import gguf
+from gguf.constants import GGMLQuantizationType
+
+
+logger = logging.getLogger("test-quants")
+
+
+c_float_p = ctypes.POINTER(ctypes.c_float)
+
+
+class ggml_init_params(ctypes.Structure):
+ _fields_ = [
+ ("mem_size", ctypes.c_size_t),
+ ("mem_buffer", ctypes.c_void_p),
+ ("no_alloc", ctypes.c_bool),
+ ]
+
+
+class GGMLQuants:
+ libggml: ctypes.CDLL
+
+ def __init__(self, libggml: Path):
+ self.libggml = ctypes.CDLL(str(libggml))
+ self.libggml.ggml_quantize_chunk.restype = ctypes.c_size_t
+ # enum ggml_type type,
+ # const float * src,
+ # void * dst,
+ # int64_t start,
+ # int64_t nrows,
+ # int64_t n_per_row,
+ # const float * imatrix) {
+ self.libggml.ggml_quantize_chunk.argtypes = (
+ ctypes.c_int,
+ ctypes.POINTER(ctypes.c_float),
+ ctypes.c_void_p,
+ ctypes.c_int64,
+ ctypes.c_int64,
+ ctypes.c_int64,
+ ctypes.POINTER(ctypes.c_float),
+ )
+
+ self.libggml.ggml_quantize_requires_imatrix.restype = ctypes.c_bool
+ self.libggml.ggml_quantize_requires_imatrix.argtypes = (ctypes.c_int,)
+
+ for t in (
+ "q4_0", "q4_1", "q5_0", "q5_1", "q8_0",
+ "q2_K", "q3_K", "q4_K", "q5_K", "q6_K",
+ "iq2_xxs", "iq2_xs", "iq2_s", "iq3_xxs", "iq3_s", "iq1_s", "iq1_m",
+ "iq4_nl", "iq4_xs",
+ ):
+ dequant_func: ctypes._NamedFuncPointer = getattr(self.libggml, "dequantize_row_" + t)
+ dequant_func.restype = None
+ dequant_func.argtypes = (ctypes.c_void_p, ctypes.POINTER(ctypes.c_float), ctypes.c_int64)
+
+ self.libggml.ggml_fp16_to_fp32_row.restype = None
+ self.libggml.ggml_fp16_to_fp32_row.argtypes = (ctypes.POINTER(ctypes.c_uint16), ctypes.POINTER(ctypes.c_float), ctypes.c_int64)
+ self.libggml.ggml_bf16_to_fp32_row.restype = None
+ self.libggml.ggml_bf16_to_fp32_row.argtypes = (ctypes.POINTER(ctypes.c_uint16), ctypes.POINTER(ctypes.c_float), ctypes.c_int64)
+
+ self.libggml.ggml_init.argtypes = (ggml_init_params,)
+
+ self.libggml.ggml_init(ggml_init_params(1 * 1024 * 1024, 0, False))
+
+ def dequantize(self, tensor: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
+ result = np.zeros(gguf.quant_shape_from_byte_shape(tensor.shape, qtype), dtype=np.float32, order="C")
+ if qtype == GGMLQuantizationType.F32:
+ # no-op
+ result = tensor.view(np.float32)
+ elif qtype == GGMLQuantizationType.F16:
+ self.libggml.ggml_fp16_to_fp32_row(tensor.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), result.ctypes.data_as(c_float_p), result.size)
+ elif qtype == GGMLQuantizationType.BF16:
+ self.libggml.ggml_bf16_to_fp32_row(tensor.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), result.ctypes.data_as(c_float_p), result.size)
+ else:
+ lw_qname = qtype.name.lower()
+ if lw_qname[-1] == "k":
+ lw_qname = lw_qname[:-1] + "K"
+ dequant_func: ctypes._NamedFuncPointer = getattr(self.libggml, "dequantize_row_" + lw_qname)
+ dequant_func(tensor.ctypes.data_as(ctypes.c_void_p), result.ctypes.data_as(c_float_p), result.size)
+ return result
+
+ def quantize(self, data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
+ result = np.zeros(gguf.quant_shape_to_byte_shape(data.shape, qtype), dtype=np.uint8, order="C")
+ if self.libggml.ggml_quantize_requires_imatrix(qtype.value):
+ # TODO: is a column-wise sum of squares appropriate?
+ qw = np.sum((data * data).reshape((-1, data.shape[-1])), axis=0).ctypes.data_as(c_float_p)
+ else:
+ qw = ctypes.cast(0, c_float_p)
+ result_size = self.libggml.ggml_quantize_chunk(qtype.value, data.ctypes.data_as(c_float_p), result.ctypes.data_as(ctypes.c_void_p), 0, prod(data.shape[:-1]), data.shape[-1], qw)
+ assert result.size == result_size
+ return result
+
+
+def compare_tensors(t1: np.ndarray, t2: np.ndarray, qtype: GGMLQuantizationType) -> bool:
+ same = np.array_equal(t1, t2)
+ if same:
+ return True
+ else:
+ block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
+ if t1.dtype == np.float32:
+ t1 = t1.reshape((-1, block_size))
+ t2 = t2.reshape((-1, block_size))
+ else:
+ t1 = t1.reshape((-1, type_size))
+ t2 = t2.reshape((-1, type_size))
+ x = t1.view(np.uint8) ^ t2.view(np.uint8)
+ diff_bits = np.count_nonzero(np.unpackbits(x, axis=-1), axis=-1)
+ num_bad_blocks = np.count_nonzero(diff_bits, axis=0)
+ if num_bad_blocks == 0 and t1.shape == t2.shape:
+ logger.debug("Bits are equal, but arrays don't match, likely contains NANs")
+ return True
+ logger.debug(f"{num_bad_blocks} bad blocks ({100 * num_bad_blocks / x.shape[0]:.6f}%)")
+ bad_block_id = np.argmax(diff_bits, axis=0)
+ logger.debug(f"Worst block id: {bad_block_id}")
+ logger.debug(f"Sample bad block ({diff_bits[bad_block_id]} differing bits):\n{t1[bad_block_id]}\nReference:\n{t2[bad_block_id]}")
+
+ sum_diff_bits = np.sum(diff_bits)
+ logger.debug(f"{sum_diff_bits} bits differ ({100 * sum_diff_bits/(x.size * 8):.6f}%)")
+ return False
+
+
+def do_test(libggml_path: Path, quick: bool = False):
+ ggml_quants = GGMLQuants(libggml_path)
+
+ np.set_printoptions(precision=None, threshold=(4 * 256) + 1, formatter={"int": lambda n: "0x%02X" % n})
+
+ r = np.random.randn(8, 1024, 1024).astype(np.float32, copy=False)
+
+ for qtype in (GGMLQuantizationType.F16, *gguf.quants._type_traits.keys()):
+ has_dequantize = False
+ has_quantize = False
+
+ try:
+ gguf.dequantize(np.zeros((gguf.GGML_QUANT_SIZES[qtype][1]), dtype=np.uint8), qtype)
+ has_dequantize = True
+ except (NotImplementedError, AssertionError) as e:
+ if isinstance(e, AssertionError):
+ logger.error(f"Error with {qtype.name}: {e}")
+ raise e
+ try:
+ gguf.quantize(np.zeros((gguf.GGML_QUANT_SIZES[qtype][0]), dtype=np.float32), qtype)
+ has_quantize = True
+ except (NotImplementedError, AssertionError) as e:
+ if isinstance(e, AssertionError):
+ logger.error(f"Error with {qtype.name}: {e}")
+ raise e
+
+ if not has_dequantize and not has_quantize:
+ continue
+
+ logger.info(f"Testing {qtype.name}")
+
+ rc = r.copy(order="C")
+
+ pyq = None
+ ggq = None
+
+ if has_quantize:
+ logger.debug(f"Quantizing to {qtype.name} with Python")
+ pyq = gguf.quants.quantize(rc, qtype)
+
+ logger.debug(f"Quantizing to {qtype.name} with C")
+ ggq = ggml_quants.quantize(rc, qtype)
+
+ if qtype == GGMLQuantizationType.F16:
+ pyq = pyq.view(np.uint8)
+ quant_equal = compare_tensors(pyq, ggq, qtype)
+
+ if not quant_equal:
+ logger.error(f"Quantization to {qtype.name} does not match ❌")
+ else:
+ logger.info(f"Quantization to {qtype.name} matches exactly ✅")
+
+ if has_dequantize:
+ if ggq is None and not quick:
+ logger.debug(f"Quantizing to {qtype.name} with C")
+ ggq = ggml_quants.quantize(rc, qtype)
+
+ if ggq is not None:
+ logger.debug(f"Dequantizing from {qtype.name} with Python")
+ pydq = gguf.quants.dequantize(ggq, qtype)
+ logger.debug(f"Dequantizing from {qtype.name} with C")
+ ggdq = ggml_quants.dequantize(ggq, qtype)
+
+ dequant_equal = compare_tensors(pydq, ggdq, qtype)
+
+ if not dequant_equal:
+ logger.error(f"Dequantization from {qtype.name} does not match ❌")
+ else:
+ logger.info(f"Dequantization from {qtype.name} matches exactly ✅")
+
+ rq_shape = gguf.quants.quant_shape_to_byte_shape((8, 1024, 1024 // 2), qtype)
+ rq = np.random.random(rq_shape).astype(np.float16).view(np.uint8)
+
+ logger.debug(f"Dequantizing random f16 data as {qtype.name} with Python")
+ pydq = gguf.quants.dequantize(rq, qtype)
+ logger.debug(f"Dequantizing random f16 data as {qtype.name} with C")
+ ggdq = ggml_quants.dequantize(rq, qtype)
+
+ dequant_equal = compare_tensors(pydq, ggdq, qtype)
+
+ if not dequant_equal:
+ logger.error(f"Dequantization from random f16 data as {qtype.name} does not match ❌")
+ else:
+ logger.info(f"Dequantization from random f16 data as {qtype.name} matches exactly ✅")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Test Python (de)quantization against the reference C implementation")
+ parser.add_argument("--libggml", type=Path, default=Path(__file__).parent.parent.parent / "build" / "ggml" / "src" / "libggml.so", help="The path to libggml.so")
+ parser.add_argument("--quick", action="store_true", help="Don't quantize with C when it's not strictly necessary")
+
+ args = parser.parse_args()
+
+ logging.basicConfig(level=logging.DEBUG)
+
+ do_test(args.libggml, args.quick)