#
# GGUF file reading/modification support. For API usage information,
# please see the files in scripts/ for some fairly simple examples.
#
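# A minimal usage sketch (the 'model.gguf' path is hypothetical; the attributes
# used below are the ones defined by ReaderField/ReaderTensor in this module):
#
#   reader = GGUFReader('model.gguf')
#   for name, field in reader.fields.items():
#       print(name, field.types)
#   for tensor in reader.tensors:
#       print(tensor.name, tensor.tensor_type, tensor.shape, tensor.n_bytes)
#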
from __future__ import annotations

import os
from collections import OrderedDict
from typing import Any, Literal, NamedTuple, TypeVar, Union

import numpy as np
import numpy.typing as npt

if __name__ == "__main__":
    import sys
    from pathlib import Path

    # Allow running file in package as a script.
    sys.path.insert(0, str(Path(__file__).parent.parent))

from gguf.constants import (
    GGML_QUANT_SIZES,
    GGUF_DEFAULT_ALIGNMENT,
    GGUF_MAGIC,
    GGUF_VERSION,
    GGMLQuantizationType,
    GGUFValueType,
)

READER_SUPPORTED_VERSIONS = [2, GGUF_VERSION]

class ReaderField(NamedTuple):
    # Offset to start of this field.
    offset: int

    # Name of the field (not necessarily from file data).
    name: str

    # Data parts. Some types have multiple components, such as strings
    # that consist of a length followed by the string data.
    parts: list[npt.NDArray[Any]] = []

    # Indexes into parts that we can call the actual data. For example
    # an array of strings will be populated with indexes to the actual
    # string data.
    data: list[int] = [-1]

    types: list[GGUFValueType] = []
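
# Worked example of the layout above (a sketch, not part of the API): for a UINT32
# metadata field, parts is [key_len, key_data, raw_type, value] and data is [3], so
# field.parts[field.data[0]][0] yields the scalar; for a STRING field the indexed
# part holds the raw UTF-8 bytes instead.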

class ReaderTensor(NamedTuple):
    name: str
    tensor_type: GGMLQuantizationType
    shape: npt.NDArray[np.uint32]
    n_elements: int
    n_bytes: int
    data_offset: int
    data: npt.NDArray[Any]
    field: ReaderField
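
# Note: ReaderTensor.data is a view into the underlying memory map, not a copy.
# Float and integer tensor types are exposed with their natural NumPy dtype, while
# quantized types are exposed as the raw uint8 bytes of their blocks (see
# GGUFReader._build_tensors below).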

class GGUFReader:
    # I - same as host, S - swapped
    byte_order: Literal['I', 'S'] = 'I'
    alignment: int = GGUF_DEFAULT_ALIGNMENT

    # Note: Internal helper, API may change.
    gguf_scalar_to_np: dict[GGUFValueType, type[np.generic]] = {
        GGUFValueType.UINT8:   np.uint8,
        GGUFValueType.INT8:    np.int8,
        GGUFValueType.UINT16:  np.uint16,
        GGUFValueType.INT16:   np.int16,
        GGUFValueType.UINT32:  np.uint32,
        GGUFValueType.INT32:   np.int32,
        GGUFValueType.FLOAT32: np.float32,
        GGUFValueType.UINT64:  np.uint64,
        GGUFValueType.INT64:   np.int64,
        GGUFValueType.FLOAT64: np.float64,
        GGUFValueType.BOOL:    np.bool_,
    }

    def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] = 'r'):
        self.data = np.memmap(path, mode = mode)
        offs = 0
        if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC:
            raise ValueError('GGUF magic invalid')
        offs += 4
        temp_version = self._get(offs, np.uint32)
        if temp_version[0] & 65535 == 0:
            # If we get 0 here that means it's (probably) a GGUF file created for
            # the opposite byte order of the machine this script is running on.
            self.byte_order = 'S'
            temp_version = temp_version.newbyteorder(self.byte_order)
        version = temp_version[0]
        if version not in READER_SUPPORTED_VERSIONS:
            raise ValueError(f'Sorry, file appears to be version {version} which we cannot handle')
        self.fields: OrderedDict[str, ReaderField] = OrderedDict()
        self.tensors: list[ReaderTensor] = []
        offs += self._push_field(ReaderField(offs, 'GGUF.version', [temp_version], [0], [GGUFValueType.UINT32]))
        temp_counts = self._get(offs, np.uint64, 2)
        offs += self._push_field(ReaderField(offs, 'GGUF.tensor_count', [temp_counts[:1]], [0], [GGUFValueType.UINT64]))
        offs += self._push_field(ReaderField(offs, 'GGUF.kv_count', [temp_counts[1:]], [0], [GGUFValueType.UINT64]))
        tensor_count, kv_count = temp_counts
        offs = self._build_fields(offs, kv_count)
        offs, tensors_fields = self._build_tensors_fields(offs, tensor_count)
        new_align = self.fields.get('general.alignment')
        if new_align is not None:
            if new_align.types != [GGUFValueType.UINT32]:
                raise ValueError('Bad type for general.alignment field')
            self.alignment = new_align.parts[-1][0]
        padding = offs % self.alignment
        if padding != 0:
            offs += self.alignment - padding
        self._build_tensors(offs, tensors_fields)

    _DT = TypeVar('_DT', bound = npt.DTypeLike)

    # Fetch a key/value metadata field by key.
    def get_field(self, key: str) -> Union[ReaderField, None]:
        return self.fields.get(key, None)
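
    # Example (a sketch; the key is only present if the file defines it):
    #   field = reader.get_field('general.architecture')
    #   if field is not None:
    #       arch = str(bytes(field.parts[field.data[0]]), encoding = 'utf-8')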

    # Fetch a tensor from the list by index.
    def get_tensor(self, idx: int) -> ReaderTensor:
        return self.tensors[idx]
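
    # Read `count` items of `dtype` from the memory map starting at `offset`,
    # returned as a view with the reader's byte order (or `override_order`) applied.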
    def _get(
        self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I', 'S', '<'] = None,
    ) -> npt.NDArray[Any]:
        count = int(count)
        itemsize = int(np.empty([], dtype = dtype).itemsize)
        end_offs = offset + itemsize * count
        return (
            self.data[offset:end_offs]
            .view(dtype = dtype)[:count]
            .newbyteorder(override_order or self.byte_order)
        )
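
    # Register a parsed field by name and return the number of bytes its parts
    # occupy in the file (or 0 when skip_sum is set).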
    def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int:
        if field.name in self.fields:
            raise KeyError(f'Duplicate {field.name} already in list at offset {field.offset}')
        self.fields[field.name] = field
        return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts)
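
    # Read a GGUF string at `offset`: a uint64 length followed by that many bytes.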
    def _get_str(self, offset: int) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]:
        slen = self._get(offset, np.uint64)
        return slen, self._get(offset + 8, np.uint8, slen[0])
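
    # Parse a single value of `raw_type` at `orig_offs`, returning its size in
    # bytes, its parts, the indexes of the parts that hold actual data, and the
    # value types. Arrays recurse once per element.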
    def _get_field_parts(
        self, orig_offs: int, raw_type: int,
    ) -> tuple[int, list[npt.NDArray[Any]], list[int], list[GGUFValueType]]:
        offs = orig_offs
        types: list[GGUFValueType] = []
        gtype = GGUFValueType(raw_type)
        types.append(gtype)
        # Handle strings.
        if gtype == GGUFValueType.STRING:
            sparts: list[npt.NDArray[Any]] = list(self._get_str(offs))
            size = sum(int(part.nbytes) for part in sparts)
            return size, sparts, [1], types
        # Check if it's a simple scalar type.
        nptype = self.gguf_scalar_to_np.get(gtype)
        if nptype is not None:
            val = self._get(offs, nptype)
            return int(val.nbytes), [val], [0], types
        # Handle arrays.
        if gtype == GGUFValueType.ARRAY:
            raw_itype = self._get(offs, np.uint32)
            offs += int(raw_itype.nbytes)
            alen = self._get(offs, np.uint64)
            offs += int(alen.nbytes)
            aparts: list[npt.NDArray[Any]] = [raw_itype, alen]
            data_idxs: list[int] = []
            for idx in range(alen[0]):
                curr_size, curr_parts, curr_idxs, curr_types = self._get_field_parts(offs, raw_itype[0])
                if idx == 0:
                    types += curr_types
                idxs_offs = len(aparts)
                aparts += curr_parts
                data_idxs += (idx + idxs_offs for idx in curr_idxs)
                offs += curr_size
            return offs - orig_offs, aparts, data_idxs, types
        # We can't deal with this one.
        raise ValueError(f'Unknown/unhandled field type {gtype}')
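
    # Parse one tensor info record: name, number of dimensions, dimensions, the raw
    # GGML type, and the tensor's offset relative to the start of the data section.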
    def _get_tensor(self, orig_offs: int) -> ReaderField:
        offs = orig_offs
        name_len, name_data = self._get_str(offs)
        offs += int(name_len.nbytes + name_data.nbytes)
        n_dims = self._get(offs, np.uint32)
        offs += int(n_dims.nbytes)
        dims = self._get(offs, np.uint64, n_dims[0])
        offs += int(dims.nbytes)
        raw_dtype = self._get(offs, np.uint32)
        offs += int(raw_dtype.nbytes)
        offset_tensor = self._get(offs, np.uint64)
        offs += int(offset_tensor.nbytes)
        return ReaderField(
            orig_offs,
            str(bytes(name_data), encoding = 'utf-8'),
            [name_len, name_data, n_dims, dims, raw_dtype, offset_tensor],
            [1, 3, 4, 5],
        )
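
    # Parse `count` key/value metadata records starting at `offs`, registering each
    # via _push_field; returns the offset just past the last record.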
    def _build_fields(self, offs: int, count: int) -> int:
        for _ in range(count):
            orig_offs = offs
            kv_klen, kv_kdata = self._get_str(offs)
            offs += int(kv_klen.nbytes + kv_kdata.nbytes)
            raw_kv_type = self._get(offs, np.uint32)
            offs += int(raw_kv_type.nbytes)
            parts: list[npt.NDArray[Any]] = [kv_klen, kv_kdata, raw_kv_type]
            idxs_offs = len(parts)
            field_size, field_parts, field_idxs, field_types = self._get_field_parts(offs, raw_kv_type[0])
            parts += field_parts
            self._push_field(ReaderField(
                orig_offs,
                str(bytes(kv_kdata), encoding = 'utf-8'),
                parts,
                [idx + idxs_offs for idx in field_idxs],
                field_types,
            ), skip_sum = True)
            offs += field_size
        return offs
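
    # Parse `count` tensor info records starting at `offs`; returns the new offset
    # and the parsed fields. The tensor data itself is resolved in _build_tensors.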
    def _build_tensors_fields(self, offs: int, count: int) -> tuple[int, list[ReaderField]]:
        tensor_fields = []
        for _ in range(count):
            field = self._get_tensor(offs)
            offs += sum(int(part.nbytes) for part in field.parts)
            tensor_fields.append(field)
        return offs, tensor_fields
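
    # Build ReaderTensor entries from the parsed tensor info fields. The byte size
    # follows n_bytes = n_elements * type_size // block_size; for example, assuming
    # GGML_QUANT_SIZES maps Q4_0 to (block_size=32, type_size=18), a 4096-element
    # Q4_0 tensor occupies 4096 * 18 // 32 = 2304 bytes.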
    def _build_tensors(self, start_offs: int, fields: list[ReaderField]) -> None:
        tensors = []
        for field in fields:
            _name_len, name_data, _n_dims, dims, raw_dtype, offset_tensor = field.parts
            ggml_type = GGMLQuantizationType(raw_dtype[0])
            n_elems = np.prod(dims)
            block_size, type_size = GGML_QUANT_SIZES[ggml_type]
            n_bytes = n_elems * type_size // block_size
            data_offs = int(start_offs + offset_tensor[0])
            item_type: npt.DTypeLike
            if ggml_type == GGMLQuantizationType.F16:
                item_count = n_elems
                item_type = np.float16
            elif ggml_type == GGMLQuantizationType.F32:
                item_count = n_elems
                item_type = np.float32
            elif ggml_type == GGMLQuantizationType.F64:
                item_count = n_elems
                item_type = np.float64
            elif ggml_type == GGMLQuantizationType.I8:
                item_count = n_elems
                item_type = np.int8
            elif ggml_type == GGMLQuantizationType.I16:
                item_count = n_elems
                item_type = np.int16
            elif ggml_type == GGMLQuantizationType.I32:
                item_count = n_elems
                item_type = np.int32
            elif ggml_type == GGMLQuantizationType.I64:
                item_count = n_elems
                item_type = np.int64
            else:
                item_count = n_bytes
                item_type = np.uint8
            tensors.append(ReaderTensor(
                name = str(bytes(name_data), encoding = 'utf-8'),
                tensor_type = ggml_type,
                shape = dims,
                n_elements = n_elems,
                n_bytes = n_bytes,
                data_offset = data_offs,
                data = self._get(data_offs, item_type, item_count),
                field = field,
            ))
        self.tensors = tensors