test_quants.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. #!/usr/bin/env python3
  2. # Test gguf.quants so that it exactly matches the C implementation of the (de)quantization
  3. # NOTE: this is kind of a mess, but at least it worked for initially testing the Python implementations.
  4. from __future__ import annotations
  5. import argparse
  6. from math import prod
  7. import os
  8. import sys
  9. from pathlib import Path
  10. import ctypes
  11. import logging
  12. import numpy as np
# Necessary to load the local gguf package
# Unless NO_LOCAL_GGUF is set in the environment, prepend the directory two
# levels above this file to sys.path (position 0, so it is searched before any
# installed copy). The existence check looks for a 'gguf-py' directory three
# levels up; presumably the inserted directory is that gguf-py checkout —
# confirm against the repository layout.
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
    sys.path.insert(0, str(Path(__file__).parent.parent))
# These imports must stay after the sys.path tweak above so it can take effect.
import gguf
from gguf.constants import GGMLQuantizationType
logger = logging.getLogger("test-quants")
# Shorthand for the ctypes `float *` pointer type used by the C bindings below.
c_float_p = ctypes.POINTER(ctypes.c_float)
  20. class ggml_init_params(ctypes.Structure):
  21. _fields_ = [
  22. ("mem_size", ctypes.c_size_t),
  23. ("mem_buffer", ctypes.c_void_p),
  24. ("no_alloc", ctypes.c_bool),
  25. ]
  26. class GGMLQuants:
  27. libggml: ctypes.CDLL
  28. def __init__(self, libggml: Path):
  29. self.libggml = ctypes.CDLL(str(libggml))
  30. self.libggml.ggml_quantize_chunk.restype = ctypes.c_size_t
  31. # enum ggml_type type,
  32. # const float * src,
  33. # void * dst,
  34. # int64_t start,
  35. # int64_t nrows,
  36. # int64_t n_per_row,
  37. # const float * imatrix) {
  38. self.libggml.ggml_quantize_chunk.argtypes = (
  39. ctypes.c_int,
  40. ctypes.POINTER(ctypes.c_float),
  41. ctypes.c_void_p,
  42. ctypes.c_int64,
  43. ctypes.c_int64,
  44. ctypes.c_int64,
  45. ctypes.POINTER(ctypes.c_float),
  46. )
  47. self.libggml.ggml_quantize_requires_imatrix.restype = ctypes.c_bool
  48. self.libggml.ggml_quantize_requires_imatrix.argtypes = (ctypes.c_int,)
  49. for t in (
  50. "q4_0", "q4_1", "q5_0", "q5_1", "q8_0",
  51. "q2_K", "q3_K", "q4_K", "q5_K", "q6_K",
  52. "tq1_0", "tq2_0",
  53. "mxfp4",
  54. "iq2_xxs", "iq2_xs", "iq2_s", "iq3_xxs", "iq3_s", "iq1_s", "iq1_m",
  55. "iq4_nl", "iq4_xs",
  56. ):
  57. dequant_func: ctypes._NamedFuncPointer = getattr(self.libggml, "dequantize_row_" + t)
  58. dequant_func.restype = None
  59. dequant_func.argtypes = (ctypes.c_void_p, ctypes.POINTER(ctypes.c_float), ctypes.c_int64)
  60. self.libggml.ggml_fp16_to_fp32_row.restype = None
  61. self.libggml.ggml_fp16_to_fp32_row.argtypes = (ctypes.POINTER(ctypes.c_uint16), ctypes.POINTER(ctypes.c_float), ctypes.c_int64)
  62. self.libggml.ggml_bf16_to_fp32_row.restype = None
  63. self.libggml.ggml_bf16_to_fp32_row.argtypes = (ctypes.POINTER(ctypes.c_uint16), ctypes.POINTER(ctypes.c_float), ctypes.c_int64)
  64. self.libggml.ggml_init.argtypes = (ggml_init_params,)
  65. self.libggml.ggml_init(ggml_init_params(1 * 1024 * 1024, 0, False))
  66. def dequantize(self, tensor: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
  67. result = np.zeros(gguf.quant_shape_from_byte_shape(tensor.shape, qtype), dtype=np.float32, order="C")
  68. if qtype == GGMLQuantizationType.F32:
  69. # no-op
  70. result = tensor.view(np.float32)
  71. elif qtype == GGMLQuantizationType.F16:
  72. self.libggml.ggml_fp16_to_fp32_row(tensor.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), result.ctypes.data_as(c_float_p), result.size)
  73. elif qtype == GGMLQuantizationType.BF16:
  74. self.libggml.ggml_bf16_to_fp32_row(tensor.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), result.ctypes.data_as(c_float_p), result.size)
  75. else:
  76. lw_qname = qtype.name.lower()
  77. if lw_qname[-1] == "k":
  78. lw_qname = lw_qname[:-1] + "K"
  79. dequant_func: ctypes._NamedFuncPointer = getattr(self.libggml, "dequantize_row_" + lw_qname)
  80. dequant_func(tensor.ctypes.data_as(ctypes.c_void_p), result.ctypes.data_as(c_float_p), result.size)
  81. return result
  82. def quantize(self, data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
  83. result = np.zeros(gguf.quant_shape_to_byte_shape(data.shape, qtype), dtype=np.uint8, order="C")
  84. if self.libggml.ggml_quantize_requires_imatrix(qtype.value):
  85. # TODO: is a column-wise sum of squares appropriate?
  86. qw = np.sum((data * data).reshape((-1, data.shape[-1])), axis=0).ctypes.data_as(c_float_p)
  87. else:
  88. qw = ctypes.cast(0, c_float_p)
  89. result_size = self.libggml.ggml_quantize_chunk(qtype.value, data.ctypes.data_as(c_float_p), result.ctypes.data_as(ctypes.c_void_p), 0, prod(data.shape[:-1]), data.shape[-1], qw)
  90. assert result.size == result_size
  91. return result
  92. def compare_tensors(t1: np.ndarray, t2: np.ndarray, qtype: GGMLQuantizationType) -> bool:
  93. same = np.array_equal(t1, t2)
  94. if same:
  95. return True
  96. else:
  97. block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
  98. if t1.dtype == np.float32:
  99. t1 = t1.reshape((-1, block_size))
  100. t2 = t2.reshape((-1, block_size))
  101. else:
  102. t1 = t1.reshape((-1, type_size))
  103. t2 = t2.reshape((-1, type_size))
  104. x = t1.view(np.uint8) ^ t2.view(np.uint8)
  105. diff_bits = np.count_nonzero(np.unpackbits(x, axis=-1), axis=-1)
  106. num_bad_blocks = np.count_nonzero(diff_bits, axis=0)
  107. if num_bad_blocks == 0 and t1.shape == t2.shape:
  108. logger.debug("Bits are equal, but arrays don't match, likely contains NANs")
  109. return True
  110. logger.debug(f"{num_bad_blocks} bad blocks ({100 * num_bad_blocks / x.shape[0]:.6f}%)")
  111. bad_block_id = np.argmax(diff_bits, axis=0)
  112. logger.debug(f"Worst block id: {bad_block_id}")
  113. logger.debug(f"Sample bad block ({diff_bits[bad_block_id]} differing bits):\n{t1[bad_block_id]}\nReference:\n{t2[bad_block_id]}")
  114. sum_diff_bits = np.sum(diff_bits)
  115. logger.debug(f"{sum_diff_bits} bits differ ({100 * sum_diff_bits / (x.size * 8):.6f}%)")
  116. return False
  117. def do_test(libggml_path: Path, quick: bool = False, user_type: GGMLQuantizationType | None = None):
  118. ggml_quants = GGMLQuants(libggml_path)
  119. np.set_printoptions(precision=None, threshold=(4 * 256) + 1, formatter={"int": lambda n: "0x%02X" % n})
  120. r = np.random.randn(8, 1024, 1024).astype(np.float32, copy=False)
  121. # test zero blocks
  122. r[0, 0, :] = 0
  123. ## Maybe test infinities? (can make NANs, not really useful in practice)
  124. # r[0, 1, 0] = np.inf
  125. # r[0, 2, 0] = -np.inf
  126. # r[0, 3, 0] = np.inf
  127. # r[0, 3, 1] = -np.inf
  128. for qtype in ((GGMLQuantizationType.F16, *gguf.quants._type_traits.keys()) if user_type is None else (user_type,)):
  129. has_dequantize = False
  130. has_quantize = False
  131. try:
  132. gguf.dequantize(np.zeros((gguf.GGML_QUANT_SIZES[qtype][1]), dtype=np.uint8), qtype)
  133. has_dequantize = True
  134. except (NotImplementedError, AssertionError) as e:
  135. if isinstance(e, AssertionError):
  136. logger.error(f"Error with {qtype.name}: {e}")
  137. raise e
  138. try:
  139. gguf.quantize(np.zeros((gguf.GGML_QUANT_SIZES[qtype][0]), dtype=np.float32), qtype)
  140. has_quantize = True
  141. except (NotImplementedError, AssertionError) as e:
  142. if isinstance(e, AssertionError):
  143. logger.error(f"Error with {qtype.name}: {e}")
  144. raise e
  145. if not has_dequantize and not has_quantize:
  146. continue
  147. logger.info(f"Testing {qtype.name}")
  148. rc = r.copy(order="C")
  149. pyq = None
  150. ggq = None
  151. if has_quantize:
  152. logger.debug(f"Quantizing to {qtype.name} with Python")
  153. pyq = gguf.quants.quantize(rc, qtype)
  154. logger.debug(f"Quantizing to {qtype.name} with C")
  155. ggq = ggml_quants.quantize(rc, qtype)
  156. if qtype == GGMLQuantizationType.F16:
  157. pyq = pyq.view(np.uint8)
  158. quant_equal = compare_tensors(pyq, ggq, qtype)
  159. if not quant_equal:
  160. logger.error(f"Quantization to {qtype.name} does not match ❌")
  161. else:
  162. logger.info(f"Quantization to {qtype.name} matches exactly ✅")
  163. if has_dequantize:
  164. if ggq is None and not quick:
  165. logger.debug(f"Quantizing to {qtype.name} with C")
  166. ggq = ggml_quants.quantize(rc, qtype)
  167. if ggq is not None:
  168. logger.debug(f"Dequantizing from {qtype.name} with Python")
  169. pydq = gguf.quants.dequantize(ggq, qtype)
  170. logger.debug(f"Dequantizing from {qtype.name} with C")
  171. ggdq = ggml_quants.dequantize(ggq, qtype)
  172. dequant_equal = compare_tensors(pydq, ggdq, qtype)
  173. if not dequant_equal:
  174. logger.error(f"Dequantization from {qtype.name} does not match ❌")
  175. else:
  176. logger.info(f"Dequantization from {qtype.name} matches exactly ✅")
  177. rq_shape = gguf.quants.quant_shape_to_byte_shape((8, 1024, 1024 // 2), qtype)
  178. rq = np.random.random(rq_shape).astype(np.float16).view(np.uint8)
  179. logger.debug(f"Dequantizing random f16 data as {qtype.name} with Python")
  180. pydq = gguf.quants.dequantize(rq, qtype)
  181. logger.debug(f"Dequantizing random f16 data as {qtype.name} with C")
  182. ggdq = ggml_quants.dequantize(rq, qtype)
  183. dequant_equal = compare_tensors(pydq, ggdq, qtype)
  184. if not dequant_equal:
  185. logger.error(f"Dequantization from random f16 data as {qtype.name} does not match ❌")
  186. else:
  187. logger.info(f"Dequantization from random f16 data as {qtype.name} matches exactly ✅")
  188. if __name__ == "__main__":
  189. parser = argparse.ArgumentParser(description="Test Python (de)quantization against the reference C implementation")
  190. parser.add_argument("--libggml", type=Path, default=Path(__file__).parent.parent.parent / "build" / "bin" / "libggml.so", help="The path to libggml.so")
  191. parser.add_argument("--quick", action="store_true", help="Don't quantize with C when it's not strictly necessary")
  192. parser.add_argument("--type", type=str, help="The quant type to test (all by default)")
  193. args = parser.parse_args()
  194. logging.basicConfig(level=logging.DEBUG)
  195. do_test(args.libggml, args.quick, GGMLQuantizationType[args.type.upper()] if args.type is not None else None)