
#!/usr/bin/env python3

# Test gguf.quants so that it exactly matches the C implementation of the (de)quantization
# NOTE: this is kind of a mess, but at least it worked for initially testing the Python implementations.

from __future__ import annotations

import argparse
from math import prod
import os
import sys
from pathlib import Path
import ctypes
import logging

import numpy as np

# Necessary to load the local gguf package
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
    sys.path.insert(0, str(Path(__file__).parent.parent))

import gguf
from gguf.constants import GGMLQuantizationType

logger = logging.getLogger("test-quants")

c_float_p = ctypes.POINTER(ctypes.c_float)

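# Mirrors `struct ggml_init_params` from ggml.h (mem_size, mem_buffer, no_alloc);
# only needed so that ggml_init() can be called through ctypes below.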
class ggml_init_params(ctypes.Structure):
    _fields_ = [
        ("mem_size", ctypes.c_size_t),
        ("mem_buffer", ctypes.c_void_p),
        ("no_alloc", ctypes.c_bool),
    ]

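# Thin ctypes wrapper around libggml, exposing the reference C quantization and
# dequantization routines that the Python implementations are compared against.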
class GGMLQuants:
    libggml: ctypes.CDLL

    def __init__(self, libggml: Path):
        self.libggml = ctypes.CDLL(str(libggml))
        self.libggml.ggml_quantize_chunk.restype = ctypes.c_size_t
        # enum ggml_type   type,
        #    const float * src,
        #           void * dst,
        #        int64_t   start,
        #        int64_t   nrows,
        #        int64_t   n_per_row,
        #    const float * imatrix) {
        self.libggml.ggml_quantize_chunk.argtypes = (
            ctypes.c_int,
            ctypes.POINTER(ctypes.c_float),
            ctypes.c_void_p,
            ctypes.c_int64,
            ctypes.c_int64,
            ctypes.c_int64,
            ctypes.POINTER(ctypes.c_float),
        )

        self.libggml.ggml_quantize_requires_imatrix.restype = ctypes.c_bool
        self.libggml.ggml_quantize_requires_imatrix.argtypes = (ctypes.c_int,)

        for t in (
            "q4_0", "q4_1", "q5_0", "q5_1", "q8_0",
            "q2_K", "q3_K", "q4_K", "q5_K", "q6_K",
            "iq2_xxs", "iq2_xs", "iq2_s", "iq3_xxs", "iq3_s", "iq1_s", "iq1_m",
            "iq4_nl", "iq4_xs",
        ):
            dequant_func: ctypes._NamedFuncPointer = getattr(self.libggml, "dequantize_row_" + t)
            dequant_func.restype = None
            dequant_func.argtypes = (ctypes.c_void_p, ctypes.POINTER(ctypes.c_float), ctypes.c_int64)

        self.libggml.ggml_fp16_to_fp32_row.restype = None
        self.libggml.ggml_fp16_to_fp32_row.argtypes = (ctypes.POINTER(ctypes.c_uint16), ctypes.POINTER(ctypes.c_float), ctypes.c_int64)
        self.libggml.ggml_bf16_to_fp32_row.restype = None
        self.libggml.ggml_bf16_to_fp32_row.argtypes = (ctypes.POINTER(ctypes.c_uint16), ctypes.POINTER(ctypes.c_float), ctypes.c_int64)

        self.libggml.ggml_init.argtypes = (ggml_init_params,)

        self.libggml.ggml_init(ggml_init_params(1 * 1024 * 1024, 0, False))

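    # Dequantize a raw byte tensor to float32 using the C reference:
    # F32 is a no-op view, F16/BF16 go through the *_to_fp32_row helpers, and
    # all other types dispatch to the matching dequantize_row_<type> function.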
    def dequantize(self, tensor: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
        result = np.zeros(gguf.quant_shape_from_byte_shape(tensor.shape, qtype), dtype=np.float32, order="C")
        if qtype == GGMLQuantizationType.F32:
            # no-op
            result = tensor.view(np.float32)
        elif qtype == GGMLQuantizationType.F16:
            self.libggml.ggml_fp16_to_fp32_row(tensor.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), result.ctypes.data_as(c_float_p), result.size)
        elif qtype == GGMLQuantizationType.BF16:
            self.libggml.ggml_bf16_to_fp32_row(tensor.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), result.ctypes.data_as(c_float_p), result.size)
        else:
            lw_qname = qtype.name.lower()
            if lw_qname[-1] == "k":
                lw_qname = lw_qname[:-1] + "K"
            dequant_func: ctypes._NamedFuncPointer = getattr(self.libggml, "dequantize_row_" + lw_qname)
            dequant_func(tensor.ctypes.data_as(ctypes.c_void_p), result.ctypes.data_as(c_float_p), result.size)
        return result

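    # Quantize float32 data with ggml_quantize_chunk(); types that require an
    # importance matrix get a column-wise sum of squares as a stand-in weight.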
    def quantize(self, data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
        result = np.zeros(gguf.quant_shape_to_byte_shape(data.shape, qtype), dtype=np.uint8, order="C")
        if self.libggml.ggml_quantize_requires_imatrix(qtype.value):
            # TODO: is a column-wise sum of squares appropriate?
            qw = np.sum((data * data).reshape((-1, data.shape[-1])), axis=0).ctypes.data_as(c_float_p)
        else:
            qw = ctypes.cast(0, c_float_p)

        result_size = self.libggml.ggml_quantize_chunk(qtype.value, data.ctypes.data_as(c_float_p), result.ctypes.data_as(ctypes.c_void_p), 0, prod(data.shape[:-1]), data.shape[-1], qw)
        assert result.size == result_size
        return result

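# Compare two tensors block by block: exact equality passes immediately; otherwise
# the arrays are reshaped into blocks and XORed bytewise to report how many blocks
# and bits differ (a bit-identical mismatch is attributed to NaNs and still passes).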
def compare_tensors(t1: np.ndarray, t2: np.ndarray, qtype: GGMLQuantizationType) -> bool:
    same = np.array_equal(t1, t2)
    if same:
        return True
    else:
        block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
        if t1.dtype == np.float32:
            t1 = t1.reshape((-1, block_size))
            t2 = t2.reshape((-1, block_size))
        else:
            t1 = t1.reshape((-1, type_size))
            t2 = t2.reshape((-1, type_size))
        x = t1.view(np.uint8) ^ t2.view(np.uint8)
        diff_bits = np.count_nonzero(np.unpackbits(x, axis=-1), axis=-1)
        num_bad_blocks = np.count_nonzero(diff_bits, axis=0)
        if num_bad_blocks == 0 and t1.shape == t2.shape:
            logger.debug("Bits are equal, but arrays don't match, likely contains NANs")
            return True
        logger.debug(f"{num_bad_blocks} bad blocks ({100 * num_bad_blocks / x.shape[0]:.6f}%)")
        bad_block_id = np.argmax(diff_bits, axis=0)
        logger.debug(f"Worst block id: {bad_block_id}")
        logger.debug(f"Sample bad block ({diff_bits[bad_block_id]} differing bits):\n{t1[bad_block_id]}\nReference:\n{t2[bad_block_id]}")

        sum_diff_bits = np.sum(diff_bits)
        logger.debug(f"{sum_diff_bits} bits differ ({100 * sum_diff_bits/(x.size * 8):.6f}%)")
        return False

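# Run the round-trip comparison for every quantization type known to gguf.quants,
# checking both quantization and dequantization against the C reference in libggml.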
def do_test(libggml_path: Path, quick: bool = False):
    ggml_quants = GGMLQuants(libggml_path)

    np.set_printoptions(precision=None, threshold=(4 * 256) + 1, formatter={"int": lambda n: "0x%02X" % n})

    r = np.random.randn(8, 1024, 1024).astype(np.float32, copy=False)

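    # Probe which directions the Python side implements for each type;
    # a NotImplementedError means that direction is skipped for this type.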
    for qtype in (GGMLQuantizationType.F16, *gguf.quants._type_traits.keys()):
        has_dequantize = False
        has_quantize = False

        try:
            gguf.dequantize(np.zeros((gguf.GGML_QUANT_SIZES[qtype][1]), dtype=np.uint8), qtype)
            has_dequantize = True
        except (NotImplementedError, AssertionError) as e:
            if isinstance(e, AssertionError):
                logger.error(f"Error with {qtype.name}: {e}")
                raise e
        try:
            gguf.quantize(np.zeros((gguf.GGML_QUANT_SIZES[qtype][0]), dtype=np.float32), qtype)
            has_quantize = True
        except (NotImplementedError, AssertionError) as e:
            if isinstance(e, AssertionError):
                logger.error(f"Error with {qtype.name}: {e}")
                raise e

        if not has_dequantize and not has_quantize:
            continue

        logger.info(f"Testing {qtype.name}")

        rc = r.copy(order="C")

        pyq = None
        ggq = None

        if has_quantize:
            logger.debug(f"Quantizing to {qtype.name} with Python")
            pyq = gguf.quants.quantize(rc, qtype)

            logger.debug(f"Quantizing to {qtype.name} with C")
            ggq = ggml_quants.quantize(rc, qtype)

            if qtype == GGMLQuantizationType.F16:
                pyq = pyq.view(np.uint8)
            quant_equal = compare_tensors(pyq, ggq, qtype)

            if not quant_equal:
                logger.error(f"Quantization to {qtype.name} does not match ❌")
            else:
                logger.info(f"Quantization to {qtype.name} matches exactly ✅")

        if has_dequantize:
            if ggq is None and not quick:
                logger.debug(f"Quantizing to {qtype.name} with C")
                ggq = ggml_quants.quantize(rc, qtype)

            if ggq is not None:
                logger.debug(f"Dequantizing from {qtype.name} with Python")
                pydq = gguf.quants.dequantize(ggq, qtype)
                logger.debug(f"Dequantizing from {qtype.name} with C")
                ggdq = ggml_quants.dequantize(ggq, qtype)

                dequant_equal = compare_tensors(pydq, ggdq, qtype)

                if not dequant_equal:
                    logger.error(f"Dequantization from {qtype.name} does not match ❌")
                else:
                    logger.info(f"Dequantization from {qtype.name} matches exactly ✅")

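            # Also compare dequantization of random f16 bytes reinterpreted as quantized
            # blocks, which likely hits bit patterns a quantization round-trip never produces.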
            rq_shape = gguf.quants.quant_shape_to_byte_shape((8, 1024, 1024 // 2), qtype)
            rq = np.random.random(rq_shape).astype(np.float16).view(np.uint8)

            logger.debug(f"Dequantizing random f16 data as {qtype.name} with Python")
            pydq = gguf.quants.dequantize(rq, qtype)
            logger.debug(f"Dequantizing random f16 data as {qtype.name} with C")
            ggdq = ggml_quants.dequantize(rq, qtype)

            dequant_equal = compare_tensors(pydq, ggdq, qtype)

            if not dequant_equal:
                logger.error(f"Dequantization from random f16 data as {qtype.name} does not match ❌")
            else:
                logger.info(f"Dequantization from random f16 data as {qtype.name} matches exactly ✅")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Test Python (de)quantization against the reference C implementation")
    parser.add_argument("--libggml", type=Path, default=Path(__file__).parent.parent.parent / "build" / "ggml" / "src" / "libggml.so", help="The path to libggml.so")
    parser.add_argument("--quick", action="store_true", help="Don't quantize with C when it's not strictly necessary")

    args = parser.parse_args()

    logging.basicConfig(level=logging.DEBUG)

    do_test(args.libggml, args.quick)