convert-hf-to-gguf.py 109 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548
  1. #!/usr/bin/env python3
  2. from __future__ import annotations
  3. import logging
  4. import argparse
  5. import contextlib
  6. import json
  7. import os
  8. import re
  9. import sys
  10. from enum import IntEnum
  11. from pathlib import Path
  12. from hashlib import sha256
  13. from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast
  14. import numpy as np
  15. import torch
  16. if TYPE_CHECKING:
  17. from torch import Tensor
  18. if 'NO_LOCAL_GGUF' not in os.environ:
  19. sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
  20. import gguf
  21. from convert import LlamaHfVocab
  22. logger = logging.getLogger("hf-to-gguf")
  23. ###### MODEL DEFINITIONS ######
  24. class SentencePieceTokenTypes(IntEnum):
  25. NORMAL = 1
  26. UNKNOWN = 2
  27. CONTROL = 3
  28. USER_DEFINED = 4
  29. UNUSED = 5
  30. BYTE = 6
  31. AnyModel = TypeVar("AnyModel", bound="type[Model]")
  32. class Model:
  33. _model_classes: dict[str, type[Model]] = {}
  34. dir_model: Path
  35. ftype: int
  36. is_big_endian: bool
  37. endianess: gguf.GGUFEndian
  38. use_temp_file: bool
  39. lazy: bool
  40. part_names: list[str]
  41. is_safetensors: bool
  42. hparams: dict[str, Any]
  43. block_count: int
  44. tensor_map: gguf.TensorNameMap
  45. tensor_names: set[str] | None
  46. fname_out: Path
  47. gguf_writer: gguf.GGUFWriter
  48. # subclasses should define this!
  49. model_arch: gguf.MODEL_ARCH
  50. def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool):
  51. if type(self) is Model:
  52. raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
  53. self.dir_model = dir_model
  54. self.ftype = ftype
  55. self.is_big_endian = is_big_endian
  56. self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
  57. self.use_temp_file = use_temp_file
  58. self.lazy = not eager
  59. self.part_names = Model.get_model_part_names(self.dir_model, ".safetensors")
  60. self.is_safetensors = len(self.part_names) > 0
  61. if not self.is_safetensors:
  62. self.part_names = Model.get_model_part_names(self.dir_model, ".bin")
  63. self.hparams = Model.load_hparams(self.dir_model)
  64. self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])
  65. self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
  66. self.tensor_names = None
  67. if self.ftype == gguf.LlamaFileType.GUESSED:
  68. # NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
  69. _, first_tensor = next(self.get_tensors())
  70. if first_tensor.dtype == torch.float16:
  71. logger.info(f"choosing --outtype f16 from first tensor type ({first_tensor.dtype})")
  72. self.ftype = gguf.LlamaFileType.MOSTLY_F16
  73. else:
  74. logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
  75. self.ftype = gguf.LlamaFileType.MOSTLY_BF16
  76. ftype_up: str = self.ftype.name.partition("_")[2].upper()
  77. ftype_lw: str = ftype_up.lower()
  78. # allow templating the file name with the output ftype, useful with the "auto" ftype
  79. self.fname_out = fname_out.parent / fname_out.name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up)
  80. self.gguf_writer = gguf.GGUFWriter(self.fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file)
  81. @classmethod
  82. def __init_subclass__(cls):
  83. # can't use an abstract property, because overriding it without type errors
  84. # would require using decorated functions instead of simply defining the property
  85. if "model_arch" not in cls.__dict__:
  86. raise TypeError(f"Missing property 'model_arch' for {cls.__name__!r}")
  87. def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any:
  88. key = next((k for k in keys if k in self.hparams), None)
  89. if key is not None:
  90. return self.hparams[key]
  91. if optional:
  92. return None
  93. raise KeyError(f"could not find any of: {keys}")
  94. def set_vocab(self):
  95. self._set_vocab_gpt2()
  96. def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
  97. tensor_names_from_parts: set[str] = set()
  98. if len(self.part_names) > 1:
  99. self.tensor_names = set()
  100. index_name = "model.safetensors" if self.is_safetensors else "pytorch_model.bin"
  101. index_name += ".index.json"
  102. logger.info(f"gguf: loading model weight map from '{index_name}'")
  103. with open(self.dir_model / index_name, "r", encoding="utf-8") as f:
  104. index: dict[str, Any] = json.load(f)
  105. weight_map = index.get("weight_map")
  106. if weight_map is None or not isinstance(weight_map, dict):
  107. raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
  108. self.tensor_names.update(weight_map.keys())
  109. else:
  110. self.tensor_names = tensor_names_from_parts
  111. for part_name in self.part_names:
  112. logger.info(f"gguf: loading model part '{part_name}'")
  113. ctx: ContextManager[Any]
  114. if self.is_safetensors:
  115. from safetensors import safe_open
  116. ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
  117. else:
  118. ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
  119. with ctx as model_part:
  120. tensor_names_from_parts.update(model_part.keys())
  121. for name in model_part.keys():
  122. data = model_part.get_tensor(name) if self.is_safetensors else model_part[name]
  123. if self.lazy:
  124. data = LazyTorchTensor.from_eager(data)
  125. yield name, data
  126. # only verify tensor name presence; it doesn't matter if they are not in the right files
  127. if len(sym_diff := tensor_names_from_parts.symmetric_difference(self.tensor_names)) > 0:
  128. raise ValueError(f"Mismatch between weight map and model parts for tensor names: {sym_diff}")
  129. def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
  130. if key not in gguf.MODEL_TENSORS[self.model_arch]:
  131. raise ValueError(f"Missing {key!r} for MODEL_TENSORS of {self.model_arch!r}")
  132. name: str = gguf.TENSOR_NAMES[key]
  133. if "{bid}" in name:
  134. assert bid is not None
  135. name = name.format(bid=bid)
  136. return name + suffix
  137. def match_model_tensor_name(self, name: str, key: gguf.MODEL_TENSOR, bid: int | None, suffix: str = ".weight") -> bool:
  138. if key not in gguf.MODEL_TENSORS[self.model_arch]:
  139. return False
  140. key_name: str = gguf.TENSOR_NAMES[key]
  141. if "{bid}" in key_name:
  142. if bid is None:
  143. return False
  144. key_name = key_name.format(bid=bid)
  145. else:
  146. if bid is not None:
  147. return False
  148. return name == (key_name + suffix)
  149. def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
  150. new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes)
  151. if new_name is None:
  152. raise ValueError(f"Can not map tensor {name!r}")
  153. return new_name
  154. def set_gguf_parameters(self):
  155. self.gguf_writer.add_name(self.dir_model.name)
  156. self.gguf_writer.add_block_count(self.block_count)
  157. if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None:
  158. self.gguf_writer.add_context_length(n_ctx)
  159. logger.info(f"gguf: context length = {n_ctx}")
  160. n_embd = self.find_hparam(["hidden_size", "n_embd"])
  161. self.gguf_writer.add_embedding_length(n_embd)
  162. logger.info(f"gguf: embedding length = {n_embd}")
  163. if (n_ff := self.find_hparam(["intermediate_size", "n_inner"], optional=True)) is not None:
  164. self.gguf_writer.add_feed_forward_length(n_ff)
  165. logger.info(f"gguf: feed forward length = {n_ff}")
  166. n_head = self.find_hparam(["num_attention_heads", "n_head"])
  167. self.gguf_writer.add_head_count(n_head)
  168. logger.info(f"gguf: head count = {n_head}")
  169. if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
  170. self.gguf_writer.add_head_count_kv(n_head_kv)
  171. logger.info(f"gguf: key-value head count = {n_head_kv}")
  172. if (rope_theta := self.hparams.get("rope_theta")) is not None:
  173. self.gguf_writer.add_rope_freq_base(rope_theta)
  174. logger.info(f"gguf: rope theta = {rope_theta}")
  175. if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
  176. self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps)
  177. logger.info(f"gguf: rms norm epsilon = {f_rms_eps}")
  178. if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None:
  179. self.gguf_writer.add_layer_norm_eps(f_norm_eps)
  180. logger.info(f"gguf: layer norm epsilon = {f_norm_eps}")
  181. if (n_experts := self.hparams.get("num_local_experts")) is not None:
  182. self.gguf_writer.add_expert_count(n_experts)
  183. logger.info(f"gguf: expert count = {n_experts}")
  184. if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
  185. self.gguf_writer.add_expert_used_count(n_experts_used)
  186. logger.info(f"gguf: experts used count = {n_experts_used}")
  187. self.gguf_writer.add_file_type(self.ftype)
  188. logger.info(f"gguf: file type = {self.ftype}")
  189. def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
  190. del bid # unused
  191. return [(self.map_tensor_name(name), data_torch)]
  192. def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
  193. del name, new_name, bid, n_dims # unused
  194. return False
  195. def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
  196. del name, new_name, bid, n_dims # unused
  197. return False
  198. def write_tensors(self):
  199. max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
  200. for name, data_torch in self.get_tensors():
  201. # we don't need these
  202. if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
  203. continue
  204. old_dtype = data_torch.dtype
  205. # convert any unsupported data types to float32
  206. if data_torch.dtype not in (torch.float16, torch.float32):
  207. data_torch = data_torch.to(torch.float32)
  208. # use the first number-like part of the tensor name as the block id
  209. bid = None
  210. for part in name.split("."):
  211. if part.isdecimal():
  212. bid = int(part)
  213. break
  214. for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)):
  215. data: np.ndarray = data # type hint
  216. n_dims = len(data.shape)
  217. data_dtype = data.dtype
  218. data_qtype: gguf.GGMLQuantizationType | None = None
  219. # when both are True, f32 should win
  220. extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims)
  221. extra_f16 = self.extra_f16_tensors(name, new_name, bid, n_dims)
  222. # Most of the codebase that takes in 1D tensors or norms only handles F32 tensors
  223. # Conditions should closely match those in llama_model_quantize_internal in llama.cpp
  224. extra_f32 = any(cond for cond in (
  225. extra_f32,
  226. n_dims == 1,
  227. new_name.endswith("_norm.weight"),
  228. ))
  229. # Some tensor types are always in float32
  230. extra_f32 = extra_f32 or any(self.match_model_tensor_name(new_name, key, bid) for key in (
  231. gguf.MODEL_TENSOR.FFN_GATE_INP,
  232. gguf.MODEL_TENSOR.POS_EMBD,
  233. gguf.MODEL_TENSOR.TOKEN_TYPES,
  234. ))
  235. # if f16 desired, convert any float32 2-dim weight tensors to float16
  236. extra_f16 = any(cond for cond in (
  237. extra_f16,
  238. (name.endswith(".weight") and n_dims >= 2),
  239. ))
  240. if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
  241. if self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
  242. data = gguf.quantize_bf16(data)
  243. assert data.dtype == np.int16
  244. data_qtype = gguf.GGMLQuantizationType.BF16
  245. elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0 and gguf.can_quantize_to_q8_0(data):
  246. data = gguf.quantize_q8_0(data)
  247. assert data.dtype == np.uint8
  248. data_qtype = gguf.GGMLQuantizationType.Q8_0
  249. else: # default to float16 for quantized tensors
  250. if data_dtype != np.float16:
  251. data = data.astype(np.float16)
  252. data_qtype = gguf.GGMLQuantizationType.F16
  253. if data_qtype is None: # by default, convert to float32
  254. if data_dtype != np.float32:
  255. data = data.astype(np.float32)
  256. data_qtype = gguf.GGMLQuantizationType.F32
  257. block_size, type_size = gguf.GGML_QUANT_SIZES[data_qtype]
  258. # reverse shape to make it similar to the internal ggml dimension order
  259. shape_str = f"""{{{', '.join(str(n) for n in reversed(
  260. (*data.shape[:-1], data.shape[-1] * data.dtype.itemsize // type_size * block_size))
  261. )}}}"""
  262. # n_dims is implicit in the shape
  263. logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
  264. self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype)
  265. def write(self):
  266. self.write_tensors()
  267. self.gguf_writer.write_header_to_file()
  268. self.gguf_writer.write_kv_data_to_file()
  269. self.gguf_writer.write_tensors_to_file(progress=True)
  270. self.gguf_writer.close()
  271. def write_vocab(self):
  272. self.gguf_writer.write_header_to_file()
  273. self.gguf_writer.write_kv_data_to_file()
  274. self.gguf_writer.close()
  275. @staticmethod
  276. def get_model_part_names(dir_model: Path, suffix: str) -> list[str]:
  277. part_names: list[str] = []
  278. for filename in os.listdir(dir_model):
  279. if filename.endswith(suffix):
  280. part_names.append(filename)
  281. part_names.sort()
  282. return part_names
  283. @staticmethod
  284. def load_hparams(dir_model: Path):
  285. with open(dir_model / "config.json", "r", encoding="utf-8") as f:
  286. return json.load(f)
  287. @classmethod
  288. def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
  289. assert names
  290. def func(modelcls: AnyModel) -> AnyModel:
  291. for name in names:
  292. cls._model_classes[name] = modelcls
  293. return modelcls
  294. return func
  295. @classmethod
  296. def from_model_architecture(cls, arch: str) -> type[Model]:
  297. try:
  298. return cls._model_classes[arch]
  299. except KeyError:
  300. raise NotImplementedError(f'Architecture {arch!r} not supported!') from None
  301. # used for GPT-2 BPE and WordPiece vocabs
  302. def get_vocab_base(self) -> tuple[list[str], list[int], str]:
  303. tokens: list[str] = []
  304. toktypes: list[int] = []
  305. from transformers import AutoTokenizer
  306. tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
  307. vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab))
  308. assert max(tokenizer.vocab.values()) < vocab_size
  309. tokpre = self.get_vocab_base_pre(tokenizer)
  310. reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
  311. added_vocab = tokenizer.get_added_vocab()
  312. for i in range(vocab_size):
  313. if i not in reverse_vocab:
  314. tokens.append(f"[PAD{i}]")
  315. toktypes.append(gguf.TokenType.USER_DEFINED)
  316. elif reverse_vocab[i] in added_vocab:
  317. tokens.append(reverse_vocab[i])
  318. if tokenizer.added_tokens_decoder[i].special:
  319. toktypes.append(gguf.TokenType.CONTROL)
  320. else:
  321. toktypes.append(gguf.TokenType.USER_DEFINED)
  322. else:
  323. tokens.append(reverse_vocab[i])
  324. toktypes.append(gguf.TokenType.NORMAL)
  325. return tokens, toktypes, tokpre
    # NOTE: this function is generated by convert-hf-to-gguf-update.py
    # do not modify it manually!
    # ref: https://github.com/ggerganov/llama.cpp/pull/6920
    #
    # Summary (kept outside the Start/End markers that delimit the generated code):
    # encodes a fixed check string with `tokenizer`, takes the SHA-256 hex digest of
    # str() of the resulting token list, and returns the pre-tokenizer name associated
    # with that digest. Raises NotImplementedError when the digest is not listed.
    #
    # The digest checks are independent `if` statements, so when the same digest is
    # listed twice the later entry wins: b6dc8df9... resolves to "olmo" (not "mpt")
    # and 0876d13b... resolves to "jina-v2-en" (not "bert-bge").
    # Marker: Start get_vocab_base_pre
    def get_vocab_base_pre(self, tokenizer) -> str:
        # encoding this string and hashing the resulting tokens would (hopefully) give us a unique identifier that
        # is specific for the BPE pre-tokenizer used by the model
        # we will use this unique identifier to write a "tokenizer.ggml.pre" entry in the GGUF file which we can
        # use in llama.cpp to implement the same pre-tokenizer

        chktxt = '\n \n\n \n\n\n \t \t\t \t\n \n \n \n \n🚀 (normal) 😶\u200d🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български \'\'\'\'\'\'```````""""......!!!!!!?????? I\'ve been \'told he\'s there, \'RE you sure? \'M not sure I\'ll make it, \'D you like some tea? We\'Ve a\'lL'

        chktok = tokenizer.encode(chktxt)
        chkhsh = sha256(str(chktok).encode()).hexdigest()

        logger.debug(f"chktok: {chktok}")
        logger.debug(f"chkhsh: {chkhsh}")

        res = None

        # NOTE: if you get an error here, you need to update the convert-hf-to-gguf-update.py script
        # or pull the latest version of the model from Huggingface
        # don't edit the hashes manually!
        if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
            # ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
            res = "llama-bpe"
        if chkhsh == "049ecf7629871e3041641907f3de7c733e4dbfdc736f57d882ba0b0845599754":
            # ref: https://huggingface.co/deepseek-ai/deepseek-llm-7b-base
            res = "deepseek-llm"
        if chkhsh == "347715f544604f9118bb75ed199f68779f423cabb20db6de6f31b908d04d7821":
            # ref: https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-base
            res = "deepseek-coder"
        if chkhsh == "8aeee3860c56296a157a1fe2fad249ec40aa59b1bb5709f4ade11c4e6fe652ed":
            # ref: https://huggingface.co/tiiuae/falcon-7b
            res = "falcon"
        if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f":
            # ref: https://huggingface.co/BAAI/bge-small-en-v1.5
            res = "bert-bge"
        if chkhsh == "b6dc8df998e1cfbdc4eac8243701a65afe638679230920b50d6f17d81c098166":
            # ref: https://huggingface.co/mosaicml/mpt-7b
            res = "mpt"
        if chkhsh == "35d91631860c815f952d711435f48d356ebac988362536bed955d43bfa436e34":
            # ref: https://huggingface.co/bigcode/starcoder2-3b
            res = "starcoder"
        if chkhsh == "3ce83efda5659b07b1ad37ca97ca5797ea4285d9b9ab0dc679e4a720c9da7454":
            # ref: https://huggingface.co/openai-community/gpt2
            res = "gpt-2"
        if chkhsh == "6221ad2852e85ce96f791f476e0b390cf9b474c9e3d1362f53a24a06dc8220ff":
            # ref: https://huggingface.co/smallcloudai/Refact-1_6-base
            res = "refact"
        if chkhsh == "9c2227e4dd922002fb81bde4fc02b0483ca4f12911410dee2255e4987644e3f8":
            # ref: https://huggingface.co/CohereForAI/c4ai-command-r-v01
            res = "command-r"
        if chkhsh == "e636dc30a262dcc0d8c323492e32ae2b70728f4df7dfe9737d9f920a282b8aea":
            # ref: https://huggingface.co/Qwen/Qwen1.5-7B
            res = "qwen2"
        if chkhsh == "b6dc8df998e1cfbdc4eac8243701a65afe638679230920b50d6f17d81c098166":
            # ref: https://huggingface.co/allenai/OLMo-1.7-7B-hf
            res = "olmo"
        if chkhsh == "a8594e3edff7c29c003940395316294b2c623e09894deebbc65f33f1515df79e":
            # ref: https://huggingface.co/databricks/dbrx-base
            res = "dbrx"
        if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f":
            # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-en
            res = "jina-v2-en"
        if chkhsh == "171aeeedd6fb548d418a7461d053f11b6f1f1fc9b387bd66640d28a4b9f5c643":
            # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-es
            res = "jina-v2-es"
        if chkhsh == "27949a2493fc4a9f53f5b9b029c82689cfbe5d3a1929bb25e043089e28466de6":
            # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-de
            res = "jina-v2-de"

        if res is None:
            logger.warning("\n")
            logger.warning("**************************************************************************************")
            logger.warning("** WARNING: The BPE pre-tokenizer was not recognized!")
            logger.warning("** There are 2 possible reasons for this:")
            logger.warning("** - the model has not been added to convert-hf-to-gguf-update.py yet")
            logger.warning("** - the pre-tokenization config has changed upstream")
            logger.warning("** Check your model files and convert-hf-to-gguf-update.py and update them accordingly.")
            logger.warning("** ref: https://github.com/ggerganov/llama.cpp/pull/6920")
            logger.warning("**")
            logger.warning(f"** chkhsh: {chkhsh}")
            logger.warning("**************************************************************************************")
            logger.warning("\n")
            raise NotImplementedError("BPE pre-tokenizer was not recognized - update get_vocab_base_pre()")

        logger.debug(f"tokenizer.ggml.pre: {repr(res)}")
        logger.debug(f"chkhsh: {chkhsh}")

        return res
    # Marker: End get_vocab_base_pre
  410. def _set_vocab_gpt2(self) -> None:
  411. tokens, toktypes, tokpre = self.get_vocab_base()
  412. self.gguf_writer.add_tokenizer_model("gpt2")
  413. self.gguf_writer.add_tokenizer_pre(tokpre)
  414. self.gguf_writer.add_token_list(tokens)
  415. self.gguf_writer.add_token_types(toktypes)
  416. special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
  417. special_vocab.add_to_gguf(self.gguf_writer)
  def _set_vocab_qwen(self):
      """Write a Qwen-style vocabulary to the GGUF writer ("gpt2" tokenizer model).

      The tokenizer is loaded through ``AutoTokenizer`` with remote code
      enabled. Its ``mergeable_ranks`` supply the regular vocabulary and are
      also used to rebuild the merge list; its ``special_tokens`` supply the
      control tokens. Ids below ``hparams["vocab_size"]`` without a token get a
      ``[PAD<id>]`` placeholder.
      """
      model_dir = self.dir_model
      hyper = self.hparams

      from transformers import AutoTokenizer
      tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)

      vocab_size = hyper["vocab_size"]
      assert max(tokenizer.get_vocab().values()) < vocab_size

      tokpre = self.get_vocab_base_pre(tokenizer)

      # Rebuild the text vocabulary and the merge rules from the rank table.
      to_text = QwenModel.token_bytes_to_string
      ranks = tokenizer.mergeable_ranks
      bpe_vocab = {}
      merge_rules = []
      for raw_token, rank in ranks.items():
          bpe_vocab[to_text(raw_token)] = rank
          if len(raw_token) > 1:
              # single-byte tokens have no merge; longer ones must split in two
              pieces = QwenModel.bpe(ranks, raw_token, max_rank=rank)
              assert len(pieces) == 2
              merge_rules.append(' '.join(to_text(piece) for piece in pieces))

      # for this kind of tokenizer, added_vocab is not a subset of vocab, so they need to be combined
      special_tokens = tokenizer.special_tokens
      id_to_text = {
          token_id: text
          for text, token_id in {**bpe_vocab, **special_tokens}.items()
      }

      tokens: list[str] = []
      toktypes: list[int] = []
      for token_id in range(vocab_size):
          if token_id not in id_to_text:
              text = f"[PAD{token_id}]"
              kind = gguf.TokenType.USER_DEFINED
          else:
              text = id_to_text[token_id]
              if text in special_tokens:
                  kind = gguf.TokenType.CONTROL
              else:
                  kind = gguf.TokenType.NORMAL
          tokens.append(text)
          toktypes.append(kind)

      writer = self.gguf_writer
      writer.add_tokenizer_model("gpt2")
      writer.add_tokenizer_pre(tokpre)
      writer.add_token_list(tokens)
      writer.add_token_types(toktypes)

      special_vocab = gguf.SpecialVocab(model_dir, load_merges=False)
      special_vocab.merges = merge_rules
      # only add special tokens when they were not already loaded from config.json
      if not special_vocab.special_token_ids:
          for role in ("bos", "eos"):
              special_vocab._set_special_token(role, tokenizer.special_tokens["<|endoftext|>"])
      # this one is usually not in config.json anyway
      special_vocab._set_special_token("unk", tokenizer.special_tokens["<|endoftext|>"])
      special_vocab.add_to_gguf(writer)
  def _set_vocab_sentencepiece(self):
      """Write a SentencePiece vocabulary to the GGUF writer ("llama" tokenizer model).

      Every piece, score and type is read from ``tokenizer.model`` in
      ``self.dir_model``. Keys of ``added_tokens.json`` (when that file exists)
      that are not already in the list are appended as USER_DEFINED tokens with
      score -1000.0. If ``hparams["vocab_size"]`` is larger than the resulting
      list, it is padded with UNUSED ``[PAD<n>]`` placeholders.

      Raises:
          FileNotFoundError: if ``tokenizer.model`` does not exist.
      """
      tokenizer_path = self.dir_model / 'tokenizer.model'

      # Check for the file before importing sentencepiece: callers use
      # FileNotFoundError to fall back to another vocab loader, and that
      # fallback should not depend on sentencepiece being installed.
      if not tokenizer_path.is_file():
          raise FileNotFoundError(f"File not found: {tokenizer_path}")

      from sentencepiece import SentencePieceProcessor

      tokens: list[bytes] = []
      scores: list[float] = []
      toktypes: list[int] = []

      tokenizer = SentencePieceProcessor()
      tokenizer.LoadFromFile(str(tokenizer_path))

      vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())

      for token_id in range(tokenizer.vocab_size()):
          piece = tokenizer.IdToPiece(token_id)
          text = piece.encode("utf-8")
          score = tokenizer.GetScore(token_id)

          # first matching predicate wins; anything else stays NORMAL
          toktype = SentencePieceTokenTypes.NORMAL
          if tokenizer.IsUnknown(token_id):
              toktype = SentencePieceTokenTypes.UNKNOWN
          elif tokenizer.IsControl(token_id):
              toktype = SentencePieceTokenTypes.CONTROL
          elif tokenizer.IsUnused(token_id):
              toktype = SentencePieceTokenTypes.UNUSED
          elif tokenizer.IsByte(token_id):
              toktype = SentencePieceTokenTypes.BYTE

          tokens.append(text)
          scores.append(score)
          toktypes.append(toktype)

      added_tokens_file = self.dir_model / 'added_tokens.json'
      if added_tokens_file.is_file():
          with open(added_tokens_file, "r", encoding="utf-8") as f:
              added_tokens_json = json.load(f)

          # A set gives O(1) duplicate checks instead of scanning the whole
          # token list for every added token.
          known_tokens = set(tokens)
          for key in added_tokens_json:
              encoded_key = key.encode("utf-8")
              if encoded_key not in known_tokens:
                  known_tokens.add(encoded_key)
                  tokens.append(encoded_key)
                  scores.append(-1000.0)
                  toktypes.append(SentencePieceTokenTypes.USER_DEFINED)

      if vocab_size > len(tokens):
          pad_count = vocab_size - len(tokens)
          logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]")
          for i in range(1, pad_count + 1):
              tokens.append(bytes(f"[PAD{i}]", encoding="utf-8"))
              scores.append(-1000.0)
              toktypes.append(SentencePieceTokenTypes.UNUSED)

      assert len(tokens) == vocab_size

      self.gguf_writer.add_tokenizer_model("llama")
      self.gguf_writer.add_tokenizer_pre("default")
      self.gguf_writer.add_token_list(tokens)
      self.gguf_writer.add_token_scores(scores)
      self.gguf_writer.add_token_types(toktypes)

      special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
      special_vocab.add_to_gguf(self.gguf_writer)
  def _set_vocab_llama_hf(self):
      """Write the vocabulary produced by ``LlamaHfVocab`` to the GGUF writer.

      Each entry yielded by ``vocab.all_tokens()`` is a (text, score, type)
      triple; the three columns are written separately under the "llama"
      tokenizer model with the "default" pre-tokenizer.
      """
      vocab = LlamaHfVocab(self.dir_model)

      tokens = []
      scores = []
      toktypes = []
      for entry in vocab.all_tokens():
          text, score, toktype = entry
          tokens += [text]
          scores += [score]
          toktypes += [toktype]

      assert len(tokens) == vocab.vocab_size

      writer = self.gguf_writer
      writer.add_tokenizer_model("llama")
      writer.add_tokenizer_pre("default")
      writer.add_token_list(tokens)
      writer.add_token_scores(scores)
      writer.add_token_types(toktypes)

      gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)).add_to_gguf(writer)
  @Model.register("GPTNeoXForCausalLM")
  class GPTNeoXModel(Model):
      """Converter registered for the ``GPTNeoXForCausalLM`` architecture."""

      model_arch = gguf.MODEL_ARCH.GPTNEOX

      def set_gguf_parameters(self):
          """Copy the GPT-NeoX hyperparameters from ``self.hparams`` into the GGUF header."""
          hp = self.hparams
          writer = self.gguf_writer

          n_blocks = hp["num_hidden_layers"]

          writer.add_name(self.dir_model.name)
          writer.add_context_length(hp["max_position_embeddings"])
          writer.add_embedding_length(hp["hidden_size"])
          writer.add_block_count(n_blocks)
          writer.add_feed_forward_length(hp["intermediate_size"])

          # rotary dims = rotary_pct * per-head size, truncated to an int
          rotary_pct = hp["rotary_pct"]
          head_size = hp["hidden_size"] // hp["num_attention_heads"]
          writer.add_rope_dimension_count(int(rotary_pct * head_size))

          writer.add_head_count(hp["num_attention_heads"])
          writer.add_parallel_residual(hp.get("use_parallel_residual", True))
          writer.add_layer_norm_eps(hp["layer_norm_eps"])
  @Model.register("BloomForCausalLM")
  class BloomModel(Model):
      """Converter registered for the ``BloomForCausalLM`` architecture."""

      model_arch = gguf.MODEL_ARCH.BLOOM

      def set_gguf_parameters(self):
          """Copy the Bloom hyperparameters from ``self.hparams`` into the GGUF header."""
          hp = self.hparams
          writer = self.gguf_writer

          writer.add_name("Bloom")
          # each value may be stored under either of two keys
          embed_dim = hp.get("hidden_size", hp.get("n_embed"))
          head_count = hp.get("n_head", hp.get("num_attention_heads"))
          writer.add_context_length(hp.get("seq_length", embed_dim))
          writer.add_embedding_length(embed_dim)
          writer.add_feed_forward_length(4 * embed_dim)
          writer.add_block_count(hp["n_layer"])
          writer.add_head_count(head_count)
          writer.add_head_count_kv(head_count)
          writer.add_layer_norm_eps(hp["layer_norm_epsilon"])
          writer.add_file_type(self.ftype)

      def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
          """Rename one tensor and regroup fused QKV weights/biases.

          Fused ``query_key_value`` tensors are reshaped to
          (n_head, 3, head_dim[, n_embed]) and re-concatenated so that all Q
          rows come first, then all K rows, then all V rows. When the tensor is
          ``word_embeddings.weight`` and the checkpoint has no output tensor, the
          embedding is emitted a second time as the output tensor.
          """
          del bid  # unused

          hp = self.hparams
          head_count = hp.get("n_head", hp.get("num_attention_heads"))
          embed_dim = hp.get("hidden_size", hp.get("n_embed"))

          name = re.sub(r'transformer\.', '', name)

          if re.match(r"h\.\d+\.self_attention\.query_key_value\.weight", name):
              # Map bloom-style qkv_linear to gpt-style qkv_linear
              # bloom: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py#L238-L252  # noqa
              # gpt-2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L312  # noqa
              per_head = data_torch.reshape((head_count, 3, embed_dim // head_count, embed_dim))
              data_torch = torch.cat(
                  tuple(per_head[:, part, :, :].reshape((-1, embed_dim)) for part in range(3)),
                  dim=0,
              )
              logger.info("re-format attention.linear_qkv.weight")
          elif re.match(r"h\.\d+\.self_attention\.query_key_value\.bias", name):
              per_head = data_torch.reshape((head_count, 3, embed_dim // head_count))
              data_torch = torch.cat(
                  tuple(per_head[:, part, :].reshape((embed_dim,)) for part in range(3)),
                  dim=0,
              )
              logger.info("re-format attention.linear_qkv.bias")

          result: list[tuple[str, Tensor]] = [(self.map_tensor_name(name), data_torch)]

          if name == "word_embeddings.weight":
              assert self.tensor_names is not None

              # TODO: tie them at runtime, don't duplicate in the model file
              has_output = any(s in self.tensor_names for s in ("lm_head.weight", "output.weight"))
              if not has_output:
                  result.append((self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch))

          return result
  @Model.register("MPTForCausalLM")
  class MPTModel(Model):
      """Converter registered for the ``MPTForCausalLM`` architecture."""

      model_arch = gguf.MODEL_ARCH.MPT

      def set_vocab(self):
          """Write the vocabulary, preferring the gpt2 path and falling back to SentencePiece."""
          try:
              self._set_vocab_gpt2()
          except Exception:
              # Fallback for SEA-LION model
              self._set_vocab_sentencepiece()
              writer = self.gguf_writer
              writer.add_add_bos_token(False)
              writer.add_pad_token_id(3)
              writer.add_eos_token_id(1)
              writer.add_unk_token_id(0)

      def set_gguf_parameters(self):
          """Copy the MPT hyperparameters (including ``attn_config``) into the GGUF header."""
          hp = self.hparams
          writer = self.gguf_writer

          n_blocks = hp["n_layers"]
          writer.add_name(self.dir_model.name)
          writer.add_context_length(hp["max_seq_len"])
          writer.add_embedding_length(hp["d_model"])
          writer.add_block_count(n_blocks)
          writer.add_feed_forward_length(4 * hp["d_model"])
          writer.add_head_count(hp["n_heads"])

          attn_config = hp["attn_config"]
          # written only when present and truthy
          if kv_n_heads := attn_config.get("kv_n_heads"):
              writer.add_head_count_kv(kv_n_heads)
          writer.add_layer_norm_eps(1e-5)

          clip_qkv = attn_config["clip_qkv"]
          if clip_qkv is not None:
              writer.add_clamp_kqv(clip_qkv)

          # without alibi the max bias is written as 0.0
          writer.add_max_alibi_bias(attn_config["alibi_bias_max"] if attn_config["alibi"] else 0.0)

      def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
          """Map one tensor name; tensors with "scales" in the name get "scales" replaced by "act.scales"."""
          del bid  # unused

          is_scales = "scales" in name
          suffixes = (".weight", ".bias", ".scales") if is_scales else (".weight", ".bias")
          mapped = self.map_tensor_name(name, try_suffixes=suffixes)
          if is_scales:
              mapped = mapped.replace("scales", "act.scales")

          return [(mapped, data_torch)]
  @Model.register("OrionForCausalLM")
  class OrionModel(Model):
      """Converter registered for the ``OrionForCausalLM`` architecture."""

      model_arch = gguf.MODEL_ARCH.ORION

      def set_vocab(self):
          """Orion uses the SentencePiece vocabulary path."""
          self._set_vocab_sentencepiece()

      def set_gguf_parameters(self):
          """Copy the Orion hyperparameters from ``self.hparams`` into the GGUF header.

          Raises:
              ValueError: if none of the known context-length keys is present.
          """
          hp = self.hparams
          n_blocks = hp["num_hidden_layers"]
          n_heads = hp["num_attention_heads"]
          n_heads_kv = hp.get("num_key_value_heads", n_heads)
          source_repo = hp.get("_name_or_path", "")

          # the first of these keys that exists provides the context length
          for candidate in ("max_sequence_length", "max_position_embeddings", "model_max_length"):
              if candidate in hp:
                  ctx_length = hp[candidate]
                  break
          else:
              raise ValueError("gguf: can not find ctx length parameter.")

          writer = self.gguf_writer
          writer.add_file_type(self.ftype)
          writer.add_name(self.dir_model.name)
          writer.add_source_hf_repo(source_repo)
          writer.add_tensor_data_layout("Meta AI original pth")
          writer.add_context_length(ctx_length)
          writer.add_embedding_length(hp["hidden_size"])
          writer.add_block_count(n_blocks)
          writer.add_feed_forward_length(hp["intermediate_size"])
          writer.add_head_count(n_heads)
          writer.add_head_count_kv(n_heads_kv)
          # note: config provides rms norm but it is actually layer norm
          # ref:  https://huggingface.co/OrionStarAI/Orion-14B-Chat/blob/276a17221ce42beb45f66fac657a41540e71f4f5/modeling_orion.py#L570-L571
          writer.add_layer_norm_eps(hp["rms_norm_eps"])
  @Model.register("BaichuanForCausalLM", "BaiChuanForCausalLM")
  class BaichuanModel(Model):
      """Converter registered for the ``BaichuanForCausalLM`` / ``BaiChuanForCausalLM`` architectures."""

      model_arch = gguf.MODEL_ARCH.BAICHUAN

      def set_vocab(self):
          """Baichuan uses the SentencePiece vocabulary path."""
          self._set_vocab_sentencepiece()

      def set_gguf_parameters(self):
          """Copy the Baichuan hyperparameters from ``self.hparams`` into the GGUF header.

          Linear rope scaling is written only when ``rope_scaling`` has a
          ``factor`` and its ``type`` is "linear".

          Raises:
              ValueError: if none of the known context-length keys is present.
          """
          hp = self.hparams
          block_count = hp["num_hidden_layers"]
          head_count = hp["num_attention_heads"]
          head_count_kv = hp.get("num_key_value_heads", head_count)
          hf_repo = hp.get("_name_or_path", "")

          # the first of these keys that exists provides the context length
          for key in ("max_sequence_length", "max_position_embeddings", "model_max_length"):
              if key in hp:
                  ctx_length = hp[key]
                  break
          else:
              raise ValueError("gguf: can not find ctx length parameter.")

          writer = self.gguf_writer
          writer.add_name(self.dir_model.name)
          writer.add_source_hf_repo(hf_repo)
          writer.add_tensor_data_layout("Meta AI original pth")
          writer.add_context_length(ctx_length)
          writer.add_embedding_length(hp["hidden_size"])
          writer.add_block_count(block_count)
          writer.add_feed_forward_length(hp["intermediate_size"])
          writer.add_rope_dimension_count(hp["hidden_size"] // hp["num_attention_heads"])
          writer.add_head_count(head_count)
          writer.add_head_count_kv(head_count_kv)
          writer.add_layer_norm_rms_eps(hp["rms_norm_eps"])
          writer.add_file_type(self.ftype)

          rope_scaling = hp.get("rope_scaling")
          if rope_scaling is not None and "factor" in rope_scaling:
              if rope_scaling.get("type") == "linear":
                  writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
                  writer.add_rope_scaling_factor(rope_scaling["factor"])

      def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
          """Rename one tensor, splitting the packed ``W_pack`` weight into Q, K and V.

          ``W_pack`` is cut into three equal row blocks; the Q and K blocks are
          additionally un-permuted with ``_reverse_hf_permute``. Every other
          tensor is only renamed.
          """
          head_count = self.hparams["num_attention_heads"]
          head_count_kv = self.hparams.get("num_key_value_heads", head_count)

          if bid is not None and name == f"model.layers.{bid}.self_attn.W_pack.weight":
              logger.info(f"Unpacking and permuting layer {bid}")
              return [
                  (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid),
                      self._reverse_hf_permute_part(data_torch, 0, head_count, head_count)),
                  (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid),
                      self._reverse_hf_permute_part(data_torch, 1, head_count, head_count_kv)),
                  (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid),
                      self._reverse_hf_part(data_torch, 2)),
              ]

          return [(self.map_tensor_name(name), data_torch)]

      def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
          """Undo the per-head row permutation of a Q or K weight.

          Rows are viewed as (heads, 2, rows // heads // 2, ...), axes 1 and 2
          are swapped, and the original shape is restored.

          When ``n_kv_head`` is given and differs from ``n_head``, the tensor is
          grouped by ``n_kv_head`` heads. The previous ``n_head //= n_kv_head``
          only gave that grouping when n_head / n_kv_head happened to equal
          n_kv_head; this now matches ``LlamaModel.permute``.
          """
          if n_kv_head is not None and n_head != n_kv_head:
              n_head = n_kv_head

          return (
              weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
              .swapaxes(1, 2)
              .reshape(weights.shape)
          )

      def _reverse_hf_permute_part(
          self, weights: Tensor, n_part: int, n_head: int, n_head_kv: int | None = None,
      ) -> Tensor:
          """Take the ``n_part``-th third of the rows of ``weights`` and un-permute it."""
          r = weights.shape[0] // 3
          return self._reverse_hf_permute(weights[r * n_part:r * n_part + r, ...], n_head, n_head_kv)

      def _reverse_hf_part(self, weights: Tensor, n_part: int) -> Tensor:
          """Return the ``n_part``-th third of the rows of ``weights`` unchanged."""
          r = weights.shape[0] // 3
          return weights[r * n_part:r * n_part + r, ...]
  @Model.register("XverseForCausalLM")
  class XverseModel(Model):
      """Converter registered for the ``XverseForCausalLM`` architecture."""

      model_arch = gguf.MODEL_ARCH.XVERSE

      def set_vocab(self):
          """Write the vocabulary loaded through ``AutoTokenizer`` ("llama" tokenizer model).

          Requires ``tokenizer.json`` in the model directory. Each id below the
          vocabulary size is classified as BYTE, CONTROL, USER_DEFINED or NORMAL.
          """
          assert (self.dir_model / "tokenizer.json").is_file()
          dir_model = self.dir_model
          hparams = self.hparams

          tokens: list[bytes] = []
          toktypes: list[int] = []

          from transformers import AutoTokenizer
          tokenizer = AutoTokenizer.from_pretrained(dir_model)

          # Read the vocab mapping once and reuse it below (the original
          # accessed tokenizer.vocab three times).
          vocab = tokenizer.vocab
          vocab_size = hparams.get("vocab_size", len(vocab))
          assert max(vocab.values()) < vocab_size

          reverse_vocab: dict[int, str] = {id_: encoded_tok for encoded_tok, id_ in vocab.items()}
          added_vocab = tokenizer.get_added_vocab()

          for token_id in range(vocab_size):
              token_text = reverse_vocab[token_id].encode('utf-8')
              # replace "\x00" to string with length > 0
              if token_text == b"\x00":
                  toktype = gguf.TokenType.BYTE  # special
                  token_text = f"<{token_text}>".encode('utf-8')
              elif re.fullmatch(br"<0x[0-9A-Fa-f]{2}>", token_text):
                  toktype = gguf.TokenType.BYTE  # special
              elif reverse_vocab[token_id] in added_vocab:
                  if tokenizer.added_tokens_decoder[token_id].special:
                      toktype = gguf.TokenType.CONTROL
                  else:
                      toktype = gguf.TokenType.USER_DEFINED
              else:
                  toktype = gguf.TokenType.NORMAL

              tokens.append(token_text)
              toktypes.append(toktype)

          self.gguf_writer.add_tokenizer_model("llama")
          self.gguf_writer.add_tokenizer_pre("default")
          self.gguf_writer.add_token_list(tokens)
          self.gguf_writer.add_token_types(toktypes)

          special_vocab = gguf.SpecialVocab(dir_model, n_vocab=len(tokens))
          special_vocab.add_to_gguf(self.gguf_writer)

      def set_gguf_parameters(self):
          """Copy the Xverse hyperparameters from ``self.hparams`` into the GGUF header.

          Linear rope scaling is written only when ``rope_scaling`` has a
          ``factor`` and its ``type`` is "linear".

          Raises:
              ValueError: if none of the known context-length keys is present.
          """
          hp = self.hparams
          block_count = hp["num_hidden_layers"]
          head_count = hp["num_attention_heads"]
          head_count_kv = hp.get("num_key_value_heads", head_count)
          hf_repo = hp.get("_name_or_path", "")

          # the first of these keys that exists provides the context length
          for key in ("max_sequence_length", "max_position_embeddings", "model_max_length"):
              if key in hp:
                  ctx_length = hp[key]
                  break
          else:
              raise ValueError("gguf: can not find ctx length parameter.")

          writer = self.gguf_writer
          writer.add_name(self.dir_model.name)
          writer.add_source_hf_repo(hf_repo)
          writer.add_tensor_data_layout("Meta AI original pth")
          writer.add_context_length(ctx_length)
          writer.add_embedding_length(hp["hidden_size"])
          writer.add_block_count(block_count)
          writer.add_feed_forward_length(hp["intermediate_size"])
          writer.add_rope_dimension_count(hp["hidden_size"] // hp["num_attention_heads"])
          writer.add_head_count(head_count)
          writer.add_head_count_kv(head_count_kv)
          writer.add_layer_norm_rms_eps(hp["rms_norm_eps"])
          writer.add_file_type(self.ftype)

          rope_scaling = hp.get("rope_scaling")
          if rope_scaling is not None and "factor" in rope_scaling:
              if rope_scaling.get("type") == "linear":
                  writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
                  writer.add_rope_scaling_factor(rope_scaling["factor"])

      def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
          """Rename one tensor, un-permuting ``q_proj`` and ``k_proj`` weights first."""
          del bid  # unused

          head_count = self.hparams["num_attention_heads"]
          head_count_kv = self.hparams.get("num_key_value_heads", head_count)

          # HF models permute some of the tensors, so we need to undo that
          if name.endswith("q_proj.weight"):
              data_torch = self._reverse_hf_permute(data_torch, head_count, head_count)
          if name.endswith("k_proj.weight"):
              data_torch = self._reverse_hf_permute(data_torch, head_count, head_count_kv)

          return [(self.map_tensor_name(name), data_torch)]

      def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
          """Undo the per-head row permutation of a Q or K weight.

          Rows are viewed as (heads, 2, rows // heads // 2, ...), axes 1 and 2
          are swapped, and the original shape is restored.

          When ``n_kv_head`` is given and differs from ``n_head``, the tensor is
          grouped by ``n_kv_head`` heads. The previous ``n_head //= n_kv_head``
          only gave that grouping when n_head / n_kv_head happened to equal
          n_kv_head; this now matches ``LlamaModel.permute``.
          """
          if n_kv_head is not None and n_head != n_kv_head:
              n_head = n_kv_head

          return (
              weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
              .swapaxes(1, 2)
              .reshape(weights.shape)
          )
  @Model.register("FalconForCausalLM", "RWForCausalLM")
  class FalconModel(Model):
      """Converter registered for the ``FalconForCausalLM`` / ``RWForCausalLM`` architectures."""

      model_arch = gguf.MODEL_ARCH.FALCON

      def set_gguf_parameters(self):
          """Copy the Falcon hyperparameters into the GGUF header, accepting old and new key names."""
          hp = self.hparams

          n_blocks = hp.get("num_hidden_layers")
          if n_blocks is None:
              n_blocks = hp["n_layer"]  # old name

          heads = hp.get("num_attention_heads")
          if heads is None:
              heads = hp["n_head"]  # old name

          kv_heads = hp.get("num_kv_heads")
          if kv_heads is None:
              kv_heads = hp.get("n_head_kv", 1)  # old name

          writer = self.gguf_writer
          writer.add_name("Falcon")
          writer.add_context_length(2048)  # not in config.json
          writer.add_tensor_data_layout("jploski")  # qkv tensor transform
          writer.add_embedding_length(hp["hidden_size"])
          writer.add_feed_forward_length(4 * hp["hidden_size"])
          writer.add_block_count(n_blocks)
          writer.add_head_count(heads)
          writer.add_head_count_kv(kv_heads)
          writer.add_layer_norm_eps(hp["layer_norm_epsilon"])
          writer.add_file_type(self.ftype)

      def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
          """Rename one tensor, re-ordering fused ``query_key_value`` weights first."""
          del bid  # unused

          # QKV tensor transform
          # The original query_key_value tensor contains n_head_kv "kv groups",
          # each consisting of n_head/n_head_kv query weights followed by one key
          # and one value weight (shared by all query heads in the kv group).
          # This layout makes it a big pain to work with in GGML.
          # So we rearrange them here, so that we have n_head query weights
          # followed by n_head_kv key weights followed by n_head_kv value weights,
          # in contiguous fashion.
          # ref: https://github.com/jploski/ggml/blob/falcon40b/examples/falcon/convert-hf-to-ggml.py

          if "query_key_value" in name:
              heads = self.find_hparam(["num_attention_heads", "n_head"])
              kv_heads = self.find_hparam(["num_kv_heads", "n_head_kv"], optional=True) or 1
              head_dim = self.hparams["hidden_size"] // heads
              width = head_dim * heads

              grouped = data_torch.view(kv_heads, heads // kv_heads + 2, head_dim, width)
              # last two slots of every group are the shared key and value
              queries = grouped[:, :-2].reshape(heads * head_dim, width)
              keys = grouped[:, [-2]].reshape(kv_heads * head_dim, width)
              values = grouped[:, [-1]].reshape(kv_heads * head_dim, width)
              data_torch = torch.cat((queries, keys, values)).reshape_as(data_torch)

          return [(self.map_tensor_name(name), data_torch)]
  @Model.register("GPTBigCodeForCausalLM")
  class StarCoderModel(Model):
      """Converter registered for the ``GPTBigCodeForCausalLM`` architecture."""

      model_arch = gguf.MODEL_ARCH.STARCODER

      def set_gguf_parameters(self):
          """Copy the StarCoder hyperparameters from ``self.hparams`` into the GGUF header."""
          hp = self.hparams
          writer = self.gguf_writer

          n_blocks = hp["n_layer"]

          writer.add_name("StarCoder")
          writer.add_context_length(hp["n_positions"])
          writer.add_embedding_length(hp["n_embd"])
          writer.add_feed_forward_length(4 * hp["n_embd"])
          writer.add_block_count(n_blocks)
          writer.add_head_count(hp["n_head"])
          writer.add_head_count_kv(1)  # a single KV head is always written
          writer.add_layer_norm_eps(hp["layer_norm_epsilon"])
          writer.add_file_type(self.ftype)
  @Model.register("GPTRefactForCausalLM")
  class RefactModel(Model):
      """Converter registered for the ``GPTRefactForCausalLM`` architecture."""

      model_arch = gguf.MODEL_ARCH.REFACT

      def set_vocab(self):
          """Write the base vocabulary, then add hard-coded fill-in-the-middle special tokens."""
          super().set_vocab()

          # TODO: how to determine special FIM tokens automatically?
          fim_vocab = gguf.SpecialVocab(
              self.dir_model,
              load_merges=False,
              special_token_types=['prefix', 'suffix', 'middle', 'fsep', 'eot'],
          )
          fim_ids = (
              ("prefix", 1),
              ("suffix", 3),
              ("middle", 2),
              ("fsep", 4),  # is this correct?
          )
          for token_type, token_id in fim_ids:
              fim_vocab._set_special_token(token_type, token_id)
          fim_vocab.add_to_gguf(self.gguf_writer)

      def set_gguf_parameters(self):
          """Copy the Refact hyperparameters into the GGUF header.

          The feed-forward length is int(2 * (4 * n_embd) / 3) rounded up to a
          multiple of 256.
          """
          hp = self.hparams

          n_embd = hp["n_embd"]
          two_thirds_inner = int(2 * (4 * n_embd) / 3)
          multiple_of = 256
          ff_dim = multiple_of * ((two_thirds_inner + multiple_of - 1) // multiple_of)

          n_blocks = hp["n_layer"]

          writer = self.gguf_writer
          writer.add_name("Refact")
          # refact uses Alibi. So this is from config.json which might be used by training.
          writer.add_context_length(hp["n_positions"])
          writer.add_embedding_length(hp["n_embd"])

          writer.add_feed_forward_length(ff_dim)
          writer.add_block_count(n_blocks)
          writer.add_head_count(hp["n_head"])
          writer.add_head_count_kv(1)
          writer.add_layer_norm_rms_eps(hp["layer_norm_epsilon"])
          writer.add_file_type(self.ftype)

      def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
          """Rename one tensor, splitting fused KV and gate/up weights by rows.

          ``attn.kv.weight`` is split after the first ``head_dim`` rows into K
          and V; ``mlp.gate_up_proj.weight`` is split after ``ff_dim`` rows into
          gate and up. Any other tensor is renamed through ``map_tensor_name``.
          """
          n_embd = self.hparams["n_embd"]
          two_thirds_inner = int(2 * (4 * n_embd) / 3)
          multiple_of = 256
          ff_dim = multiple_of * ((two_thirds_inner + multiple_of - 1) // multiple_of)

          n_head = self.hparams["n_head"]
          n_head_kv = 1
          head_dim = self.hparams["n_embd"] // n_head
          kv_rows = n_head_kv * head_dim

          if bid is not None:
              if name == f"transformer.h.{bid}.attn.kv.weight":
                  return [
                      (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), data_torch[:kv_rows]),
                      (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), data_torch[kv_rows:]),
                  ]
              if name == f"transformer.h.{bid}.attn.q.weight":
                  return [(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), data_torch)]
              if name == f"transformer.h.{bid}.mlp.gate_up_proj.weight":
                  return [
                      (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), data_torch[:ff_dim]),
                      (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), data_torch[ff_dim:]),
                  ]

          # nothing special matched: plain rename
          return [(self.map_tensor_name(name), data_torch)]
  @Model.register("PersimmonForCausalLM")
  class PersimmonModel(Model):
      """Converter registered for the ``PersimmonForCausalLM`` architecture."""

      model_arch = gguf.MODEL_ARCH.PERSIMMON

      def set_gguf_parameters(self):
          """Copy the Persimmon hyperparameters from ``self.hparams`` into the GGUF header."""
          hp = self.hparams
          n_blocks = hp.get("num_layers", hp.get("num_hidden_layers"))
          n_heads = hp["num_attention_heads"]
          n_heads_kv = n_heads
          embed_dim = hp["hidden_size"]

          writer = self.gguf_writer
          writer.add_name('persimmon-8b-chat')
          writer.add_context_length(hp["max_position_embeddings"])
          writer.add_embedding_length(embed_dim)
          writer.add_block_count(n_blocks)
          writer.add_feed_forward_length(hp["intermediate_size"])

          # NOTE: not sure about this change - why does the model not have a rope dimension count when it is smaller
          #       than the head size?
          #       ref: https://github.com/ggerganov/llama.cpp/pull/4889
          # self.gguf_writer.add_rope_dimension_count(hidden_size // head_count)
          writer.add_rope_dimension_count(embed_dim // n_heads // 2)

          writer.add_head_count(n_heads)
          writer.add_head_count_kv(n_heads_kv)
          writer.add_rope_freq_base(hp["rope_theta"])
          writer.add_layer_norm_eps(hp["layer_norm_eps"])

      def set_vocab(self):
          """Persimmon uses the SentencePiece vocabulary path."""
          self._set_vocab_sentencepiece()
          # self.gguf_writer.add_bos_token_id(71013)
          # self.gguf_writer.add_eos_token_id(71013)

      def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
          """Return True for every tensor, whatever the arguments."""
          del name, new_name, bid, n_dims  # unused

          # TODO: FP16 conversion produces garbage outputs. (Q8_0 does not, so..?)
          return True
  @Model.register("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM")
  class StableLMModel(Model):
      """Converter registered for the StableLM family of architectures."""

      model_arch = gguf.MODEL_ARCH.STABLELM

      def set_vocab(self):
          """Use the gpt2 vocabulary path when ``tokenizer.json`` exists, else the Qwen path."""
          if not (self.dir_model / "tokenizer.json").is_file():
              # StableLM 2 1.6B uses a vocab in a similar format to Qwen's vocab
              self._set_vocab_qwen()
              return
          self._set_vocab_gpt2()

      def set_gguf_parameters(self):
          """Copy the StableLM hyperparameters from ``self.hparams`` into the GGUF header."""
          hp = self.hparams
          writer = self.gguf_writer
          n_blocks = hp["num_hidden_layers"]

          writer.add_name(self.dir_model.name)
          writer.add_context_length(hp["max_position_embeddings"])
          writer.add_embedding_length(hp["hidden_size"])
          writer.add_block_count(n_blocks)
          writer.add_feed_forward_length(hp["intermediate_size"])

          # rotary dims = rotary factor * per-head size, truncated to an int
          rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"])
          head_size = hp["hidden_size"] // hp["num_attention_heads"]
          writer.add_rope_dimension_count(int(rotary_factor * head_size))

          writer.add_head_count(hp["num_attention_heads"])
          writer.add_head_count_kv(hp["num_key_value_heads"])
          writer.add_parallel_residual(hp.get("use_parallel_residual", True))
          writer.add_layer_norm_eps(self.find_hparam(["layer_norm_eps", "norm_eps"]))
          writer.add_file_type(self.ftype)

      # Per-block buffers of per-head q/k layernorm weights waiting to be stacked;
      # created on first use with one dict per block.
      _q_norms: list[dict[str, Tensor]] | None = None
      _k_norms: list[dict[str, Tensor]] | None = None

      def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
          """Rename one tensor, buffering per-head q/k layernorm weights until a block is complete.

          Per-head norm tensors are stored per block and nothing is returned
          for them until the expected number has been seen; then they are
          returned stacked as a single tensor.
          """
          n_head = self.hparams["num_attention_heads"]
          n_kv_head = self.hparams["num_key_value_heads"]

          for layer_name, buffer_attr, expected in (
              ("q_layernorm", "_q_norms", n_head),
              ("k_layernorm", "_k_norms", n_kv_head),
          ):
              if f"{layer_name}.norms" not in name:
                  continue

              assert bid is not None

              buffers = getattr(self, buffer_attr)
              if buffers is None:
                  buffers = [{} for _ in range(self.block_count)]
                  setattr(self, buffer_attr, buffers)

              buffers[bid][name] = data_torch

              if len(buffers[bid]) >= expected:
                  return self._stack_qk_norm(bid, expected, buffers[bid], layer_name)
              return []

          return [(self.map_tensor_name(name), data_torch)]

      def _stack_qk_norm(self, bid: int, n_head: int, norms: dict[str, Tensor], layer_name: str = "q_layernorm"):
          """Remove the ``n_head`` per-head norms of block ``bid`` from ``norms`` and return them stacked on dim 0."""
          # extract the norms in order
          ordered: list[Tensor] = [
              norms.pop(f"model.layers.{bid}.self_attn.{layer_name}.norms.{xid}.weight")
              for xid in range(n_head)
          ]
          stacked = torch.stack(ordered, dim=0)

          merged_name = f"model.layers.{bid}.self_attn.{layer_name}.weight"
          return [(self.map_tensor_name(merged_name), stacked)]

      def write_tensors(self):
          """Write all tensors, then fail if any buffered q/k norm was never stacked.

          Raises:
              ValueError: if a norm tensor is still left in the buffers.
          """
          super().write_tensors()

          # flatten two `list[dict[str, Tensor]]` into a single `list[str]`
          leftovers: list[str] = []
          for buffers in (self._q_norms, self._k_norms):
              if buffers is None:
                  continue
              for per_block in buffers:
                  leftovers.extend(per_block.keys())

          if leftovers:
              raise ValueError(f"Unprocessed norms: {leftovers}")
  @Model.register("LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM")
  class LlamaModel(Model):
      """Converter for the architectures registered above, emitted as ``MODEL_ARCH.LLAMA``.

      On top of the base-class behaviour this class permutes the q/k projection
      weights and merges the per-expert ``block_sparse_moe`` weights of a block
      into one stacked tensor per projection.
      """

      model_arch = gguf.MODEL_ARCH.LLAMA

      # per-block buffers of expert tensors waiting to be merged; created lazily
      _experts: list[dict[str, Tensor]] | None = None

      def set_vocab(self):
          """Write the vocabulary, trying the available tokenizer loaders in turn.

          The sentencepiece loader is tried first, then the llama-hf loader, then
          the gpt2 loader. When ``vocab_size`` equals 32016 the infill special
          tokens (prefix/suffix/middle/eot) are written as well.
          """
          try:
              self._set_vocab_sentencepiece()
          except FileNotFoundError:
              try:
                  self._set_vocab_llama_hf()
              except (FileNotFoundError, TypeError):
                  # Llama 3
                  self._set_vocab_gpt2()

          # Apply to CodeLlama only (and ignore for Llama 3 with a vocab size of 128256)
          if self.hparams.get("vocab_size", 32000) != 32016:
              return

          infill_tokens = (("prefix", 32007), ("suffix", 32008), ("middle", 32009), ("eot", 32010))
          special_vocab = gguf.SpecialVocab(
              self.dir_model,
              load_merges=False,
              special_token_types=[token_type for token_type, _ in infill_tokens],
          )
          for token_type, token_id in infill_tokens:
              special_vocab._set_special_token(token_type, token_id)
          special_vocab.add_to_gguf(self.gguf_writer)

      def set_gguf_parameters(self):
          """Write the base parameters plus vocab size, rope dimension count and linear rope scaling."""
          super().set_gguf_parameters()
          hparams = self.hparams
          writer = self.gguf_writer
          writer.add_vocab_size(hparams["vocab_size"])
          writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"])

          rope_scaling = hparams.get("rope_scaling")
          if rope_scaling is None or "factor" not in rope_scaling:
              return
          # only the "linear" scaling type is written
          if rope_scaling.get("type") == "linear":
              writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
              writer.add_rope_scaling_factor(rope_scaling["factor"])

      @staticmethod
      def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
          """Swap the two middle axes of the per-head ``(2, rows // n_head // 2)`` row split.

          The first dimension of ``weights`` is viewed as
          ``(heads, 2, rows // heads // 2)``, axes 1 and 2 are swapped, and the
          result is reshaped back to the original shape. ``heads`` is
          ``n_head_kv`` when it is given and differs from ``n_head``, else ``n_head``.
          """
          heads = n_head_kv if (n_head_kv is not None and n_head != n_head_kv) else n_head
          full_shape = weights.shape
          grouped = weights.reshape(heads, 2, full_shape[0] // heads // 2, *full_shape[1:])
          return grouped.swapaxes(1, 2).reshape(full_shape)

      def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
          """Permute q/k projections and merge per-expert weights.

          Expert tensors are buffered per block and an empty list is returned until
          ``num_local_experts * 3`` of them have been seen for that block; then the
          three stacked tensors (w1, w2, w3) are returned. Every other tensor is
          returned immediately under its mapped name.
          """
          n_head = self.hparams["num_attention_heads"]
          n_kv_head = self.hparams.get("num_key_value_heads")

          # a name cannot end with both suffixes, so at most one branch runs
          if name.endswith("q_proj.weight"):
              data_torch = LlamaModel.permute(data_torch, n_head, n_head)
          elif name.endswith("k_proj.weight"):
              data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)

          if "block_sparse_moe.experts" not in name:
              return [(self.map_tensor_name(name), data_torch)]

          # process the experts separately
          n_experts = self.hparams["num_local_experts"]
          assert bid is not None

          if self._experts is None:
              self._experts = [{} for _ in range(self.block_count)]

          pending = self._experts[bid]
          pending[name] = data_torch
          if len(pending) < n_experts * 3:
              return []

          # merge the experts into a single 3d tensor
          tensors: list[tuple[str, Tensor]] = []
          for wid in ("w1", "w2", "w3"):
              datas: list[Tensor] = [
                  pending.pop(f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{wid}.weight")
                  for xid in range(n_experts)
              ]
              stacked = torch.stack(datas, dim=0)
              merged_name = f"layers.{bid}.feed_forward.experts.{wid}.weight"
              tensors.append((self.map_tensor_name(merged_name), stacked))
          return tensors

      def write_tensors(self):
          """Write all tensors, then raise ``ValueError`` if expert tensors were left unmerged."""
          super().write_tensors()

          if self._experts is None:
              return
          # flatten `list[dict[str, Tensor]]` into `list[str]`
          experts = [key for pending in self._experts for key in pending]
          if experts:
              raise ValueError(f"Unprocessed experts: {experts}")
  @Model.register("GrokForCausalLM")
  class GrokModel(Model):
      """Converter for "GrokForCausalLM" checkpoints, emitted as ``MODEL_ARCH.GROK``.

      Tensors whose name contains ``.moe.`` are buffered per block and merged
      into one stacked tensor per projection.
      """

      model_arch = gguf.MODEL_ARCH.GROK

      # per-block buffers of expert tensors waiting to be merged; created lazily
      _experts: list[dict[str, Tensor]] | None = None

      def __init__(self, *args, **kwargs):
          """Forward all arguments unchanged to the base class."""
          super().__init__(*args, **kwargs)

      def set_vocab(self):
          """Write the vocabulary using the sentencepiece loader."""
          self._set_vocab_sentencepiece()

      def set_gguf_parameters(self):
          """Write the base parameters and set the model name to "Grok"."""
          super().set_gguf_parameters()
          self.gguf_writer.add_name("Grok")

      def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
          """Map a tensor name, merging per-expert weights once a block is complete.

          Returns an empty list while fewer than ``num_local_experts * 3`` expert
          tensors have been buffered for the block; once that count is reached the
          stacked "linear", "linear_1" and "linear_v" tensors are returned.
          """
          if ".moe." not in name:
              return [(self.map_tensor_name(name), data_torch)]

          # process the experts separately
          n_experts = self.hparams["num_local_experts"]
          assert bid is not None

          if self._experts is None:
              self._experts = [{} for _ in range(self.block_count)]

          pending = self._experts[bid]
          pending[name] = data_torch
          if len(pending) < n_experts * 3:
              return []

          # merge the experts into a single 3d tensor
          tensors: list[tuple[str, Tensor]] = []
          for wid in ("linear", "linear_1", "linear_v"):
              datas: list[Tensor] = [
                  pending.pop(f"transformer.decoder_layer.{bid}.moe.{xid}.{wid}.weight")
                  for xid in range(n_experts)
              ]
              stacked = torch.stack(datas, dim=0)
              merged_name = f"transformer.decoder_layer.{bid}.moe.{wid}.weight"
              tensors.append((self.map_tensor_name(merged_name), stacked))
          return tensors
  @Model.register("DbrxForCausalLM")
  class DbrxModel(Model):
      """Converter for "DbrxForCausalLM" checkpoints, emitted as ``MODEL_ARCH.DBRX``."""

      model_arch = gguf.MODEL_ARCH.DBRX

      def set_gguf_parameters(self):
          """Write the DBRX hyperparameters read from ``ffn_config``, ``attn_config`` and the top-level config.

          The base implementation is not called; every key is written here.
          """
          hparams = self.hparams
          ffn_config = hparams["ffn_config"]
          attn_config = hparams["attn_config"]
          writer = self.gguf_writer

          writer.add_name(hparams["model_type"])
          writer.add_block_count(hparams["n_layers"])
          writer.add_context_length(hparams["max_seq_len"])
          writer.add_embedding_length(hparams["d_model"])
          writer.add_feed_forward_length(ffn_config["ffn_hidden_size"])

          writer.add_head_count(hparams["n_heads"])
          writer.add_head_count_kv(attn_config["kv_n_heads"])

          writer.add_rope_freq_base(attn_config["rope_theta"])

          writer.add_clamp_kqv(attn_config["clip_qkv"])

          writer.add_expert_count(ffn_config["moe_num_experts"])
          writer.add_expert_used_count(ffn_config["moe_top_k"])

          writer.add_layer_norm_eps(1e-5)

          # the file type is added exactly once, right before it is logged
          writer.add_file_type(self.ftype)
          logger.info(f"gguf: file type = {self.ftype}")

      def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
          """Reshape the fused expert tensors and map the tensor name.

          Expert tensors (names containing one of the ``ffn.experts.mlp.*`` keys and
          no ".weight") are viewed as ``(n_expert, n_ff, n_embd)``, optionally
          permuted, and get ".weight" appended to their name before mapping.
          """
          del bid  # unused

          ffn_config = self.hparams["ffn_config"]
          n_expert = ffn_config["moe_num_experts"]
          n_ff = ffn_config["ffn_hidden_size"]
          n_embd = self.hparams["d_model"]

          # Specific behavior for experts tensors: suffix .weight, view as 3D and transpose
          # original implementation expects (n_expert, n_ff, n_embd) for all experts weights
          # But llama.cpp moe graph works differently
          # AND the dimensions in ggml are typically in the reverse order of the pytorch dimensions
          # so (n_expert, n_ff, n_embd) in pytorch is {n_embd, n_ff, n_expert} in ggml_tensor
          exp_tensor_names = {"ffn.experts.mlp.w1": None,       # LLM_TENSOR_FFN_GATE_EXPS ggml_tensor->ne{n_embd, n_ff,   n_expert}
                              "ffn.experts.mlp.w2": (0, 2, 1),  # LLM_TENSOR_FFN_DOWN_EXPS ggml_tensor->ne{n_ff,   n_embd, n_expert}
                              "ffn.experts.mlp.v1": None}       # LLM_TENSOR_FFN_UP_EXPS   ggml_tensor->ne{n_embd, n_ff,   n_expert}
          experts = False

          for exp_tensor_name, permute_tensor in exp_tensor_names.items():
              if exp_tensor_name in name and ".weight" not in name:
                  experts = True
                  data_torch = data_torch.view(n_expert, n_ff, n_embd)
                  if permute_tensor is not None:
                      data_torch = data_torch.permute(*permute_tensor)
                  break

          # map tensor names
          # In MoE models the ffn tensors are typically most of the model weights,
          # and need to be quantizable. Quantize expects tensor names to be suffixed by .weight.
          # Every other model has the weight names ending in .weight,
          # let's assume that is the convention which is not the case for dbrx:
          # https://huggingface.co/databricks/dbrx-instruct/blob/main/model.safetensors.index.json#L15
          new_name = self.map_tensor_name(name if not experts else name + ".weight", try_suffixes=(".weight",))

          return [(new_name, data_torch)]

      def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
          """Return True for every tensor with more than one dimension, regardless of its name."""
          del name, new_name, bid  # unused

          return n_dims > 1
  @Model.register("MiniCPMForCausalLM")
  class MiniCPMModel(Model):
      """Converter for "MiniCPMForCausalLM" checkpoints, emitted as ``MODEL_ARCH.MINICPM``."""

      model_arch = gguf.MODEL_ARCH.MINICPM

      def set_gguf_parameters(self):
          """Write the MiniCPM hyperparameters; the base implementation is not called."""
          hparams = self.hparams
          writer = self.gguf_writer

          block_count = hparams["num_hidden_layers"]
          writer.add_name("MiniCPM")
          writer.add_context_length(hparams["max_position_embeddings"])
          writer.add_embedding_length(hparams["hidden_size"])
          writer.add_block_count(block_count)
          writer.add_feed_forward_length(hparams["intermediate_size"])
          writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"])
          writer.add_head_count(hparams["num_attention_heads"])
          writer.add_head_count_kv(hparams["num_key_value_heads"])
          writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
          writer.add_file_type(self.ftype)

      def set_vocab(self):
          """Write the vocabulary using the llama-hf loader."""
          self._set_vocab_llama_hf()

      def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
          """Swap the two middle axes of the per-head ``(2, rows // heads // 2)`` row split.

          When ``n_kv_head`` is given and differs from ``n_head``, the head count
          used for the split is ``n_head // n_kv_head``.
          """
          # NOTE(review): LlamaModel.permute in this file uses `n_head = n_head_kv`
          # at this point instead of the integer division - confirm which is intended.
          if n_kv_head is not None and n_head != n_kv_head:
              n_head //= n_kv_head

          full_shape = weights.shape
          grouped = weights.reshape(n_head, 2, full_shape[0] // n_head // 2, *full_shape[1:])
          return grouped.swapaxes(1, 2).reshape(full_shape)

      def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
          """Permute the q/k projection weights and return the tensor under its mapped name."""
          del bid  # unused

          n_head = self.hparams["num_attention_heads"]
          n_kv_head = self.hparams.get("num_key_value_heads")

          # HF models permute some of the tensors, so we need to undo that
          # (a name cannot end with both suffixes, so at most one branch runs)
          if name.endswith("q_proj.weight"):
              data_torch = self._reverse_hf_permute(data_torch, n_head, n_head)
          elif name.endswith("k_proj.weight"):
              data_torch = self._reverse_hf_permute(data_torch, n_head, n_kv_head)

          return [(self.map_tensor_name(name), data_torch)]
  @Model.register("QWenLMHeadModel")
  class QwenModel(Model):
      """Converter for "QWenLMHeadModel" checkpoints, emitted as ``MODEL_ARCH.QWEN``."""

      model_arch = gguf.MODEL_ARCH.QWEN

      @staticmethod
      def token_bytes_to_string(b):
          """Render raw token bytes as text using the GPT-2 byte-to-unicode table."""
          from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
          byte_encoder = bytes_to_unicode()
          # latin-1 maps every byte to the code point with the same value
          code_points = map(ord, b.decode('latin-1'))
          return ''.join(byte_encoder[code_point] for code_point in code_points)

      @staticmethod
      def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]:
          """Split ``token`` into byte pieces by repeatedly merging the lowest-ranked adjacent pair.

          Starts from single bytes and stops when no adjacent pair is in
          ``mergeable_ranks`` or, if ``max_rank`` is given, when the best available
          rank is not below ``max_rank``. Ties go to the leftmost pair.
          """
          parts = [bytes([b]) for b in token]
          while True:
              min_idx = None
              min_rank = None
              for idx, (left, right) in enumerate(zip(parts, parts[1:])):
                  rank = mergeable_ranks.get(left + right)
                  if rank is None:
                      continue
                  if min_rank is None or rank < min_rank:
                      min_idx, min_rank = idx, rank

              if min_rank is None:
                  break
              if max_rank is not None and min_rank >= max_rank:
                  break

              assert min_idx is not None
              # replace the chosen pair with its concatenation
              parts[min_idx:min_idx + 2] = [parts[min_idx] + parts[min_idx + 1]]
          return parts

      def set_vocab(self):
          """Write the vocabulary using the qwen loader."""
          self._set_vocab_qwen()

      def set_gguf_parameters(self):
          """Write the Qwen hyperparameters; the base implementation is not called."""
          hparams = self.hparams
          writer = self.gguf_writer

          writer.add_name("Qwen")
          writer.add_context_length(hparams["max_position_embeddings"])
          writer.add_block_count(hparams["num_hidden_layers"])
          writer.add_embedding_length(hparams["hidden_size"])
          writer.add_feed_forward_length(hparams["intermediate_size"])
          writer.add_rope_freq_base(hparams["rotary_emb_base"])
          writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"])
          writer.add_head_count(hparams["num_attention_heads"])
          writer.add_layer_norm_rms_eps(hparams["layer_norm_epsilon"])
          writer.add_file_type(self.ftype)
  @Model.register("Qwen2ForCausalLM")
  class Qwen2Model(Model):
      """Converter for "Qwen2ForCausalLM" checkpoints, emitted as ``MODEL_ARCH.QWEN2``."""

      model_arch = gguf.MODEL_ARCH.QWEN2

      def set_vocab(self):
          """Write the vocabulary with the sentencepiece loader, or the gpt2 loader as fallback."""
          try:
              self._set_vocab_sentencepiece()
          except FileNotFoundError:
              # the sentencepiece loader raised FileNotFoundError: use the gpt2 loader instead
              self._set_vocab_gpt2()
  @Model.register("Qwen2MoeForCausalLM")
  class Qwen2MoeModel(Model):
      """Converter for "Qwen2MoeForCausalLM" checkpoints, emitted as ``MODEL_ARCH.QWEN2MOE``.

      Tensors whose name contains "experts" are buffered per block and merged
      into one stacked tensor per projection.
      """

      model_arch = gguf.MODEL_ARCH.QWEN2MOE

      # per-block buffers of expert tensors waiting to be merged; created lazily
      _experts: list[dict[str, Tensor]] | None = None

      def set_gguf_parameters(self):
          """Write the base parameters plus the expert count when ``num_experts`` is configured."""
          super().set_gguf_parameters()
          n_experts = self.hparams.get("num_experts")
          if n_experts is not None:
              self.gguf_writer.add_expert_count(n_experts)

      def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
          """Map a tensor name, merging per-expert weights once a block is complete.

          Returns an empty list while fewer than ``num_experts * 3`` expert tensors
          have been buffered for the block; once that count is reached the stacked
          down/gate/up projection tensors are returned.
          """
          if "experts" not in name:
              return [(self.map_tensor_name(name), data_torch)]

          # process the experts separately
          n_experts = self.hparams["num_experts"]
          assert bid is not None

          if self._experts is None:
              self._experts = [{} for _ in range(self.block_count)]

          pending = self._experts[bid]
          pending[name] = data_torch
          if len(pending) < n_experts * 3:
              return []

          # merge the experts into a single 3d tensor
          tensors: list[tuple[str, Tensor]] = []
          for w_name in ("down_proj", "gate_proj", "up_proj"):
              datas: list[Tensor] = [
                  pending.pop(f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight")
                  for xid in range(n_experts)
              ]
              stacked = torch.stack(datas, dim=0)
              merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
              tensors.append((self.map_tensor_name(merged_name), stacked))
          return tensors

      def write_tensors(self):
          """Write all tensors, then raise ``ValueError`` if expert tensors were left unmerged."""
          super().write_tensors()

          if self._experts is None:
              return
          # flatten `list[dict[str, Tensor]]` into `list[str]`
          experts = [key for pending in self._experts for key in pending]
          if experts:
              raise ValueError(f"Unprocessed experts: {experts}")
  @Model.register("GPT2LMHeadModel")
  class GPT2Model(Model):
      """Converter for "GPT2LMHeadModel" checkpoints, emitted as ``MODEL_ARCH.GPT2``."""

      model_arch = gguf.MODEL_ARCH.GPT2

      def set_gguf_parameters(self):
          """Write the GPT-2 hyperparameters; the base implementation is not called.

          The model name is the name of the model directory and the feed-forward
          length is written as ``4 * n_embd``.
          """
          hparams = self.hparams
          writer = self.gguf_writer

          writer.add_name(self.dir_model.name)
          writer.add_block_count(hparams["n_layer"])
          writer.add_context_length(hparams["n_ctx"])
          writer.add_embedding_length(hparams["n_embd"])
          writer.add_feed_forward_length(4 * hparams["n_embd"])
          writer.add_head_count(hparams["n_head"])
          writer.add_layer_norm_eps(hparams["layer_norm_epsilon"])
          writer.add_file_type(self.ftype)

      def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
          """Drop unneeded buffers, transpose the projection weights and tie the output to the embedding.

          Returns an empty list for the attention bias/mask buffers. The token
          embedding is returned twice: under its own name and as the output tensor.
          """
          del bid  # unused

          # we don't need these
          if name.endswith((".attn.bias", ".attn.masked_bias")):
              return []

          # ".c_proj.weight" matches every tensor whose name ends with it, so one
          # entry is enough
          if name.endswith((".c_attn.weight", ".c_proj.weight", ".c_fc.weight")):
              data_torch = data_torch.transpose(1, 0)

          new_name = self.map_tensor_name(name)
          tensors: list[tuple[str, Tensor]] = [(new_name, data_torch)]

          # note: GPT2 output is tied to (same as) wte in original model
          if new_name == self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD):
              tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch))

          return tensors
  @Model.register("PhiForCausalLM")
  class Phi2Model(Model):
      """Converter for "PhiForCausalLM" checkpoints, emitted as ``MODEL_ARCH.PHI2``."""

      model_arch = gguf.MODEL_ARCH.PHI2

      def set_gguf_parameters(self):
          """Write the Phi-2 hyperparameters; the base implementation is not called.

          The feed-forward length is written as ``4 * n_embd``, the kv head count
          equals the head count, and the rope dimension count is derived from
          ``partial_rotary_factor``.
          """
          find = self.find_hparam
          writer = self.gguf_writer

          block_count = find(["num_hidden_layers", "n_layer"])
          rot_pct = find(["partial_rotary_factor"])
          n_embd = find(["hidden_size", "n_embd"])
          n_head = find(["num_attention_heads", "n_head"])

          writer.add_name("Phi2")
          writer.add_context_length(find(["n_positions", "max_position_embeddings"]))
          writer.add_embedding_length(n_embd)
          writer.add_feed_forward_length(4 * n_embd)
          writer.add_block_count(block_count)
          writer.add_head_count(n_head)
          writer.add_head_count_kv(n_head)
          writer.add_layer_norm_eps(find(["layer_norm_epsilon", "layer_norm_eps"]))
          rope_dims = int(rot_pct * n_embd) // n_head
          writer.add_rope_dimension_count(rope_dims)
          writer.add_file_type(self.ftype)
          writer.add_add_bos_token(False)
  @Model.register("Phi3ForCausalLM")
  class Phi3MiniModel(Model):
      """Converter for "Phi3ForCausalLM" checkpoints, emitted as ``MODEL_ARCH.PHI3``."""

      model_arch = gguf.MODEL_ARCH.PHI3

      def set_vocab(self):
          """Write the vocabulary from ``tokenizer.model`` plus ``added_tokens.json``.

          The token table has ``vocab_size`` slots (from the config, else from the
          tokenizer), pre-filled with ``[PAD{i}]`` placeholders. Tokenizer pieces
          overwrite their slots, then in-range added tokens overwrite theirs.

          Raises:
              ValueError: if ``tokenizer.model`` is missing from the model directory.
          """
          from sentencepiece import SentencePieceProcessor

          tokenizer_path = self.dir_model / 'tokenizer.model'

          if not tokenizer_path.is_file():
              raise ValueError(f'Error: Missing {tokenizer_path}')

          tokenizer = SentencePieceProcessor()
          tokenizer.LoadFromFile(str(tokenizer_path))

          vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())

          # placeholders for every slot; real entries overwrite them below
          tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
          scores: list[float] = [-10000.0] * vocab_size
          toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size

          for token_id in range(tokenizer.vocab_size()):
              piece = tokenizer.IdToPiece(token_id)

              # first matching predicate wins; NORMAL when none matches
              toktype = SentencePieceTokenTypes.NORMAL
              if tokenizer.IsUnknown(token_id):
                  toktype = SentencePieceTokenTypes.UNKNOWN
              elif tokenizer.IsControl(token_id):
                  toktype = SentencePieceTokenTypes.CONTROL
              elif tokenizer.IsUnused(token_id):
                  toktype = SentencePieceTokenTypes.UNUSED
              elif tokenizer.IsByte(token_id):
                  toktype = SentencePieceTokenTypes.BYTE

              tokens[token_id] = piece.encode("utf-8")
              scores[token_id] = tokenizer.GetScore(token_id)
              toktypes[token_id] = toktype

          added_tokens_file = self.dir_model / 'added_tokens.json'
          if added_tokens_file.is_file():
              with open(added_tokens_file, "r", encoding="utf-8") as f:
                  added_tokens_json = json.load(f)

              for key, token_id in added_tokens_json.items():
                  if token_id >= vocab_size:
                      logger.debug(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}')
                      continue

                  tokens[token_id] = key.encode("utf-8")
                  scores[token_id] = -1000.0
                  toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED

          self.gguf_writer.add_tokenizer_model("llama")
          self.gguf_writer.add_tokenizer_pre("default")
          self.gguf_writer.add_token_list(tokens)
          self.gguf_writer.add_token_scores(scores)
          self.gguf_writer.add_token_types(toktypes)

          special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
          special_vocab.add_to_gguf(self.gguf_writer)

      def set_gguf_parameters(self):
          """Write the Phi-3 hyperparameters; the base implementation is not called.

          The feed-forward length is taken from ``intermediate_size`` in the config
          and falls back to 8192 when that key is absent. The kv head count equals
          the head count and the rope dimension count is ``n_embd // n_head``.
          """
          find = self.find_hparam
          writer = self.gguf_writer

          block_count = find(["num_hidden_layers", "n_layer"])
          rot_pct = 1.0
          n_embd = find(["hidden_size", "n_embd"])
          n_head = find(["num_attention_heads", "n_head"])
          rms_eps = find(["rms_norm_eps"])

          writer.add_name("Phi3")
          writer.add_context_length(find(["n_positions", "max_position_embeddings"]))

          writer.add_embedding_length(n_embd)
          # read from the config rather than assuming one fixed model size
          writer.add_feed_forward_length(self.hparams.get("intermediate_size", 8192))
          writer.add_block_count(block_count)
          writer.add_head_count(n_head)
          writer.add_head_count_kv(n_head)
          writer.add_layer_norm_rms_eps(rms_eps)
          writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head)
          writer.add_file_type(self.ftype)
  @Model.register("PlamoForCausalLM")
  class PlamoModel(Model):
      """Converter for "PlamoForCausalLM" checkpoints, emitted as ``MODEL_ARCH.PLAMO``."""

      model_arch = gguf.MODEL_ARCH.PLAMO

      def set_vocab(self):
          """Write the vocabulary using the sentencepiece loader."""
          self._set_vocab_sentencepiece()

      def set_gguf_parameters(self):
          """Write the PLaMo hyperparameters; the base implementation is not called.

          The context length (4096) and the kv head count (5) are fixed values
          rather than config lookups.
          """
          hparams = self.hparams
          writer = self.gguf_writer
          block_count = hparams["num_hidden_layers"]

          writer.add_name("PLaMo")
          writer.add_context_length(4096)  # not in config.json
          writer.add_embedding_length(hparams["hidden_size"])
          writer.add_feed_forward_length(hparams["intermediate_size"])
          writer.add_block_count(block_count)
          writer.add_head_count(hparams["num_attention_heads"])
          writer.add_head_count_kv(5)  # hparams["num_key_value_heads"]) is wrong
          writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
          writer.add_file_type(self.ftype)

      def shuffle_attn_q_weight(self, data_torch):
          """Reorder the rows of a (5120, 5120) weight by swapping the 8 and 5 axes of an (8, 5, 128) row split."""
          assert data_torch.size() == (5120, 5120)
          by_group = data_torch.reshape(8, 5, 128, 5120)
          swapped = torch.permute(by_group, (1, 0, 2, 3))
          return torch.reshape(swapped, (5120, 5120))

      def shuffle_attn_output_weight(self, data_torch):
          """Reorder the columns of a (5120, 5120) weight by swapping the 8 and 5 axes of an (8, 5, 128) column split."""
          assert data_torch.size() == (5120, 5120)
          by_group = data_torch.reshape(5120, 8, 5, 128)
          swapped = torch.permute(by_group, (0, 2, 1, 3))
          return torch.reshape(swapped, (5120, 5120))

      def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
          """Map the tensor name and shuffle the attention q / output weights."""
          del bid  # unused

          new_name = self.map_tensor_name(name)

          # shuffle for broadcasting of gqa in ggml_mul_mat
          if new_name.endswith("attn_q.weight"):
              shuffled = self.shuffle_attn_q_weight(data_torch)
          elif new_name.endswith("attn_output.weight"):
              shuffled = self.shuffle_attn_output_weight(data_torch)
          else:
              shuffled = data_torch

          return [(new_name, shuffled)]
  @Model.register("CodeShellForCausalLM")
  class CodeShellModel(Model):
      """Converter for "CodeShellForCausalLM" checkpoints, emitted as ``MODEL_ARCH.CODESHELL``."""

      model_arch = gguf.MODEL_ARCH.CODESHELL

      def set_gguf_parameters(self):
          """Write the CodeShell hyperparameters; the base implementation is not called.

          The feed-forward length is written as ``4 * n_embd``; the rope frequency
          base (10000.0) and linear rope scaling with factor 1.0 are fixed values.
          """
          hparams = self.hparams
          writer = self.gguf_writer
          block_count = hparams["n_layer"]

          writer.add_name("CodeShell")
          writer.add_context_length(hparams["n_positions"])
          writer.add_embedding_length(hparams["n_embd"])
          writer.add_feed_forward_length(4 * hparams["n_embd"])
          writer.add_block_count(block_count)
          writer.add_head_count(hparams["n_head"])
          writer.add_head_count_kv(hparams["num_query_groups"])
          writer.add_layer_norm_eps(hparams["layer_norm_epsilon"])
          writer.add_file_type(self.ftype)
          writer.add_rope_freq_base(10000.0)
          writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
          writer.add_rope_scaling_factor(1.0)

      def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
          """Map the tensor name; reuse the token embedding as output when the model has no output tensor."""
          del bid  # unused

          new_name = self.map_tensor_name(name)
          tensors: list[tuple[str, Tensor]] = [(new_name, data_torch)]

          if new_name != self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD):
              return tensors

          assert self.tensor_names is not None
          has_output = any(s in self.tensor_names for s in ("lm_head.weight", "output.weight"))
          if not has_output:
              # copy tok_embd.weight to output.weight
              tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch))

          return tensors
  @Model.register("InternLM2ForCausalLM")
  class InternLM2Model(Model):
      """Converter for "InternLM2ForCausalLM" checkpoints, emitted as ``MODEL_ARCH.INTERNLM2``.

      The fused ``attention.wqkv`` weight of each layer is split into separate
      q, k and v tensors; q and k are additionally permuted.
      """

      model_arch = gguf.MODEL_ARCH.INTERNLM2

      def set_vocab(self):
          """Write the vocabulary from ``tokenizer.model`` plus ``added_tokens.json``.

          Exits the process with status 1 when ``tokenizer.model`` is missing. When
          the model directory name contains "chat", the eos token id is replaced
          by the id returned from ``_try_get_sft_eos``.
          """
          # (TODO): Is there a better way?
          # Copy from _set_vocab_sentencepiece, The only difference is that we will treat the character
          # \x00 specially and convert it into an emoji character to prevent it from being mistakenly
          # recognized as an empty string in C++.
          from sentencepiece import SentencePieceProcessor
          from sentencepiece import sentencepiece_model_pb2 as model

          tokenizer_path = self.dir_model / 'tokenizer.model'

          tokens: list[bytes] = []
          scores: list[float] = []
          toktypes: list[int] = []

          if not tokenizer_path.is_file():
              logger.error(f'Error: Missing {tokenizer_path}')
              sys.exit(1)

          sentencepiece_model = model.ModelProto()
          # read inside a context manager so the file handle is closed right away
          with open(tokenizer_path, "rb") as f:
              sentencepiece_model.ParseFromString(f.read())
          add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix

          tokenizer = SentencePieceProcessor()
          tokenizer.LoadFromFile(str(tokenizer_path))

          vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())

          for token_id in range(vocab_size):
              piece = tokenizer.IdToPiece(token_id)
              text = piece.encode("utf-8")
              score = tokenizer.GetScore(token_id)
              if text == b"\x00":
                  # (TODO): fixme
                  # Hack here and replace the \x00 characters.
                  logger.warning(f"InternLM2 convert token '{text}' to '🐉'!")
                  text = "🐉".encode("utf-8")

              # first matching predicate wins; NORMAL when none matches
              toktype = SentencePieceTokenTypes.NORMAL
              if tokenizer.IsUnknown(token_id):
                  toktype = SentencePieceTokenTypes.UNKNOWN
              elif tokenizer.IsControl(token_id):
                  toktype = SentencePieceTokenTypes.CONTROL
              elif tokenizer.IsUnused(token_id):
                  toktype = SentencePieceTokenTypes.UNUSED
              elif tokenizer.IsByte(token_id):
                  toktype = SentencePieceTokenTypes.BYTE

              tokens.append(text)
              scores.append(score)
              toktypes.append(toktype)

          added_tokens_file = self.dir_model / 'added_tokens.json'
          if added_tokens_file.is_file():
              with open(added_tokens_file, "r", encoding="utf-8") as f:
                  added_tokens_json = json.load(f)

              # added tokens are appended after the tokenizer pieces, in file order
              for key in added_tokens_json:
                  tokens.append(key.encode("utf-8"))
                  scores.append(-1000.0)
                  toktypes.append(SentencePieceTokenTypes.USER_DEFINED)

          self.gguf_writer.add_tokenizer_model("llama")
          self.gguf_writer.add_tokenizer_pre("default")
          self.gguf_writer.add_token_list(tokens)
          self.gguf_writer.add_token_scores(scores)
          self.gguf_writer.add_token_types(toktypes)
          self.gguf_writer.add_add_space_prefix(add_prefix)

          special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
          old_eos = special_vocab.special_token_ids["eos"]
          if "chat" in os.path.basename(self.dir_model.absolute()):
              # For the chat model, we replace the eos with '<|im_end|>'.
              # TODO: this is a hack, should be fixed
              #       https://github.com/ggerganov/llama.cpp/pull/6745#issuecomment-2067687048
              special_vocab.special_token_ids["eos"] = self._try_get_sft_eos(tokenizer)
              # adjacent literals keep source indentation out of the logged message
              logger.warning(
                  f"Replace eos:{old_eos} with a special token:{special_vocab.special_token_ids['eos']} "
                  "in chat mode so that the conversation can end normally."
              )

          special_vocab.add_to_gguf(self.gguf_writer)

      def _try_get_sft_eos(self, tokenizer):
          """Return the id of the chat end-of-turn token.

          Exactly one of '[UNUSED_TOKEN_145]' and '<|im_end|>' must encode to a
          single token; that token's id is returned.
          """
          unused_145_list = tokenizer.Encode('[UNUSED_TOKEN_145]')
          im_end_list = tokenizer.Encode('<|im_end|>')
          eos_token = None
          assert (len(unused_145_list) == 1) ^ (len(im_end_list) == 1)
          if len(unused_145_list) == 1:
              eos_token = unused_145_list[0]
          if len(im_end_list) == 1:
              eos_token = im_end_list[0]
          # compare against None so that a token id of 0 is accepted
          assert eos_token is not None
          return eos_token

      def _hf_permute_qk(self, weights, n_head: int, n_head_kv: int):
          """Swap the two middle axes of the per-head ``(2, rows // heads // 2)`` row split.

          ``heads`` is ``n_head_kv`` when it is given and differs from ``n_head``,
          else ``n_head``.
          """
          if n_head_kv is not None and n_head != n_head_kv:
              n_head = n_head_kv
          return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
                  .swapaxes(1, 2)
                  .reshape(weights.shape))

      def set_gguf_parameters(self):
          """Write the InternLM2 hyperparameters; the base implementation is not called."""
          hparams = self.hparams
          writer = self.gguf_writer

          writer.add_name("InternLM2")
          writer.add_context_length(hparams["max_position_embeddings"])
          writer.add_block_count(hparams["num_hidden_layers"])
          writer.add_embedding_length(hparams["hidden_size"])
          writer.add_feed_forward_length(hparams["intermediate_size"])
          writer.add_rope_freq_base(hparams["rope_theta"])
          writer.add_head_count(hparams["num_attention_heads"])
          writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
          writer.add_head_count_kv(hparams["num_key_value_heads"])
          writer.add_file_type(self.ftype)

      def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
          """Split a fused ``wqkv`` weight into q, k and v; map every other tensor by name.

          For ``wqkv`` tensors the block id is taken from the tensor name (as a
          string) and the ``bid`` argument is ignored.
          """
          num_heads = self.hparams["num_attention_heads"]
          num_kv_heads = self.hparams["num_key_value_heads"]
          hidden_size = self.hparams["hidden_size"]
          q_per_kv = num_heads // num_kv_heads
          head_dim = hidden_size // num_heads
          num_groups = num_heads // q_per_kv

          # match once and reuse the captured layer index
          qkv_match = re.match(r"model\.layers\.(\d+)\.attention\.wqkv", name)
          if qkv_match is None:
              return [(self.map_tensor_name(name), data_torch)]

          layer_id = qkv_match.group(1)
          # qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
          qkv = data_torch.T.reshape((-1, num_groups, q_per_kv + 2, head_dim))
          # within each group: q_per_kv query slots, then one key slot, then one value slot
          q = qkv[..., : q_per_kv, :]
          k = qkv[..., q_per_kv: q_per_kv + 1, :]
          v = qkv[..., q_per_kv + 1: q_per_kv + 2, :]

          # The model weights of q and k require additional reshape.
          # q = self._hf_permute_qk(rearrange(q, " o g n i ->  o (g n i)").T, num_heads, num_heads)
          q = self._hf_permute_qk(q.reshape((q.shape[0], -1)).T, num_heads, num_heads)
          # k = self._hf_permute_qk(rearrange(k, " o g n i ->  o (g n i)").T, num_heads, num_kv_heads)
          k = self._hf_permute_qk(k.reshape((k.shape[0], -1)).T, num_heads, num_kv_heads)
          # v = rearrange(v, " o g n i ->  o (g n i)").T
          v = v.reshape((v.shape[0], -1)).T

          return [
              (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, layer_id), q),
              (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, layer_id), k),
              (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, layer_id), v),
          ]
  1622. @Model.register("BertModel", "CamembertModel")
  1623. class BertModel(Model):
  1624. model_arch = gguf.MODEL_ARCH.BERT
  def __init__(self, *args, **kwargs):
      """Initialise the base class with all arguments and start with an unset vocab size."""
      super().__init__(*args, **kwargs)
      # stays None until set_vocab() stores the number of tokens
      self.vocab_size = None
  def set_gguf_parameters(self):
      """Write the base parameters, disable causal attention and record the pooling type.

      The pooling type is written only when ``modules.json`` in the model
      directory lists a ``sentence_transformers.models.Pooling`` module; that
      module's ``config.json`` then selects MEAN or CLS pooling.

      Raises:
          NotImplementedError: if the pooling config enables neither mean-token
              nor cls-token pooling.
      """
      super().set_gguf_parameters()
      self.gguf_writer.add_causal_attention(False)

      # get pooling path
      module_path = self.dir_model / "modules.json"
      if not module_path.is_file():
          return
      with open(module_path, encoding="utf-8") as f:
          modules = json.load(f)
      # the first pooling module wins
      pooling_path = next(
          (mod["path"] for mod in modules if mod["type"] == "sentence_transformers.models.Pooling"),
          None,
      )
      if pooling_path is None:
          return

      # get pooling type
      with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f:
          pooling = json.load(f)
      if pooling["pooling_mode_mean_tokens"]:
          pooling_type = gguf.PoolingType.MEAN
      elif pooling["pooling_mode_cls_token"]:
          pooling_type = gguf.PoolingType.CLS
      else:
          raise NotImplementedError("Only MEAN and CLS pooling types supported")
      self.gguf_writer.add_pooling_type(pooling_type)
  1652. def set_vocab(self):
  1653. tokens, toktypes, tokpre = self.get_vocab_base()
  1654. self.vocab_size = len(tokens)
  1655. # we need this to validate the size of the token_type embeddings
  1656. # though currently we are passing all zeros to the token_type embeddings
  1657. self.gguf_writer.add_token_type_count(2) # "Sequence A" or "Sequence B"
  1658. # convert to phantom space vocab
  1659. def phantom(tok):
  1660. if tok.startswith("[") and tok.endswith("]"):
  1661. return tok
  1662. if tok.startswith("##"):
  1663. return tok[2:]
  1664. return "\u2581" + tok
  1665. tokens = list(map(phantom, tokens))
  1666. # add vocab to gguf
  1667. self.gguf_writer.add_tokenizer_model("bert")
  1668. self.gguf_writer.add_tokenizer_pre(tokpre)
  1669. self.gguf_writer.add_token_list(tokens)
  1670. self.gguf_writer.add_token_types(toktypes)
  1671. # handle special tokens
  1672. special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
  1673. special_vocab.add_to_gguf(self.gguf_writer)
  1674. def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
  1675. del bid # unused
  1676. # we are only using BERT for embeddings so we don't need the pooling layer
  1677. if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"):
  1678. return [] # we don't need these
  1679. return [(self.map_tensor_name(name), data_torch)]
  1680. @Model.register("NomicBertModel")
  1681. class NomicBertModel(BertModel):
  1682. model_arch = gguf.MODEL_ARCH.NOMIC_BERT
  1683. def __init__(self, *args, **kwargs):
  1684. super().__init__(*args, **kwargs)
  1685. # the HF config claims n_ctx=8192, but it uses RoPE scaling
  1686. self.hparams["n_ctx"] = 2048
  1687. # SwigLU activation
  1688. assert self.hparams["activation_function"] == "swiglu"
  1689. # this doesn't do anything in the HF version
  1690. assert self.hparams["causal"] is False
  1691. # no bias tensors
  1692. assert self.hparams["qkv_proj_bias"] is False
  1693. assert self.hparams["mlp_fc1_bias"] is False
  1694. assert self.hparams["mlp_fc2_bias"] is False
  1695. # norm at end of layer
  1696. assert self.hparams["prenorm"] is False
  1697. # standard RoPE
  1698. assert self.hparams["rotary_emb_fraction"] == 1.0
  1699. assert self.hparams["rotary_emb_interleaved"] is False
  1700. assert self.hparams["rotary_emb_scale_base"] is None
  1701. def set_gguf_parameters(self):
  1702. super().set_gguf_parameters()
  1703. self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
  1704. @Model.register("GemmaForCausalLM")
  1705. class GemmaModel(Model):
  1706. model_arch = gguf.MODEL_ARCH.GEMMA
  1707. def set_vocab(self):
  1708. self._set_vocab_sentencepiece()
  1709. # TODO: these special tokens should be exported only for the CodeGemma family
  1710. special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False,
  1711. special_token_types = ['prefix', 'suffix', 'middle', 'fsep', 'eot'])
  1712. special_vocab._set_special_token("prefix", 67)
  1713. special_vocab._set_special_token("suffix", 69)
  1714. special_vocab._set_special_token("middle", 68)
  1715. special_vocab._set_special_token("fsep", 70)
  1716. special_vocab._set_special_token("eot", 107)
  1717. special_vocab.add_to_gguf(self.gguf_writer)
  1718. def set_gguf_parameters(self):
  1719. hparams = self.hparams
  1720. block_count = hparams["num_hidden_layers"]
  1721. self.gguf_writer.add_name(self.dir_model.name)
  1722. self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
  1723. self.gguf_writer.add_embedding_length(hparams["hidden_size"])
  1724. self.gguf_writer.add_block_count(block_count)
  1725. self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
  1726. self.gguf_writer.add_head_count(hparams["num_attention_heads"])
  1727. self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"])
  1728. self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
  1729. self.gguf_writer.add_key_length(hparams["head_dim"])
  1730. self.gguf_writer.add_value_length(hparams["head_dim"])
  1731. self.gguf_writer.add_file_type(self.ftype)
  1732. def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
  1733. del bid # unused
  1734. # lm_head is not used in llama.cpp, while autoawq will include this tensor in model
  1735. # To prevent errors, skip loading lm_head.weight.
  1736. if name == "lm_head.weight":
  1737. logger.debug(f"Skipping get tensor {name!r} in safetensors so that convert can end normally.")
  1738. return []
  1739. # ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89
  1740. if name.endswith("norm.weight"):
  1741. data_torch = data_torch + 1
  1742. return [(self.map_tensor_name(name), data_torch)]
  1743. @Model.register("Starcoder2ForCausalLM")
  1744. class StarCoder2Model(Model):
  1745. model_arch = gguf.MODEL_ARCH.STARCODER2
  1746. @Model.register("MambaForCausalLM", "MambaLMHeadModel")
  1747. class MambaModel(Model):
  1748. model_arch = gguf.MODEL_ARCH.MAMBA
  1749. def set_vocab(self):
  1750. vocab_size = self.hparams["vocab_size"]
  1751. # Round vocab size to next multiple of 8
  1752. pad_vocab = self.hparams.get("pad_vocab_size_multiple", 8)
  1753. # pad using ceiling division
  1754. # ref: https://stackoverflow.com/a/17511341/22827863
  1755. vocab_size = -(vocab_size // -pad_vocab) * pad_vocab
  1756. self.hparams["vocab_size"] = vocab_size
  1757. if (self.dir_model / "tokenizer.json").is_file():
  1758. self._set_vocab_gpt2()
  1759. elif (self.dir_model / "tokenizer.model").is_file():
  1760. self._set_vocab_sentencepiece()
  1761. else:
  1762. # Use the GPT-NeoX tokenizer when no tokenizer files are present
  1763. tokenizer_path = Path(sys.path[0]) / "models" / "ggml-vocab-gpt-neox.gguf"
  1764. logger.warning(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'")
  1765. neox_reader = gguf.GGUFReader(tokenizer_path, "r")
  1766. field = neox_reader.get_field(gguf.Keys.Tokenizer.MODEL)
  1767. self.gguf_writer.add_tokenizer_model(bytes(field.parts[-1]).decode("utf-8") if field else "gpt2")
  1768. field = neox_reader.get_field(gguf.Keys.Tokenizer.PRE)
  1769. self.gguf_writer.add_tokenizer_pre(bytes(field.parts[-1]).decode("utf-8") if field else "mpt")
  1770. field = neox_reader.get_field(gguf.Keys.Tokenizer.LIST)
  1771. assert field
  1772. self.gguf_writer.add_token_list([bytes(field.parts[i]) for i in field.data][:vocab_size])
  1773. field = neox_reader.get_field(gguf.Keys.Tokenizer.TOKEN_TYPE)
  1774. assert field
  1775. self.gguf_writer.add_token_types([field.parts[i].tolist()[0] for i in field.data][:vocab_size])
  1776. field = neox_reader.get_field(gguf.Keys.Tokenizer.MERGES)
  1777. assert field
  1778. self.gguf_writer.add_token_merges([bytes(field.parts[i]) for i in field.data])
  1779. field = neox_reader.get_field(gguf.Keys.Tokenizer.BOS_ID)
  1780. self.gguf_writer.add_bos_token_id(field.parts[-1].tolist()[0] if field else 1)
  1781. field = neox_reader.get_field(gguf.Keys.Tokenizer.EOS_ID)
  1782. self.gguf_writer.add_eos_token_id(field.parts[-1].tolist()[0] if field else 0)
  1783. field = neox_reader.get_field(gguf.Keys.Tokenizer.UNK_ID)
  1784. self.gguf_writer.add_unk_token_id(field.parts[-1].tolist()[0] if field else 0)
  1785. field = neox_reader.get_field(gguf.Keys.Tokenizer.PAD_ID)
  1786. self.gguf_writer.add_pad_token_id(field.parts[-1].tolist()[0] if field else 0)
  1787. def set_gguf_parameters(self):
  1788. d_model = self.find_hparam(["hidden_size", "d_model"])
  1789. d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
  1790. d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
  1791. d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 16
  1792. # ceiling division
  1793. # ref: https://stackoverflow.com/a/17511341/22827863
  1794. # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
  1795. dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16)
  1796. rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
  1797. # Fail early for models which don't have a block expansion factor of 2
  1798. assert d_inner == 2 * d_model
  1799. self.gguf_writer.add_name(self.dir_model.name)
  1800. self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
  1801. self.gguf_writer.add_embedding_length(d_model)
  1802. self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
  1803. self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
  1804. self.gguf_writer.add_block_count(self.hparams["n_layer"])
  1805. self.gguf_writer.add_ssm_conv_kernel(d_conv)
  1806. self.gguf_writer.add_ssm_inner_size(d_inner)
  1807. self.gguf_writer.add_ssm_state_size(d_state)
  1808. self.gguf_writer.add_ssm_time_step_rank(dt_rank)
  1809. self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
  1810. self.gguf_writer.add_file_type(self.ftype)
  1811. _tok_embd = None
  1812. def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
  1813. del bid # unused
  1814. output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT)
  1815. tok_embd_name = self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD)
  1816. new_name = self.map_tensor_name(name)
  1817. if name.endswith(".A_log"):
  1818. logger.debug("A_log --> A ==> " + new_name)
  1819. data_torch = -torch.exp(data_torch)
  1820. # assuming token_embd.weight is seen before output.weight
  1821. if self._tok_embd is not None and new_name == output_name:
  1822. if torch.equal(self._tok_embd, data_torch):
  1823. logger.debug(f"{output_name} is equivalent to {tok_embd_name}, omitting")
  1824. return []
  1825. elif new_name == tok_embd_name:
  1826. self._tok_embd = data_torch
  1827. return [(new_name, data_torch)]
  1828. def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
  1829. del n_dims # unused
  1830. return bid is not None and new_name in (
  1831. self.format_tensor_name(n, bid, ".weight" if name.endswith(".weight") else "") for n in [
  1832. gguf.MODEL_TENSOR.SSM_CONV1D,
  1833. gguf.MODEL_TENSOR.SSM_X,
  1834. gguf.MODEL_TENSOR.SSM_DT,
  1835. gguf.MODEL_TENSOR.SSM_A,
  1836. gguf.MODEL_TENSOR.SSM_D,
  1837. ]
  1838. )
  1839. @Model.register("CohereForCausalLM")
  1840. class CommandR2Model(Model):
  1841. model_arch = gguf.MODEL_ARCH.COMMAND_R
  1842. def __init__(self, *args, **kwargs):
  1843. super().__init__(*args, **kwargs)
  1844. # max_position_embeddings = 8192 in config.json but model was actually
  1845. # trained on 128k context length
  1846. self.hparams["max_position_embeddings"] = self.hparams["model_max_length"]
  1847. def set_gguf_parameters(self):
  1848. super().set_gguf_parameters()
  1849. self.gguf_writer.add_logit_scale(self.hparams["logit_scale"])
  1850. self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
  1851. @Model.register("OlmoForCausalLM")
  1852. @Model.register("OLMoForCausalLM")
  1853. class OlmoModel(Model):
  1854. model_arch = gguf.MODEL_ARCH.OLMO
  1855. def set_gguf_parameters(self):
  1856. super().set_gguf_parameters()
  1857. self.gguf_writer.add_layer_norm_eps(1e-5)
  1858. clip_qkv = self.hparams.get("clip_qkv")
  1859. if clip_qkv is not None:
  1860. self.gguf_writer.add_clamp_kqv(clip_qkv)
  1861. # Same as super class, but permuting q_proj, k_proj
  1862. # Copied from: LlamaModel
  1863. def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
  1864. del bid # unused
  1865. n_head = self.hparams["num_attention_heads"]
  1866. n_kv_head = self.hparams.get("num_key_value_heads")
  1867. if name.endswith("q_proj.weight"):
  1868. data_torch = LlamaModel.permute(data_torch, n_head, n_head)
  1869. if name.endswith("k_proj.weight"):
  1870. data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
  1871. return [(self.map_tensor_name(name), data_torch)]
  1872. @Model.register("JinaBertModel", "JinaBertForMaskedLM")
  1873. class JinaBertV2Model(BertModel):
  1874. model_arch = gguf.MODEL_ARCH.JINA_BERT_V2
  1875. def __init__(self, *args, **kwargs):
  1876. super().__init__(*args, **kwargs)
  1877. self.intermediate_size = self.hparams["intermediate_size"]
  1878. def get_tensors(self):
  1879. for name, data in super().get_tensors():
  1880. if 'gated_layers' in name:
  1881. d1 = data[:self.intermediate_size, :]
  1882. name1 = name.replace('gated_layers', 'gated_layers_w')
  1883. d2 = data[self.intermediate_size:, :]
  1884. name2 = name.replace('gated_layers', 'gated_layers_v')
  1885. yield name1, d1
  1886. yield name2, d2
  1887. continue
  1888. yield name, data
  1889. def set_vocab(self, *args, **kwargs):
  1890. tokenizer_class = 'BertTokenizer'
  1891. with open(self.dir_model / "tokenizer_config.json", "r", encoding="utf-8") as f:
  1892. tokenizer_class = json.load(f)['tokenizer_class']
  1893. if tokenizer_class == 'BertTokenizer':
  1894. super().set_vocab()
  1895. elif tokenizer_class == 'RobertaTokenizer':
  1896. self._set_vocab_gpt2()
  1897. self.gguf_writer.add_token_type_count(2)
  1898. else:
  1899. raise NotImplementedError(f'Tokenizer {tokenizer_class} is not supported for JinaBertModel')
  1900. self.gguf_writer.add_add_bos_token(True)
  1901. self.gguf_writer.add_add_eos_token(True)
  1902. ###### CONVERSION LOGIC ######
  1903. # tree of lazy tensors
  1904. class LazyTorchTensor(gguf.LazyBase):
  1905. _tensor_type = torch.Tensor
  1906. # to keep the type-checker happy
  1907. dtype: torch.dtype
  1908. shape: torch.Size
  1909. # only used when converting a torch.Tensor to a np.ndarray
  1910. _dtype_map: dict[torch.dtype, type] = {
  1911. torch.float16: np.float16,
  1912. torch.float32: np.float32,
  1913. }
  1914. def numpy(self) -> gguf.LazyNumpyTensor:
  1915. dtype = self._dtype_map[self.dtype]
  1916. return gguf.LazyNumpyTensor(
  1917. meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape),
  1918. lazy=self._lazy,
  1919. args=(self,),
  1920. func=(lambda s: s[0].numpy())
  1921. )
  1922. @classmethod
  1923. def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: torch.Size) -> Tensor:
  1924. return torch.empty(size=shape, dtype=dtype, device="meta")
  1925. @classmethod
  1926. def __torch_function__(cls, func, types, args=(), kwargs=None):
  1927. del types # unused
  1928. if kwargs is None:
  1929. kwargs = {}
  1930. if func is torch.Tensor.numpy:
  1931. return args[0].numpy()
  1932. return LazyTorchTensor._wrap_fn(func)(*args, **kwargs)
  1933. def parse_args() -> argparse.Namespace:
  1934. parser = argparse.ArgumentParser(
  1935. description="Convert a huggingface model to a GGML compatible file")
  1936. parser.add_argument(
  1937. "--vocab-only", action="store_true",
  1938. help="extract only the vocab",
  1939. )
  1940. parser.add_argument(
  1941. "--awq-path", type=Path, default=None,
  1942. help="Path to scale awq cache file",
  1943. )
  1944. parser.add_argument(
  1945. "--outfile", type=Path,
  1946. help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
  1947. )
  1948. parser.add_argument(
  1949. "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
  1950. help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
  1951. )
  1952. parser.add_argument(
  1953. "--bigendian", action="store_true",
  1954. help="model is executed on big endian machine",
  1955. )
  1956. parser.add_argument(
  1957. "model", type=Path,
  1958. help="directory containing model file",
  1959. )
  1960. parser.add_argument(
  1961. "--use-temp-file", action="store_true",
  1962. help="use the tempfile library while processing (helpful when running out of memory, process killed)",
  1963. )
  1964. parser.add_argument(
  1965. "--no-lazy", action="store_true",
  1966. help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)",
  1967. )
  1968. parser.add_argument(
  1969. "--model-name", type=str, default=None,
  1970. help="name of the model",
  1971. )
  1972. parser.add_argument(
  1973. "--verbose", action="store_true",
  1974. help="increase output verbosity",
  1975. )
  1976. return parser.parse_args()
  1977. def main() -> None:
  1978. args = parse_args()
  1979. logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
  1980. dir_model = args.model
  1981. if args.awq_path:
  1982. sys.path.insert(1, str(Path(__file__).parent / 'awq-py'))
  1983. from awq.apply_awq import add_scale_weights # type: ignore[import-not-found]
  1984. tmp_model_path = args.model / "weighted_model"
  1985. dir_model = tmp_model_path
  1986. if tmp_model_path.is_dir():
  1987. logger.info(f"{tmp_model_path} exists as a weighted model.")
  1988. else:
  1989. tmp_model_path.mkdir(parents=True, exist_ok=True)
  1990. logger.info("Saving new weighted model ...")
  1991. add_scale_weights(str(args.model), str(args.awq_path), str(tmp_model_path))
  1992. logger.info(f"Saved weighted model at {tmp_model_path}.")
  1993. if not dir_model.is_dir():
  1994. logger.error(f'Error: {args.model} is not a directory')
  1995. sys.exit(1)
  1996. ftype_map: dict[str, gguf.LlamaFileType] = {
  1997. "f32": gguf.LlamaFileType.ALL_F32,
  1998. "f16": gguf.LlamaFileType.MOSTLY_F16,
  1999. "bf16": gguf.LlamaFileType.MOSTLY_BF16,
  2000. "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
  2001. "auto": gguf.LlamaFileType.GUESSED,
  2002. }
  2003. if args.outfile is not None:
  2004. fname_out = args.outfile
  2005. else:
  2006. # output in the same directory as the model by default
  2007. fname_out = dir_model / 'ggml-model-{ftype}.gguf'
  2008. logger.info(f"Loading model: {dir_model.name}")
  2009. hparams = Model.load_hparams(dir_model)
  2010. with torch.inference_mode():
  2011. model_class = Model.from_model_architecture(hparams["architectures"][0])
  2012. model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file, args.no_lazy)
  2013. logger.info("Set model parameters")
  2014. model_instance.set_gguf_parameters()
  2015. logger.info("Set model tokenizer")
  2016. model_instance.set_vocab()
  2017. model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
  2018. if args.vocab_only:
  2019. logger.info(f"Exporting model vocab to '{model_instance.fname_out}'")
  2020. model_instance.write_vocab()
  2021. else:
  2022. logger.info(f"Exporting model to '{model_instance.fname_out}'")
  2023. model_instance.write()
  2024. logger.info(f"Model successfully exported to '{model_instance.fname_out}'")
  2025. if __name__ == '__main__':
  2026. main()