# gguf_reader.py
  1. #
  2. # GGUF file reading/modification support. For API usage information,
  3. # please see the files scripts/ for some fairly simple examples.
  4. #
  5. from __future__ import annotations
  6. import logging
  7. import os
  8. from collections import OrderedDict
  9. from typing import Any, Literal, NamedTuple, TypeVar, Union
  10. import numpy as np
  11. import numpy.typing as npt
  12. from .quants import quant_shape_to_byte_shape
  13. if __name__ == "__main__":
  14. import sys
  15. from pathlib import Path
  16. # Allow running file in package as a script.
  17. sys.path.insert(0, str(Path(__file__).parent.parent))
  18. from gguf.constants import (
  19. GGML_QUANT_SIZES,
  20. GGUF_DEFAULT_ALIGNMENT,
  21. GGUF_MAGIC,
  22. GGUF_VERSION,
  23. GGMLQuantizationType,
  24. GGUFValueType,
  25. )
  26. logger = logging.getLogger(__name__)
  27. READER_SUPPORTED_VERSIONS = [2, GGUF_VERSION]
  28. class ReaderField(NamedTuple):
  29. # Offset to start of this field.
  30. offset: int
  31. # Name of the field (not necessarily from file data).
  32. name: str
  33. # Data parts. Some types have multiple components, such as strings
  34. # that consist of a length followed by the string data.
  35. parts: list[npt.NDArray[Any]] = []
  36. # Indexes into parts that we can call the actual data. For example
  37. # an array of strings will be populated with indexes to the actual
  38. # string data.
  39. data: list[int] = [-1]
  40. types: list[GGUFValueType] = []
  41. class ReaderTensor(NamedTuple):
  42. name: str
  43. tensor_type: GGMLQuantizationType
  44. shape: npt.NDArray[np.uint32]
  45. n_elements: int
  46. n_bytes: int
  47. data_offset: int
  48. data: npt.NDArray[Any]
  49. field: ReaderField
  50. class GGUFReader:
  51. # I - same as host, S - swapped
  52. byte_order: Literal['I', 'S'] = 'I'
  53. alignment: int = GGUF_DEFAULT_ALIGNMENT
  54. data_offset: int
  55. # Note: Internal helper, API may change.
  56. gguf_scalar_to_np: dict[GGUFValueType, type[np.generic]] = {
  57. GGUFValueType.UINT8: np.uint8,
  58. GGUFValueType.INT8: np.int8,
  59. GGUFValueType.UINT16: np.uint16,
  60. GGUFValueType.INT16: np.int16,
  61. GGUFValueType.UINT32: np.uint32,
  62. GGUFValueType.INT32: np.int32,
  63. GGUFValueType.FLOAT32: np.float32,
  64. GGUFValueType.UINT64: np.uint64,
  65. GGUFValueType.INT64: np.int64,
  66. GGUFValueType.FLOAT64: np.float64,
  67. GGUFValueType.BOOL: np.bool_,
  68. }
  69. def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] = 'r'):
  70. self.data = np.memmap(path, mode = mode)
  71. offs = 0
  72. # Check for GGUF magic
  73. if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC:
  74. raise ValueError('GGUF magic invalid')
  75. offs += 4
  76. # Check GGUF version
  77. temp_version = self._get(offs, np.uint32)
  78. if temp_version[0] & 65535 == 0:
  79. # If we get 0 here that means it's (probably) a GGUF file created for
  80. # the opposite byte order of the machine this script is running on.
  81. self.byte_order = 'S'
  82. temp_version = temp_version.newbyteorder(self.byte_order)
  83. version = temp_version[0]
  84. if version not in READER_SUPPORTED_VERSIONS:
  85. raise ValueError(f'Sorry, file appears to be version {version} which we cannot handle')
  86. self.fields: OrderedDict[str, ReaderField] = OrderedDict()
  87. self.tensors: list[ReaderTensor] = []
  88. offs += self._push_field(ReaderField(offs, 'GGUF.version', [temp_version], [0], [GGUFValueType.UINT32]))
  89. # Check tensor count and kv count
  90. temp_counts = self._get(offs, np.uint64, 2)
  91. offs += self._push_field(ReaderField(offs, 'GGUF.tensor_count', [temp_counts[:1]], [0], [GGUFValueType.UINT64]))
  92. offs += self._push_field(ReaderField(offs, 'GGUF.kv_count', [temp_counts[1:]], [0], [GGUFValueType.UINT64]))
  93. tensor_count, kv_count = temp_counts
  94. offs = self._build_fields(offs, kv_count)
  95. # Build Tensor Info Fields
  96. offs, tensors_fields = self._build_tensor_info(offs, tensor_count)
  97. new_align = self.fields.get('general.alignment')
  98. if new_align is not None:
  99. if new_align.types != [GGUFValueType.UINT32]:
  100. raise ValueError('Bad type for general.alignment field')
  101. self.alignment = new_align.parts[-1][0]
  102. padding = offs % self.alignment
  103. if padding != 0:
  104. offs += self.alignment - padding
  105. self.data_offset = offs
  106. self._build_tensors(offs, tensors_fields)
  107. _DT = TypeVar('_DT', bound = npt.DTypeLike)
  108. # Fetch a key/value metadata field by key.
  109. def get_field(self, key: str) -> Union[ReaderField, None]:
  110. return self.fields.get(key, None)
  111. # Fetch a tensor from the list by index.
  112. def get_tensor(self, idx: int) -> ReaderTensor:
  113. return self.tensors[idx]
  114. def _get(
  115. self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I', 'S', '<'] = None,
  116. ) -> npt.NDArray[Any]:
  117. count = int(count)
  118. itemsize = int(np.empty([], dtype = dtype).itemsize)
  119. end_offs = offset + itemsize * count
  120. return (
  121. self.data[offset:end_offs]
  122. .view(dtype = dtype)[:count]
  123. .newbyteorder(override_order or self.byte_order)
  124. )
  125. def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int:
  126. if field.name in self.fields:
  127. # TODO: add option to generate error on duplicate keys
  128. # raise KeyError(f'Duplicate {field.name} already in list at offset {field.offset}')
  129. logger.warning(f'Duplicate key {field.name} at offset {field.offset}')
  130. self.fields[field.name + '_{}'.format(field.offset)] = field
  131. else:
  132. self.fields[field.name] = field
  133. return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts)
  134. def _get_str(self, offset: int) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]:
  135. slen = self._get(offset, np.uint64)
  136. return slen, self._get(offset + 8, np.uint8, slen[0])
  137. def _get_field_parts(
  138. self, orig_offs: int, raw_type: int,
  139. ) -> tuple[int, list[npt.NDArray[Any]], list[int], list[GGUFValueType]]:
  140. offs = orig_offs
  141. types: list[GGUFValueType] = []
  142. gtype = GGUFValueType(raw_type)
  143. types.append(gtype)
  144. # Handle strings.
  145. if gtype == GGUFValueType.STRING:
  146. sparts: list[npt.NDArray[Any]] = list(self._get_str(offs))
  147. size = sum(int(part.nbytes) for part in sparts)
  148. return size, sparts, [1], types
  149. # Check if it's a simple scalar type.
  150. nptype = self.gguf_scalar_to_np.get(gtype)
  151. if nptype is not None:
  152. val = self._get(offs, nptype)
  153. return int(val.nbytes), [val], [0], types
  154. # Handle arrays.
  155. if gtype == GGUFValueType.ARRAY:
  156. raw_itype = self._get(offs, np.uint32)
  157. offs += int(raw_itype.nbytes)
  158. alen = self._get(offs, np.uint64)
  159. offs += int(alen.nbytes)
  160. aparts: list[npt.NDArray[Any]] = [raw_itype, alen]
  161. data_idxs: list[int] = []
  162. for idx in range(alen[0]):
  163. curr_size, curr_parts, curr_idxs, curr_types = self._get_field_parts(offs, raw_itype[0])
  164. if idx == 0:
  165. types += curr_types
  166. idxs_offs = len(aparts)
  167. aparts += curr_parts
  168. data_idxs += (idx + idxs_offs for idx in curr_idxs)
  169. offs += curr_size
  170. return offs - orig_offs, aparts, data_idxs, types
  171. # We can't deal with this one.
  172. raise ValueError('Unknown/unhandled field type {gtype}')
  173. def _get_tensor_info_field(self, orig_offs: int) -> ReaderField:
  174. offs = orig_offs
  175. # Get Tensor Name
  176. name_len, name_data = self._get_str(offs)
  177. offs += int(name_len.nbytes + name_data.nbytes)
  178. # Get Tensor Dimensions Count
  179. n_dims = self._get(offs, np.uint32)
  180. offs += int(n_dims.nbytes)
  181. # Get Tensor Dimension Array
  182. dims = self._get(offs, np.uint64, n_dims[0])
  183. offs += int(dims.nbytes)
  184. # Get Tensor Encoding Scheme Type
  185. raw_dtype = self._get(offs, np.uint32)
  186. offs += int(raw_dtype.nbytes)
  187. # Get Tensor Offset
  188. offset_tensor = self._get(offs, np.uint64)
  189. offs += int(offset_tensor.nbytes)
  190. return ReaderField(
  191. orig_offs,
  192. str(bytes(name_data), encoding = 'utf-8'),
  193. [name_len, name_data, n_dims, dims, raw_dtype, offset_tensor],
  194. [1, 3, 4, 5],
  195. )
  196. def _build_fields(self, offs: int, count: int) -> int:
  197. for _ in range(count):
  198. orig_offs = offs
  199. kv_klen, kv_kdata = self._get_str(offs)
  200. offs += int(kv_klen.nbytes + kv_kdata.nbytes)
  201. raw_kv_type = self._get(offs, np.uint32)
  202. offs += int(raw_kv_type.nbytes)
  203. parts: list[npt.NDArray[Any]] = [kv_klen, kv_kdata, raw_kv_type]
  204. idxs_offs = len(parts)
  205. field_size, field_parts, field_idxs, field_types = self._get_field_parts(offs, raw_kv_type[0])
  206. parts += field_parts
  207. self._push_field(ReaderField(
  208. orig_offs,
  209. str(bytes(kv_kdata), encoding = 'utf-8'),
  210. parts,
  211. [idx + idxs_offs for idx in field_idxs],
  212. field_types,
  213. ), skip_sum = True)
  214. offs += field_size
  215. return offs
  216. def _build_tensor_info(self, offs: int, count: int) -> tuple[int, list[ReaderField]]:
  217. tensor_fields = []
  218. for _ in range(count):
  219. field = self._get_tensor_info_field(offs)
  220. offs += sum(int(part.nbytes) for part in field.parts)
  221. tensor_fields.append(field)
  222. return offs, tensor_fields
  223. def _build_tensors(self, start_offs: int, fields: list[ReaderField]) -> None:
  224. tensors = []
  225. tensor_names = set() # keep track of name to prevent duplicated tensors
  226. for field in fields:
  227. _name_len, name_data, _n_dims, dims, raw_dtype, offset_tensor = field.parts
  228. # check if there's any tensor having same name already in the list
  229. tensor_name = str(bytes(name_data), encoding = 'utf-8')
  230. if tensor_name in tensor_names:
  231. raise ValueError(f'Found duplicated tensor with name {tensor_name}')
  232. tensor_names.add(tensor_name)
  233. ggml_type = GGMLQuantizationType(raw_dtype[0])
  234. n_elems = int(np.prod(dims))
  235. np_dims = tuple(reversed(dims.tolist()))
  236. block_size, type_size = GGML_QUANT_SIZES[ggml_type]
  237. n_bytes = n_elems * type_size // block_size
  238. data_offs = int(start_offs + offset_tensor[0])
  239. item_type: npt.DTypeLike
  240. if ggml_type == GGMLQuantizationType.F16:
  241. item_count = n_elems
  242. item_type = np.float16
  243. elif ggml_type == GGMLQuantizationType.F32:
  244. item_count = n_elems
  245. item_type = np.float32
  246. elif ggml_type == GGMLQuantizationType.F64:
  247. item_count = n_elems
  248. item_type = np.float64
  249. elif ggml_type == GGMLQuantizationType.I8:
  250. item_count = n_elems
  251. item_type = np.int8
  252. elif ggml_type == GGMLQuantizationType.I16:
  253. item_count = n_elems
  254. item_type = np.int16
  255. elif ggml_type == GGMLQuantizationType.I32:
  256. item_count = n_elems
  257. item_type = np.int32
  258. elif ggml_type == GGMLQuantizationType.I64:
  259. item_count = n_elems
  260. item_type = np.int64
  261. else:
  262. item_count = n_bytes
  263. item_type = np.uint8
  264. np_dims = quant_shape_to_byte_shape(np_dims, ggml_type)
  265. tensors.append(ReaderTensor(
  266. name = tensor_name,
  267. tensor_type = ggml_type,
  268. shape = dims,
  269. n_elements = n_elems,
  270. n_bytes = n_bytes,
  271. data_offset = data_offs,
  272. data = self._get(data_offs, item_type, item_count).reshape(np_dims),
  273. field = field,
  274. ))
  275. self.tensors = tensors