qwen2_vl_surgery.py

import argparse
import os
from typing import Dict, List, Optional

import torch
import numpy as np
from gguf import *
from transformers import (
    AutoProcessor,
    Qwen2VLConfig,
    Qwen2VLProcessor,
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLConfig,  # type: ignore[reportAttributeAccessIssue]
    Qwen2_5_VLForConditionalGeneration,  # type: ignore[reportAttributeAccessIssue]
)

VISION = "clip.vision"


def k(raw_key: str, arch: str) -> str:
    return raw_key.format(arch=arch)


def get_n_wa_pattern(fullatt_block_indexes: Optional[List[int]]):
    # Derive the repeating window/full attention pattern length from the indexes
    # of the full-attention blocks.
    if fullatt_block_indexes is None:
        return 0
    n_wa = fullatt_block_indexes[0]
    for a, b in zip(fullatt_block_indexes, fullatt_block_indexes[1:]):
        if b - a - 1 != n_wa:
            raise ValueError(
                f"window/full attention layers should follow a fixed pattern: "
                f"{n_wa} window-attention layers followed by one full-attention layer"
            )
    return n_wa + 1
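
# For illustration (hypothetical config values): with fullatt_block_indexes = [7, 15, 23, 31]
# every full-attention block sits after 7 window-attention blocks, so get_n_wa_pattern()
# returns a pattern length of 8 (7 window layers + 1 full-attention layer).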


class VL2:

    @staticmethod
    def to_gguf_name(name: str) -> str:
        og = name
        name = name.replace("text_model", "t").replace("vision_model", "v")
        name = name.replace("blocks", "blk").replace("embeddings.", "")
        name = name.replace("attn.", "attn_")
        name = name.replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("proj.", "out.")
        # name = name.replace("layrnorm", "ln").replace("layer_norm", "ln").replace("layernorm", "ln")
        name = name.replace("norm1", "ln1").replace("norm2", "ln2")
        name = name.replace("merger.mlp", 'mm')
        print(f"[to_gguf_name] {og} --> {name}")
        return name
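
    # Illustrative examples of the renaming above (the merger weights are passed in
    # without the "vision_model." prefix by find_vision_tensors below):
    #   "vision_model.blocks.0.attn.qkv.weight" -> "v.blk.0.attn_qkv.weight"
    #   "merger.mlp.0.weight"                   -> "mm.0.weight"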

    @classmethod
    def find_vision_tensors(cls, qwen2vl, dtype) -> Dict[str, np.ndarray]:
        vision_model = qwen2vl.visual
        tensor_map = {}
        for name, ten in vision_model.state_dict().items():
            ten = ten.numpy()
            if 'qkv' in name:
                # split the fused qkv weight/bias into separate q, k and v tensors
                if ten.ndim == 2:  # weight
                    c3, _ = ten.shape
                else:  # bias
                    c3 = ten.shape[0]
                assert c3 % 3 == 0
                c = c3 // 3
                wq = ten[:c]
                wk = ten[c: c * 2]
                wv = ten[c * 2:]
                tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "q")] = wq
                tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "k")] = wk
                tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "v")] = wv
            elif 'merger' in name:
                if name.endswith("ln_q.weight"):
                    tensor_map['v.post_ln.weight'] = ten
                elif name.endswith("ln_q.bias"):
                    tensor_map['v.post_ln.bias'] = ten
                else:
                    # "merger.mlp.%d.weight/bias" --> "mm.%d.weight/bias"
                    tensor_map[cls.to_gguf_name(name)] = ten
            elif 'patch_embed.proj.weight' in name:
                # NOTE: split Conv3D into Conv2Ds
                c1, c2, kt, kh, kw = ten.shape
                assert kt == 2, "Current implementation only supports temporal_patch_size of 2"
                tensor_map["v.patch_embd.weight"] = ten[:, :, 0, ...]
                tensor_map["v.patch_embd.weight.1"] = ten[:, :, 1, ...]
            else:
                tensor_map[cls.to_gguf_name(f"vision_model.{name}")] = ten

        # norms and biases stay in fp32; everything else is cast to the requested dtype
        for new_name, ten in tensor_map.items():
            if ten.ndim <= 1 or new_name.endswith("_norm.weight"):
                tensor_map[new_name] = ten.astype(np.float32)
            else:
                tensor_map[new_name] = ten.astype(dtype)
        tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=np.float32)  # dummy tensor, just here as a placeholder
        return tensor_map
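
# A minimal usage sketch for the class above (assumes the default checkpoint named in
# this script is available; tensor count and shapes depend on the model size):
#   model = Qwen2VLForConditionalGeneration.from_pretrained(
#       "Qwen/Qwen2-VL-2B-Instruct", torch_dtype=torch.float32, device_map="cpu")
#   tensors = VL2.find_vision_tensors(model, np.float32)
#   print(len(tensors), tensors["v.patch_embd.weight"].shape)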


class VL25(VL2):

    @staticmethod
    def to_gguf_name(name: str) -> str:
        og = name
        name = name.replace("text_model", "t").replace("vision_model", "v")
        name = name.replace("blocks", "blk").replace("embeddings.", "")
        name = name.replace("attn.", "attn_")
        name = name.replace("mlp.down_proj", "ffn_down").replace("mlp.up_proj", "ffn_up")
        name = name.replace("mlp.gate_proj", "ffn_gate").replace("proj.", "out.")
        name = name.replace("norm1", "ln1").replace("norm2", "ln2")
        name = name.replace("merger.mlp", 'mm')
        print(f"[vl25][to_gguf_name] {og} --> {name}")
        return name
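
# Illustrative examples of the Qwen2.5-VL gated-MLP renaming handled by VL25:
#   "vision_model.blocks.0.mlp.gate_proj.weight" -> "v.blk.0.ffn_gate.weight"
#   "vision_model.blocks.0.mlp.down_proj.weight" -> "v.blk.0.ffn_down.weight"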


def main(args):
    if args.data_type == 'fp32':
        dtype = torch.float32
        np_dtype = np.float32
        ftype = 0
    elif args.data_type == 'fp16':
        dtype = torch.float16
        np_dtype = np.float16
        ftype = 1
    else:
        raise ValueError(f"unsupported data type: {args.data_type}")

    local_model = False
    model_path = ""
    model_name = args.model_name
    print("model_name: ", model_name)

    if args.model_type == "qwen2vl":
        qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained(
            model_name, torch_dtype=dtype, device_map="cpu"
        )
        cfg: Qwen2VLConfig = qwen2vl.config  # type: ignore[reportAssignmentType]
        vcfg = cfg.vision_config
    else:
        qwen2vl = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_name, torch_dtype=dtype, device_map="cpu"
        )
        cfg: Qwen2_5_VLConfig = qwen2vl.config  # type: ignore[reportAssignmentType]
        vcfg = cfg.vision_config

    if os.path.isdir(model_name):
        local_model = True
        if model_name.endswith(os.sep):
            model_name = model_name[:-1]
        model_path = model_name
        model_name = os.path.basename(model_name)
    fname_out = f"{model_name.replace('/', '-').lower()}-vision.gguf"

    fout = GGUFWriter(path=fname_out, arch="clip")
    fout.add_description("image encoder for Qwen2VL")

    fout.add_file_type(ftype)
    fout.add_bool("clip.has_text_encoder", False)
    fout.add_bool("clip.has_vision_encoder", True)
    fout.add_bool("clip.has_qwen2vl_merger", True)

    print(cfg.vision_config)
    if 'silu' in cfg.vision_config.hidden_act.lower():
        fout.add_bool("clip.use_silu", True)
        fout.add_bool("clip.use_gelu", False)
    elif 'gelu' in cfg.vision_config.hidden_act.lower():
        fout.add_bool("clip.use_silu", False)
        fout.add_bool("clip.use_gelu", 'quick' not in cfg.vision_config.hidden_act.lower())
    else:
        raise ValueError(f"unsupported hidden activation: {cfg.vision_config.hidden_act}")

    if args.model_type == "qwen2.5vl":
        fout.add_uint32("clip.vision.n_wa_pattern", get_n_wa_pattern(vcfg.fullatt_block_indexes))
        fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.hidden_size)
        fout.add_uint32("clip.vision.projection_dim", vcfg.out_hidden_size)
        fout.add_string("clip.projector_type", "qwen2.5vl_merger")
    else:
        fout.add_string("clip.projector_type", "qwen2vl_merger")
        fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim)
        fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size)

    if args.model_type == "qwen2.5vl":
        tensor_map = VL25.find_vision_tensors(qwen2vl, np_dtype)
    else:
        tensor_map = VL2.find_vision_tensors(qwen2vl, np_dtype)
    for name, data in tensor_map.items():
        fout.add_tensor(name, data)

    fout.add_uint32("clip.vision.patch_size", vcfg.patch_size)
    fout.add_uint32("clip.vision.image_size", 14 * 40)  # some reasonable size that is divisible by (14*2)
    fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), vcfg.num_heads)
    fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
    fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), vcfg.depth)
    fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), 0)  # not sure what this does, put 0 here as a placeholder
    fout.add_name(model_name)
  167. """
  168. HACK: Since vision rope related parameter aren't stored in the `Qwen2VLConfig,
  169. it will be hardcoded in the `clip_image_build_graph` from `clip.cpp`.
  170. """
  171. if local_model:
  172. processor: Qwen2VLProcessor = AutoProcessor.from_pretrained(model_path)
  173. else:
  174. processor: Qwen2VLProcessor = AutoProcessor.from_pretrained(model_name)
  175. fout.add_array("clip.vision.image_mean", processor.image_processor.image_mean) # type: ignore[reportAttributeAccessIssue]
  176. fout.add_array("clip.vision.image_std", processor.image_processor.image_std) # type: ignore[reportAttributeAccessIssue]
  177. fout.write_header_to_file()
  178. fout.write_kv_data_to_file()
  179. fout.write_tensors_to_file()
  180. fout.close()
  181. print("save model as: ", fname_out)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("model_name", nargs='?', default="Qwen/Qwen2-VL-2B-Instruct")
    parser.add_argument("--model_type", nargs='?', choices=['qwen2vl', 'qwen2.5vl'], default="qwen2vl")
    parser.add_argument("--data_type", nargs='?', choices=['fp32', 'fp16'], default="fp32")
    args = parser.parse_args()
    main(args)
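
# Example invocations (a sketch; the second model id is an assumption, substitute whatever
# Qwen2.5-VL checkpoint or local directory you actually want to convert):
#   python qwen2_vl_surgery.py Qwen/Qwen2-VL-2B-Instruct --data_type fp16
#   python qwen2_vl_surgery.py Qwen/Qwen2.5-VL-3B-Instruct --model_type qwen2.5vl --data_type fp16
# Each run writes the vision encoder to e.g. "qwen-qwen2-vl-2b-instruct-vision.gguf"
# in the current directory.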