clip-model.h

#pragma once

#include "ggml.h"
#include "clip.h"
#include "clip-impl.h"

#include <array>
#include <vector>
#include <unordered_set>
#include <cstdint>
#include <cmath>

enum ffn_op_type {
    FFN_GELU,
    FFN_GELU_ERF,
    FFN_SILU,
    FFN_GELU_QUICK,
};

enum norm_type {
    NORM_TYPE_NORMAL,
    NORM_TYPE_RMS,
};

enum patch_merge_type {
    PATCH_MERGE_FLAT,
    PATCH_MERGE_SPATIAL_UNPAD,
};

struct clip_hparams {
    int32_t image_size = 0;
    int32_t patch_size = 0;
    int32_t n_embd = 0;
    int32_t n_ff = 0;
    int32_t projection_dim = 0;
    int32_t n_head = 0;
    int32_t n_layer = 0;

    // idefics3
    int32_t image_longest_edge = 0;
    int32_t image_min_pixels = -1;
    int32_t image_max_pixels = -1;
    int32_t n_merge = 0; // number of patch merges **per-side**

    float image_mean[3];
    float image_std[3];

    // for models using dynamic image size, we need a smaller image size for warmup,
    // otherwise the user will get an OOM every time they load the model
    int32_t warmup_image_size = 0;
    int32_t warmup_audio_size = 3000;

    ffn_op_type ffn_op = FFN_GELU;

    patch_merge_type mm_patch_merge_type = PATCH_MERGE_FLAT;

    float eps = 1e-6;
    float rope_theta = 0.0;

    std::vector<clip_image_size> image_res_candidates; // for llava-uhd style models
    int32_t image_crop_resolution;
    std::unordered_set<int32_t> vision_feature_layer;
    int32_t attn_window_size = 0;
    int32_t n_wa_pattern = 0;
    std::unordered_set<int32_t> wa_layer_indexes; // explicit layer indexes that use full attention (for irregular patterns like YoutuVL)

    // audio
    int32_t n_mel_bins = 0; // whisper preprocessor
    int32_t proj_stack_factor = 0; // ultravox

    // audio-to-mel preprocessor params
    int32_t audio_chunk_len   = -1; // in seconds
    int32_t audio_sample_rate = -1;
    int32_t audio_n_fft       = -1;
    int32_t audio_window_len  = -1;
    int32_t audio_hop_len     = -1;

    // legacy
    bool has_llava_projector = false;
    int minicpmv_version = 0;
    int32_t minicpmv_query_num = 0; // MiniCPM-V query number

    // custom value provided by user, can be undefined if not set
    int32_t custom_image_min_tokens = -1;
    int32_t custom_image_max_tokens = -1;

    void set_limit_image_tokens(int n_tokens_min, int n_tokens_max) {
        const int cur_merge  = n_merge == 0 ? 1 : n_merge;
        const int patch_area = patch_size * patch_size * cur_merge * cur_merge;
        image_min_pixels = (custom_image_min_tokens > 0 ? custom_image_min_tokens : n_tokens_min) * patch_area;
        image_max_pixels = (custom_image_max_tokens > 0 ? custom_image_max_tokens : n_tokens_max) * patch_area;
        warmup_image_size = static_cast<int>(std::sqrt(image_max_pixels));
    }

    void set_warmup_n_tokens(int n_tokens) {
        int n_tok_per_side = static_cast<int>(std::sqrt(n_tokens));
        GGML_ASSERT(n_tok_per_side * n_tok_per_side == n_tokens && "n_tokens must be n*n");
        const int cur_merge = n_merge == 0 ? 1 : n_merge;
        warmup_image_size = n_tok_per_side * patch_size * cur_merge;
        // TODO: support warmup size for custom token numbers
    }
};
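
// Worked example (hypothetical values): assuming patch_size = 14 and n_merge = 2, one merged
// output token covers a (14*2) x (14*2) = 28x28 pixel area, so patch_area = 784. A call like
// set_limit_image_tokens(64, 1024) then gives:
//   image_min_pixels  =   64 * 784 =  50176
//   image_max_pixels  = 1024 * 784 = 802816
//   warmup_image_size = (int) sqrt(802816) = 896
// while set_warmup_n_tokens(256) picks a square grid of sqrt(256) = 16 tokens per side,
// i.e. warmup_image_size = 16 * 14 * 2 = 448.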

struct clip_layer {
    // attention
    ggml_tensor * k_w = nullptr;
    ggml_tensor * k_b = nullptr;
    ggml_tensor * q_w = nullptr;
    ggml_tensor * q_b = nullptr;
    ggml_tensor * v_w = nullptr;
    ggml_tensor * v_b = nullptr;
    ggml_tensor * qkv_w = nullptr;
    ggml_tensor * qkv_b = nullptr;
    ggml_tensor * o_w = nullptr;
    ggml_tensor * o_b = nullptr;
    ggml_tensor * k_norm = nullptr;
    ggml_tensor * q_norm = nullptr;

    // layernorm 1
    ggml_tensor * ln_1_w = nullptr;
    ggml_tensor * ln_1_b = nullptr;

    ggml_tensor * ff_up_w = nullptr;
    ggml_tensor * ff_up_b = nullptr;
    ggml_tensor * ff_gate_w = nullptr;
    ggml_tensor * ff_gate_b = nullptr;
    ggml_tensor * ff_down_w = nullptr;
    ggml_tensor * ff_down_b = nullptr;

    // layernorm 2
    ggml_tensor * ln_2_w = nullptr;
    ggml_tensor * ln_2_b = nullptr;

    // layer scale (no bias)
    ggml_tensor * ls_1_w = nullptr;
    ggml_tensor * ls_2_w = nullptr;

    // qwen3vl deepstack merger
    ggml_tensor * deepstack_norm_w = nullptr;
    ggml_tensor * deepstack_norm_b = nullptr;
    ggml_tensor * deepstack_fc1_w = nullptr;
    ggml_tensor * deepstack_fc1_b = nullptr;
    ggml_tensor * deepstack_fc2_w = nullptr;
    ggml_tensor * deepstack_fc2_b = nullptr;

    // lfm2
    ggml_tensor * ff_norm_w = nullptr;
    ggml_tensor * ff_norm_b = nullptr;
    ggml_tensor * ff_norm_1_w = nullptr;
    ggml_tensor * ff_norm_1_b = nullptr;
    ggml_tensor * ff_up_1_w = nullptr;
    ggml_tensor * ff_up_1_b = nullptr;
    ggml_tensor * ff_down_1_w = nullptr;
    ggml_tensor * ff_down_1_b = nullptr;
    ggml_tensor * pos_bias_u = nullptr;
    ggml_tensor * pos_bias_v = nullptr;
    ggml_tensor * norm_conv_w = nullptr;
    ggml_tensor * norm_conv_b = nullptr;
    ggml_tensor * linear_pos_w = nullptr;
    ggml_tensor * conv_norm_w = nullptr;
    ggml_tensor * conv_norm_b = nullptr;
    ggml_tensor * conv_dw_w = nullptr;
    ggml_tensor * conv_dw_b = nullptr;
    ggml_tensor * conv_pw1_w = nullptr;
    ggml_tensor * conv_pw1_b = nullptr;
    ggml_tensor * conv_pw2_w = nullptr;
    ggml_tensor * conv_pw2_b = nullptr;

    bool has_deepstack() const {
        return deepstack_fc1_w != nullptr;
    }
};

// Expanded MobileNetV5 block structure for Gemma3n vision encoder
struct mobilenetv5_block {
    // Stage 0 (Edge Residual)
    ggml_tensor * s0_conv_exp_w = nullptr;
    ggml_tensor * s0_bn1_w = nullptr;
    ggml_tensor * s0_conv_pwl_w = nullptr;
    ggml_tensor * s0_bn2_w = nullptr;

    // Stage 1+ (Universal Inverted Residual)
    ggml_tensor * dw_start_w = nullptr;
    ggml_tensor * dw_start_bn_w = nullptr;
    ggml_tensor * pw_exp_w = nullptr;
    ggml_tensor * pw_exp_bn_w = nullptr;
    ggml_tensor * dw_mid_w = nullptr;
    ggml_tensor * dw_mid_bn_w = nullptr;
    ggml_tensor * pw_proj_w = nullptr;
    ggml_tensor * pw_proj_bn_w = nullptr;
    ggml_tensor * layer_scale_w = nullptr;

    // Attention (MQA) components
    ggml_tensor * attn_q_w = nullptr;
    ggml_tensor * attn_k_w = nullptr;
    ggml_tensor * attn_v_w = nullptr;
    ggml_tensor * attn_o_w = nullptr;

    // Optional downsampling/norm in attention
    ggml_tensor * attn_k_dw_w = nullptr;
    ggml_tensor * attn_k_norm_w = nullptr;
    ggml_tensor * attn_v_dw_w = nullptr;
    ggml_tensor * attn_v_norm_w = nullptr;

    // Block norm (often present in attention blocks)
    ggml_tensor * attn_norm_w = nullptr;
};

struct clip_model {
    clip_modality modality = CLIP_MODALITY_VISION;
    projector_type proj_type = PROJECTOR_TYPE_MLP;
    clip_hparams hparams;

    // embeddings
    ggml_tensor * class_embedding = nullptr;
    ggml_tensor * patch_embeddings_0 = nullptr;
    ggml_tensor * patch_embeddings_1 = nullptr; // second Conv2D kernel when we decouple Conv3D along the temporal dimension (Qwen2VL)
    ggml_tensor * patch_bias = nullptr;
    ggml_tensor * position_embeddings = nullptr;
    ggml_tensor * norm_embd_w = nullptr;
    ggml_tensor * norm_embd_b = nullptr;
    ggml_tensor * pre_ln_w = nullptr;
    ggml_tensor * pre_ln_b = nullptr;

    std::vector<clip_layer> layers;
    int32_t n_deepstack_layers = 0; // used by Qwen3-VL, calculated from clip_layer

    ggml_tensor * post_ln_w;
    ggml_tensor * post_ln_b;

    ggml_tensor * projection; // TODO: rename it to fc (fully connected layer)

    ggml_tensor * mm_fc_w;
    ggml_tensor * mm_fc_b;
    ggml_tensor * mm_ffn_up_w = nullptr;
    ggml_tensor * mm_ffn_up_b = nullptr;
    ggml_tensor * mm_ffn_gate_w = nullptr;
    ggml_tensor * mm_ffn_gate_b = nullptr;
    ggml_tensor * mm_ffn_down_w = nullptr;
    ggml_tensor * mm_ffn_down_b = nullptr;
    ggml_tensor * mm_post_norm_w = nullptr;
    ggml_tensor * mm_post_norm_b = nullptr;

    // LLaVA projection
    ggml_tensor * mm_input_norm_w = nullptr;
    ggml_tensor * mm_input_norm_b = nullptr;
    ggml_tensor * mm_0_w = nullptr;
    ggml_tensor * mm_0_b = nullptr;
    ggml_tensor * mm_2_w = nullptr;
    ggml_tensor * mm_2_b = nullptr;

    ggml_tensor * image_newline = nullptr;

    // Yi type models with mlp+normalization projection
    ggml_tensor * mm_1_w = nullptr; // Yi type models have 0, 1, 3, 4
    ggml_tensor * mm_1_b = nullptr;
    ggml_tensor * mm_3_w = nullptr;
    ggml_tensor * mm_3_b = nullptr;
    ggml_tensor * mm_4_w = nullptr;
    ggml_tensor * mm_4_b = nullptr;

    // GLMV-Edge projection
    ggml_tensor * mm_model_adapter_conv_w = nullptr;
    ggml_tensor * mm_model_adapter_conv_b = nullptr;

    // MobileVLM projection
    ggml_tensor * mm_model_mlp_1_w = nullptr;
    ggml_tensor * mm_model_mlp_1_b = nullptr;
    ggml_tensor * mm_model_mlp_3_w = nullptr;
    ggml_tensor * mm_model_mlp_3_b = nullptr;
    ggml_tensor * mm_model_block_1_block_0_0_w = nullptr;
    ggml_tensor * mm_model_block_1_block_0_1_w = nullptr;
    ggml_tensor * mm_model_block_1_block_0_1_b = nullptr;
    ggml_tensor * mm_model_block_1_block_1_fc1_w = nullptr;
    ggml_tensor * mm_model_block_1_block_1_fc1_b = nullptr;
    ggml_tensor * mm_model_block_1_block_1_fc2_w = nullptr;
    ggml_tensor * mm_model_block_1_block_1_fc2_b = nullptr;
    ggml_tensor * mm_model_block_1_block_2_0_w = nullptr;
    ggml_tensor * mm_model_block_1_block_2_1_w = nullptr;
    ggml_tensor * mm_model_block_1_block_2_1_b = nullptr;
    ggml_tensor * mm_model_block_2_block_0_0_w = nullptr;
    ggml_tensor * mm_model_block_2_block_0_1_w = nullptr;
    ggml_tensor * mm_model_block_2_block_0_1_b = nullptr;
    ggml_tensor * mm_model_block_2_block_1_fc1_w = nullptr;
    ggml_tensor * mm_model_block_2_block_1_fc1_b = nullptr;
    ggml_tensor * mm_model_block_2_block_1_fc2_w = nullptr;
    ggml_tensor * mm_model_block_2_block_1_fc2_b = nullptr;
    ggml_tensor * mm_model_block_2_block_2_0_w = nullptr;
    ggml_tensor * mm_model_block_2_block_2_1_w = nullptr;
    ggml_tensor * mm_model_block_2_block_2_1_b = nullptr;

    // MobileVLM_V2 projection
    ggml_tensor * mm_model_mlp_0_w = nullptr;
    ggml_tensor * mm_model_mlp_0_b = nullptr;
    ggml_tensor * mm_model_mlp_2_w = nullptr;
    ggml_tensor * mm_model_mlp_2_b = nullptr;
    ggml_tensor * mm_model_peg_0_w = nullptr;
    ggml_tensor * mm_model_peg_0_b = nullptr;

    // MINICPMV projection
    ggml_tensor * mm_model_pos_embed_k = nullptr;
    ggml_tensor * mm_model_query = nullptr;
    ggml_tensor * mm_model_proj = nullptr;
    ggml_tensor * mm_model_kv_proj = nullptr;
    ggml_tensor * mm_model_attn_q_w = nullptr;
    ggml_tensor * mm_model_attn_q_b = nullptr;
    ggml_tensor * mm_model_attn_k_w = nullptr;
    ggml_tensor * mm_model_attn_k_b = nullptr;
    ggml_tensor * mm_model_attn_v_w = nullptr;
    ggml_tensor * mm_model_attn_v_b = nullptr;
    ggml_tensor * mm_model_attn_o_w = nullptr;
    ggml_tensor * mm_model_attn_o_b = nullptr;
    ggml_tensor * mm_model_ln_q_w = nullptr;
    ggml_tensor * mm_model_ln_q_b = nullptr;
    ggml_tensor * mm_model_ln_kv_w = nullptr;
    ggml_tensor * mm_model_ln_kv_b = nullptr;
    ggml_tensor * mm_model_ln_post_w = nullptr;
    ggml_tensor * mm_model_ln_post_b = nullptr;

    // gemma3
    ggml_tensor * mm_input_proj_w = nullptr;
    ggml_tensor * mm_soft_emb_norm_w = nullptr;

    // mobilenetv5 for gemma3n
    std::vector<mobilenetv5_block> mobilenet_blocks;
    std::vector<int> mobilenet_stage_ends;
    ggml_tensor * mobilenet_stem_conv_w = nullptr;
    ggml_tensor * mobilenet_stem_conv_b = nullptr;
    ggml_tensor * mobilenet_stem_norm_w = nullptr;
    ggml_tensor * mm_post_proj_norm_w = nullptr;

    // Multi-Scale Fusion Adapter (MSFA) components
    ggml_tensor * msfa_concat_conv_w = nullptr;
    ggml_tensor * msfa_concat_norm_w = nullptr;
    ggml_tensor * msfa_ffn_expand_w = nullptr;
    ggml_tensor * msfa_ffn_project_w = nullptr;
    ggml_tensor * msfa_ffn_expand_bn = nullptr;
    ggml_tensor * msfa_ffn_project_bn = nullptr;

    // pixtral, glm4v
    ggml_tensor * token_embd_img_break = nullptr;
    ggml_tensor * mm_patch_merger_w = nullptr;
    ggml_tensor * mm_patch_merger_b = nullptr;

    // ultravox / whisper encoder
    ggml_tensor * conv1d_1_w = nullptr;
    ggml_tensor * conv1d_1_b = nullptr;
    ggml_tensor * conv1d_2_w = nullptr;
    ggml_tensor * conv1d_2_b = nullptr;
    ggml_tensor * mm_norm_pre_w = nullptr;
    ggml_tensor * mm_norm_pre_b = nullptr;
    ggml_tensor * mm_norm_mid_w = nullptr;

    // cogvlm
    ggml_tensor * mm_post_fc_norm_w = nullptr;
    ggml_tensor * mm_post_fc_norm_b = nullptr;
    ggml_tensor * mm_h_to_4h_w = nullptr;
    ggml_tensor * mm_gate_w = nullptr;
    ggml_tensor * mm_4h_to_h_w = nullptr;
    ggml_tensor * mm_boi = nullptr;
    ggml_tensor * mm_eoi = nullptr;

    // lfm2 audio
    std::array<ggml_tensor *, 7> pre_encode_conv_X_w = {nullptr};
    std::array<ggml_tensor *, 7> pre_encode_conv_X_b = {nullptr};
    ggml_tensor * pre_encode_out_w = nullptr;
    ggml_tensor * pre_encode_out_b = nullptr;

    bool audio_has_avgpool() const {
        return proj_type == PROJECTOR_TYPE_QWEN2A
            || proj_type == PROJECTOR_TYPE_VOXTRAL
            || proj_type == PROJECTOR_TYPE_MUSIC_FLAMINGO;
    }

    bool audio_has_stack_frames() const {
        return proj_type == PROJECTOR_TYPE_ULTRAVOX
            || proj_type == PROJECTOR_TYPE_VOXTRAL;
    }
};
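
// Sketch (assumption, not spelled out in this header): n_deepstack_layers is documented as
// being calculated from clip_layer, which clip_layer::has_deepstack() makes straightforward;
// the counting presumably happens in the model-loading code and could look roughly like:
//
//   int32_t n_deepstack_layers = 0;
//   for (const auto & layer : model.layers) {
//       if (layer.has_deepstack()) {
//           n_deepstack_layers++;
//       }
//   }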

const clip_hparams * clip_get_hparams(const struct clip_ctx * ctx);