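#!/usr/bin/env python3
"""Convert the Gemma 3 vision tower (SigLIP encoder + multimodal projector)
from Hugging Face safetensors into a GGUF file (arch "clip") that can be
loaded as an mmproj file by a matching CLIP implementation."""
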
import gguf
import argparse
import logging
import sys
import torch
import json
import os
import numpy as np
from typing import cast, ContextManager, Any, Iterator
from pathlib import Path
from torch import Tensor

logger = logging.getLogger("gemma3-mmproj")


# (copied from convert_hf_to_gguf.py)
# tree of lazy tensors
class LazyTorchTensor(gguf.LazyBase):
    _tensor_type = torch.Tensor

    # to keep the type-checker happy
    dtype: torch.dtype
    shape: torch.Size

    # only used when converting a torch.Tensor to a np.ndarray
    _dtype_map: dict[torch.dtype, type] = {
        torch.float16: np.float16,
        torch.float32: np.float32,
    }

    # used for safetensors slices
    # ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046
    # TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734
    _dtype_str_map: dict[str, torch.dtype] = {
        "F64": torch.float64,
        "F32": torch.float32,
        "BF16": torch.bfloat16,
        "F16": torch.float16,
        # "U64": torch.uint64,
        "I64": torch.int64,
        # "U32": torch.uint32,
        "I32": torch.int32,
        # "U16": torch.uint16,
        "I16": torch.int16,
        "U8": torch.uint8,
        "I8": torch.int8,
        "BOOL": torch.bool,
        "F8_E4M3": torch.float8_e4m3fn,
        "F8_E5M2": torch.float8_e5m2,
    }

    def numpy(self) -> gguf.LazyNumpyTensor:
        dtype = self._dtype_map[self.dtype]
        return gguf.LazyNumpyTensor(
            meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape),
            args=(self,),
            func=(lambda s: s.numpy())
        )

    @classmethod
    def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: tuple[int, ...]) -> Tensor:
        return torch.empty(size=shape, dtype=dtype, device="meta")

    @classmethod
    def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
        dtype = cls._dtype_str_map[st_slice.get_dtype()]
        shape: tuple[int, ...] = tuple(st_slice.get_shape())
        lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
        return cast(torch.Tensor, lazy)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        del types  # unused

        if kwargs is None:
            kwargs = {}

        if func is torch.Tensor.numpy:
            return args[0].numpy()

        return cls._wrap_fn(func)(*args, **kwargs)
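

# Gathers the Gemma 3 vision tower hyperparameters and tensors and writes them
# to a GGUF file with arch "clip": projector metadata, vision hyperparameters,
# and the renamed tensors.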
class Gemma3VisionTower:
    hparams: dict
    gguf_writer: gguf.GGUFWriter
    fname_out: Path
    ftype: gguf.LlamaFileType

    @staticmethod
    def load_hparams(dir_model: Path):
        with open(dir_model / "config.json", "r", encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]:
        part_names: list[str] = []
        for filename in os.listdir(dir_model):
            if filename.startswith(prefix) and filename.endswith(suffix):
                part_names.append(filename)
        part_names.sort()
        return part_names

    def __init__(self,
                 dir_model: Path,
                 fname_out: Path,
                 ftype: gguf.LlamaFileType,
                 is_big_endian: bool):
        hparams = Gemma3VisionTower.load_hparams(dir_model)
        self.hparams = hparams
        self.fname_out = fname_out
        self.ftype = ftype
        endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
        self.gguf_writer = gguf.GGUFWriter(path=None, arch="clip", endianess=endianess)

        text_config = hparams["text_config"]
        vision_config = hparams["vision_config"]
        assert hparams["architectures"][0] == "Gemma3ForConditionalGeneration"
        assert text_config is not None
        assert vision_config is not None

        self.gguf_writer.add_string ("clip.projector_type", "gemma3")
        self.gguf_writer.add_bool   ("clip.has_text_encoder", False)
        self.gguf_writer.add_bool   ("clip.has_vision_encoder", True)
        self.gguf_writer.add_bool   ("clip.has_llava_projector", False)  # legacy
        self.gguf_writer.add_uint32 ("clip.vision.image_size", vision_config["image_size"])
        self.gguf_writer.add_uint32 ("clip.vision.patch_size", vision_config["patch_size"])
        self.gguf_writer.add_uint32 ("clip.vision.embedding_length", vision_config["hidden_size"])
        self.gguf_writer.add_uint32 ("clip.vision.feed_forward_length", vision_config["intermediate_size"])
        self.gguf_writer.add_uint32 ("clip.vision.projection_dim", text_config["hidden_size"])
        self.gguf_writer.add_uint32 ("clip.vision.block_count", vision_config["num_hidden_layers"])
        self.gguf_writer.add_uint32 ("clip.vision.attention.head_count", vision_config["num_attention_heads"])
        self.gguf_writer.add_float32("clip.vision.attention.layer_norm_epsilon", vision_config.get("layer_norm_eps", 1e-6))
        # default values taken from HF transformers code
        self.gguf_writer.add_array  ("clip.vision.image_mean", [0.5, 0.5, 0.5])
        self.gguf_writer.add_array  ("clip.vision.image_std", [0.5, 0.5, 0.5])
        self.gguf_writer.add_bool   ("clip.use_gelu", True)

        # load tensors
        for name, data_torch in self.get_tensors(dir_model):
            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)
            self.add_tensor(name, data_torch)
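
    # Stream tensors from all safetensors shards as lazy tensors, so the whole
    # model never has to be materialized in memory at once.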
    def get_tensors(self, dir_model: Path) -> Iterator[tuple[str, Tensor]]:
        part_names = Gemma3VisionTower.get_model_part_names(dir_model, "model", ".safetensors")
        tensor_names_from_parts: set[str] = set()
        for part_name in part_names:
            logger.info(f"gguf: loading model part '{part_name}'")
            from safetensors import safe_open
            ctx = cast(ContextManager[Any], safe_open(dir_model / part_name, framework="pt", device="cpu"))
            with ctx as model_part:
                tensor_names_from_parts.update(model_part.keys())
                for name in model_part.keys():
                    data = model_part.get_slice(name)
                    data = LazyTorchTensor.from_safetensors_slice(data)
                    yield name, data
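
    # Rename one tensor from its HF name to the GGUF clip name, choose its
    # quantization type, and add it to the writer; non-vision tensors are skipped.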
    def add_tensor(self, name: str, data_torch: Tensor):
        is_1d = len(data_torch.shape) == 1
        is_embd = ".embeddings." in name
        old_dtype = data_torch.dtype
        can_quantize = not is_1d and not is_embd
        data_qtype = gguf.GGMLQuantizationType.F32

        # this is to support old checkpoints
        # TODO: remove this when we have the final model
        name = name.replace("vision_model.vision_model.", "vision_tower.vision_model.")
        name = name.replace("multimodal_projector.", "multi_modal_projector.")

        # filter only vision tensors
        if not name.startswith("vision_tower.vision_model.") and not name.startswith("multi_modal_projector."):
            return

        # prefix
        name = name.replace("vision_tower.vision_model.encoder.layers.", "v.blk.")
        name = name.replace("vision_tower.vision_model.", "v.")
        # projector and input embd
        name = name.replace(".embeddings.patch_embedding.", ".patch_embd.")
        name = name.replace(".embeddings.position_embedding.", ".position_embd.")
        name = name.replace(
            "multi_modal_projector.mm_input_projection_weight",
            "mm.input_projection.weight"
        )
        name = name.replace(
            "multi_modal_projector.mm_soft_emb_norm.weight",
            "mm.soft_emb_norm.weight"
        )
        name = name.replace("post_layernorm.", "post_ln.")
        # each block
        name = name.replace(".self_attn.k_proj.", ".attn_k.")
        name = name.replace(".self_attn.v_proj.", ".attn_v.")
        name = name.replace(".self_attn.q_proj.", ".attn_q.")
        name = name.replace(".self_attn.out_proj.", ".attn_out.")
        name = name.replace(".layer_norm1.", ".ln1.")
        name = name.replace(".layer_norm2.", ".ln2.")
        name = name.replace(".mlp.fc1.", ".ffn_down.")
        name = name.replace(".mlp.fc2.", ".ffn_up.")
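
        # map the requested output type to the quantization used for 2D,
        # non-embedding tensors (1D tensors and embeddings stay F32)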
        if can_quantize:
            if self.ftype == gguf.LlamaFileType.ALL_F32:
                data_qtype = gguf.GGMLQuantizationType.F32
            elif self.ftype == gguf.LlamaFileType.MOSTLY_F16:
                data_qtype = gguf.GGMLQuantizationType.F16
            elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
                data_qtype = gguf.GGMLQuantizationType.BF16
            elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0:
                data_qtype = gguf.GGMLQuantizationType.Q8_0
            else:
                raise ValueError(f"Unsupported file type: {self.ftype}")

        # correct the norm value; only "soft_emb_norm" needs this correction, as it's part of the Gemma projector
        # the other norm values are part of the SigLIP model, and they are already correct
        # ref code: Gemma3RMSNorm
        if "soft_emb_norm.weight" in name:
            logger.info(f"Correcting norm value for '{name}'")
            data_torch = data_torch + 1

        data = data_torch.numpy()

        try:
            data = gguf.quants.quantize(data, data_qtype)
        except Exception as e:
            logger.error(f"Error quantizing tensor '{name}': {e}, fallback to F16")
            data_qtype = gguf.GGMLQuantizationType.F16
            data = gguf.quants.quantize(data, data_qtype)

        # reverse shape to make it similar to the internal ggml dimension order
        shape_str = f"{{{', '.join(str(n) for n in reversed(data_torch.shape))}}}"
        logger.info(f"{f'%-32s' % f'{name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")

        self.gguf_writer.add_tensor(name, data, raw_dtype=data_qtype)
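
    # Write the GGUF header, the key/value metadata, and the tensor data to disk.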
    def write(self):
        self.gguf_writer.write_header_to_file(path=self.fname_out)
        self.gguf_writer.write_kv_data_to_file()
        self.gguf_writer.write_tensors_to_file(progress=True)
        self.gguf_writer.close()


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Convert Gemma 3 vision tower safetensors to GGUF format")
    parser.add_argument(
        "--outfile", type=Path, default="mmproj.gguf",
        help="path to write to",
    )
    parser.add_argument(
        "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0"], default="f16",
        help="output format",
    )
    parser.add_argument(
        "--bigendian", action="store_true",
        help="model is executed on a big-endian machine",
    )
    parser.add_argument(
        "model", type=Path,
        help="directory containing the model files",
        nargs="?",
    )
    parser.add_argument(
        "--verbose", action="store_true",
        help="increase output verbosity",
    )
    args = parser.parse_args()
    if args.model is None:
        parser.error("the following arguments are required: model")
    return args
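

# Entry point: configure logging, map --outtype to a gguf.LlamaFileType, and run
# the conversion under torch.inference_mode() so no autograd state is tracked.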
def main() -> None:
    args = parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
    else:
        logging.basicConfig(level=logging.INFO)

    dir_model = args.model
    if not dir_model.is_dir():
        logger.error(f'Error: {args.model} is not a directory')
        sys.exit(1)

    ftype_map: dict[str, gguf.LlamaFileType] = {
        "f32": gguf.LlamaFileType.ALL_F32,
        "f16": gguf.LlamaFileType.MOSTLY_F16,
        "bf16": gguf.LlamaFileType.MOSTLY_BF16,
        "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
    }

    logger.info(f"Loading model: {dir_model.name}")

    with torch.inference_mode():
        gemma3_vision_tower = Gemma3VisionTower(
            dir_model=dir_model,
            fname_out=args.outfile,
            ftype=ftype_map[args.outtype],
            is_big_endian=args.bigendian,
        )
        gemma3_vision_tower.write()


if __name__ == '__main__':
    main()
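
# Example invocation (paths are placeholders):
#   python gemma3_convert_encoder_to_gguf.py --outtype f16 --outfile mmproj.gguf /path/to/gemma-3-model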