quants.py 61 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318
  1. from __future__ import annotations
  2. from abc import ABC, abstractmethod
  3. from typing import Any, Callable, Sequence
  4. from math import log2, ceil
  5. from numpy.typing import DTypeLike
  6. from .constants import GGML_QUANT_SIZES, GGMLQuantizationType, QK_K
  7. from .lazy import LazyNumpyTensor
  8. import numpy as np
  9. def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]:
  10. block_size, type_size = GGML_QUANT_SIZES[quant_type]
  11. if shape[-1] % block_size != 0:
  12. raise ValueError(f"Quantized tensor row size ({shape[-1]}) is not a multiple of {quant_type.name} block size ({block_size})")
  13. return (*shape[:-1], shape[-1] // block_size * type_size)
  14. def quant_shape_from_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]:
  15. block_size, type_size = GGML_QUANT_SIZES[quant_type]
  16. if shape[-1] % type_size != 0:
  17. raise ValueError(f"Quantized tensor bytes per row ({shape[-1]}) is not a multiple of {quant_type.name} type size ({type_size})")
  18. return (*shape[:-1], shape[-1] // type_size * block_size)
  19. # This is faster than np.vectorize and np.apply_along_axis because it works on more than one row at a time
  20. def _apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.ndarray, otype: DTypeLike, oshape: tuple[int, ...]) -> np.ndarray:
  21. rows = arr.reshape((-1, arr.shape[-1]))
  22. osize = 1
  23. for dim in oshape:
  24. osize *= dim
  25. out = np.empty(shape=osize, dtype=otype)
  26. # compute over groups of 16 rows (arbitrary, but seems good for performance)
  27. n_groups = (rows.shape[0] // 16) or 1
  28. np.concatenate([func(group).ravel() for group in np.array_split(rows, n_groups)], axis=0, out=out)
  29. return out.reshape(oshape)
  30. # round away from zero
  31. # ref: https://stackoverflow.com/a/59143326/22827863
  32. def np_roundf(n: np.ndarray) -> np.ndarray:
  33. a = abs(n)
  34. floored = np.floor(a)
  35. b = floored + np.floor(2 * (a - floored))
  36. return np.sign(n) * b
  37. class QuantError(Exception): ...
  38. _type_traits: dict[GGMLQuantizationType, type[__Quant]] = {}
  39. def quantize(data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
  40. if qtype == GGMLQuantizationType.F32:
  41. return data.astype(np.float32, copy=False)
  42. elif qtype == GGMLQuantizationType.F16:
  43. return data.astype(np.float16, copy=False)
  44. elif (q := _type_traits.get(qtype)) is not None:
  45. return q.quantize(data)
  46. else:
  47. raise NotImplementedError(f"Quantization for {qtype.name} is not yet implemented")
  48. def dequantize(data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
  49. if qtype == GGMLQuantizationType.F32:
  50. return data.view(np.float32)
  51. elif qtype == GGMLQuantizationType.F16:
  52. return data.view(np.float16).astype(np.float32)
  53. elif (q := _type_traits.get(qtype)) is not None:
  54. return q.dequantize(data)
  55. else:
  56. raise NotImplementedError(f"Dequantization for {qtype.name} is not yet implemented")
  57. class __Quant(ABC):
  58. qtype: GGMLQuantizationType
  59. block_size: int
  60. type_size: int
  61. grid: np.ndarray[Any, np.dtype[np.float32]] | None = None
  62. grid_shape: tuple[int, int] = (0, 0)
  63. grid_map: tuple[int | float, ...] = ()
  64. grid_hex: bytes | None = None
  65. def __init__(self):
  66. return TypeError("Quant conversion classes can't have instances")
  67. def __init_subclass__(cls, qtype: GGMLQuantizationType) -> None:
  68. cls.qtype = qtype
  69. cls.block_size, cls.type_size = GGML_QUANT_SIZES[qtype]
  70. cls.__quantize_lazy = LazyNumpyTensor._wrap_fn(
  71. cls.__quantize_array,
  72. meta_noop=(np.uint8, cls.__shape_to_bytes)
  73. )
  74. cls.__dequantize_lazy = LazyNumpyTensor._wrap_fn(
  75. cls.__dequantize_array,
  76. meta_noop=(np.float32, cls.__shape_from_bytes)
  77. )
  78. assert qtype not in _type_traits
  79. _type_traits[qtype] = cls
  80. @classmethod
  81. def init_grid(cls):
  82. if cls.grid is not None or cls.grid_hex is None:
  83. return
  84. bits_per_elem = ceil(log2(len(cls.grid_map)))
  85. assert bits_per_elem != 0, cls.qtype.name
  86. elems_per_byte = 8 // bits_per_elem
  87. grid = np.frombuffer(cls.grid_hex, dtype=np.uint8)
  88. # decode hexadecimal chars from grid
  89. grid = grid.reshape((-1, 2))
  90. grid = (np.where(grid > 0x40, grid + 9, grid) & 0x0F) << np.array([4, 0], dtype=np.uint8).reshape((1, 2))
  91. grid = grid[..., 0] | grid[..., 1]
  92. # unpack the grid values
  93. grid = grid.reshape((-1, 1)) >> np.array([i for i in range(0, 8, 8 // elems_per_byte)], dtype=np.uint8).reshape((1, elems_per_byte))
  94. grid = (grid & ((1 << bits_per_elem) - 1)).reshape((-1, 1))
  95. grid_map = np.array(cls.grid_map, dtype=np.float32).reshape((1, -1))
  96. grid = np.take_along_axis(grid_map, grid, axis=-1)
  97. cls.grid = grid.reshape((1, 1, *cls.grid_shape))
  98. @classmethod
  99. @abstractmethod
  100. def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
  101. raise NotImplementedError
  102. @classmethod
  103. @abstractmethod
  104. def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
  105. raise NotImplementedError
  106. @classmethod
  107. def quantize_rows(cls, rows: np.ndarray) -> np.ndarray:
  108. rows = rows.astype(np.float32, copy=False)
  109. shape = rows.shape
  110. n_blocks = rows.size // cls.block_size
  111. blocks = rows.reshape((n_blocks, cls.block_size))
  112. blocks = cls.quantize_blocks(blocks)
  113. assert blocks.dtype == np.uint8
  114. assert blocks.shape[-1] == cls.type_size
  115. return blocks.reshape(cls.__shape_to_bytes(shape))
  116. @classmethod
  117. def dequantize_rows(cls, rows: np.ndarray) -> np.ndarray:
  118. rows = rows.view(np.uint8)
  119. shape = rows.shape
  120. n_blocks = rows.size // cls.type_size
  121. blocks = rows.reshape((n_blocks, cls.type_size))
  122. blocks = cls.dequantize_blocks(blocks)
  123. assert blocks.dtype == np.float32
  124. assert blocks.shape[-1] == cls.block_size
  125. return blocks.reshape(cls.__shape_from_bytes(shape))
  126. @classmethod
  127. def __shape_to_bytes(cls, shape: Sequence[int]):
  128. return quant_shape_to_byte_shape(shape, cls.qtype)
  129. @classmethod
  130. def __shape_from_bytes(cls, shape: Sequence[int]):
  131. return quant_shape_from_byte_shape(shape, cls.qtype)
  132. @classmethod
  133. def __quantize_array(cls, array: np.ndarray) -> np.ndarray:
  134. return _apply_over_grouped_rows(cls.quantize_rows, arr=array, otype=np.uint8, oshape=cls.__shape_to_bytes(array.shape))
  135. @classmethod
  136. def __dequantize_array(cls, array: np.ndarray) -> np.ndarray:
  137. cls.init_grid()
  138. return _apply_over_grouped_rows(cls.dequantize_rows, arr=array, otype=np.float32, oshape=cls.__shape_from_bytes(array.shape))
  139. @classmethod
  140. def __quantize_lazy(cls, lazy_tensor: LazyNumpyTensor, /) -> Any:
  141. pass
  142. @classmethod
  143. def __dequantize_lazy(cls, lazy_tensor: LazyNumpyTensor, /) -> Any:
  144. pass
  145. @classmethod
  146. def can_quantize(cls, tensor: np.ndarray | LazyNumpyTensor) -> bool:
  147. return tensor.shape[-1] % cls.block_size == 0
  148. @classmethod
  149. def quantize(cls, tensor: np.ndarray | LazyNumpyTensor) -> np.ndarray:
  150. if not cls.can_quantize(tensor):
  151. raise QuantError(f"Can't quantize tensor with shape {tensor.shape} to {cls.qtype.name}")
  152. if isinstance(tensor, LazyNumpyTensor):
  153. return cls.__quantize_lazy(tensor)
  154. else:
  155. return cls.__quantize_array(tensor)
  156. @classmethod
  157. def dequantize(cls, tensor: np.ndarray | LazyNumpyTensor) -> np.ndarray:
  158. if isinstance(tensor, LazyNumpyTensor):
  159. return cls.__dequantize_lazy(tensor)
  160. else:
  161. return cls.__dequantize_array(tensor)
  162. class BF16(__Quant, qtype=GGMLQuantizationType.BF16):
  163. @classmethod
  164. # same as ggml_compute_fp32_to_bf16 in ggml-impl.h
  165. def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
  166. n = blocks.view(np.uint32)
  167. # force nan to quiet
  168. n = np.where((n & 0x7fffffff) > 0x7f800000, (n & np.uint32(0xffff0000)) | np.uint32(64 << 16), n)
  169. # round to nearest even
  170. n = (np.uint64(n) + (0x7fff + ((n >> 16) & 1))) >> 16
  171. return n.astype(np.uint16).view(np.uint8)
  172. @classmethod
  173. def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
  174. return (blocks.view(np.int16).astype(np.int32) << 16).view(np.float32)
  175. class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0):
  176. @classmethod
  177. def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
  178. n_blocks = blocks.shape[0]
  179. imax = abs(blocks).argmax(axis=-1, keepdims=True)
  180. max = np.take_along_axis(blocks, imax, axis=-1)
  181. d = max / -8
  182. with np.errstate(divide="ignore"):
  183. id = np.where(d == 0, 0, 1 / d)
  184. qs = np.trunc((blocks * id) + np.float32(8.5), dtype=np.float32).astype(np.uint8).clip(0, 15)
  185. qs = qs.reshape((n_blocks, 2, cls.block_size // 2))
  186. qs = qs[..., 0, :] | (qs[..., 1, :] << np.uint8(4))
  187. d = d.astype(np.float16).view(np.uint8)
  188. return np.concatenate([d, qs], axis=-1)
  189. @classmethod
  190. def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
  191. n_blocks = blocks.shape[0]
  192. d, qs = np.hsplit(blocks, [2])
  193. d = d.view(np.float16).astype(np.float32)
  194. qs = qs.reshape((n_blocks, -1, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
  195. qs = (qs & np.uint8(0x0F)).reshape((n_blocks, -1)).astype(np.int8) - np.int8(8)
  196. return (d * qs.astype(np.float32))
  197. class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1):
  198. @classmethod
  199. def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
  200. n_blocks = blocks.shape[0]
  201. max = blocks.max(axis=-1, keepdims=True)
  202. min = blocks.min(axis=-1, keepdims=True)
  203. d = (max - min) / 15
  204. with np.errstate(divide="ignore"):
  205. id = np.where(d == 0, 0, 1 / d)
  206. qs = np.trunc((blocks - min) * id + np.float32(0.5), dtype=np.float32).astype(np.uint8).clip(0, 15)
  207. qs = qs.reshape((n_blocks, 2, cls.block_size // 2))
  208. qs = qs[..., 0, :] | (qs[..., 1, :] << np.uint8(4))
  209. d = d.astype(np.float16).view(np.uint8)
  210. m = min.astype(np.float16).view(np.uint8)
  211. return np.concatenate([d, m, qs], axis=-1)
  212. @classmethod
  213. def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
  214. n_blocks = blocks.shape[0]
  215. d, rest = np.hsplit(blocks, [2])
  216. m, qs = np.hsplit(rest, [2])
  217. d = d.view(np.float16).astype(np.float32)
  218. m = m.view(np.float16).astype(np.float32)
  219. qs = qs.reshape((n_blocks, -1, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
  220. qs = (qs & np.uint8(0x0F)).reshape((n_blocks, -1)).astype(np.float32)
  221. return (d * qs) + m
  222. class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0):
  223. @classmethod
  224. def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
  225. n_blocks = blocks.shape[0]
  226. imax = abs(blocks).argmax(axis=-1, keepdims=True)
  227. max = np.take_along_axis(blocks, imax, axis=-1)
  228. d = max / -16
  229. with np.errstate(divide="ignore"):
  230. id = np.where(d == 0, 0, 1 / d)
  231. q = np.trunc((blocks * id) + np.float32(16.5), dtype=np.float32).astype(np.uint8).clip(0, 31)
  232. qs = q.reshape((n_blocks, 2, cls.block_size // 2))
  233. qs = (qs[..., 0, :] & np.uint8(0x0F)) | (qs[..., 1, :] << np.uint8(4))
  234. qh = np.packbits(q.reshape((n_blocks, 1, 32)) >> np.uint8(4), axis=-1, bitorder="little").reshape(n_blocks, 4)
  235. d = d.astype(np.float16).view(np.uint8)
  236. return np.concatenate([d, qh, qs], axis=-1)
  237. @classmethod
  238. def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
  239. n_blocks = blocks.shape[0]
  240. d, rest = np.hsplit(blocks, [2])
  241. qh, qs = np.hsplit(rest, [4])
  242. d = d.view(np.float16).astype(np.float32)
  243. qh = qh.view(np.uint32)
  244. qh = qh.reshape((n_blocks, 1)) >> np.array([i for i in range(32)], dtype=np.uint32).reshape((1, 32))
  245. ql = qs.reshape((n_blocks, -1, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
  246. qh = (qh & np.uint32(0x01)).astype(np.uint8)
  247. ql = (ql & np.uint8(0x0F)).reshape((n_blocks, -1))
  248. qs = (ql | (qh << np.uint8(4))).astype(np.int8) - np.int8(16)
  249. return (d * qs.astype(np.float32))
  250. class Q5_1(__Quant, qtype=GGMLQuantizationType.Q5_1):
  251. @classmethod
  252. def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
  253. n_blocks = blocks.shape[0]
  254. max = blocks.max(axis=-1, keepdims=True)
  255. min = blocks.min(axis=-1, keepdims=True)
  256. d = (max - min) / 31
  257. with np.errstate(divide="ignore"):
  258. id = np.where(d == 0, 0, 1 / d)
  259. q = np.trunc((blocks - min) * id + np.float32(0.5), dtype=np.float32).astype(np.uint8).clip(0, 31)
  260. qs = q.reshape((n_blocks, 2, cls.block_size // 2))
  261. qs = (qs[..., 0, :] & np.uint8(0x0F)) | (qs[..., 1, :] << np.uint8(4))
  262. qh = np.packbits(q.reshape((n_blocks, 1, 32)) >> np.uint8(4), axis=-1, bitorder="little").reshape(n_blocks, 4)
  263. d = d.astype(np.float16).view(np.uint8)
  264. m = min.astype(np.float16).view(np.uint8)
  265. return np.concatenate([d, m, qh, qs], axis=-1)
  266. @classmethod
  267. def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
  268. n_blocks = blocks.shape[0]
  269. d, rest = np.hsplit(blocks, [2])
  270. m, rest = np.hsplit(rest, [2])
  271. qh, qs = np.hsplit(rest, [4])
  272. d = d.view(np.float16).astype(np.float32)
  273. m = m.view(np.float16).astype(np.float32)
  274. qh = qh.view(np.uint32)
  275. qh = qh.reshape((n_blocks, 1)) >> np.array([i for i in range(32)], dtype=np.uint32).reshape((1, 32))
  276. ql = qs.reshape((n_blocks, -1, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
  277. qh = (qh & np.uint32(0x01)).astype(np.uint8)
  278. ql = (ql & np.uint8(0x0F)).reshape((n_blocks, -1))
  279. qs = (ql | (qh << np.uint8(4))).astype(np.float32)
  280. return (d * qs) + m
  281. class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0):
  282. @classmethod
  283. # Implementation of Q8_0 with bit-exact same results as reference implementation in ggml-quants.c
  284. def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
  285. d = abs(blocks).max(axis=1, keepdims=True) / 127
  286. with np.errstate(divide="ignore"):
  287. id = np.where(d == 0, 0, 1 / d)
  288. qs = np_roundf(blocks * id)
  289. # (n_blocks, 2)
  290. d = d.astype(np.float16).view(np.uint8)
  291. # (n_blocks, block_size)
  292. qs = qs.astype(np.int8).view(np.uint8)
  293. return np.concatenate([d, qs], axis=1)
  294. @classmethod
  295. def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
  296. d, x = np.split(blocks, [2], axis=1)
  297. d = d.view(np.float16).astype(np.float32)
  298. x = x.view(np.int8).astype(np.float32)
  299. return (x * d)
class Q2_K(__Quant, qtype=GGMLQuantizationType.Q2_K):
    # Only dequantization is implemented for this type in this class.
    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        """Dequantize Q2_K super-blocks of QK_K 2-bit values.

        Byte layout per block: QK_K // 16 scale bytes (low nibble = scale,
        high nibble = min, one byte per 16-element sub-block), QK_K // 4 bytes of
        2-bit quants, then fp16 ``d`` and fp16 ``dmin``.
        Each value is ``d * scale * q - dmin * min``.
        """
        n_blocks = blocks.shape[0]
        scales, rest = np.hsplit(blocks, [QK_K // 16])
        qs, rest = np.hsplit(rest, [QK_K // 4])
        d, dmin = np.hsplit(rest, [2])
        d = d.view(np.float16).astype(np.float32)
        dmin = dmin.view(np.float16).astype(np.float32)
        # (n_blocks, 16, 1)
        dl = (d * (scales & 0xF).astype(np.float32)).reshape((n_blocks, QK_K // 16, 1))
        ml = (dmin * (scales >> 4).astype(np.float32)).reshape((n_blocks, QK_K // 16, 1))
        # For each group of 32 quant bytes, the bit pairs at shifts 0, 2, 4, 6
        # produce four consecutive runs of 32 values.
        shift = np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1))
        qs = (qs.reshape((n_blocks, -1, 1, 32)) >> shift) & np.uint8(3)
        qs = qs.reshape((n_blocks, QK_K // 16, 16)).astype(np.float32)
        qs = dl * qs - ml
        return qs.reshape((n_blocks, -1))
class Q3_K(__Quant, qtype=GGMLQuantizationType.Q3_K):
    # Only dequantization is implemented for this type in this class.
    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        """Dequantize Q3_K super-blocks of QK_K 3-bit values.

        Byte layout per block: QK_K // 8 bytes of high-bit mask, QK_K // 4 bytes of
        2-bit low quants, 12 bytes of packed 6-bit scales (16 of them), then an
        fp16 ``d``. Each value is ``d * (scale - 32) * (low - 4 * (1 - hbit))``.
        """
        n_blocks = blocks.shape[0]
        hmask, rest = np.hsplit(blocks, [QK_K // 8])
        qs, rest = np.hsplit(rest, [QK_K // 4])
        scales, d = np.hsplit(rest, [12])
        d = d.view(np.float16).astype(np.float32)
        # The scales are packed at 6-bit each in this pattern:
        #  0: IIIIAAAA
        #  1: JJJJBBBB
        #  2: KKKKCCCC
        #  3: LLLLDDDD
        #  4: MMMMEEEE
        #  5: NNNNFFFF
        #  6: OOOOGGGG
        #  7: PPPPHHHH
        #  8: MMIIEEAA
        #  9: NNJJFFBB
        # 10: OOKKGGCC
        # 11: PPLLHHDD
        lscales, hscales = np.hsplit(scales, [8])
        # low 4 bits of the 16 scales: all low nibbles of bytes 0-7, then all high nibbles
        lscales = lscales.reshape((n_blocks, 1, 8)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 2, 1))
        lscales = lscales.reshape((n_blocks, 16))
        # high 2 bits of the 16 scales, from bytes 8-11 at shifts 0, 2, 4, 6
        hscales = hscales.reshape((n_blocks, 1, 4)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 4, 1))
        hscales = hscales.reshape((n_blocks, 16))
        scales = (lscales & np.uint8(0x0F)) | ((hscales & np.uint8(0x03)) << np.uint8(4))
        scales = (scales.astype(np.int8) - np.int8(32)).astype(np.float32)
        dl = (d * scales).reshape((n_blocks, 16, 1))
        ql = qs.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1))
        qh = hmask.reshape(n_blocks, -1, 1, 32) >> np.array([i for i in range(8)], dtype=np.uint8).reshape((1, 1, 8, 1))
        ql = ql.reshape((n_blocks, 16, QK_K // 16)) & np.uint8(3)
        qh = (qh.reshape((n_blocks, 16, QK_K // 16)) & np.uint8(1))
        qh = qh ^ np.uint8(1)  # strangely, the offset is zero when the bitmask is 1
        q = (ql.astype(np.int8) - (qh << np.uint8(2)).astype(np.int8)).astype(np.float32)
        return (dl * q).reshape((n_blocks, QK_K))
class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K):
    # bytes of packed 6-bit scales and mins per block (also used by Q5_K)
    K_SCALE_SIZE = 12

    @staticmethod
    def get_scale_min(scales: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Unpack 12 bytes per block into 8 six-bit scales and 8 six-bit mins.

        Returns ``(sc, min)``, each of shape ``(n_blocks, 8)``.
        """
        n_blocks = scales.shape[0]
        scales = scales.view(np.uint8)
        ### Unpacking the following: ###
        #  0 EEAAAAAA
        #  1 FFBBBBBB
        #  2 GGCCCCCC
        #  3 HHDDDDDD
        #  4 eeaaaaaa
        #  5 ffbbbbbb
        #  6 ggcccccc
        #  7 hhdddddd
        #  8 eeeeEEEE
        #  9 ffffFFFF
        # 10 ggggGGGG
        # 11 hhhhHHHH
        scales = scales.reshape((n_blocks, 3, 4))
        d, m, m_d = np.split(scales, 3, axis=-2)

        # scales 0-3 are the low 6 bits of bytes 0-3; scales 4-7 combine the low
        # nibbles of bytes 8-11 with the top 2 bits of bytes 0-3
        sc = np.concatenate([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], axis=-1)
        # mins 0-3 are the low 6 bits of bytes 4-7; mins 4-7 combine the high
        # nibbles of bytes 8-11 with the top 2 bits of bytes 4-7
        min = np.concatenate([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], axis=-1)

        return (sc.reshape((n_blocks, 8)), min.reshape((n_blocks, 8)))

    # Only dequantization is implemented for this type in this class.
    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        """Dequantize Q4_K super-blocks of QK_K 4-bit values.

        Byte layout per block: fp16 ``d``, fp16 ``dmin``, K_SCALE_SIZE bytes of
        packed scales/mins, then 4-bit quants. Each run of 32 values uses one
        (scale, min) pair: ``d * scale * q - dmin * min``.
        """
        n_blocks = blocks.shape[0]
        d, rest = np.hsplit(blocks, [2])
        dmin, rest = np.hsplit(rest, [2])
        scales, qs = np.hsplit(rest, [cls.K_SCALE_SIZE])

        d = d.view(np.float16).astype(np.float32)
        dmin = dmin.view(np.float16).astype(np.float32)

        sc, m = Q4_K.get_scale_min(scales)

        d = (d * sc.astype(np.float32)).reshape((n_blocks, -1, 1))
        dm = (dmin * m.astype(np.float32)).reshape((n_blocks, -1, 1))

        # per 32 quant bytes: the 32 low nibbles come first, then the 32 high nibbles
        qs = qs.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
        qs = (qs & np.uint8(0x0F)).reshape((n_blocks, -1, 32)).astype(np.float32)

        return (d * qs - dm).reshape((n_blocks, QK_K))
class Q5_K(__Quant, qtype=GGMLQuantizationType.Q5_K):
    # Only dequantization is implemented for this type in this class.
    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        """Dequantize Q5_K super-blocks of QK_K 5-bit values.

        Byte layout per block: fp16 ``d``, fp16 ``dmin``, Q4_K.K_SCALE_SIZE bytes
        of packed scales/mins (same packing as Q4_K), QK_K // 8 bytes of 5th
        bits, then 4-bit low quants. Each run of 32 values uses one
        (scale, min) pair: ``d * scale * q - dmin * min``.
        """
        n_blocks = blocks.shape[0]
        d, rest = np.hsplit(blocks, [2])
        dmin, rest = np.hsplit(rest, [2])
        scales, rest = np.hsplit(rest, [Q4_K.K_SCALE_SIZE])
        qh, qs = np.hsplit(rest, [QK_K // 8])

        d = d.view(np.float16).astype(np.float32)
        dmin = dmin.view(np.float16).astype(np.float32)

        sc, m = Q4_K.get_scale_min(scales)

        d = (d * sc.astype(np.float32)).reshape((n_blocks, -1, 1))
        dm = (dmin * m.astype(np.float32)).reshape((n_blocks, -1, 1))

        ql = qs.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
        # bit g of each qh byte is the 5th bit for the g-th run of 32 values
        qh = qh.reshape((n_blocks, -1, 1, 32)) >> np.array([i for i in range(8)], dtype=np.uint8).reshape((1, 1, 8, 1))
        ql = (ql & np.uint8(0x0F)).reshape((n_blocks, -1, 32))
        qh = (qh & np.uint8(0x01)).reshape((n_blocks, -1, 32))
        q = (ql | (qh << np.uint8(4))).astype(np.float32)

        return (d * q - dm).reshape((n_blocks, QK_K))
class Q6_K(__Quant, qtype=GGMLQuantizationType.Q6_K):
    # Only dequantization is implemented for this type in this class.
    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        """Dequantize Q6_K super-blocks of QK_K 6-bit values.

        Byte layout per block: QK_K // 2 bytes of low nibbles, QK_K // 4 bytes of
        2-bit high parts, QK_K // 16 int8 scales (one per 16 values), then an
        fp16 ``d``. Each value is ``d * scale * (q - 32)``.
        """
        n_blocks = blocks.shape[0]
        ql, rest = np.hsplit(blocks, [QK_K // 2])
        qh, rest = np.hsplit(rest, [QK_K // 4])
        scales, d = np.hsplit(rest, [QK_K // 16])

        scales = scales.view(np.int8).astype(np.float32)
        d = d.view(np.float16).astype(np.float32)
        d = (d * scales).reshape((n_blocks, QK_K // 16, 1))

        # per 64 low-quant bytes: the low nibbles come first, then the high nibbles
        ql = ql.reshape((n_blocks, -1, 1, 64)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
        ql = (ql & np.uint8(0x0F)).reshape((n_blocks, -1, 32))
        # per 32 high bytes, bit pairs at shifts 0, 2, 4, 6 serve four runs of 32 values
        qh = qh.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1))
        qh = (qh & np.uint8(0x03)).reshape((n_blocks, -1, 32))
        q = (ql | (qh << np.uint8(4))).astype(np.int8) - np.int8(32)
        q = q.reshape((n_blocks, QK_K // 16, -1)).astype(np.float32)

        return (d * q).reshape((n_blocks, QK_K))
class TQ1_0(__Quant, qtype=GGMLQuantizationType.TQ1_0):
    @classmethod
    def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        """Quantize to ternary values {-1, 0, 1} scaled by the block's absolute max.

        Trits (value + 1) are packed in base 3. The first 160 values go into 32
        bytes of 5 trits, the next 80 into 16 bytes of 5 trits, and the last 16
        into 4 bytes of 4 trits (weighted 81..3, so the first trit is the most
        significant). Each byte stores ceil(v * 256 / 243), so that
        ``dequantize_blocks`` can extract trits by multiplication. Output is the
        packed bytes followed by the fp16 scale.
        NOTE(review): the hard-coded slice bounds assume 256 values per block
        (QK_K == 256) -- confirm.
        """
        n_blocks = blocks.shape[0]

        d = abs(blocks).max(axis=-1, keepdims=True)
        with np.errstate(divide="ignore"):
            id = np.where(d == 0, 0, 1 / d)
        qs = np_roundf(blocks * id)
        # map -1, 0, 1 to trits 0, 1, 2
        qs = (qs.astype(np.int8) + np.int8(1)).astype(np.uint8)

        qs0, qs1, qh = qs[..., :(32 * 5)], qs[..., (32 * 5):(48 * 5)], qs[..., (48 * 5):]
        qs0 = qs0.reshape((n_blocks, -1, 5, 32)) * np.array([81, 27, 9, 3, 1], dtype=np.uint8).reshape((1, 1, 5, 1))
        qs0 = np.sum(qs0, axis=-2).reshape((n_blocks, -1))
        qs1 = qs1.reshape((n_blocks, -1, 5, 16)) * np.array([81, 27, 9, 3, 1], dtype=np.uint8).reshape((1, 1, 5, 1))
        qs1 = np.sum(qs1, axis=-2).reshape((n_blocks, -1))
        qh = qh.reshape((n_blocks, -1, 4, 4)) * np.array([81, 27, 9, 3], dtype=np.uint8).reshape((1, 1, 4, 1))
        qh = np.sum(qh, axis=-2).reshape((n_blocks, -1))
        qs = np.concatenate([qs0, qs1, qh], axis=-1)
        # ceiling division (243 == 3 ** 5) into the 0..255 byte range
        qs = (qs.astype(np.uint16) * 256 + (243 - 1)) // 243

        qs = qs.astype(np.uint8)
        d = d.astype(np.float16).view(np.uint8)

        return np.concatenate([qs, d], axis=-1)

    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        """Unpack the base-3 ternary encoding written by ``quantize_blocks``.

        Multiplying a stored byte by 3**k (wrapping in uint8) moves the k-th trit
        to the top. ``(byte * 3) >> 8`` then reads it as 0, 1 or 2.
        """
        n_blocks = blocks.shape[0]

        qs, rest = np.hsplit(blocks, [(QK_K - 4 * QK_K // 64) // 5])
        qh, d = np.hsplit(rest, [QK_K // 64])

        d = d.view(np.float16).astype(np.float32)

        qs0, qs1 = qs[..., :32], qs[..., 32:]
        qs0 = qs0.reshape((n_blocks, -1, 1, 32)) * np.array([1, 3, 9, 27, 81], dtype=np.uint8).reshape((1, 1, 5, 1))
        qs0 = qs0.reshape((n_blocks, -1))
        qs1 = qs1.reshape((n_blocks, -1, 1, 16)) * np.array([1, 3, 9, 27, 81], dtype=np.uint8).reshape((1, 1, 5, 1))
        qs1 = qs1.reshape((n_blocks, -1))
        qh = qh.reshape((n_blocks, -1, 1, 4)) * np.array([1, 3, 9, 27], dtype=np.uint8).reshape((1, 1, 4, 1))
        qh = qh.reshape((n_blocks, -1))
        qs = np.concatenate([qs0, qs1, qh], axis=-1)
        # top trit -> 0, 1, 2, then back to -1, 0, 1
        qs = ((qs.astype(np.uint16) * 3) >> 8).astype(np.int8) - np.int8(1)

        return (d * qs.astype(np.float32))
class TQ2_0(__Quant, qtype=GGMLQuantizationType.TQ2_0):
    @classmethod
    def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        """Quantize to ternary values stored as 2-bit codes 0, 1, 2 (for -1, 0, 1).

        The scale is the block's absolute max. Within each run of 128 values,
        byte m holds values m, m+32, m+64 and m+96 at bit offsets 0, 2, 4 and 6.
        Output is the packed codes followed by the fp16 scale.
        """
        n_blocks = blocks.shape[0]

        d = abs(blocks).max(axis=-1, keepdims=True)
        with np.errstate(divide="ignore"):
            id = np.where(d == 0, 0, 1 / d)
        qs = np_roundf(blocks * id)
        qs = (qs.astype(np.int8) + np.int8(1)).astype(np.uint8)

        qs = qs.reshape((n_blocks, -1, 4, 32)) << np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1))
        qs = qs[..., 0, :] | qs[..., 1, :] | qs[..., 2, :] | qs[..., 3, :]
        qs = qs.reshape((n_blocks, -1))

        d = d.astype(np.float16).view(np.uint8)

        return np.concatenate([qs, d], axis=-1)

    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        """Unpack QK_K // 4 bytes of 2-bit codes and apply the trailing fp16 scale."""
        n_blocks = blocks.shape[0]

        qs, d = np.hsplit(blocks, [QK_K // 4])

        d = d.view(np.float16).astype(np.float32)

        qs = qs.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1))
        # codes 0, 1, 2 -> -1, 0, 1
        qs = (qs & 0x03).reshape((n_blocks, -1)).astype(np.int8) - np.int8(1)

        return (d * qs.astype(np.float32))
class MXFP4(__Quant, qtype=GGMLQuantizationType.MXFP4):
    # e2m1 values (doubled)
    # ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
    # Because these are doubled, the shared scale is decoded at half its value
    # (e8m0_to_fp32_half). Indices 8-15 are the negatives of indices 0-7.
    kvalues = (0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12)

    @staticmethod
    # see ggml_e8m0_to_fp32_half in ggml-impl.h
    def e8m0_to_fp32_half(x: np.ndarray) -> np.ndarray:
        """Decode e8m0 exponents ``x`` into float32 ``2 ** (x - 128)``.

        For x < 2 the result is a float32 subnormal built from 0x00200000 << x.
        Otherwise (x - 1) is placed directly in the exponent field.
        """
        bits = np.where(x < 2, np.uint32(0x00200000) << np.uint32(x), np.uint32(x - 1) << np.uint32(23))
        return bits.view(np.float32)

    @classmethod
    def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        """Quantize each block to a shared e8m0 exponent byte plus 4-bit kvalue indices.

        Each value takes the index of the nearest ``d * kvalue`` (first index on
        ties). The first half of the block goes in the low nibbles, the second
        half in the high nibbles.
        """
        n_blocks = blocks.shape[0]

        d = abs(blocks).max(axis=-1, keepdims=True)

        with np.errstate(divide="ignore"):
            # NOTE(review): for an absolute max below 2**-125 this expression is
            # negative, and the float -> uint8 cast below is platform-dependent --
            # confirm whether clipping at 0 is intended.
            e = np.where(d > 0, np.floor(np.log2(d)) - 2 + 127, 0).astype(np.uint8)

        d = cls.e8m0_to_fp32_half(e)

        kvalues = np.array(cls.kvalues, dtype=np.int8).reshape((1, 1, 16))

        # (n_blocks, block_size, 16) absolute error against every candidate value
        errs = np.abs(d.reshape((n_blocks, 1, 1)) * kvalues.astype(np.float32) - blocks.reshape((n_blocks, cls.block_size, 1)))
        best = np.argmin(errs, axis=-1, keepdims=True)

        qs = best.reshape(n_blocks, 2, cls.block_size // 2).astype(np.uint8)
        qs = qs[:, 0] | (qs[:, 1] << np.uint8(4))

        qs = qs.reshape((n_blocks, cls.block_size // 2))

        return np.concatenate([e, qs], axis=-1)

    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        """Decode one exponent byte plus packed 4-bit kvalue indices back to float32."""
        n_blocks = blocks.shape[0]

        e, qs = np.hsplit(blocks, [1])
        d = cls.e8m0_to_fp32_half(e)
        # low nibbles are the first half of the block, high nibbles the second
        qs = qs.reshape((n_blocks, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 2, 1))
        qs = (qs & np.uint8(0x0F)).view(np.int8)

        kvalues = np.array(cls.kvalues, dtype=np.int8).reshape(1, 1, 16)
        qs = np.take_along_axis(kvalues, qs, axis=-1).reshape((n_blocks, cls.block_size))

        return (d * qs.astype(np.float32))
  class IQ2_XXS(__Quant, qtype=GGMLQuantizationType.IQ2_XXS):
      """Dequantizer for GGML IQ2_XXS blocks.

      Each 8-element group is a codeword from a 256-entry grid, with a
      per-element sign pattern taken from the ``ksigns`` table and a 4-bit
      sub-block scale.
      """

      # Sign-pattern table with 128 entries: entry i has i in its low 7 bits,
      # and bit 7 is set exactly when that makes the byte's popcount even.
      # Also reused by IQ2_XS and IQ3_XXS.
      ksigns: bytes = (
          b"\x00\x81\x82\x03\x84\x05\x06\x87\x88\x09\x0a\x8b\x0c\x8d\x8e\x0f"
          b"\x90\x11\x12\x93\x14\x95\x96\x17\x18\x99\x9a\x1b\x9c\x1d\x1e\x9f"
          b"\xa0\x21\x22\xa3\x24\xa5\xa6\x27\x28\xa9\xaa\x2b\xac\x2d\x2e\xaf"
          b"\x30\xb1\xb2\x33\xb4\x35\x36\xb7\xb8\x39\x3a\xbb\x3c\xbd\xbe\x3f"
          b"\xc0\x41\x42\xc3\x44\xc5\xc6\x47\x48\xc9\xca\x4b\xcc\x4d\x4e\xcf"
          b"\x50\xd1\xd2\x53\xd4\x55\x56\xd7\xd8\x59\x5a\xdb\x5c\xdd\xde\x5f"
          b"\x60\xe1\xe2\x63\xe4\x65\x66\xe7\xe8\x69\x6a\xeb\x6c\xed\xee\x6f"
          b"\xf0\x71\x72\xf3\x74\xf5\xf6\x77\x78\xf9\xfa\x7b\xfc\x7d\x7e\xff"
      )

      # iq2xxs_grid, but with each byte of the original packed in 2 bits,
      # by mapping 0x08 to 0, 0x19 to 1, and 0x2b to 2.
      grid_shape = (256, 8)
      grid_map = (0x08, 0x19, 0x2b)
      grid_hex = (
          b"00000200050008000a00110014002000220028002a0041004400500058006100"
          b"6400800082008a00a20001010401100115014001840198010002020222028202"
          b"010404041004210424044004420448046004810484049004a404000502050805"
          b"200546056905800591050906100640068406a406000805080808140828084108"
          b"440850085208880804094009020a140a01100410101021104010601084109010"
          b"951000110811201150115a118011241245120014081420142514491480141815"
          b"6215001616160118041810184018811800190519a019511a002002200a204420"
          b"6120802082202921482100220222012404241024402456240025412564259026"
          b"082820289428442a014004401040184021402440404048405640604081408440"
          b"9040004120416141804185410142104248425642684200440844204480449944"
          b"124524450046014804481048404845480049584961498249454a904a00500850"
          b"1150195020508050885004514251a4519152905492540a550156545600581158"
          b"195864584059085a046010604060686000615561186260620064056410651265"
          b"84654268008002800a8041808280048118814081118201840484108415844084"
          b"608400854685948509864086608602880489118a0490109024904090a1901691"
          b"8091459200942294449451958198209902a050a085a009a100a218a450a804a9"
      )

      @classmethod
      def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
          """Dequantize a 2-D uint8 array with one IQ2_XXS block per row.

          Row layout, as split below: a 2-byte float16 super-scale ``d``, then
          pairs of uint32 words (native byte order via ``view``). Word 0 of a
          pair holds four 8-bit grid indices. Word 1 holds four 7-bit
          ``ksigns`` indices in bits 0..27 and a 4-bit scale ``s`` in bits
          28..31. One pair covers 4 x 8 = 32 output elements, scaled by
          ``d * (0.5 + s) * 0.25``.

          Returns:
              float32 array of shape ``(n_blocks, elements_per_block)``.
          """
          n_blocks = blocks.shape[0]

          d, qs = np.hsplit(blocks, [2])
          d = d.view(np.float16).astype(np.float32)

          qs = qs.view(np.uint32).reshape(n_blocks, -1, 2)

          db = d * (np.float32(0.5) + (qs[..., 1] >> 28).astype(np.float32)) * np.float32(0.25)
          db = db.reshape((n_blocks, -1, 1, 1))

          # get the sign indices and unpack the bits
          signs = qs[..., 1].reshape((n_blocks, -1, 1)) >> np.array([0, 7, 14, 21], dtype=np.uint32).reshape((1, 1, 4))
          ksigns = np.frombuffer(cls.ksigns, dtype=np.uint8).reshape((1, 1, 1, 128))
          signs = (signs & np.uint32(0x7F)).reshape((n_blocks, -1, 4, 1))
          signs = np.take_along_axis(ksigns, signs, axis=-1)
          signs = signs.reshape((n_blocks, -1, 4, 1)) >> np.arange(8, dtype=np.uint8).reshape((1, 1, 1, 8))
          signs = signs & np.uint8(0x01)
          signs = np.where(signs == 0, np.float32(1), np.float32(-1))
          signs = signs.reshape((n_blocks, -1, 4, 8))

          # cls.grid is expected to be prepared outside this class (it may be None).
          assert cls.grid is not None
          # .copy() makes word 0 contiguous so it can be reinterpreted as 4 index bytes.
          grid = np.take_along_axis(cls.grid, qs[..., 0].copy().view(np.uint8).reshape((n_blocks, -1, 1, 1)), axis=-2)
          grid = grid.reshape((n_blocks, -1, 4, 8))

          return (db * grid * signs).reshape((n_blocks, -1))
  class IQ2_XS(__Quant, qtype=GGMLQuantizationType.IQ2_XS):
      """Dequantizer for GGML IQ2_XS blocks.

      Each 8-element group is a codeword from a 512-entry grid. Its signs come
      from the ``IQ2_XXS.ksigns`` table, and a 4-bit scale applies per
      sub-block.
      """

      # iq2xs_grid, but with each byte of the original packed in 2 bits,
      # by mapping 0x08 to 0, 0x19 to 1, and 0x2b to 2.
      grid_shape = (512, 8)
      grid_map = (0x08, 0x19, 0x2b)
      grid_hex = (
          b"00000200050008000a0011001400160019002000220025002800410044004600"
          b"49005000520055005800610064008000820085008800910094009900a0000101"
          b"04010601090110011201150118011a0121012401400142014501480151015401"
          b"6001680181018401900100020202050208021102140220024102440250025502"
          b"80028a0201040404060409041004120415041804210424044004420445044804"
          b"5104540456046004810484049004000502050505080511051405200541054405"
          b"500561058005010604061006260640064206840600080208050808080a081108"
          b"14082008250841084408500858088008a008aa08010904091009400981098909"
          b"000a200a280a960aa00a01100410061009101010121015101810211024104010"
          b"4210451048105110541060106a10811084109010001102110511081111111411"
          b"2011411144115011801194119611011204120612101240126012001402140514"
          b"0814111414142014411444144914501464148014011504151015401500161416"
          b"49160118041810181218401854188618001905196619511aa91a002002200520"
          b"08200a201120142020204120442050208020a020012104211021402148216521"
          b"002222228022a82201240424102429244024002541255225992501261a26a626"
          b"002808280a28202855288828a22868299029082a202a822a882a8a2a01400440"
          b"0640094010401240154018402140244040404240454048404a40514054406040"
          b"6540814084409040004102410541084111411441204141414441504180418541"
          b"a241014204421042124229424042004402440544084411441444194420444144"
          b"4444504480449444014504451045244540459a4500460a464446504601480448"
          b"1048404845485448624800491149444950496949044a00500250055008501150"
          b"145020502850415044505050805001510451105115514051425100524452aa52"
          b"0154045410542154405460548154a154005508558055885521566856a1560058"
          b"14584158505899581a5940594259855a0160046010604060546062608660a960"
          b"006124624a62926200641664106540654565a46501686a682569066a546a626a"
          b"00800280058008801180148020802a8041804480508080808280a880aa800181"
          b"0481068110814081518159810082208280828282a082a8820184048410841284"
          b"158440846084898400854485a58518866a860088088825885a8880888288a888"
          b"0689228a808a888a968aa88a0190049010904090569084900091229164915692"
          b"89920094059444945094589429959095929541965198a6984999159a609a00a0"
          b"02a008a00aa020a02aa0a0a051a159a1a6a100a202a208a22aa280a2a0a240a4"
          b"95a465a698a60aa820a822a828a8a0a8a8a804a984a986a928aa2aaa91aaaaaa"
      )

      @classmethod
      def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
          """Dequantize a 2-D uint8 array with one IQ2_XS block per row.

          Row layout, as split below:
          - a 2-byte float16 super-scale ``d``;
          - ``2 * QK_K // 8`` bytes viewed as uint16 words, where bits 0..8 are
            the grid index and bits 9..15 are a ``ksigns`` index;
          - the remaining bytes, which hold two 4-bit scales ``s`` each (low
            nibble first).

          Each scale covers two 8-element groups and yields
          ``d * (0.5 + s) * 0.25``.

          Returns:
              float32 array of shape ``(n_blocks, elements_per_block)``.
          """
          n_blocks = blocks.shape[0]

          d, rest = np.hsplit(blocks, [2])
          qs, scales = np.hsplit(rest, [2 * QK_K // 8])

          d = d.view(np.float16).astype(np.float32)
          qs = qs.view(np.uint16)

          scales = scales.reshape((n_blocks, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2))
          scales = (scales & 0x0F).reshape((n_blocks, -1))
          db = d * (np.float32(0.5) + scales) * np.float32(0.25)
          db = db.reshape((n_blocks, -1, 1, 1))

          # get the sign indices and unpack the bits
          signs = np.frombuffer(IQ2_XXS.ksigns, dtype=np.uint8).reshape(1, 1, 128)
          signs = np.take_along_axis(signs, (qs >> 9).reshape((n_blocks, -1, 1)), axis=-1)
          signs = signs.reshape((n_blocks, -1, 1)) >> np.arange(8, dtype=np.uint8).reshape((1, 1, 8))
          signs = signs & np.uint8(0x01)
          signs = np.where(signs == 0, np.float32(1), np.float32(-1))
          signs = signs.reshape((n_blocks, -1, 2, 8))

          # cls.grid is expected to be prepared outside this class (it may be None).
          assert cls.grid is not None
          # The low 9 bits select one of the 512 grid codewords.
          grid = np.take_along_axis(cls.grid, (qs & np.uint16(511)).reshape((n_blocks, -1, 1, 1)), axis=-2)
          grid = grid.reshape((n_blocks, -1, 2, 8))

          return (db * grid * signs).reshape((n_blocks, -1))
  class IQ2_S(__Quant, qtype=GGMLQuantizationType.IQ2_S):
      """Dequantizer for GGML IQ2_S blocks.

      Each 8-element group is a codeword from a 1024-entry grid. Signs are
      stored explicitly (one bit per element), and a 4-bit scale applies per
      sub-block.
      """

      # iq2s_grid, but with each byte of the original packed in 2 bits,
      # by mapping 0x08 to 0, 0x19 to 1, and 0x2b to 2.
      grid_shape = (1024, 8)
      grid_map = (0x08, 0x19, 0x2b)
      grid_hex = (
          b"00000200050008000a0011001400160019002000220025002800410044004600"
          b"490050005200550058006100640066006900800082008500880091009400a000"
          b"a500aa0001010401060109011001120115011801210124014001420145014801"
          b"510154015601590160016501680181018401900192019501a101a40100020202"
          b"050208021102140220022a02410244024602490250025502800285028a029402"
          b"a202010404040604090410041204150418042104240426042904400442044504"
          b"48044a0451045404560459046004620465048104840486048904900495049804"
          b"a104a40400050205050508050a05110514051605190520052505280541054405"
          b"46054905500552055505580561056405800582058505880591059405a0050106"
          b"0406060609061006150640064506480651065406600681068406900600080208"
          b"050808081108140816081908200825082a084108440846084908500852085508"
          b"580861086408800885089408aa08010904091009120915091809210940094509"
          b"480951095409600981099009000a110a140a220a280a2a0a500a990a01100410"
          b"0610091010101210151018102110241026104010421045104810511054105610"
          b"59106010621065106810811084108610901095109810a110a410001102110511"
          b"08110a1111111411161119112011221125112811411144114611491150115211"
          b"5511581161116411801182118511881191119411011204120912101215122112"
          b"2412401245125112541281128412901200140214051408141114141416141914"
          b"2014251428144114441446144914501452145514581461146414801482148514"
          b"881491149414a014011504150615091510151215151518152115241540154215"
          b"4515481551155415601581158415901500160516081611161416201641164416"
          b"50168016aa160118041806180918101815181818211840184218451848185118"
          b"541860188118841800190219051908191119141920194119441950196919a219"
          b"041a101a401a561a00200220052008201120142016201920202025202a204120"
          b"4420502052205520642080208a209420aa200121042110211221152121214021"
          b"4221452151215421602181218421902100220a22222228222a22442250228822"
          b"8a22a82201240424062409241024152418242124242440244224452448245124"
          b"5424602481248424902400250525082511251425202541254425502566258025"
          b"0126042610264026592600280528112814284128442850288a28aa2801290429"
          b"102995290a2a222a642a882a8a2a014004400640094010401240154018401a40"
          b"21402440264040404240454048404a4051405440564059406040624065408140"
          b"8440904095409840a140a4400041024105410841114114411641194120412241"
          b"2541414144414641494150415241554158416141644180418241854188419141"
          b"9441a04101420442104212421542184224424042454248425142544260428142"
          b"844200440244054408440a441144144416441944204422442544284441444444"
          b"46444944504452445544584461446444804482448544884491449444a0440145"
          b"0445064509451045124515451845214524454045424545454845514554456045"
          b"6a4581458445904500460246054608461146144620464146444650468046a546"
          b"0148044809481048124815481848214824484048424845484848514854486048"
          b"84489048004902490549084911491449204941494449504980499649014a044a"
          b"104a404a00500250055008501150145016501950205022502550285041504450"
          b"4650495050505250555058506150645080508250855088509150945001510451"
          b"0651095110511251155118512151245140514251455148515151545160518151"
          b"8451905100520552085211521452205241524452505269528052015404540654"
          b"0954105412541554185421542454405442544554485451545454605481548454"
          b"9054005502550555085511551455205541554455505580550156045610562656"
          b"405600580258055808581158145820584158445850585a588058015904591059"
          b"4059005a195a855aa85a01600460066010601260156018602160246040604560"
          b"4860516054606060846090600061026105610861116114612061416144615061"
          b"806199610462106240625662a162006405640864116414642064416444645064"
          b"806401650465106540654a656865926500669466016804681068656898680069"
          b"2a69426aa16a0080028005800880118014801980208025804180448050805280"
          b"5580588061808080858091809480018104810981108112811581188121812481"
          b"408142814581488151815481818184819081a981008205820a82118214824182"
          b"4482508201840484068409841084128415841884218440844284458448845184"
          b"5484608481848484908400850285058508851185148520854185448550858085"
          b"8a85018604861086298640860088058811881488418844885088a28801890489"
          b"40896589228a588a5a8a828aa28a019004900990109012901590189024904090"
          b"4290459048905190549060908190849090900091059111911491419144915091"
          b"5a910192049210924092a6920094029405940894119414942094419444945094"
          b"8094969401950495109540959895a19500964696649601980498109826984098"
          b"a998009949995299909a00a005a00aa014a022a02aa041a044a050a0a2a0aaa0"
          b"40a165a102a20aa222a228a22aa282a288a28aa2a8a201a404a410a440a489a4"
          b"a4a400a519a551a60aa828a8a2a854a986a908aa0aaa20aa22aa28aa88aaaaaa"
      )

      @classmethod
      def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
          """Dequantize a 2-D uint8 array with one IQ2_S block per row.

          Row layout, as split below:
          - a 2-byte float16 super-scale ``d``;
          - ``QK_K // 8`` bytes holding the low 8 bits of each grid index;
          - ``QK_K // 8`` bytes of explicit sign bits (a set bit means -1);
          - ``QK_K // 32`` bytes, each supplying the 2 high index bits for four
            groups;
          - the remaining bytes, which hold two 4-bit scales ``s`` each (low
            nibble first).

          Each scale covers two 8-element groups and yields
          ``d * (0.5 + s) * 0.25``.

          Returns:
              float32 array of shape ``(n_blocks, elements_per_block)``.
          """
          n_blocks = blocks.shape[0]

          d, rest = np.hsplit(blocks, [2])
          qs, rest = np.hsplit(rest, [QK_K // 8])
          signs, rest = np.hsplit(rest, [QK_K // 8])
          qh, scales = np.hsplit(rest, [QK_K // 32])

          d = d.view(np.float16).astype(np.float32)

          scales = scales.reshape((n_blocks, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2))
          scales = (scales & 0x0F).reshape((n_blocks, -1))
          db = d * (np.float32(0.5) + scales) * np.float32(0.25)
          db = db.reshape((n_blocks, -1, 1, 1))

          # unpack the sign bits
          signs = signs.reshape((n_blocks, -1, 1)) >> np.arange(8, dtype=np.uint8).reshape((1, 1, 8))
          signs = signs & np.uint8(0x01)
          signs = np.where(signs == 0, np.float32(1), np.float32(-1))
          signs = signs.reshape((n_blocks, -1, 2, 8))

          # Combine into 10-bit grid indices: 2 bits from qh on top of each qs byte.
          qh = qh.reshape((n_blocks, -1, 1)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4))
          qs = qs.astype(np.uint16) | ((qh & 0x03).astype(np.uint16) << 8).reshape((n_blocks, -1))

          # cls.grid is expected to be prepared outside this class (it may be None).
          assert cls.grid is not None
          grid = np.take_along_axis(cls.grid, qs.reshape((n_blocks, -1, 1, 1)), axis=-2)
          grid = grid.reshape((n_blocks, -1, 2, 8))

          return (db * grid * signs).reshape((n_blocks, -1))
  class IQ3_XXS(__Quant, qtype=GGMLQuantizationType.IQ3_XXS):
      """Dequantizer for GGML IQ3_XXS blocks.

      Each 4-element group is a codeword from a 256-entry grid. Signs come from
      the ``IQ2_XXS.ksigns`` table, and a 4-bit scale applies per 32-element
      sub-block.
      """

      # 256 codewords of 4 elements each, whose values are drawn from grid_map.
      # grid_hex presumably stores each element as an index into grid_map; it
      # is expanded into cls.grid outside this class (not visible here).
      grid_shape = (256, 4)
      grid_map = (0x04, 0x0c, 0x14, 0x1c, 0x24, 0x2c, 0x34, 0x3e)
      grid_hex = (
          b"0000020004001100130017002000220031004200730075000101030110011201"
          b"2101250130013201410154017001000202020402110220022202310233023702"
          b"5102570275020103070310031203250370031304370444045704730475040105"
          b"0705320552053506640610071407160743076107011003101010121021102310"
          b"3010321034104710501000110211111120112211011203121012121221123012"
          b"7212001302132013311346136613011405145014201524154615711505162217"
          b"4017002002201120132020202220262031204220012103210521102112212121"
          b"3021632167217021002202221122172220222222372240225522012310231423"
          b"7023742335245324032527254125742501270327162745270130103012302130"
          b"2330503065307230003102312031313144314631013203321032253252327232"
          b"1133333330344734723400350635223555351436363663363337603704401740"
          b"3540374053405740744120423742404260426642074345430444514464442545"
          b"4345704505471047124730471250415070500051065126515551145232527252"
          b"0253535310542354275472540255315550562457425724604460466064602161"
          b"6161176264623063366344640565526533660367216703700570077010703270"
          b"5270267140711272457252720073157333736073217441740075027524753076"
      )

      @classmethod
      def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
          """Dequantize a 2-D uint8 array with one IQ3_XXS block per row.

          Row layout, as split below:
          - a 2-byte float16 super-scale ``d``;
          - ``QK_K // 4`` bytes of 8-bit grid indices;
          - the remaining bytes, viewed as uint32 words with one word per
            32-element sub-block. Bits 0..27 hold four 7-bit ``ksigns`` indices
            and bits 28..31 the scale ``s``, giving ``d * (0.5 + s) * 0.5``.

          Returns:
              float32 array of shape ``(n_blocks, elements_per_block)``.
          """
          n_blocks = blocks.shape[0]

          d, rest = np.hsplit(blocks, [2])
          qs, scales = np.hsplit(rest, [QK_K // 4])

          d = d.view(np.float16).astype(np.float32)
          scales = scales.view(np.uint32)

          db = d * (np.float32(0.5) + (scales >> 28).astype(np.float32)) * np.float32(0.5)
          db = db.reshape((n_blocks, -1, 1, 1))

          # get the sign indices and unpack the bits
          signs = scales.reshape((n_blocks, -1, 1)) >> np.array([0, 7, 14, 21], dtype=np.uint32).reshape((1, 1, 4))
          ksigns = np.frombuffer(IQ2_XXS.ksigns, dtype=np.uint8).reshape((1, 1, 1, 128))
          signs = (signs & np.uint32(0x7F)).reshape((n_blocks, -1, 4, 1))
          signs = np.take_along_axis(ksigns, signs, axis=-1)
          signs = signs.reshape((n_blocks, -1, 4, 1)) >> np.arange(8, dtype=np.uint8).reshape((1, 1, 1, 8))
          signs = signs & np.uint8(0x01)
          signs = np.where(signs == 0, np.float32(1), np.float32(-1))
          signs = signs.reshape((n_blocks, -1, 4, 8))

          # cls.grid is expected to be prepared outside this class (it may be None).
          assert cls.grid is not None
          grid = np.take_along_axis(cls.grid, qs.reshape((n_blocks, -1, 1, 1)), axis=-2)
          grid = grid.reshape((n_blocks, -1, 4, 8))

          return (db * grid * signs).reshape((n_blocks, -1))
  class IQ3_S(__Quant, qtype=GGMLQuantizationType.IQ3_S):
      """Dequantizer for GGML IQ3_S blocks.

      Each 4-element group is a codeword from a 512-entry grid. Signs are
      stored explicitly (one bit per element), and a 4-bit odd-multiplier scale
      applies per 32-element sub-block.
      """

      # 512 codewords of 4 elements each, whose values are drawn from grid_map
      # (the odd values 1..15). grid_hex presumably stores each element as an
      # index into grid_map; it is expanded into cls.grid outside this class
      # (not visible here).
      grid_shape = (512, 4)
      grid_map = (0x01, 0x03, 0x05, 0x07, 0x09, 0x0b, 0x0d, 0x0f)
      grid_hex = (
          b"0000010002000500070010001100120014001600200021002500330040004200"
          b"4500470051005300600062007100740077000001010102010401100111011501"
          b"2001230127013101350144016101650172010002010205020702100213021602"
          b"2102250230023402420245024702510253027002730203031103150320032203"
          b"3103330336034403500352036703710375030004130417042104240432044004"
          b"4304510470040205040520052205260533054105450547056605730506061106"
          b"1306310652067106000702070407200722072607330750075407001001100210"
          b"0410101011101310151017102010221031103410361054105610611072100011"
          b"0111031106111011141121113011331141115011521170117611001212121512"
          b"1712201224123212401243125512601272120113041307131013131321132713"
          b"3013341341136213701303140514121414143114331442144614501454140115"
          b"1015131521153015321551152016241627164416461601170317101712172117"
          b"3517411762177017002001200320052007201020122014201620212023202720"
          b"3020322041204320452050205220672070207320752000210221102113211721"
          b"2221252131213421422151210122042207222122232230223722412253225722"
          b"7122742200230223052311232223242331233323422350236623012407242024"
          b"2324322435244124722475240425112522253725402553257025002602260726"
          b"2126552661260527112726273027432750270230113013301530173022303130"
          b"3330353042304430473051306330713001310331053114312131233140316031"
          b"7231763100321232203232323432503201331033143321332333273330334133"
          b"4333473355337333033411341634223431345234603464340135103512352535"
          b"3235443556357335163641360137033720372237353700400440124020402440"
          b"2740324041405040704002410741114113412241304135414341514155410142"
          b"0342104215422142334240425742624270420443114313432043224331433543"
          b"0044024424443744404471440545074521456245134634466046104715473047"
          b"4347514702501050145022504050445047505250665074500151035105511251"
          b"2151325172510052115223523052365253520253075310532753445351536553"
          b"7353015404542054325446541255265551555355425602570457225711601360"
          b"1560316033606060006120612761646112623462426255626262706200631463"
          b"2163406325644364626400650365346560650566406611671367007004700770"
          b"2070227036704070547062700271117124714371457101720472107216722172"
          b"3072517202733273357353730174057413742074507422754275027631760077"
      )

      @classmethod
      def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
          """Dequantize a 2-D uint8 array with one IQ3_S block per row.

          Row layout, as split below:
          - a 2-byte float16 super-scale ``d``;
          - ``QK_K // 4`` bytes holding the low 8 bits of each grid index;
          - ``QK_K // 32`` bytes, each supplying the 9th index bit for eight
            groups;
          - ``QK_K // 8`` bytes of explicit sign bits (a set bit means -1);
          - the remaining bytes, which hold two 4-bit scales ``s`` each (low
            nibble first).

          Each scale covers eight 4-element groups and yields
          ``d * (1 + 2 * s)``.

          Returns:
              float32 array of shape ``(n_blocks, elements_per_block)``.
          """
          n_blocks = blocks.shape[0]

          d, rest = np.hsplit(blocks, [2])
          qs, rest = np.hsplit(rest, [QK_K // 4])
          qh, rest = np.hsplit(rest, [QK_K // 32])
          signs, scales = np.hsplit(rest, [QK_K // 8])

          d = d.view(np.float16).astype(np.float32)

          scales = scales.reshape((n_blocks, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2))
          scales = (scales & 0x0F).reshape((n_blocks, -1))
          db = d * (1 + 2 * scales)
          db = db.reshape((n_blocks, -1, 1, 1))

          # unpack the sign bits
          signs = signs.reshape((n_blocks, -1, 1)) >> np.arange(8, dtype=np.uint8).reshape((1, 1, 8))
          signs = signs & np.uint8(0x01)
          signs = np.where(signs == 0, np.float32(1), np.float32(-1))
          signs = signs.reshape((n_blocks, -1, 4, 8))

          # Combine into 9-bit grid indices: one qh bit on top of each qs byte.
          qh = qh.reshape((n_blocks, -1, 1)) >> np.arange(8, dtype=np.uint8)
          qh = (qh & 0x01).astype(np.uint16).reshape((n_blocks, -1))
          qs = qs.astype(np.uint16) | (qh << 8)

          # cls.grid is expected to be prepared outside this class (it may be None).
          assert cls.grid is not None
          grid = np.take_along_axis(cls.grid, qs.reshape((n_blocks, -1, 1, 1)), axis=-2)
          grid = grid.reshape((n_blocks, -1, 4, 8))

          return (db * grid * signs).reshape((n_blocks, -1))
  class IQ1_S(__Quant, qtype=GGMLQuantizationType.IQ1_S):
      """Dequantizer for GGML IQ1_S blocks.

      Each 8-element group is a codeword from a 2048-entry ternary grid
      (elements -1, 0, 1). A signed ``delta`` is added to the codeword, and the
      result is scaled per 32-element sub-block.
      """

      # iq1s_grid, with each byte packed into 2 bits
      # -1, 0, 1 <=> 0, 1, 2
      grid_shape = (2048, 8)
      grid_map = (-1, 0, 1)
      grid_hex = (
          b"00000200050008000a00110015002000220028002a0045005100540056006500"
          b"8000820088008a009500a000a200a800aa000401050111011401160119011a01"
          b"2501410146014901520155015a0161016401660168018501910194019601a501"
          b"0002020208020a0215022002220228022a024502510259026402690280028202"
          b"88028a02910295029902a002a202a802aa021104140416042504410449045504"
          b"5a046404650491049904a5040105040505050605150518051a05290540054505"
          b"4a0550055105540555055605590560056205650568056a058105910595059805"
          b"9a05a105a405a505a605a9051406190641064406500652065506580660066106"
          b"6606690685069106940699060008020808080a0815082008220828082a084508"
          b"5108560865088008820888088a089508a008a208a808aa080509110914091909"
          b"2409250941095009510955096109640969099109940996099909a509000a020a"
          b"080a0a0a150a200a220a280a2a0a450a510a590a610a650a800a820a850a880a"
          b"8a0a950aa00aa20aa80aaa0a1010111014101910241025104110441050105510"
          b"58106110641065106910911094109610a110a510011104110611091110111211"
          b"1511181121112411291145114a11501151115211541155115611591160116511"
          b"841192119511a111a41111121412161225124012461249125212551258125a12"
          b"641266128512911294129612a512011406140914141415141814191421142614"
          b"41144514461448144a1451145414551456145914621465146814841489149014"
          b"94149514981499149a14a114a414a514a914021505150a151115141515151615"
          b"191520152215251528152a154115441545154615511552155415551556155915"
          b"5a1561156415651566156915801582158415851588158a159015911594159515"
          b"961599159a15a015a215a51501160416051606161516161618161a1621162616"
          b"401642164416451648164a165116551656165816591661166416651668166916"
          b"6a1686168a1692169516a416a916111816182518411844184618491850185518"
          b"58185a1860186118641866186918851891189418a5181019121915191a192119"
          b"25194219441945194819511954195519561959195a19601965196a1989199119"
          b"921995199819a119a619a919091a161a241a261a441a461a491a501a521a551a"
          b"581a611a661a691a851a911a961a9a1a0020022008200a201520202022202520"
          b"28202a20452051205920612065208020822088208a209520a020a220a520a820"
          b"aa2005211121142119212521422144214921552158215a216121642165216621"
          b"8521902196219921a521012208220a22112215222022222228222a2245225122"
          b"562259226522812288228a2291229522a022a222a822aa220524142416241924"
          b"252444244524462449245224552458245a2466248524912494249924a124a524"
          b"0925152521252925402545254825512554255525592562256525682589259025"
          b"9425952598259a25a125a425a625a92505261026122619262526412649265526"
          b"6026612669268426862690269a260028022808280a2815282028222828282a28"
          b"45285128542865288028822888288a28a028a228a828aa280929112914291929"
          b"2529462949295229552961296429662969298529902996299929a429a529002a"
          b"022a082a0a2a202a222a282a2a2a452a512a562a592a652a802a822a882a8a2a"
          b"952aa02aa22aa82aaa2a054011401640254049405240554058405a4061406440"
          b"664094409940a140a6400041014104410641094112411541164118411a412141"
          b"26412941454148414a41514154415541564159415a41654168416a4181418441"
          b"8641904192419541a041a141a241054211421442164225424142524255425a42"
          b"6442694289429442a5420144154419442944454448444a445144544455445644"
          b"61446244654468446a44814486448944904492449544a044a144a94401450245"
          b"05450a4511451445154516451945204525452a45414544454545464549455045"
          b"5145544555455645584559456145644565456645694582458445854588459145"
          b"94459545964599459a45a545a845aa450146054609461446154618461a462146"
          b"2446294640464246454648465046514652465546564659466246654668468146"
          b"85468a4694469546a146a446a6460548114815481a4825484248494850485548"
          b"5848614864486648694885489148944896489948a5480149054906490a491049"
          b"144915491849214924492649404945494a495149524954495549564959496049"
          b"6249654966496a49864989499249954996499849a149a449a649a949164a444a"
          b"464a494a554a584a5a4a644a694a944aa54a0150045005500650095012501550"
          b"1a50215024502950405045504850515054505550565059506550685086508950"
          b"95509850a050a150a650a9500551085109510a51115114511551165118511951"
          b"20512551265128512a5141514451455146514951505151515251545155515651"
          b"585159515a51615164516551665169518251855191519451955196519951a051"
          b"a551aa5101520652125215521a5221522452425245524a525152545255525652"
          b"595262526552855290529252955299529a52a452045405541154145415541654"
          b"185419542154255428542a54415444544554465449544a545054515454545554"
          b"5654585459545a54615462546454655466546954805488548a54915494549554"
          b"96549954a154a454a554aa540155025504550555065509551055115512551455"
          b"1555165519551a55215524552555265529554055415542554455455546554855"
          b"4955505551555255545555555655585559555a55605561556455655566556855"
          b"69556a5581558455855589558a559055915594559555965598559955a155a455"
          b"a555a655a9550056015602560456065608560956115614561556185619562056"
          b"2156225624562556265628562956415645564656485649564a56505651565256"
          b"545655565656585659565a566156645665566956825685568656885689568a56"
          b"915695569a56a256a556a656a856a95604580558065809581058155818582158"
          b"2a58455848584a58515854585558565858585958605862586458655882588958"
          b"9058925895589858a158a9580159025905590a59115914591559165919592559"
          b"41594459455946594959505951595259545955595659585959595a5961596459"
          b"655966596959815985598959915994599559965998599959a559045a085a155a"
          b"1a5a205a255a265a295a455a485a495a515a555a565a585a595a625a655a685a"
          b"6a5a815a8a5a925a955a965a985a9a5aa15a0560146016601960256044605060"
          b"5560566058605a60616064606660696081609660a56001610461066109611261"
          b"15612161226126612961456149615161556156615961656166616a6184618a61"
          b"92619561a161a661a96111621662196240624162466255625662586260628562"
          b"91629662a56211641264156416641a6421642664296440644264456448644a64"
          b"516454645564566459645a646064626465648464856489649064926494649564"
          b"966498649a64a164a464a964056508650a651165156516651965446545654665"
          b"496550655165546555655665596561656465656566656965866589658a659165"
          b"9565966599659a65a265a565a665a86502660966156620662666286629664066"
          b"456648664a66516654665566566658665a666066656668668066826685668a66"
          b"9466966698669966a066a466a666aa661668196825684168526855685a686168"
          b"6968856891689868a66801690469106915692169246926692969406941694569"
          b"4669486951695469556956695969606965696a69826984698a699569a169a469"
          b"a569a969116a166a186a416a446a496a506a556a586a5a6a646a656a696a866a"
          b"946a986a9a6aa66a0080028008800a802080228028802a804580508051805480"
          b"5680598065808080828088808a809580a080a280a880aa800581118114811681"
          b"1981258141814481498150815281558156815881598164816681698185818981"
          b"948196819981a5810082028208820a8215822082228228822a82518254825982"
          b"65828082828288828a829582a082a282a882aa82148419844184448451845584"
          b"5a846184648469849484998401850985128515851a8526852985408541854585"
          b"4885518554855585568559855a856585668568856a8581858485868589859085"
          b"928595859885a68511861686198625864186448649864a865086558659865a86"
          b"618666866a86858691869a86a4860088028808880a8815882088228828882a88"
          b"41884588518854885988658869888088828888888a889588a088a288a888aa88"
          b"05890689118914891689258941894489468949895089528955895a8961896489"
          b"858996899989a589008a028a088a0a8a158a208a228a288a2a8a458a518a548a"
          b"568a808a828a888a8a8a958aa08aa28aa88aaa8a059011901690189019902590"
          b"419046904990559058905a9069906a9085909190949096909990a59001910491"
          b"069109911091159118911a912191249126912991409145915091519154915591"
          b"569159916291659184918691929195919891a191a491a691a991059211921492"
          b"19922592449246924992509252925592589266926992859294929692a9920194"
          b"04940694109415941894269440944a9451945494559456945894599460946194"
          b"62946594849486949294949495949894a194a9940095059508950a9510951195"
          b"14951595169519952195259529952a9541954495459546954995509551955295"
          b"549555955695589559955a956195649565956695699581958595889591959295"
          b"94959595969599959a95a095a295a595a895aa95019604961096159619962096"
          b"2696299645964896499651965296559656965996659668968296849689968a96"
          b"929694969596a496a696a9960598169819982598419846985098529855985698"
          b"5a98649865988598919896989998a59804990699099910991299159918991a99"
          b"209921992499269940994299459948994a995199549955995699599962996599"
          b"66996a99819984999099929995999a99a199a699059a159a259a449a469a499a"
          b"509a559a589a619a859a919a949a959a969a00a002a008a00aa015a020a022a0"
          b"28a02aa045a051a054a056a059a080a082a088a08aa095a0a0a0a2a0a8a0aaa0"
          b"05a109a111a114a116a119a11aa146a149a151a155a158a15aa161a164a185a1"
          b"90a192a196a199a102a208a20aa210a219a222a228a22aa245a251a256a259a2"
          b"65a280a282a288a28aa295a2a0a2a2a2a8a2aaa219a425a441a444a450a454a4"
          b"55a458a45aa461a465a466a468a469a485a406a509a510a512a515a518a526a5"
          b"29a542a545a551a554a555a556a559a565a56aa581a584a585a586a589a592a5"
          b"95a598a505a611a616a61aa621a625a644a646a64aa652a655a656a658a660a6"
          b"62a686a690a695a696a699a6a1a6a4a6a6a600a802a808a80aa820a822a828a8"
          b"2aa851a854a856a859a880a882a888a88aa895a8a0a8a2a8a8a8aaa805a914a9"
          b"19a921a925a941a950a955a95aa961a966a969a990a996a900aa02aa08aa0aaa"
          b"20aa22aa28aa2aaa51aa54aa56aa80aa82aa88aa8aaa95aaa0aaa2aaa8aaaaaa"
      )
      # Magnitude of the offset added to every grid element before scaling.
      delta = np.float32(0.125)

      @classmethod
      def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
          """Dequantize a 2-D uint8 array with one IQ1_S block per row.

          Row layout, as split below:
          - a 2-byte float16 super-scale;
          - ``QK_K // 8`` bytes holding the low 8 bits of each grid index;
          - the remaining bytes, viewed as uint16 words with one word per
            32-element sub-block.

          Within each uint16 word:
          - bits 0..11 give four 3-bit high index extensions;
          - bits 12..14 give the scale ``s`` (multiplier ``2*s + 1``);
          - bit 15 makes ``delta`` negative when set.

          Returns:
              float32 array of shape ``(n_blocks, elements_per_block)``.
          """
          count = blocks.shape[0]

          raw_scale, tail = np.hsplit(blocks, [2])
          low_index, high = np.hsplit(tail, [QK_K // 8])

          high = high.view(np.uint16)
          scale = raw_scale.view(np.float16).astype(np.float32)

          # Per-sub-block odd multiplier from bits 12..14.
          sub_scale = scale * (2 * ((high >> 12) & 7) + 1)
          sub_scale = sub_scale.reshape((count, -1, 1, 1))

          # Bit 15 selects the sign of the shared offset.
          offset = np.where((high & np.uint16(0x8000)) == 0, cls.delta, -cls.delta)
          offset = offset.reshape((count, -1, 1, 1))

          # Bits 0..11 extend four 8-bit indices to 11 bits each.
          field_shift = np.array([0, 3, 6, 9], dtype=np.uint16).reshape((1, 1, 4))
          extension = (high.reshape((count, -1, 1)) >> field_shift) & 7
          index = low_index.astype(np.uint16) | (extension << 8).reshape((count, -1))

          # cls.grid is expected to be prepared outside this class (it may be None).
          assert cls.grid is not None
          codewords = np.take_along_axis(cls.grid, index.reshape((count, -1, 1, 1)), axis=-2)
          codewords = codewords.reshape((count, -1, 4, 8))

          return (sub_scale * (codewords + offset)).reshape((count, -1))
  class IQ1_M(__Quant, qtype=GGMLQuantizationType.IQ1_M):
      """Dequantizer for GGML IQ1_M blocks.

      Reuses the IQ1_S grid tables and ``delta``. Unlike the other types here,
      the float16 super-scale is not stored contiguously. Its 16 bits are
      spread across the top nibbles of four uint16 scale words.
      """
      grid_shape = IQ1_S.grid_shape
      grid_map = IQ1_S.grid_map
      grid_hex = IQ1_S.grid_hex
      delta = IQ1_S.delta

      # Okay *this* type is weird. It's the only one which stores the f16 scales in multiple parts.
      @classmethod
      def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
          """Dequantize a 2-D uint8 array with one IQ1_M block per row.

          Row layout, as split below:
          - ``QK_K // 8`` bytes holding the low 8 bits of each grid index;
          - ``QK_K // 16`` bytes of ``qh``, each holding two 4-bit fields
            (bits 0..2 are index high bits, bit 3 is the delta sign);
          - four uint16 scale words. Bits 12..15 of the words assemble the
            float16 super-scale ``d``, lowest word first. Bits 0..11 hold four
            3-bit sub-scales ``s`` each, giving ``d * (2*s + 1)`` per 16-element
            half sub-block.

          Returns:
              float32 array of shape ``(n_blocks, elements_per_block)``.
          """
          n_blocks = blocks.shape[0]

          qs, rest = np.hsplit(blocks, [QK_K // 8])
          qh, scales = np.hsplit(rest, [QK_K // 16])

          # The f16 scale is packed across multiple bytes: the top nibble of
          # word k supplies bits 4k..4k+3 of the float16 bit pattern.
          scales = scales.view(np.uint16)
          d = (scales.reshape((n_blocks, 4)) & np.uint16(0xF000)) >> np.array([12, 8, 4, 0], dtype=np.uint16).reshape((1, 4))
          d = d[..., 0] | d[..., 1] | d[..., 2] | d[..., 3]
          d = d.view(np.float16).astype(np.float32).reshape((n_blocks, 1))

          scales = scales.reshape((n_blocks, -1, 1)) >> np.array([0, 3, 6, 9], dtype=np.uint16).reshape((1, 1, 4))
          scales = (scales & 0x07).reshape((n_blocks, -1))
          dl = d * (2 * scales + 1)
          dl = dl.reshape((n_blocks, -1, 2, 1, 1))

          # Split each qh byte into its two 4-bit fields, then extend the
          # 8-bit qs indices to 11 bits with bits 0..2 of each field.
          qh = qh.reshape((n_blocks, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2))
          qs = qs.astype(np.uint16) | ((qh & 0x07).astype(np.uint16) << 8).reshape((n_blocks, -1))

          # Bit 3 of each field makes delta negative. The parentheses are explicit:
          # `&` already binds tighter than `==` in Python, but the bare form reads
          # like the C precedence trap.
          delta = np.where((qh & 0x08) == 0, cls.delta, -cls.delta)
          delta = delta.reshape((n_blocks, -1, 2, 2, 1))

          # cls.grid is expected to be prepared outside this class (it may be None).
          assert cls.grid is not None
          grid = np.take_along_axis(cls.grid, qs.reshape((n_blocks, -1, 1, 1)), axis=-2)
          grid = grid.reshape((n_blocks, -1, 2, 2, 8))

          return (dl * (grid + delta)).reshape((n_blocks, -1))
  class IQ4_NL(__Quant, qtype=GGMLQuantizationType.IQ4_NL):
      # The 16 non-linear levels a 4-bit code can select (IQ4_XS reuses this table).
      kvalues = (-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113)

      @classmethod
      def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
          """Dequantize a batch of IQ4_NL blocks, one block per row of ``blocks``.

          Each row starts with a float16 scale (2 bytes), followed by packed
          4-bit codes. The codes are unpacked in groups of
          ``cls.block_size // 2`` bytes: first the low nibbles of the group,
          then its high nibbles. Each code is mapped through ``kvalues`` and
          multiplied by the block's scale.

          Returns a float32 array with one row of values per block.
          """
          n = blocks.shape[0]
          scale_bytes, packed = np.hsplit(blocks, [2])
          scale = scale_bytes.view(np.float16).astype(np.float32)

          groups = packed.reshape((n, -1, 1, cls.block_size // 2))
          # Stack (byte, byte >> 4) on the group axis, then keep 4 bits of each.
          codes = np.concatenate((groups, groups >> np.uint8(4)), axis=-2) & np.uint8(0x0F)

          # Integer-array indexing performs the table lookup. Every level is an
          # exact float32, so building the table in float32 loses nothing.
          levels = np.array(cls.kvalues, dtype=np.float32)
          return scale * levels[codes.reshape((n, -1))]
  class IQ4_XS(__Quant, qtype=GGMLQuantizationType.IQ4_XS):
      @classmethod
      def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
          """Dequantize a batch of IQ4_XS blocks, one block per row of ``blocks``.

          Row layout, in order:

          - a float16 super-block scale (2 bytes);
          - one uint16 holding the high 2 bits of every sub-block scale;
          - ``QK_K // 64`` bytes holding the low 4 bits of the sub-block
            scales, two per byte with the low nibble first;
          - the packed 4-bit codes.

          Each 6-bit sub-block scale is stored with an offset of 32. The codes
          come in 16-byte groups: the low nibbles, then the high nibbles, give
          32 values per group. Each code is mapped through ``IQ4_NL.kvalues``.

          Returns a float32 array with one row of values per block.
          """
          n = blocks.shape[0]
          d_bytes, tail = np.hsplit(blocks, [2])
          hi_bytes, tail = np.hsplit(tail, [2])
          lo_bytes, packed = np.hsplit(tail, [QK_K // 64])

          d = d_bytes.view(np.float16).astype(np.float32)

          # Low 4 bits of each sub-block scale, two per byte.
          lo = lo_bytes.reshape((n, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2))
          lo = lo.reshape((n, -1)) & np.uint8(0x0F)

          # High 2 bits: consecutive 2-bit fields of the single uint16 word.
          hi_shifts = np.arange(0, 2 * (QK_K // 32), 2, dtype=np.uint16).reshape((1, -1, 1))
          hi = hi_bytes.view(np.uint16).reshape((n, 1, -1)) >> hi_shifts
          hi = hi.reshape((n, -1)).astype(np.uint8) & np.uint8(0x03)

          # Combine into 6-bit values and remove the +32 storage offset.
          sub_scale = (lo | (hi << np.uint8(4))).astype(np.int8) - np.int8(32)
          sub_d = (d * sub_scale.astype(np.float32)).reshape((n, -1, 1))

          groups = packed.reshape((n, -1, 1, 16))
          # Stack (byte, byte >> 4) on the group axis, then keep 4 bits of each.
          codes = np.concatenate((groups, groups >> np.uint8(4)), axis=-2) & np.uint8(0x0F)

          # Every IQ4_NL level is an exact float32, so index a float32 table directly.
          levels = np.array(IQ4_NL.kvalues, dtype=np.float32)
          return (sub_d * levels[codes.reshape((n, -1, 32))]).reshape((n, -1))