test_quants.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. #!/usr/bin/env python3
  2. # Test gguf.quants so that it exactly matches the C implementation of the (de)quantization
  3. # NOTE: this is kind of a mess, but at least it worked for initially testing the Python implementations.
  4. from __future__ import annotations
  5. import argparse
  6. from math import prod
  7. import os
  8. import sys
  9. from pathlib import Path
  10. import ctypes
  11. import logging
  12. import numpy as np
  13. # Necessary to load the local gguf package
  14. if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
  15. sys.path.insert(0, str(Path(__file__).parent.parent))
  16. import gguf
  17. from gguf.constants import GGMLQuantizationType
  18. logger = logging.getLogger("test-quants")
  19. c_float_p = ctypes.POINTER(ctypes.c_float)
  20. class ggml_init_params(ctypes.Structure):
  21. _fields_ = [
  22. ("mem_size", ctypes.c_size_t),
  23. ("mem_buffer", ctypes.c_void_p),
  24. ("no_alloc", ctypes.c_bool),
  25. ]
  26. class GGMLQuants:
  27. libggml: ctypes.CDLL
  28. def __init__(self, libggml: Path):
  29. self.libggml = ctypes.CDLL(str(libggml))
  30. self.libggml.ggml_quantize_chunk.restype = ctypes.c_size_t
  31. # enum ggml_type type,
  32. # const float * src,
  33. # void * dst,
  34. # int64_t start,
  35. # int64_t nrows,
  36. # int64_t n_per_row,
  37. # const float * imatrix) {
  38. self.libggml.ggml_quantize_chunk.argtypes = (
  39. ctypes.c_int,
  40. ctypes.POINTER(ctypes.c_float),
  41. ctypes.c_void_p,
  42. ctypes.c_int64,
  43. ctypes.c_int64,
  44. ctypes.c_int64,
  45. ctypes.POINTER(ctypes.c_float),
  46. )
  47. self.libggml.ggml_quantize_requires_imatrix.restype = ctypes.c_bool
  48. self.libggml.ggml_quantize_requires_imatrix.argtypes = (ctypes.c_int,)
  49. for t in (
  50. "q4_0", "q4_1", "q5_0", "q5_1", "q8_0",
  51. "q2_K", "q3_K", "q4_K", "q5_K", "q6_K",
  52. "tq1_0", "tq2_0",
  53. "iq2_xxs", "iq2_xs", "iq2_s", "iq3_xxs", "iq3_s", "iq1_s", "iq1_m",
  54. "iq4_nl", "iq4_xs",
  55. ):
  56. dequant_func: ctypes._NamedFuncPointer = getattr(self.libggml, "dequantize_row_" + t)
  57. dequant_func.restype = None
  58. dequant_func.argtypes = (ctypes.c_void_p, ctypes.POINTER(ctypes.c_float), ctypes.c_int64)
  59. self.libggml.ggml_fp16_to_fp32_row.restype = None
  60. self.libggml.ggml_fp16_to_fp32_row.argtypes = (ctypes.POINTER(ctypes.c_uint16), ctypes.POINTER(ctypes.c_float), ctypes.c_int64)
  61. self.libggml.ggml_bf16_to_fp32_row.restype = None
  62. self.libggml.ggml_bf16_to_fp32_row.argtypes = (ctypes.POINTER(ctypes.c_uint16), ctypes.POINTER(ctypes.c_float), ctypes.c_int64)
  63. self.libggml.ggml_init.argtypes = (ggml_init_params,)
  64. self.libggml.ggml_init(ggml_init_params(1 * 1024 * 1024, 0, False))
  65. def dequantize(self, tensor: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
  66. result = np.zeros(gguf.quant_shape_from_byte_shape(tensor.shape, qtype), dtype=np.float32, order="C")
  67. if qtype == GGMLQuantizationType.F32:
  68. # no-op
  69. result = tensor.view(np.float32)
  70. elif qtype == GGMLQuantizationType.F16:
  71. self.libggml.ggml_fp16_to_fp32_row(tensor.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), result.ctypes.data_as(c_float_p), result.size)
  72. elif qtype == GGMLQuantizationType.BF16:
  73. self.libggml.ggml_bf16_to_fp32_row(tensor.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), result.ctypes.data_as(c_float_p), result.size)
  74. else:
  75. lw_qname = qtype.name.lower()
  76. if lw_qname[-1] == "k":
  77. lw_qname = lw_qname[:-1] + "K"
  78. dequant_func: ctypes._NamedFuncPointer = getattr(self.libggml, "dequantize_row_" + lw_qname)
  79. dequant_func(tensor.ctypes.data_as(ctypes.c_void_p), result.ctypes.data_as(c_float_p), result.size)
  80. return result
  81. def quantize(self, data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
  82. result = np.zeros(gguf.quant_shape_to_byte_shape(data.shape, qtype), dtype=np.uint8, order="C")
  83. if self.libggml.ggml_quantize_requires_imatrix(qtype.value):
  84. # TODO: is a column-wise sum of squares appropriate?
  85. qw = np.sum((data * data).reshape((-1, data.shape[-1])), axis=0).ctypes.data_as(c_float_p)
  86. else:
  87. qw = ctypes.cast(0, c_float_p)
  88. result_size = self.libggml.ggml_quantize_chunk(qtype.value, data.ctypes.data_as(c_float_p), result.ctypes.data_as(ctypes.c_void_p), 0, prod(data.shape[:-1]), data.shape[-1], qw)
  89. assert result.size == result_size
  90. return result
  91. def compare_tensors(t1: np.ndarray, t2: np.ndarray, qtype: GGMLQuantizationType) -> bool:
  92. same = np.array_equal(t1, t2)
  93. if same:
  94. return True
  95. else:
  96. block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
  97. if t1.dtype == np.float32:
  98. t1 = t1.reshape((-1, block_size))
  99. t2 = t2.reshape((-1, block_size))
  100. else:
  101. t1 = t1.reshape((-1, type_size))
  102. t2 = t2.reshape((-1, type_size))
  103. x = t1.view(np.uint8) ^ t2.view(np.uint8)
  104. diff_bits = np.count_nonzero(np.unpackbits(x, axis=-1), axis=-1)
  105. num_bad_blocks = np.count_nonzero(diff_bits, axis=0)
  106. if num_bad_blocks == 0 and t1.shape == t2.shape:
  107. logger.debug("Bits are equal, but arrays don't match, likely contains NANs")
  108. return True
  109. logger.debug(f"{num_bad_blocks} bad blocks ({100 * num_bad_blocks / x.shape[0]:.6f}%)")
  110. bad_block_id = np.argmax(diff_bits, axis=0)
  111. logger.debug(f"Worst block id: {bad_block_id}")
  112. logger.debug(f"Sample bad block ({diff_bits[bad_block_id]} differing bits):\n{t1[bad_block_id]}\nReference:\n{t2[bad_block_id]}")
  113. sum_diff_bits = np.sum(diff_bits)
  114. logger.debug(f"{sum_diff_bits} bits differ ({100 * sum_diff_bits / (x.size * 8):.6f}%)")
  115. return False
  116. def do_test(libggml_path: Path, quick: bool = False):
  117. ggml_quants = GGMLQuants(libggml_path)
  118. np.set_printoptions(precision=None, threshold=(4 * 256) + 1, formatter={"int": lambda n: "0x%02X" % n})
  119. r = np.random.randn(8, 1024, 1024).astype(np.float32, copy=False)
  120. for qtype in (GGMLQuantizationType.F16, *gguf.quants._type_traits.keys()):
  121. has_dequantize = False
  122. has_quantize = False
  123. try:
  124. gguf.dequantize(np.zeros((gguf.GGML_QUANT_SIZES[qtype][1]), dtype=np.uint8), qtype)
  125. has_dequantize = True
  126. except (NotImplementedError, AssertionError) as e:
  127. if isinstance(e, AssertionError):
  128. logger.error(f"Error with {qtype.name}: {e}")
  129. raise e
  130. try:
  131. gguf.quantize(np.zeros((gguf.GGML_QUANT_SIZES[qtype][0]), dtype=np.float32), qtype)
  132. has_quantize = True
  133. except (NotImplementedError, AssertionError) as e:
  134. if isinstance(e, AssertionError):
  135. logger.error(f"Error with {qtype.name}: {e}")
  136. raise e
  137. if not has_dequantize and not has_quantize:
  138. continue
  139. logger.info(f"Testing {qtype.name}")
  140. rc = r.copy(order="C")
  141. pyq = None
  142. ggq = None
  143. if has_quantize:
  144. logger.debug(f"Quantizing to {qtype.name} with Python")
  145. pyq = gguf.quants.quantize(rc, qtype)
  146. logger.debug(f"Quantizing to {qtype.name} with C")
  147. ggq = ggml_quants.quantize(rc, qtype)
  148. if qtype == GGMLQuantizationType.F16:
  149. pyq = pyq.view(np.uint8)
  150. quant_equal = compare_tensors(pyq, ggq, qtype)
  151. if not quant_equal:
  152. logger.error(f"Quantization to {qtype.name} does not match ❌")
  153. else:
  154. logger.info(f"Quantization to {qtype.name} matches exactly ✅")
  155. if has_dequantize:
  156. if ggq is None and not quick:
  157. logger.debug(f"Quantizing to {qtype.name} with C")
  158. ggq = ggml_quants.quantize(rc, qtype)
  159. if ggq is not None:
  160. logger.debug(f"Dequantizing from {qtype.name} with Python")
  161. pydq = gguf.quants.dequantize(ggq, qtype)
  162. logger.debug(f"Dequantizing from {qtype.name} with C")
  163. ggdq = ggml_quants.dequantize(ggq, qtype)
  164. dequant_equal = compare_tensors(pydq, ggdq, qtype)
  165. if not dequant_equal:
  166. logger.error(f"Dequantization from {qtype.name} does not match ❌")
  167. else:
  168. logger.info(f"Dequantization from {qtype.name} matches exactly ✅")
  169. rq_shape = gguf.quants.quant_shape_to_byte_shape((8, 1024, 1024 // 2), qtype)
  170. rq = np.random.random(rq_shape).astype(np.float16).view(np.uint8)
  171. logger.debug(f"Dequantizing random f16 data as {qtype.name} with Python")
  172. pydq = gguf.quants.dequantize(rq, qtype)
  173. logger.debug(f"Dequantizing random f16 data as {qtype.name} with C")
  174. ggdq = ggml_quants.dequantize(rq, qtype)
  175. dequant_equal = compare_tensors(pydq, ggdq, qtype)
  176. if not dequant_equal:
  177. logger.error(f"Dequantization from random f16 data as {qtype.name} does not match ❌")
  178. else:
  179. logger.info(f"Dequantization from random f16 data as {qtype.name} matches exactly ✅")
  180. if __name__ == "__main__":
  181. parser = argparse.ArgumentParser(description="Test Python (de)quantization against the reference C implementation")
  182. parser.add_argument("--libggml", type=Path, default=Path(__file__).parent.parent.parent / "build" / "ggml" / "src" / "libggml.so", help="The path to libggml.so")
  183. parser.add_argument("--quick", action="store_true", help="Don't quantize with C when it's not strictly necessary")
  184. args = parser.parse_args()
  185. logging.basicConfig(level=logging.DEBUG)
  186. do_test(args.libggml, args.quick)