  1. #
  2. # GGUF file reading/modification support. For API usage information,
  3. # please see the files scripts/ for some fairly simple examples.
  4. #
  5. from __future__ import annotations
  6. import logging
  7. import os
  8. import sys
  9. from collections import OrderedDict
  10. from typing import Any, Literal, NamedTuple, TypeVar, Union
  11. import numpy as np
  12. import numpy.typing as npt
  13. from .quants import quant_shape_to_byte_shape
if __name__ == "__main__":
    from pathlib import Path

    # Allow running file in package as a script: put the package parent on
    # sys.path so the absolute `gguf.constants` import below resolves.
    sys.path.insert(0, str(Path(__file__).parent.parent))

from gguf.constants import (
    GGML_QUANT_SIZES,
    GGUF_DEFAULT_ALIGNMENT,
    GGUF_MAGIC,
    GGUF_VERSION,
    GGMLQuantizationType,
    GGUFValueType,
    GGUFEndian,
)

logger = logging.getLogger(__name__)

# GGUF container versions this reader accepts (version 2 plus the current one).
READER_SUPPORTED_VERSIONS = [2, GGUF_VERSION]
class ReaderField(NamedTuple):
    """One parsed key/value metadata field (or tensor-info entry) from a GGUF file.

    Holds the raw numpy views that make up the field on disk (`parts`) plus
    bookkeeping (`data`, `types`) describing which parts carry the payload
    and how to decode it.
    """

    # Offset to start of this field.
    offset: int

    # Name of the field (not necessarily from file data).
    name: str

    # Data parts. Some types have multiple components, such as strings
    # that consist of a length followed by the string data.
    # NOTE(review): these mutable defaults are shared by every instance that
    # omits the argument; callers in this file always pass explicit lists,
    # and NamedTuple instances never mutate them — confirm before reusing.
    parts: list[npt.NDArray[Any]] = []

    # Indexes into parts that we can call the actual data. For example
    # an array of strings will be populated with indexes to the actual
    # string data.
    data: list[int] = [-1]

    # Value type(s); for ARRAY fields this is [ARRAY, element_type].
    types: list[GGUFValueType] = []

    def contents(self, index_or_slice: int | slice = slice(None)) -> Any:
        """Decode this field's payload into plain Python values.

        For ARRAY fields, `index_or_slice` selects elements (an int returns a
        single element, a slice returns a list); the default returns the whole
        array. Scalars and strings ignore `index_or_slice`. Returns None when
        no type information is present.
        """
        if self.types:
            # GGUF strings are stored as raw bytes; decode as UTF-8.
            to_string = lambda x: str(x.tobytes(), encoding='utf-8') # noqa: E731
            main_type = self.types[0]

            if main_type == GGUFValueType.ARRAY:
                sub_type = self.types[-1]

                if sub_type == GGUFValueType.STRING:
                    # `data` holds one parts-index per array element here.
                    indices = self.data[index_or_slice]

                    if isinstance(index_or_slice, int):
                        return to_string(self.parts[indices]) # type: ignore
                    else:
                        return [to_string(self.parts[idx]) for idx in indices] # type: ignore
                else:
                    # FIXME: When/if _get_field_parts() support multi-dimensional arrays, this must do so too

                    # Check if it's unsafe to perform slice optimization on data
                    # if any(True for idx in self.data if len(self.parts[idx]) != 1):
                    #     optim_slice = slice(None)
                    # else:
                    #     optim_slice = index_or_slice
                    #     index_or_slice = slice(None)

                    # if isinstance(optim_slice, int):
                    #     return self.parts[self.data[optim_slice]].tolist()[0]
                    # else:
                    #     return [pv for idx in self.data[optim_slice] for pv in self.parts[idx].tolist()]

                    if isinstance(index_or_slice, int):
                        # Each referenced part is a length-1 array; unwrap it.
                        return self.parts[self.data[index_or_slice]].tolist()[0]
                    else:
                        return [pv for idx in self.data[index_or_slice] for pv in self.parts[idx].tolist()]

            if main_type == GGUFValueType.STRING:
                return to_string(self.parts[-1])
            else:
                # Scalar: the last part is a length-1 array holding the value.
                return self.parts[-1].tolist()[0]

        return None
class ReaderTensor(NamedTuple):
    """Tensor metadata plus a zero-copy view of its data from a GGUF file."""

    # Tensor name, decoded from the file as UTF-8.
    name: str
    # On-disk GGML encoding/quantization type.
    tensor_type: GGMLQuantizationType
    # Dimensions exactly as stored in the file (GGUF axis order; `data` is
    # reshaped to the reversed order — see GGUFReader._build_tensors).
    # NOTE(review): populated from uint64 values read in _build_tensors,
    # though annotated uint32 here — confirm which is intended.
    shape: npt.NDArray[np.uint32]
    # Total element count (product of `shape`).
    n_elements: int
    # Size of the encoded data in bytes.
    n_bytes: int
    # Absolute offset of the tensor data within the file.
    data_offset: int
    # Memmap-backed view of the (possibly quantized) tensor data.
    data: npt.NDArray[Any]
    # The raw tensor-info field this tensor was built from.
    field: ReaderField
class GGUFReader:
    """Reads GGUF model files via a numpy memmap.

    Parses the header, all key/value metadata fields (exposed in `fields`),
    and all tensor-info records (exposed as `tensors`, whose `data` entries
    are zero-copy views into the mapped file).
    """

    # I - same as host, S - swapped
    byte_order: Literal['I', 'S'] = 'I'
    # Tensor-data alignment; may be overridden by the file's
    # `general.alignment` field in __init__.
    alignment: int = GGUF_DEFAULT_ALIGNMENT
    # Absolute file offset where the aligned tensor data region begins.
    data_offset: int

    # Note: Internal helper, API may change.
    # Maps scalar GGUF value types to the numpy dtype used to read them.
    gguf_scalar_to_np: dict[GGUFValueType, type[np.generic]] = {
        GGUFValueType.UINT8:   np.uint8,
        GGUFValueType.INT8:    np.int8,
        GGUFValueType.UINT16:  np.uint16,
        GGUFValueType.INT16:   np.int16,
        GGUFValueType.UINT32:  np.uint32,
        GGUFValueType.INT32:   np.int32,
        GGUFValueType.FLOAT32: np.float32,
        GGUFValueType.UINT64:  np.uint64,
        GGUFValueType.INT64:   np.int64,
        GGUFValueType.FLOAT64: np.float64,
        GGUFValueType.BOOL:    np.bool_,
    }

    def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] = 'r'):
        """Open and fully parse the GGUF file at `path`.

        `mode` is passed straight to numpy.memmap ('r' read-only,
        'r+' read/write, 'c' copy-on-write).

        Raises ValueError on bad magic, unsupported version, or a malformed
        `general.alignment` field.
        """
        self.data = np.memmap(path, mode = mode)
        offs = 0

        # Check for GGUF magic
        if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC:
            raise ValueError('GGUF magic invalid')
        offs += 4

        # Check GGUF version
        temp_version = self._get(offs, np.uint32)
        if temp_version[0] & 65535 == 0:
            # If we get 0 here that means it's (probably) a GGUF file created for
            # the opposite byte order of the machine this script is running on.
            self.byte_order = 'S'
            temp_version = temp_version.view(temp_version.dtype.newbyteorder(self.byte_order))
        version = temp_version[0]
        if version not in READER_SUPPORTED_VERSIONS:
            raise ValueError(f'Sorry, file appears to be version {version} which we cannot handle')
        if sys.byteorder == "little":
            # Host is little endian
            host_endian = GGUFEndian.LITTLE
            swapped_endian = GGUFEndian.BIG
        else:
            # Sorry PDP or other weird systems that don't use BE or LE.
            host_endian = GGUFEndian.BIG
            swapped_endian = GGUFEndian.LITTLE
        self.endianess = swapped_endian if self.byte_order == "S" else host_endian
        self.fields: OrderedDict[str, ReaderField] = OrderedDict()
        self.tensors: list[ReaderTensor] = []
        offs += self._push_field(ReaderField(offs, 'GGUF.version', [temp_version], [0], [GGUFValueType.UINT32]))

        # Check tensor count and kv count
        temp_counts = self._get(offs, np.uint64, 2)
        offs += self._push_field(ReaderField(offs, 'GGUF.tensor_count', [temp_counts[:1]], [0], [GGUFValueType.UINT64]))
        offs += self._push_field(ReaderField(offs, 'GGUF.kv_count', [temp_counts[1:]], [0], [GGUFValueType.UINT64]))
        tensor_count, kv_count = temp_counts
        offs = self._build_fields(offs, kv_count)

        # Build Tensor Info Fields
        offs, tensors_fields = self._build_tensor_info(offs, tensor_count)
        new_align = self.fields.get('general.alignment')
        if new_align is not None:
            if new_align.types != [GGUFValueType.UINT32]:
                raise ValueError('Bad type for general.alignment field')
            # NOTE(review): the value itself is not validated; a zero
            # alignment would make the modulo below raise ZeroDivisionError.
            self.alignment = new_align.parts[-1][0]
        # Round offs up to the next alignment boundary: that is where the
        # tensor data region starts.
        padding = offs % self.alignment
        if padding != 0:
            offs += self.alignment - padding
        self.data_offset = offs
        self._build_tensors(offs, tensors_fields)

    _DT = TypeVar('_DT', bound = npt.DTypeLike)

    # Fetch a key/value metadata field by key.
    def get_field(self, key: str) -> Union[ReaderField, None]:
        return self.fields.get(key, None)

    # Fetch a tensor from the list by index.
    def get_tensor(self, idx: int) -> ReaderTensor:
        return self.tensors[idx]

    def _get(
        self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I', 'S', '<'] = None,
    ) -> npt.NDArray[Any]:
        """Return `count` items of `dtype` at `offset` as a zero-copy view.

        The view's byte order is the reader's detected order unless
        `override_order` is given (e.g. '<' for the always-little-endian
        magic check).
        """
        count = int(count)
        itemsize = int(np.empty([], dtype = dtype).itemsize)
        end_offs = offset + itemsize * count
        arr = self.data[offset:end_offs].view(dtype=dtype)[:count]
        return arr.view(arr.dtype.newbyteorder(self.byte_order if override_order is None else override_order))

    def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int:
        """Register `field` in self.fields and return its on-disk byte size.

        Returns 0 when `skip_sum` is set (caller tracks the size itself).
        Duplicate keys are kept under a '<name>_<offset>' alias with a warning.
        """
        if field.name in self.fields:
            # TODO: add option to generate error on duplicate keys
            # raise KeyError(f'Duplicate {field.name} already in list at offset {field.offset}')
            logger.warning(f'Duplicate key {field.name} at offset {field.offset}')
            self.fields[field.name + '_{}'.format(field.offset)] = field
        else:
            self.fields[field.name] = field
        return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts)

    def _get_str(self, offset: int) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]:
        """Read a GGUF string at `offset`: a uint64 length followed by that
        many raw bytes. Returns (length_view, bytes_view)."""
        slen = self._get(offset, np.uint64)
        return slen, self._get(offset + 8, np.uint8, slen[0])

    def _get_field_parts(
        self, orig_offs: int, raw_type: int,
    ) -> tuple[int, list[npt.NDArray[Any]], list[int], list[GGUFValueType]]:
        """Parse one field value of type `raw_type` at `orig_offs`.

        Returns (byte_size, parts, data_indexes, types). `data_indexes` point
        at the entries of `parts` that hold payload (see ReaderField). Arrays
        recurse per element; raises ValueError on an unknown type.
        """
        offs = orig_offs
        types: list[GGUFValueType] = []
        gtype = GGUFValueType(raw_type)
        types.append(gtype)
        # Handle strings.
        if gtype == GGUFValueType.STRING:
            sparts: list[npt.NDArray[Any]] = list(self._get_str(offs))
            size = sum(int(part.nbytes) for part in sparts)
            return size, sparts, [1], types
        # Check if it's a simple scalar type.
        nptype = self.gguf_scalar_to_np.get(gtype)
        if nptype is not None:
            val = self._get(offs, nptype)
            return int(val.nbytes), [val], [0], types
        # Handle arrays: uint32 element type, uint64 length, then elements.
        if gtype == GGUFValueType.ARRAY:
            raw_itype = self._get(offs, np.uint32)
            offs += int(raw_itype.nbytes)
            alen = self._get(offs, np.uint64)
            offs += int(alen.nbytes)
            aparts: list[npt.NDArray[Any]] = [raw_itype, alen]
            data_idxs: list[int] = []
            # FIXME: Handle multi-dimensional arrays properly instead of flattening
            for idx in range(alen[0]):
                curr_size, curr_parts, curr_idxs, curr_types = self._get_field_parts(offs, raw_itype[0])
                # Element types are homogeneous, so record them only once.
                if idx == 0:
                    types += curr_types
                idxs_offs = len(aparts)
                aparts += curr_parts
                data_idxs += (idx + idxs_offs for idx in curr_idxs)
                offs += curr_size
            return offs - orig_offs, aparts, data_idxs, types
        # We can't deal with this one.
        raise ValueError(f'Unknown/unhandled field type {gtype}')

    def _get_tensor_info_field(self, orig_offs: int) -> ReaderField:
        """Parse one tensor-info record at `orig_offs`.

        On-disk layout: name (GGUF string), uint32 dim count, uint64 dims,
        uint32 GGML type, uint64 data offset (relative to the data region).
        """
        offs = orig_offs

        # Get Tensor Name
        name_len, name_data = self._get_str(offs)
        offs += int(name_len.nbytes + name_data.nbytes)

        # Get Tensor Dimensions Count
        n_dims = self._get(offs, np.uint32)
        offs += int(n_dims.nbytes)

        # Get Tensor Dimension Array
        dims = self._get(offs, np.uint64, n_dims[0])
        offs += int(dims.nbytes)

        # Get Tensor Encoding Scheme Type
        raw_dtype = self._get(offs, np.uint32)
        offs += int(raw_dtype.nbytes)

        # Get Tensor Offset
        offset_tensor = self._get(offs, np.uint64)
        offs += int(offset_tensor.nbytes)

        return ReaderField(
            orig_offs,
            str(bytes(name_data), encoding = 'utf-8'),
            [name_len, name_data, n_dims, dims, raw_dtype, offset_tensor],
            [1, 3, 4, 5],
        )

    def _build_fields(self, offs: int, count: int) -> int:
        """Parse `count` key/value metadata fields starting at `offs`;
        register each via _push_field and return the offset past the last."""
        for _ in range(count):
            orig_offs = offs
            kv_klen, kv_kdata = self._get_str(offs)
            offs += int(kv_klen.nbytes + kv_kdata.nbytes)
            raw_kv_type = self._get(offs, np.uint32)
            offs += int(raw_kv_type.nbytes)
            parts: list[npt.NDArray[Any]] = [kv_klen, kv_kdata, raw_kv_type]
            idxs_offs = len(parts)
            field_size, field_parts, field_idxs, field_types = self._get_field_parts(offs, raw_kv_type[0])
            parts += field_parts
            self._push_field(ReaderField(
                orig_offs,
                str(bytes(kv_kdata), encoding = 'utf-8'),
                parts,
                # Shift value indexes past the key/type parts prepended above.
                [idx + idxs_offs for idx in field_idxs],
                field_types,
            ), skip_sum = True)
            offs += field_size
        return offs

    def _build_tensor_info(self, offs: int, count: int) -> tuple[int, list[ReaderField]]:
        """Parse `count` tensor-info records starting at `offs`.
        Returns (offset past the last record, list of raw info fields)."""
        tensor_fields = []
        for _ in range(count):
            field = self._get_tensor_info_field(offs)
            offs += sum(int(part.nbytes) for part in field.parts)
            tensor_fields.append(field)
        return offs, tensor_fields

    def _build_tensors(self, start_offs: int, fields: list[ReaderField]) -> None:
        """Materialize self.tensors from parsed tensor-info `fields`.

        `start_offs` is the absolute offset of the tensor data region; each
        tensor's stored offset is relative to it. Raises ValueError on
        duplicate tensor names.
        """
        tensors = []
        tensor_names = set() # keep track of name to prevent duplicated tensors
        for field in fields:
            _name_len, name_data, _n_dims, dims, raw_dtype, offset_tensor = field.parts
            # check if there's any tensor having same name already in the list
            tensor_name = str(bytes(name_data), encoding = 'utf-8')
            if tensor_name in tensor_names:
                raise ValueError(f'Found duplicated tensor with name {tensor_name}')
            tensor_names.add(tensor_name)
            ggml_type = GGMLQuantizationType(raw_dtype[0])
            n_elems = int(np.prod(dims))
            # numpy shape is the reverse of the GGUF dimension order.
            np_dims = tuple(reversed(dims.tolist()))
            block_size, type_size = GGML_QUANT_SIZES[ggml_type]
            n_bytes = n_elems * type_size // block_size
            data_offs = int(start_offs + offset_tensor[0])
            item_type: npt.DTypeLike
            if ggml_type == GGMLQuantizationType.F16:
                item_count = n_elems
                item_type = np.float16
            elif ggml_type == GGMLQuantizationType.F32:
                item_count = n_elems
                item_type = np.float32
            elif ggml_type == GGMLQuantizationType.F64:
                item_count = n_elems
                item_type = np.float64
            elif ggml_type == GGMLQuantizationType.I8:
                item_count = n_elems
                item_type = np.int8
            elif ggml_type == GGMLQuantizationType.I16:
                item_count = n_elems
                item_type = np.int16
            elif ggml_type == GGMLQuantizationType.I32:
                item_count = n_elems
                item_type = np.int32
            elif ggml_type == GGMLQuantizationType.I64:
                item_count = n_elems
                item_type = np.int64
            else:
                # Quantized types: expose the raw bytes, with the innermost
                # dimension converted to its byte width.
                item_count = n_bytes
                item_type = np.uint8
                np_dims = quant_shape_to_byte_shape(np_dims, ggml_type)
            tensors.append(ReaderTensor(
                name = tensor_name,
                tensor_type = ggml_type,
                shape = dims,
                n_elements = n_elems,
                n_bytes = n_bytes,
                data_offset = data_offs,
                data = self._get(data_offs, item_type, item_count).reshape(np_dims),
                field = field,
            ))
        self.tensors = tensors