# minicpmv-convert-image-encoder-to-gguf.py
# coding=utf-8
# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Siglip model. """
# Copied from HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit and add tgt_sizes

import os
import math
import warnings

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.init import _calculate_fan_in_and_fan_out

from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)

class SiglipVisionConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
    Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
    [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_channels (`int`, *optional*, defaults to 3):
            Number of channels in the input images.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        patch_size (`int`, *optional*, defaults to 16):
            The size (resolution) of each patch.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.

    Example:

    ```python
    >>> from transformers import SiglipVisionConfig, SiglipVisionModel

    >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
    >>> configuration = SiglipVisionConfig()

    >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
    >>> model = SiglipVisionModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "siglip_vision_model"

    def __init__(
        self,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_channels=3,
        image_size=224,
        patch_size=16,
        hidden_act="gelu_pytorch_tanh",
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act

_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"

SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "google/siglip-base-patch16-224",
    # See all SigLIP models at https://huggingface.co/models?filter=siglip
]


# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
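
# Sketch of what _get_unpad_data produces (illustrative values only):
#   attention_mask = torch.tensor([[1, 1, 0], [1, 1, 1]])
#   indices             -> positions of the non-padded tokens in the flattened batch: [0, 1, 3, 4, 5]
#   cu_seqlens          -> cumulative sequence lengths with a leading zero: [0, 2, 5]
#   max_seqlen_in_batch -> 3
# These are the quantities expected by FlashAttention's variable-length (varlen) kernels.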


def _trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    # Values are generated by using a truncated uniform distribution and
    # then using the inverse CDF for the normal distribution.
    # Get upper and lower cdf values
    l = norm_cdf((a - mean) / std)
    u = norm_cdf((b - mean) / std)

    # Uniformly fill tensor with values from [l, u], then translate to
    # [2l-1, 2u-1].
    tensor.uniform_(2 * l - 1, 2 * u - 1)

    # Use inverse cdf transform for normal distribution to get truncated
    # standard normal
    if tensor.dtype in [torch.float16, torch.bfloat16]:
        # The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu
        og_dtype = tensor.dtype
        tensor = tensor.to(torch.float32)
        tensor.erfinv_()
        tensor = tensor.to(og_dtype)
    else:
        tensor.erfinv_()

    # Transform to proper mean, std
    tensor.mul_(std * math.sqrt(2.0))
    tensor.add_(mean)

    # Clamp to ensure it's in the proper range
    if tensor.dtype == torch.float16:
        # The `clamp_` op is not (yet?) defined in float16+cpu
        tensor = tensor.to(torch.float32)
        tensor.clamp_(min=a, max=b)
        tensor = tensor.to(torch.float16)
    else:
        tensor.clamp_(min=a, max=b)


def trunc_normal_tf_(
    tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
):
    """Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \\leq \\text{mean} \\leq b`.

    NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
    bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
    and the result is subsequently scaled and shifted by the mean and std args.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    """
    with torch.no_grad():
        _trunc_normal_(tensor, 0, 1.0, a, b)
        tensor.mul_(std).add_(mean)


def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    denom = fan_in
    if mode == "fan_in":
        denom = fan_in
    elif mode == "fan_out":
        denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
    elif distribution == "normal":
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")


def lecun_normal_(tensor):
    variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")


def default_flax_embed_init(tensor):
    variance_scaling_(tensor, mode="fan_in", distribution="normal")
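
# Quick sketch of how the initializers above behave (illustrative shapes only):
#   w = torch.empty(256, 128)   # fan_in = 128
#   lecun_normal_(w)            # truncated-normal init, std ~= sqrt(1 / fan_in)
#   default_flax_embed_init(w)  # plain normal init, std = sqrt(1 / fan_in)
# They are used by SiglipPreTrainedModel._init_weights below.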


class SiglipVisionEmbeddings(nn.Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches_per_side = self.image_size // self.patch_size
        self.num_patches = self.num_patches_per_side**2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)

class SiglipAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)


# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
class SiglipMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)


# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip
class SiglipEncoderLayer(nn.Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self.self_attn = SiglipAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

class SiglipPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = SiglipVisionConfig
    base_model_prefix = "siglip"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, SiglipVisionEmbeddings):
            width = self.config.hidden_size
            nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
        elif isinstance(module, nn.Embedding):
            default_flax_embed_init(module.weight)
        elif isinstance(module, SiglipAttention):
            nn.init.normal_(module.q_proj.weight)
            nn.init.normal_(module.k_proj.weight)
            nn.init.normal_(module.v_proj.weight)
            nn.init.normal_(module.out_proj.weight)
            nn.init.zeros_(module.q_proj.bias)
            nn.init.zeros_(module.k_proj.bias)
            nn.init.zeros_(module.v_proj.bias)
            nn.init.zeros_(module.out_proj.bias)
        elif isinstance(module, SiglipMLP):
            nn.init.normal_(module.fc1.weight)
            nn.init.normal_(module.fc2.weight)
            nn.init.normal_(module.fc1.bias, std=1e-6)
            nn.init.normal_(module.fc2.bias, std=1e-6)
        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            lecun_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

SIGLIP_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`SiglipVisionConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

SIGLIP_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip
class SiglipEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`SiglipEncoderLayer`].

    Args:
        config: SiglipVisionConfig
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False


class SiglipVisionTransformer(SiglipPreTrainedModel):
    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"
    _supports_flash_attn_2 = True

    def __init__(self, config: SiglipVisionConfig):
        super().__init__(config)
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.embeddings.patch_embedding

import argparse
import json
import re

import numpy as np
from gguf import *
from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer
from transformers.models.idefics2.configuration_idefics2 import Idefics2VisionConfig

TEXT = "clip.text"
VISION = "clip.vision"


def add_key_str(raw_key: str, arch: str) -> str:
    return raw_key.format(arch=arch)


def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_minicpmv: bool) -> bool:
    if name in (
        "logit_scale",
        "text_model.embeddings.position_ids",
        "vision_model.embeddings.position_ids",
    ):
        return True

    if has_minicpmv and name in ["visual_projection.weight"]:
        return True

    if name.startswith("v") and not has_vision:
        return True

    if name.startswith("t") and not has_text:
        return True

    return False


def get_tensor_name(name: str) -> str:
    if "projection" in name:
        return name
    if "mm_projector" in name:
        name = name.replace("model.mm_projector", "mm")
        name = re.sub(r'mm\.mlp\.mlp', 'mm.model.mlp', name, count=1)
        name = re.sub(r'mm\.peg\.peg', 'mm.model.peg', name, count=1)
        return name

    return name.replace("text_model", "t").replace("vision_model", "v").replace("encoder.layers", "blk").replace("embeddings.", "").replace("_proj", "").replace("self_attn.", "attn_").replace("layer_norm", "ln").replace("layernorm", "ln").replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("embedding", "embd").replace("final", "post").replace("layrnorm", "ln")
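
# Illustrative examples of the renaming performed by get_tensor_name():
#   "vision_model.encoder.layers.0.self_attn.q_proj.weight" -> "v.blk.0.attn_q.weight"
#   "vision_model.embeddings.patch_embedding.weight"        -> "v.patch_embd.weight"
#   "vision_model.post_layernorm.weight"                    -> "v.post_ln.weight"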


def bytes_to_unicode():
    """
    Returns a dict mapping utf-8 bytes to unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))

ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model-dir", help="Path to model directory cloned from HF Hub", required=True)
ap.add_argument("--use-f32", action="store_true", default=False, help="Use f32 instead of f16")
ap.add_argument("--text-only", action="store_true", required=False,
                help="Save a text-only model. It can't be used to encode images")
ap.add_argument("--vision-only", action="store_true", required=False,
                help="Save a vision-only model. It can't be used to encode texts")
ap.add_argument("--clip-model-is-vision", action="store_true", required=False,
                help="The clip model is a pure vision model (ShareGPT4V vision extract for example)")
ap.add_argument("--clip-model-is-openclip", action="store_true", required=False,
                help="The clip model is from openclip (for ViT-SO400M type)")
ap.add_argument("--minicpmv-projector", help="Path to minicpmv.projector file. If specified, save an image encoder for MiniCPM-V models.")
ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2"], default="mlp")
ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None)
# Example --image-mean 0.48145466 0.4578275 0.40821073 --image-std 0.26862954 0.26130258 0.27577711
# Example --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5
default_image_mean = [0.5, 0.5, 0.5]
default_image_std = [0.5, 0.5, 0.5]
ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor)', default=None)
ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None)
ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3; MiniCPM-o-2.6 use 4; MiniCPM-V 4.0 use 5; MiniCPM-o-4.0 use 6', default=2)

args = ap.parse_args()
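
# Typical invocation (paths are illustrative; the projector file is picked up
# automatically if it sits next to the model as <model-dir>/minicpmv.projector):
#   python minicpmv-convert-image-encoder-to-gguf.py \
#       -m ../MiniCPM-V-2_6 \
#       --minicpmv-projector ../MiniCPM-V-2_6/minicpmv.projector \
#       --output-dir ../MiniCPM-V-2_6 \
#       --minicpmv_version 3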

if args.text_only and args.vision_only:
    print("--text-only and --vision-only arguments cannot be specified at the same time.")
    exit(1)

if args.use_f32:
    print("WARNING: Weights for the convolution op are always saved in f16, as the convolution op in GGML does not support 32-bit kernel weights yet.")

# output in the same directory as the model if output_dir is None
dir_model = args.model_dir

# Read config.json to get the actual model configuration
config_path = os.path.join(dir_model, "config.json")
model_config = {}
if os.path.isfile(config_path):
    with open(config_path, "r", encoding="utf-8") as f:
        model_config = json.load(f)
    print(f"Loaded config from {config_path}")
else:
    print(f"Warning: config.json not found at {config_path}")

# If --minicpmv-projector is not specified but the default path exists, use the default path
if args.minicpmv_projector is None:
    default_projector_path = os.path.join(dir_model, "minicpmv.projector")
    if os.path.isfile(default_projector_path):
        args.minicpmv_projector = default_projector_path
        print(f"Found default projector file: {default_projector_path}")

# If --output-dir is not specified, use model_dir as the default value
if args.output_dir is None:
    args.output_dir = dir_model

if args.clip_model_is_vision or not os.path.exists(dir_model + "/vocab.json") or args.clip_model_is_openclip:
    vocab = None
    tokens = None
else:
    with open(dir_model + "/vocab.json", "r", encoding="utf-8") as f:
        vocab = json.load(f)
        tokens = [key for key in vocab]

# possible data types
#   ftype == 0 -> float32
#   ftype == 1 -> float16
#
# map from ftype to string
ftype_str = ["f32", "f16"]

ftype = 1
if args.use_f32:
    ftype = 0

# if args.clip_model_is_vision or args.clip_model_is_openclip:
#     model = CLIPVisionModel.from_pretrained(dir_model)
#     processor = None
# else:
#     model = CLIPModel.from_pretrained(dir_model)
#     processor = CLIPProcessor.from_pretrained(dir_model)

minicpmv_version = args.minicpmv_version

# Use actual config values instead of hardcoded ones
if model_config:
    # For the projector/resampler, use the main model's hidden_size
    emb_dim = model_config.get("hidden_size", 1536)

    # For the vision model, use vision_config values
    vision_config_dict = model_config.get("vision_config", {})
    default_vision_config = {
        "hidden_size": vision_config_dict.get("hidden_size", 1152),
        "image_size": vision_config_dict.get("image_size", 980),
        "intermediate_size": vision_config_dict.get("intermediate_size", 4304),
        "model_type": vision_config_dict.get("model_type", "siglip"),
        "num_attention_heads": vision_config_dict.get("num_attention_heads", 16),
        "num_hidden_layers": vision_config_dict.get("num_hidden_layers", 27),
        "patch_size": vision_config_dict.get("patch_size", 14),
    }

    # Use the vision model's num_hidden_layers for block_count
    block_count = vision_config_dict.get("num_hidden_layers", 27)

    print(f"Using config values: emb_dim={emb_dim}, block_count={block_count}")
    print(f"Vision config: {default_vision_config}")
else:
    # Fall back to the original hardcoded logic if config.json is not found
    emb_dim = 4096
    block_count = 26
    if minicpmv_version == 1:
        emb_dim = 2304
        block_count = 26
    elif minicpmv_version == 2:
        emb_dim = 4096
        block_count = 27
    elif minicpmv_version == 3:
        emb_dim = 3584
        block_count = 27
    elif minicpmv_version == 4:
        emb_dim = 3584
        block_count = 27
    elif minicpmv_version == 5:
        emb_dim = 2560
        block_count = 27
    elif minicpmv_version == 6:
        emb_dim = 4096
        block_count = 27

    default_vision_config = {
        "hidden_size": 1152,
        "image_size": 980,
        "intermediate_size": 4304,
        "model_type": "idefics2",
        "num_attention_heads": 16,
        "num_hidden_layers": 27,
        "patch_size": 14,
    }

vision_config = Idefics2VisionConfig(**default_vision_config)
model = Idefics2VisionTransformer(vision_config)
if minicpmv_version == 3 or (model_config and model_config.get("vision_config", {}).get("model_type") == "siglip"):
    vision_config = SiglipVisionConfig(**default_vision_config)
    model = SiglipVisionTransformer(vision_config)
elif minicpmv_version == 4:
    vision_config = SiglipVisionConfig(**default_vision_config)
    model = SiglipVisionTransformer(vision_config)
elif minicpmv_version == 5:
    default_vision_config["model_type"] = "siglip_vision_model"
    vision_config = SiglipVisionConfig(**default_vision_config)
    model = SiglipVisionTransformer(vision_config)
elif minicpmv_version == 6:
    default_vision_config["model_type"] = "siglip_vision_model"
    vision_config = SiglipVisionConfig(**default_vision_config)
    model = SiglipVisionTransformer(vision_config)

processor = None
# if model.attn_pool is not None:
#     model.attn_pool = torch.nn.Identity()

# model.blocks = model.blocks[:-1]
model.load_state_dict(torch.load(os.path.join(dir_model, "minicpmv.clip")))
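
# NOTE: dir_model is expected to contain "minicpmv.clip" (the vision tower weights) and,
# for the mmproj output, "minicpmv.projector" (the resampler weights). In the llama.cpp
# workflow these are typically split out of the original checkpoint by the accompanying
# minicpmv-surgery.py script before this converter is run.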

fname_middle = None
has_text_encoder = True
has_vision_encoder = True
has_minicpmv_projector = False

if args.text_only:
    fname_middle = "text-"
    has_vision_encoder = False
elif args.minicpmv_projector is not None:
    fname_middle = "mmproj-"
    has_text_encoder = False
    has_minicpmv_projector = True
elif args.vision_only:
    fname_middle = "vision-"
    has_text_encoder = False
else:
    fname_middle = ""

output_dir = args.output_dir
os.makedirs(output_dir, exist_ok=True)
output_prefix = os.path.basename(output_dir).replace("ggml_", "")
fname_out = os.path.join(output_dir, f"{fname_middle}model-{ftype_str[ftype]}.gguf")
fout = GGUFWriter(path=fname_out, arch="clip")

fout.add_bool("clip.has_text_encoder", has_text_encoder)
fout.add_bool("clip.has_vision_encoder", has_vision_encoder)
fout.add_bool("clip.has_minicpmv_projector", has_minicpmv_projector)
fout.add_file_type(ftype)

if args.text_only:
    fout.add_description("text-only CLIP model")
elif args.vision_only and not has_minicpmv_projector:
    fout.add_description("vision-only CLIP model")
elif has_minicpmv_projector:
    fout.add_description("image encoder for MiniCPM-V")
    # add projector type
    fout.add_string("clip.projector_type", "resampler")
    fout.add_int32("clip.minicpmv_version", minicpmv_version)
else:
    fout.add_description("two-tower CLIP model")

if has_vision_encoder:
    # vision_model hparams - use actual config values
    vision_image_size = model_config.get("image_size", 448) if model_config else 448
    vision_patch_size = default_vision_config.get("patch_size", 14)
    vision_hidden_size = default_vision_config.get("hidden_size", 1152)
    vision_intermediate_size = default_vision_config.get("intermediate_size", 4304)
    vision_attention_heads = default_vision_config.get("num_attention_heads", 16)

    fout.add_uint32("clip.vision.image_size", vision_image_size)
    fout.add_uint32("clip.vision.patch_size", vision_patch_size)
    fout.add_uint32(add_key_str(KEY_EMBEDDING_LENGTH, VISION), vision_hidden_size)
    fout.add_uint32(add_key_str(KEY_FEED_FORWARD_LENGTH, VISION), vision_intermediate_size)
    fout.add_uint32("clip.vision.projection_dim", 0)
    fout.add_uint32(add_key_str(KEY_ATTENTION_HEAD_COUNT, VISION), vision_attention_heads)
    fout.add_float32(add_key_str(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
    fout.add_uint32(add_key_str(KEY_BLOCK_COUNT, VISION), block_count)

    # Add MiniCPM-V specific parameters
    query_num = model_config.get("query_num", 0) if model_config else 0
    resampler_emb_dim = model_config.get("hidden_size", 0) if model_config else 0
    fout.add_uint32("clip.minicpmv_query_num", query_num)

    if processor is not None:
        image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean
        image_std = processor.image_processor.image_std if args.image_std is None or args.image_std == default_image_std else args.image_std
    else:
        image_mean = args.image_mean if args.image_mean is not None else default_image_mean
        image_std = args.image_std if args.image_std is not None else default_image_std
    fout.add_array("clip.vision.image_mean", image_mean)
    fout.add_array("clip.vision.image_std", image_std)

use_gelu = True
fout.add_bool("clip.use_gelu", use_gelu)
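
# For reference, add_key_str() simply fills the {arch} placeholder in the gguf key
# templates with "clip.vision", e.g. KEY_EMBEDDING_LENGTH ("{arch}.embedding_length")
# becomes "clip.vision.embedding_length".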


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.
    omega = 1. / 10000 ** omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    if isinstance(grid_size, int):
        grid_h_size, grid_w_size = grid_size, grid_size
    else:
        grid_h_size, grid_w_size = grid_size[0], grid_size[1]

    grid_h = np.arange(grid_h_size, dtype=np.float32)
    grid_w = np.arange(grid_w_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
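
# Shape check for the helpers above (illustrative):
#   get_2d_sincos_pos_embed(64, (3, 4)).shape == (12, 64)
# i.e. one 64-dim sin/cos embedding per grid position; the script below builds a
# 70x70 grid with emb_dim dimensions for the resampler's pos_embed_k table.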


def _replace_name_resampler(s, v):
    if re.match("resampler.pos_embed", s):
        return {
            s: v,
            re.sub("pos_embed", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(emb_dim, (70, 70))),
        }
    if re.match("resampler.proj", s):
        return {
            re.sub("proj", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(emb_dim, (70, 70))),
            re.sub("proj", "proj.weight", s): v.transpose(-1, -2).contiguous(),
        }
    if re.match("resampler.attn.in_proj_.*", s):
        return {
            re.sub("attn.in_proj_", "attn.q.", s): v.chunk(3, dim=0)[0],
            re.sub("attn.in_proj_", "attn.k.", s): v.chunk(3, dim=0)[1],
            re.sub("attn.in_proj_", "attn.v.", s): v.chunk(3, dim=0)[2],
        }
    return {s: v}
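
# _replace_name_resampler() rewrites the resampler checkpoint; for example (illustrative
# tensor names), "resampler.attn.in_proj_weight" is split into "resampler.attn.q.weight",
# "resampler.attn.k.weight" and "resampler.attn.v.weight" (the fused in-projection is
# chunked into separate q/k/v tensors), and fixed 2D sin/cos position tables
# ("pos_embed_k") are materialized from get_2d_sincos_pos_embed().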

if has_minicpmv_projector:
    projector = torch.load(args.minicpmv_projector)
    new_state_dict = {}
    for k, v in projector.items():
        kvs = _replace_name_resampler(k, v)
        for nk, nv in kvs.items():
            new_state_dict[nk] = nv
    projector = new_state_dict
    ftype_cur = 0
    for name, data in projector.items():
        name = get_tensor_name(name)
        data = data.squeeze().numpy()

        n_dims = len(data.shape)
        if ftype == 1:
            if name[-7:] == ".weight" and n_dims == 2:
                print("  Converting to float16")
                data = data.astype(np.float16)
                ftype_cur = 1
            else:
                print("  Converting to float32")
                data = data.astype(np.float32)
                ftype_cur = 0
        else:
            if data.dtype != np.float32:
                print("  Converting to float32")
                data = data.astype(np.float32)
                ftype_cur = 0

        fout.add_tensor(name, data)
        print(f"{name} - {ftype_str[ftype_cur]} - shape = {data.shape}")

    print("Projector tensors added\n")


def _replace_name(s, v):
    s = "vision_model." + s
    if re.match("vision_model.embeddings.position_embedding", s):
        v = v.unsqueeze(0)
        return {s: v}

    return {s: v}


state_dict = model.state_dict()
new_state_dict = {}
for k, v in state_dict.items():
    kvs = _replace_name(k, v)
    for nk, nv in kvs.items():
        new_state_dict[nk] = nv
state_dict = new_state_dict

for name, data in state_dict.items():
    if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_minicpmv_projector):
        # we don't need this
        print(f"skipping parameter: {name}")
        continue

    name = get_tensor_name(name)
    data = data.squeeze().numpy()

    n_dims = len(data.shape)
    # ftype == 0 -> float32, ftype == 1 -> float16
    ftype_cur = 0
    if n_dims == 4:
        print(f"tensor {name} is always saved in f16")
        data = data.astype(np.float16)
        ftype_cur = 1
    elif ftype == 1:
        if name[-7:] == ".weight" and n_dims == 2:
            print("  Converting to float16")
            data = data.astype(np.float16)
            ftype_cur = 1
        else:
            print("  Converting to float32")
            data = data.astype(np.float32)
            ftype_cur = 0
    else:
        if data.dtype != np.float32:
            print("  Converting to float32")
            data = data.astype(np.float32)
            ftype_cur = 0

    print(f"{name} - {ftype_str[ftype_cur]} - shape = {data.shape}")
    fout.add_tensor(name, data)

fout.write_header_to_file()
fout.write_kv_data_to_file()
fout.write_tensors_to_file()

fout.close()

print("Done. Output file: " + fname_out)