minicpmv-convert-image-encoder-to-gguf.py

# coding=utf-8
# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Siglip model."""
# Copied from HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit, with tgt_sizes added
import os
import math
import warnings
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn.init import _calculate_fan_in_and_fan_out
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


class SiglipVisionConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
    Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
    [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_channels (`int`, *optional*, defaults to 3):
            Number of channels in the input images.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        patch_size (`int`, *optional*, defaults to 16):
            The size (resolution) of each patch.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.

    Example:

    ```python
    >>> from transformers import SiglipVisionConfig, SiglipVisionModel

    >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
    >>> configuration = SiglipVisionConfig()

    >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
    >>> model = SiglipVisionModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "siglip_vision_model"

    def __init__(
        self,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_channels=3,
        image_size=224,
        patch_size=16,
        hidden_act="gelu_pytorch_tanh",
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act


_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"

SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "google/siglip-base-patch16-224",
    # See all SigLIP models at https://huggingface.co/models?filter=siglip
]


# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
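
# Worked example (illustration only, derived from the function above): for a padded batch of two
# sequences with lengths 2 and 3,
#   attention_mask = torch.tensor([[1, 1, 0], [1, 1, 1]])
# gives
#   indices             -> tensor([0, 1, 3, 4, 5])   (flat positions of the real tokens)
#   cu_seqlens          -> tensor([0, 2, 5])         (cumulative sequence lengths, as used by varlen attention kernels)
#   max_seqlen_in_batch -> 3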


def _trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    # Values are generated by using a truncated uniform distribution and
    # then using the inverse CDF for the normal distribution.
    # Get upper and lower cdf values
    l = norm_cdf((a - mean) / std)
    u = norm_cdf((b - mean) / std)

    # Uniformly fill tensor with values from [l, u], then translate to
    # [2l-1, 2u-1].
    tensor.uniform_(2 * l - 1, 2 * u - 1)

    # Use inverse cdf transform for normal distribution to get truncated
    # standard normal
    if tensor.dtype in [torch.float16, torch.bfloat16]:
        # The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu
        og_dtype = tensor.dtype
        tensor = tensor.to(torch.float32)
        tensor.erfinv_()
        tensor = tensor.to(og_dtype)
    else:
        tensor.erfinv_()

    # Transform to proper mean, std
    tensor.mul_(std * math.sqrt(2.0))
    tensor.add_(mean)

    # Clamp to ensure it's in the proper range
    if tensor.dtype == torch.float16:
        # The `clamp_` op is not (yet?) defined in float16+cpu
        tensor = tensor.to(torch.float32)
        tensor.clamp_(min=a, max=b)
        tensor = tensor.to(torch.float16)
    else:
        tensor.clamp_(min=a, max=b)


def trunc_normal_tf_(
    tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
):
    """Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \\leq \\text{mean} \\leq b`.

    NOTE: this 'tf' variant behaves closer to the Tensorflow / JAX impl, where the
    bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
    and the result is subsequently scaled and shifted by the mean and std args.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    """
    with torch.no_grad():
        _trunc_normal_(tensor, 0, 1.0, a, b)
        tensor.mul_(std).add_(mean)


def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    denom = fan_in
    if mode == "fan_in":
        denom = fan_in
    elif mode == "fan_out":
        denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
    elif distribution == "normal":
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")


def lecun_normal_(tensor):
    variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")


def default_flax_embed_init(tensor):
    variance_scaling_(tensor, mode="fan_in", distribution="normal")


class SiglipVisionEmbeddings(nn.Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches_per_side = self.image_size // self.patch_size
        self.num_patches = self.num_patches_per_side**2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)


class SiglipAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)


# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
class SiglipMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)


# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip
class SiglipEncoderLayer(nn.Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self.self_attn = SiglipAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)


class SiglipPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = SiglipVisionConfig
    base_model_prefix = "siglip"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, SiglipVisionEmbeddings):
            width = self.config.hidden_size
            nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
        elif isinstance(module, nn.Embedding):
            default_flax_embed_init(module.weight)
        elif isinstance(module, SiglipAttention):
            nn.init.normal_(module.q_proj.weight)
            nn.init.normal_(module.k_proj.weight)
            nn.init.normal_(module.v_proj.weight)
            nn.init.normal_(module.out_proj.weight)
            nn.init.zeros_(module.q_proj.bias)
            nn.init.zeros_(module.k_proj.bias)
            nn.init.zeros_(module.v_proj.bias)
            nn.init.zeros_(module.out_proj.bias)
        elif isinstance(module, SiglipMLP):
            nn.init.normal_(module.fc1.weight)
            nn.init.normal_(module.fc2.weight)
            nn.init.normal_(module.fc1.bias, std=1e-6)
            nn.init.normal_(module.fc2.bias, std=1e-6)
        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            lecun_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


SIGLIP_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`SiglipVisionConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

SIGLIP_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip
class SiglipEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`SiglipEncoderLayer`].

    Args:
        config: SiglipVisionConfig
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False


class SiglipVisionTransformer(SiglipPreTrainedModel):
    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"
    _supports_flash_attn_2 = True

    def __init__(self, config: SiglipVisionConfig):
        super().__init__(config)
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.embeddings.patch_embedding
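

# ------------------------------------------------------------------------------------------------
# The GGUF conversion script starts here. The Siglip classes above are only used to host the vision
# weights ("minicpmv.clip") for the newer MiniCPM-V variants (minicpmv_version 3 and 4); older
# variants are loaded into an Idefics2VisionTransformer instead (see below).
# ------------------------------------------------------------------------------------------------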

import argparse
import json
import re

import numpy as np
from gguf import *
from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer, Idefics2VisionConfig

TEXT = "clip.text"
VISION = "clip.vision"


def add_key_str(raw_key: str, arch: str) -> str:
    return raw_key.format(arch=arch)


def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_minicpmv: bool) -> bool:
    if name in (
        "logit_scale",
        "text_model.embeddings.position_ids",
        "vision_model.embeddings.position_ids",
    ):
        return True

    if has_minicpmv and name in ["visual_projection.weight"]:
        return True

    if name.startswith("v") and not has_vision:
        return True

    if name.startswith("t") and not has_text:
        return True

    return False


def get_tensor_name(name: str) -> str:
    if "projection" in name:
        return name
    if "mm_projector" in name:
        name = name.replace("model.mm_projector", "mm")
        name = re.sub(r'mm\.mlp\.mlp', 'mm.model.mlp', name, count=1)
        name = re.sub(r'mm\.peg\.peg', 'mm.model.peg', name, count=1)
        return name

    return name.replace("text_model", "t").replace("vision_model", "v").replace("encoder.layers", "blk").replace("embeddings.", "").replace("_proj", "").replace("self_attn.", "attn_").replace("layer_norm", "ln").replace("layernorm", "ln").replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("embedding", "embd").replace("final", "post").replace("layrnorm", "ln")
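
# For reference, the chained replacements above map HF-style parameter names onto the short names
# used on the GGUF/clip.cpp side, e.g.:
#   "vision_model.encoder.layers.0.self_attn.q_proj.weight" -> "v.blk.0.attn_q.weight"
#   "vision_model.post_layernorm.weight"                    -> "v.post_ln.weight"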


def bytes_to_unicode():
    """
    Returns a mapping from utf-8 bytes to unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    This also avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))
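
# Illustration of the mapping built above: printable bytes map to themselves, while bytes that would
# be whitespace/control characters are shifted into the 256+ code-point range, e.g.
#   bytes_to_unicode()[ord("A")] == "A"
#   bytes_to_unicode()[ord(" ")] == chr(256 + 32)  # "Ġ"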


ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model-dir", help="Path to model directory cloned from HF Hub", required=True)
ap.add_argument("--use-f32", action="store_true", default=False, help="Use f32 instead of f16")
ap.add_argument("--text-only", action="store_true", required=False,
                help="Save a text-only model. It can't be used to encode images")
ap.add_argument("--vision-only", action="store_true", required=False,
                help="Save a vision-only model. It can't be used to encode texts")
ap.add_argument("--clip-model-is-vision", action="store_true", required=False,
                help="The clip model is a pure vision model (ShareGPT4V vision extract for example)")
ap.add_argument("--clip-model-is-openclip", action="store_true", required=False,
                help="The clip model is from openclip (for ViT-SO400M type)")
ap.add_argument("--minicpmv-projector", help="Path to minicpmv.projector file. If specified, save an image encoder for MiniCPM-V models.")
ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2"], default="mlp")
ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None)
# Example --image_mean 0.48145466 0.4578275 0.40821073 --image_std 0.26862954 0.26130258 0.27577711
# Example --image_mean 0.5 0.5 0.5 --image_std 0.5 0.5 0.5
default_image_mean = [0.48145466, 0.4578275, 0.40821073]
default_image_std = [0.26862954, 0.26130258, 0.27577711]
ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor)', default=None)
ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None)
ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3; MiniCPM-o-2.6 use 4', default=2)

args = ap.parse_args()
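
# Example invocation (paths are placeholders; the flags correspond to the parser above and roughly
# follow the MiniCPM-V conversion steps documented in llama.cpp):
#
#   python ./minicpmv-convert-image-encoder-to-gguf.py \
#       -m ../MiniCPM-Llama3-V-2_5 \
#       --minicpmv-projector ../MiniCPM-Llama3-V-2_5/minicpmv.projector \
#       --output-dir ../MiniCPM-Llama3-V-2_5/ \
#       --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 \
#       --minicpmv_version 2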

if args.text_only and args.vision_only:
    print("--text-only and --vision-only arguments cannot be specified at the same time.")
    exit(1)

if args.use_f32:
    print("WARNING: Weights for the convolution op are always saved in f16, as the convolution op in GGML does not support 32-bit kernel weights yet.")

# output in the same directory as the model if output_dir is None
dir_model = args.model_dir

if args.clip_model_is_vision or not os.path.exists(dir_model + "/vocab.json") or args.clip_model_is_openclip:
    vocab = None
    tokens = None
else:
    with open(dir_model + "/vocab.json", "r", encoding="utf-8") as f:
        vocab = json.load(f)
        tokens = [key for key in vocab]

# possible data types
#   ftype == 0 -> float32
#   ftype == 1 -> float16
#
# map from ftype to string
ftype_str = ["f32", "f16"]

ftype = 1
if args.use_f32:
    ftype = 0

# if args.clip_model_is_vision or args.clip_model_is_openclip:
#     model = CLIPVisionModel.from_pretrained(dir_model)
#     processor = None
# else:
#     model = CLIPModel.from_pretrained(dir_model)
#     processor = CLIPProcessor.from_pretrained(dir_model)

minicpmv_version = args.minicpmv_version
emb_dim = 4096
block_count = 26
if minicpmv_version == 1:
    emb_dim = 2304
    block_count = 26
elif minicpmv_version == 2:
    emb_dim = 4096
    block_count = 27
elif minicpmv_version == 3:
    emb_dim = 3584
    block_count = 27
elif minicpmv_version == 4:
    emb_dim = 3584
    block_count = 27
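
# Note (assumption, not stated in this file): emb_dim is the width of the resampler position
# embedding generated below, and it tracks the hidden size of the language model each MiniCPM-V
# release pairs with (e.g. 2304 for MiniCPM-2.4B, 4096 for Llama-3-8B, 3584 for Qwen2-7B);
# block_count is the vision encoder layer count written into the GGUF header.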

default_vision_config = {
    "hidden_size": 1152,
    "image_size": 980,
    "intermediate_size": 4304,
    "model_type": "idefics2",
    "num_attention_heads": 16,
    "num_hidden_layers": 27,
    "patch_size": 14,
}
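
# With image_size=980 and patch_size=14 the encoder sees a 980 / 14 = 70 x 70 grid of patches, which
# is why the resampler's pos_embed_k below is built from a (70, 70) sincos grid.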

vision_config = Idefics2VisionConfig(**default_vision_config)
model = Idefics2VisionTransformer(vision_config)
if minicpmv_version == 3:
    vision_config = SiglipVisionConfig(**default_vision_config)
    model = SiglipVisionTransformer(vision_config)
elif minicpmv_version == 4:
    vision_config = SiglipVisionConfig(**default_vision_config)
    model = SiglipVisionTransformer(vision_config)

processor = None
# if model.attn_pool is not None:
#     model.attn_pool = torch.nn.Identity()

# model.blocks = model.blocks[:-1]
model.load_state_dict(torch.load(os.path.join(dir_model, "minicpmv.clip")))

fname_middle = None
has_text_encoder = True
has_vision_encoder = True
has_minicpmv_projector = False

if args.text_only:
    fname_middle = "text-"
    has_vision_encoder = False
elif args.minicpmv_projector is not None:
    fname_middle = "mmproj-"
    has_text_encoder = False
    has_minicpmv_projector = True
elif args.vision_only:
    fname_middle = "vision-"
    has_text_encoder = False
else:
    fname_middle = ""

output_dir = args.output_dir if args.output_dir is not None else dir_model
os.makedirs(output_dir, exist_ok=True)
output_prefix = os.path.basename(output_dir).replace("ggml_", "")
fname_out = os.path.join(output_dir, f"{fname_middle}model-{ftype_str[ftype]}.gguf")
fout = GGUFWriter(path=fname_out, arch="clip")

fout.add_bool("clip.has_text_encoder", has_text_encoder)
fout.add_bool("clip.has_vision_encoder", has_vision_encoder)
fout.add_bool("clip.has_minicpmv_projector", has_minicpmv_projector)
fout.add_file_type(ftype)

if args.text_only:
    fout.add_description("text-only CLIP model")
elif args.vision_only and not has_minicpmv_projector:
    fout.add_description("vision-only CLIP model")
elif has_minicpmv_projector:
    fout.add_description("image encoder for MiniCPM-V")
    # add projector type
    fout.add_string("clip.projector_type", "resampler")
    fout.add_int32("clip.minicpmv_version", minicpmv_version)
else:
    fout.add_description("two-tower CLIP model")

if has_vision_encoder:
    # vision_model hparams
    fout.add_uint32("clip.vision.image_size", 448)
    fout.add_uint32("clip.vision.patch_size", 14)
    fout.add_uint32(add_key_str(KEY_EMBEDDING_LENGTH, VISION), 1152)
    fout.add_uint32(add_key_str(KEY_FEED_FORWARD_LENGTH, VISION), 4304)
    fout.add_uint32("clip.vision.projection_dim", 0)
    fout.add_uint32(add_key_str(KEY_ATTENTION_HEAD_COUNT, VISION), 16)
    fout.add_float32(add_key_str(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
    fout.add_uint32(add_key_str(KEY_BLOCK_COUNT, VISION), block_count)

    if processor is not None:
        image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean
        image_std = processor.image_processor.image_std if args.image_std is None or args.image_std == default_image_std else args.image_std
    else:
        image_mean = args.image_mean if args.image_mean is not None else default_image_mean
        image_std = args.image_std if args.image_std is not None else default_image_std
    fout.add_array("clip.vision.image_mean", image_mean)
    fout.add_array("clip.vision.image_std", image_std)

use_gelu = True
fout.add_bool("clip.use_gelu", use_gelu)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.
    omega = 1. / 10000 ** omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    if isinstance(grid_size, int):
        grid_h_size, grid_w_size = grid_size, grid_size
    else:
        grid_h_size, grid_w_size = grid_size[0], grid_size[1]

    grid_h = np.arange(grid_h_size, dtype=np.float32)
    grid_w = np.arange(grid_w_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
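
# Shape check, following directly from the functions above: with emb_dim = 4096 and a (70, 70) grid,
# get_2d_sincos_pos_embed(4096, (70, 70)) returns an array of shape (70 * 70, 4096) = (4900, 4096),
# i.e. one 4096-dim sin/cos position vector per patch position.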


def _replace_name_resampler(s, v):
    if re.match("resampler.pos_embed", s):
        return {
            s: v,
            re.sub("pos_embed", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(emb_dim, (70, 70))),
        }
    if re.match("resampler.proj", s):
        return {
            re.sub("proj", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(emb_dim, (70, 70))),
            re.sub("proj", "proj.weight", s): v.transpose(-1, -2).contiguous(),
        }
    if re.match("resampler.attn.in_proj_.*", s):
        return {
            re.sub("attn.in_proj_", "attn.q.", s): v.chunk(3, dim=0)[0],
            re.sub("attn.in_proj_", "attn.k.", s): v.chunk(3, dim=0)[1],
            re.sub("attn.in_proj_", "attn.v.", s): v.chunk(3, dim=0)[2],
        }
    return {s: v}
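
# The in_proj split above relies on the packed layout used by torch.nn.MultiheadAttention:
# "resampler.attn.in_proj_weight" has shape (3 * embed_dim, embed_dim) with the q, k and v projections
# stacked along dim 0, so chunk(3, dim=0) yields one (embed_dim, embed_dim) matrix each
# (and the matching split applies to "in_proj_bias").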


if has_minicpmv_projector:
    projector = torch.load(args.minicpmv_projector)
    new_state_dict = {}
    for k, v in projector.items():
        kvs = _replace_name_resampler(k, v)
        for nk, nv in kvs.items():
            new_state_dict[nk] = nv
    projector = new_state_dict

    ftype_cur = 0
    for name, data in projector.items():
        name = get_tensor_name(name)
        data = data.squeeze().numpy()

        n_dims = len(data.shape)
        if ftype == 1:
            if name[-7:] == ".weight" and n_dims == 2:
                print("  Converting to float16")
                data = data.astype(np.float16)
                ftype_cur = 1
            else:
                print("  Converting to float32")
                data = data.astype(np.float32)
                ftype_cur = 0
        else:
            if data.dtype != np.float32:
                print("  Converting to float32")
                data = data.astype(np.float32)
                ftype_cur = 0

        fout.add_tensor(name, data)
        print(f"{name} - {ftype_str[ftype_cur]} - shape = {data.shape}")

    print("Projector tensors added\n")


def _replace_name(s, v):
    s = "vision_model." + s
    if re.match("vision_model.embeddings.position_embedding", s):
        v = v.unsqueeze(0)
        return {s: v}

    return {s: v}


state_dict = model.state_dict()
new_state_dict = {}
for k, v in state_dict.items():
    kvs = _replace_name(k, v)
    for nk, nv in kvs.items():
        new_state_dict[nk] = nv
state_dict = new_state_dict

for name, data in state_dict.items():
    if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_minicpmv_projector):
        # we don't need this
        print(f"skipping parameter: {name}")
        continue

    name = get_tensor_name(name)
    data = data.squeeze().numpy()

    n_dims = len(data.shape)

    # ftype == 0 -> float32, ftype == 1 -> float16
    ftype_cur = 0
    if n_dims == 4:
        print(f"tensor {name} is always saved in f16")
        data = data.astype(np.float16)
        ftype_cur = 1
    elif ftype == 1:
        if name[-7:] == ".weight" and n_dims == 2:
            print("  Converting to float16")
            data = data.astype(np.float16)
            ftype_cur = 1
        else:
            print("  Converting to float32")
            data = data.astype(np.float32)
            ftype_cur = 0
    else:
        if data.dtype != np.float32:
            print("  Converting to float32")
            data = data.astype(np.float32)
            ftype_cur = 0

    print(f"{name} - {ftype_str[ftype_cur]} - shape = {data.shape}")
    fout.add_tensor(name, data)

fout.write_header_to_file()
fout.write_kv_data_to_file()
fout.write_tensors_to_file()
fout.close()

print("Done. Output file: " + fname_out)