mobilenetv5.cpp

#include "models.h"

// Helpers for MobileNetV5 Blocks

// RMS Norm 2D - normalizes over channels for each spatial position
ggml_tensor * clip_graph_mobilenetv5::rms_norm_2d(ggml_tensor * inp, ggml_tensor * weight, float eps) {
    // inp: [W, H, C, B]
    ggml_tensor * cur = ggml_permute(ctx0, inp, 2, 1, 0, 3);
    cur = ggml_cont(ctx0, cur);
    cur = ggml_rms_norm(ctx0, cur, eps);
    if (weight) {
        cur = ggml_mul(ctx0, cur, weight);
    }
    cur = ggml_permute(ctx0, cur, 2, 1, 0, 3);
    cur = ggml_cont(ctx0, cur);
    return cur;
}
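
// Shape walk-through for rms_norm_2d above (illustrative, no behavioural change):
// [W, H, C, B] --permute(2,1,0,3)--> [C, H, W, B], so ggml_rms_norm (which
// normalizes along dim 0) runs over the channels; a per-channel weight of shape
// [C] then broadcasts over H, W and B before the layout is permuted back.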
  16. // Conv2dSame padding - asymmetric SAME padding like PyTorch/TF
  17. ggml_tensor* clip_graph_mobilenetv5::pad_same_2d(ggml_tensor* inp, int kernel_h, int kernel_w, int stride_h, int stride_w, int dilation_h, int dilation_w) {
  18. const int64_t ih = inp->ne[1]; // height
  19. const int64_t iw = inp->ne[0]; // width
  20. // Calculate output size (ceil division)
  21. const int64_t oh = (ih + stride_h - 1) / stride_h;
  22. const int64_t ow = (iw + stride_w - 1) / stride_w;
  23. // Calculate padding needed
  24. const int64_t pad_h = std::max((int64_t)0, (oh - 1) * stride_h + (kernel_h - 1) * dilation_h + 1 - ih);
  25. const int64_t pad_w = std::max((int64_t)0, (ow - 1) * stride_w + (kernel_w - 1) * dilation_w + 1 - iw);
  26. // Split padding asymmetrically
  27. const int pad_h_top = pad_h / 2;
  28. const int pad_h_bottom = pad_h - pad_h_top;
  29. const int pad_w_left = pad_w / 2;
  30. const int pad_w_right = pad_w - pad_w_left;
  31. // Apply padding if needed
  32. // ggml_pad_ext: (ctx, tensor, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3)
  33. // For [W, H, C, B]: p0=width, p1=height, p2=channels, p3=batch
  34. if (pad_h > 0 || pad_w > 0) {
  35. inp = ggml_pad_ext(ctx0, inp,
  36. pad_w_left, pad_w_right, // width padding (dim 0)
  37. pad_h_top, pad_h_bottom, // height padding (dim 1)
  38. 0, 0, // no channel padding (dim 2)
  39. 0, 0); // no batch padding (dim 3)
  40. }
  41. return inp;
  42. }
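
// Worked example for pad_same_2d above (illustrative numbers, not taken from any
// specific model config): iw = 48, kernel_w = 3, stride_w = 2, dilation_w = 1
//   ow    = ceil(48 / 2) = 24
//   pad_w = max(0, (24 - 1) * 2 + (3 - 1) * 1 + 1 - 48) = 1
//   -> pad_w_left = 0, pad_w_right = 1 (the extra pixel goes to the right,
//      matching PyTorch/TF "SAME" semantics)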

// Edge Residual Block (Stage 0)
ggml_tensor * clip_graph_mobilenetv5::build_edge_residual(ggml_tensor * inp, const mobilenetv5_block & block, int stride) {
    ggml_tensor * cur = inp;

    // 1. Expansion Conv (3x3)
    if (stride == 2) {
        // Case: Downsampling (Block 0)
        // Replicates Conv2dSame(kernel=3, stride=2)
        cur = pad_same_2d(cur, 3, 3, stride, stride);
        cur = ggml_conv_2d_direct(ctx0, block.s0_conv_exp_w, cur, stride, stride, 0, 0, 1, 1);
    } else {
        // Case: Normal 3x3 Block (Block 1, 2)
        // Replicates Conv2d(kernel=3, stride=1, padding=1)
        cur = ggml_conv_2d_direct(ctx0, block.s0_conv_exp_w, cur, stride, stride, 1, 1, 1, 1);
    }

    // BN + Activation
    if (block.s0_bn1_w) cur = rms_norm_2d(cur, block.s0_bn1_w);
    cur = ggml_gelu(ctx0, cur);

    // 2. Pointwise Linear Conv (1x1)
    // 1x1 Convs usually have padding=0 and stride=1
    cur = ggml_conv_2d_direct(ctx0, block.s0_conv_pwl_w, cur, 1, 1, 0, 0, 1, 1);
    if (block.s0_bn2_w) cur = rms_norm_2d(cur, block.s0_bn2_w);

    // 3. Residual Connection
    // Only apply residual if spatial dimensions and channels match (stride 1)
    if (stride == 1 && inp->ne[2] == cur->ne[2] && inp->ne[0] == cur->ne[0]) {
        cur = ggml_add(ctx0, cur, inp);
    }
    return cur;
}

// Universal Inverted Residual Block (Stage 1+)
ggml_tensor * clip_graph_mobilenetv5::build_inverted_residual(ggml_tensor * inp, const mobilenetv5_block & block, int stride) {
    ggml_tensor * cur = inp;

    // 1. Depthwise Start (Optional)
    // NOTE: dw_start always has stride=1 (no downsampling here)
    if (block.dw_start_w) {
        int k = block.dw_start_w->ne[0]; // 3 or 5
        int p = k / 2;
        cur = ggml_conv_2d_dw(ctx0, block.dw_start_w, cur, 1, 1, p, p, 1, 1);
        if (block.dw_start_bn_w) cur = rms_norm_2d(cur, block.dw_start_bn_w);
    }

    // 2. Pointwise Expansion (1x1)
    if (block.pw_exp_w) {
        // Standard 1x1 conv, pad=0, stride=1
        cur = ggml_conv_2d_direct(ctx0, block.pw_exp_w, cur, 1, 1, 0, 0, 1, 1);
        if (block.pw_exp_bn_w) cur = rms_norm_2d(cur, block.pw_exp_bn_w);
        cur = ggml_gelu(ctx0, cur);
    }

    // 3. Depthwise Mid (Optional)
    // NOTE: dw_mid is where downsampling happens (stride=2 for first block of stage)
    if (block.dw_mid_w) {
        int k = block.dw_mid_w->ne[0]; // 3 or 5
        if (stride > 1) {
            // Case: Stride 2 (Downsample) -> Use Asymmetric "Same" Padding
            cur = pad_same_2d(cur, k, k, stride, stride);
            cur = ggml_conv_2d_dw(ctx0, block.dw_mid_w, cur, stride, stride, 0, 0, 1, 1); // pad=0
        } else {
            // Case: Stride 1 -> Use Standard Symmetric Padding
            int p = k / 2;
            cur = ggml_conv_2d_dw(ctx0, block.dw_mid_w, cur, stride, stride, p, p, 1, 1);
        }
        if (block.dw_mid_bn_w) cur = rms_norm_2d(cur, block.dw_mid_bn_w);
        cur = ggml_gelu(ctx0, cur);
    }

    // 4. Pointwise Projection (1x1)
    if (block.pw_proj_w) {
        cur = ggml_conv_2d_direct(ctx0, block.pw_proj_w, cur, 1, 1, 0, 0, 1, 1);
        if (block.pw_proj_bn_w) cur = rms_norm_2d(cur, block.pw_proj_bn_w);
    }

    // Apply Layer Scaling if present
    if (block.layer_scale_w) {
        cur = ggml_mul(ctx0, cur, block.layer_scale_w);
    }

    // 5. Residual Connection
    bool same_spatial = (inp->ne[0] == cur->ne[0]) && (inp->ne[1] == cur->ne[1]);
    bool same_channel = (inp->ne[2] == cur->ne[2]);
    if (same_spatial && same_channel) {
        cur = ggml_add(ctx0, cur, inp);
    }
    return cur;
}
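
// Note: the weight names above mirror timm's UniversalInvertedResidual layout
// (dw_start -> pw_exp -> dw_mid -> pw_proj -> layer_scale); each sub-layer is
// optional and is simply skipped when its tensor is absent from the loaded model.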

// Attention Block (MQA)
ggml_tensor * clip_graph_mobilenetv5::build_mobilenet_attn(ggml_tensor * inp, const mobilenetv5_block & block) {
    ggml_tensor * cur = inp;

    // Norm
    if (block.attn_norm_w) {
        cur = rms_norm_2d(cur, block.attn_norm_w, 1e-6f);
    }

    // 1. Q Calculation
    ggml_tensor * q = ggml_conv_2d_direct(ctx0, block.attn_q_w, cur, 1, 1, 0, 0, 1, 1);

    // 2. K Calculation (Downsampled)
    // Uses Conv2dSame(640, 640, kernel_size=(3, 3), stride=(2, 2), groups=640)
    ggml_tensor * k_inp = cur;
    if (block.attn_k_dw_w) {
        int k_size = block.attn_k_dw_w->ne[0]; // Usually 3
        k_inp = pad_same_2d(cur, k_size, k_size, 2, 2); // Apply SAME padding
        k_inp = ggml_conv_2d_dw(ctx0, block.attn_k_dw_w, k_inp, 2, 2, 0, 0, 1, 1); // padding=0
        if (block.attn_k_norm_w) {
            k_inp = rms_norm_2d(k_inp, block.attn_k_norm_w, 1e-6f);
        }
    }
    ggml_tensor * k = ggml_conv_2d_direct(ctx0, block.attn_k_w, k_inp, 1, 1, 0, 0, 1, 1);

    // 3. V Calculation (Downsampled)
    // Uses Conv2dSame(640, 640, kernel_size=(3, 3), stride=(2, 2), groups=640)
    ggml_tensor * v_inp = cur;
    if (block.attn_v_dw_w) {
        int v_size = block.attn_v_dw_w->ne[0]; // Usually 3
        v_inp = pad_same_2d(cur, v_size, v_size, 2, 2); // Apply SAME padding
        v_inp = ggml_conv_2d_dw(ctx0, block.attn_v_dw_w, v_inp, 2, 2, 0, 0, 1, 1); // padding=0
        if (block.attn_v_norm_w) {
            v_inp = rms_norm_2d(v_inp, block.attn_v_norm_w, 1e-6f);
        }
    }
    ggml_tensor * v = ggml_conv_2d_direct(ctx0, block.attn_v_w, v_inp, 1, 1, 0, 0, 1, 1);

    const int W = cur->ne[0];
    const int H = cur->ne[1];
    const int B = cur->ne[3];
    const int D = k->ne[2]; // Head dimension
    const int n_head = q->ne[2] / D;
    const int N = W * H;

    // Process Q: [W, H, D*n_head, B] -> [D, N, n_head, B]
    q = ggml_reshape_3d(ctx0, q, N, D*n_head, B);
    q = ggml_reshape_4d(ctx0, q, N, D, n_head, B);
    q = ggml_permute(ctx0, q, 1, 0, 2, 3); // [D, N, n_head, B]
    q = ggml_cont(ctx0, q);

    const int Wk = k->ne[0];
    const int Hk = k->ne[1];
    const int M = Wk * Hk;

    // Process K: [Wk, Hk, D, B] -> [D, M, 1, B]
    k = ggml_reshape_3d(ctx0, k, M, D, B);
    k = ggml_reshape_4d(ctx0, k, M, D, 1, B);
    k = ggml_permute(ctx0, k, 1, 0, 2, 3); // [D, M, 1, B]
    k = ggml_cont(ctx0, k);

    // Process V: [Wk, Hk, D, B] -> [M, D, 1, B]
    v = ggml_reshape_3d(ctx0, v, M, D, B);
    v = ggml_reshape_4d(ctx0, v, M, D, 1, B);
    v = ggml_cont(ctx0, v); // [M, D, 1, B]
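
    // Shape sketch for the multi-query attention below (ggml_mul_mat contracts
    // over dim 0 and broadcasts the single K/V head across the n_head query heads):
    //   scores = mul_mat(k, q)     : [D, M, 1, B] x [D, N, n_head, B] -> [M, N, n_head, B]
    //   softmax over dim 0 (the M key positions), then
    //   kqv    = mul_mat(v, scores): [M, D, 1, B] x [M, N, n_head, B] -> [D, N, n_head, B]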

    // Multi-Query Attention
    float scale = 1.0f / sqrtf((float)D);

    // Step 1: Compute Q @ K^T, scale and softmax
    ggml_tensor * scores = ggml_mul_mat(ctx0, k, q);
    scores = ggml_scale(ctx0, scores, scale);
    scores = ggml_soft_max(ctx0, scores);

    // Step 2: Apply attention to V, then reshape back to [W, H, D*n_head, B]
    ggml_tensor * kqv = ggml_mul_mat(ctx0, v, scores);
    kqv = ggml_permute(ctx0, kqv, 1, 0, 2, 3);
    kqv = ggml_cont(ctx0, kqv);
    kqv = ggml_reshape_3d(ctx0, kqv, N, D * n_head, B);
    kqv = ggml_reshape_4d(ctx0, kqv, W, H, D * n_head, B);
    kqv = ggml_cont(ctx0, kqv);

    // Output projection
    cur = ggml_conv_2d_direct(ctx0, block.attn_o_w, kqv, 1, 1, 0, 0, 1, 1);

    // Residual & Layer Scale
    if (inp->ne[0] == cur->ne[0] && inp->ne[2] == cur->ne[2]) {
        if (block.layer_scale_w) {
            cur = ggml_mul(ctx0, cur, block.layer_scale_w);
        }
        cur = ggml_add(ctx0, cur, inp);
    }
    return cur;
}

ggml_cgraph * clip_graph_mobilenetv5::build() {
    ggml_tensor * inp = build_inp_raw();

    // 1. Stem - Conv2dSame(3, 64, kernel_size=(3, 3), stride=(2, 2))
    ggml_tensor * cur = pad_same_2d(inp, 3, 3, 2, 2); // Apply SAME padding
    cur = ggml_conv_2d_direct(ctx0, model.mobilenet_stem_conv_w, cur, 2, 2, 0, 0, 1, 1); // padding=0
    if (model.mobilenet_stem_conv_b) {
        cur = ggml_add(ctx0, cur, model.mobilenet_stem_conv_b);
    }
    if (model.mobilenet_stem_norm_w) cur = rms_norm_2d(cur, model.mobilenet_stem_norm_w);
    cur = ggml_gelu(ctx0, cur);

    // 2. Blocks
    std::vector<ggml_tensor *> intermediate_features;
    const int total_blocks = (int) model.mobilenet_blocks.size();

    // A new stage starts at block 0 and right after each recorded stage end;
    // the first block of a stage downsamples (stride 2).
    auto is_stage_start = [&](int i) {
        if (i == 0) return true;
        for (int end_idx : model.mobilenet_stage_ends) {
            if (i == end_idx + 1) return true;
        }
        return false;
    };

    // Feature maps are captured at the ends of stages 2 and 3 for the MSFA;
    // if fewer stage ends are known, fall back to the last block only.
    auto is_fusion_point = [&](int i) {
        if (model.mobilenet_stage_ends.size() >= 4) {
            if (i == model.mobilenet_stage_ends[2]) return true; // End of Stage 2
            if (i == model.mobilenet_stage_ends[3]) return true; // End of Stage 3
        } else {
            if (i == total_blocks - 1) return true;
        }
        return false;
    };
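
    // Illustrative dispatch (hypothetical stage indices; the real values come from
    // the model's metadata): with mobilenet_stage_ends = {2, 10, 30, 40}, blocks
    // 0, 3, 11 and 31 run with stride 2, and the outputs of blocks 30 and 40 are
    // collected for the multi-scale fusion below.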

    for (int i = 0; i < total_blocks; i++) {
        const auto & block = model.mobilenet_blocks[i];
        int stride = is_stage_start(i) ? 2 : 1;

        if      (block.s0_conv_exp_w) cur = build_edge_residual(cur, block, stride);
        else if (block.attn_q_w)      cur = build_mobilenet_attn(cur, block);
        else                          cur = build_inverted_residual(cur, block, stride);

        if (is_fusion_point(i)) {
            intermediate_features.push_back(cur);
        }
    }

    // 3. Multi-Scale Fusion Adapter (MSFA)
    if (!intermediate_features.empty()) {
        // A. Reference Resolution: the PyTorch implementation uses inputs[0].
        // We assume intermediate_features[0] is the "High Resolution" target;
        // in MobileNet designs this is typically the captured feature map with the
        // smallest stride (e.g. 32x32), i.e. the largest spatial size.
        ggml_tensor * target_feat = intermediate_features[0];
        int high_res_w = target_feat->ne[0];
        int high_res_h = target_feat->ne[1];

        std::vector<ggml_tensor *> resized_feats;

        // B. Resize inputs to match inputs[0] (High Resolution)
        for (auto feat : intermediate_features) {
            int feat_w = feat->ne[0];
            int feat_h = feat->ne[1];

            // PyTorch: if feat_size < high_resolution: interpolate
            if (feat_w < high_res_w || feat_h < high_res_h) {
                // PyTorch 'nearest' interpolation accepts arbitrary float scales,
                // while ggml_upscale takes an integer factor, so the resolutions
                // must divide evenly (e.g. 16 -> 32 means scale=2); the height is
                // assumed to scale by the same factor as the width.
                int scale_w = high_res_w / feat_w;
                GGML_ASSERT(high_res_w % feat_w == 0);

                // Upsample (Nearest Neighbor) by scale_w
                feat = ggml_upscale(ctx0, feat, scale_w, ggml_scale_mode::GGML_SCALE_MODE_NEAREST);
            }
            resized_feats.push_back(feat);
        }

        // C. Concatenate at High Resolution (Channel Dim = 2 in ggml)
        cur = resized_feats[0];
        for (size_t k = 1; k < resized_feats.size(); ++k) {
            cur = ggml_concat(ctx0, cur, resized_feats[k], 2);
        }

        // D. FFN (UniversalInvertedResidual)
        // Structure: Expand Conv -> Norm -> GELU -> Project Conv -> Norm

        // 1. Expansion
        if (model.msfa_ffn_expand_w) {
            // 1x1 Conv
            cur = ggml_conv_2d_direct(ctx0, model.msfa_ffn_expand_w, cur, 1, 1, 0, 0, 1, 1);
            if (model.msfa_ffn_expand_bn) {
                cur = rms_norm_2d(cur, model.msfa_ffn_expand_bn);
            }
            cur = ggml_gelu(ctx0, cur);
        }

        // 2. Projection (No DW because kernel_size=0)
        if (model.msfa_ffn_project_w) {
            // 1x1 Conv
            cur = ggml_conv_2d_direct(ctx0, model.msfa_ffn_project_w, cur, 1, 1, 0, 0, 1, 1);
            // UniversalInvertedResidual typically has a norm after projection
            if (model.msfa_ffn_project_bn) {
                cur = rms_norm_2d(cur, model.msfa_ffn_project_bn);
            }
        }

        // E. Final Downsample to Target Resolution (Output Resolution)
        // PyTorch: matches self.output_resolution (e.g. 16x16)
        const int target_out_res = 16;
        int current_w = cur->ne[0];
        if (current_w > target_out_res) {
            int s = current_w / target_out_res;
            GGML_ASSERT(current_w % target_out_res == 0);
            // Avg Pool: Kernel=s, Stride=s
            cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_AVG, s, s, s, s, 0, 0);
        }

        // F. Final Norm
        if (model.msfa_concat_norm_w) {
            cur = rms_norm_2d(cur, model.msfa_concat_norm_w);
        }
    }
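
    // Worked MSFA example for the block above, assuming a 768x768 input as used by
    // Gemma 3n (illustrative; actual sizes depend on the configured image size):
    //   stage 2 output 48x48, stage 3 output 24x24 -> the 24x24 map is upscaled x2
    //   to 48x48, concatenated on the channel dim, passed through the FFN, then
    //   average-pooled with s = 48/16 = 3 down to the 16x16 output resolution.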

    // 4. Gemma 3n Multimodal Projection (Embedder)
    // Input: 'cur' is [Width, Height, Channels, Batch]
    int W = cur->ne[0];
    int H = cur->ne[1];
    int C = cur->ne[2];
    int B = cur->ne[3];
    GGML_ASSERT(C == hparams.n_embd);

    // 1. Permute and Flatten to [Channels, Tokens, Batch]
    // PyTorch expects (Batch, Seq, Hidden), GGML usually processes (Hidden, Seq, Batch)
    cur = ggml_permute(ctx0, cur, 2, 1, 0, 3); // -> [C, H, W, B]
    cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); // -> [C, W, H, B]
    cur = ggml_cont(ctx0, cur);
    cur = ggml_reshape_3d(ctx0, cur, C, W*H, B);
    cur = ggml_cont(ctx0, cur);

    // 2. FEATURE SCALING
    // PyTorch: vision_outputs *= self.config.vision_config.hidden_size**0.5
    const float scale_factor = sqrtf((float)C);
    cur = ggml_scale(ctx0, cur, scale_factor);
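
    // For example, with a vision hidden size of C = 2048 (matching the weight shape
    // noted below) and a 16x16 fused feature map, this produces 256 tokens, each
    // scaled by sqrt(2048) ~= 45.25 before the soft-embedding norm.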

    // 3. SOFT EMBEDDING NORM
    // PyTorch: self._norm(x) * self.weight
    // We must normalize regardless, then multiply if the weight exists.
    {
        const float eps = 1e-6f; // Gemma3n uses 1e-6
        cur = ggml_rms_norm(ctx0, cur, eps);
        if (model.mm_soft_emb_norm_w) {
            // Weight shape is (2048,) -> element-wise broadcast multiply
            cur = ggml_mul(ctx0, cur, model.mm_soft_emb_norm_w);
        }
    }

    // 4. PROJECTION
    // PyTorch: embedding_projection = nn.Linear(vision_hidden, text_hidden, bias=False)
    // Weight stored as [out_features, in_features] = [text_hidden_size, vision_hidden_size]
    if (model.mm_input_proj_w) {
        cur = ggml_mul_mat(ctx0, model.mm_input_proj_w, cur);
    }

    // 5. POST PROJECTION NORM
    // PyTorch: embedding_post_projection_norm = Gemma3nRMSNorm(..., with_scale=False)
    // with_scale=False means the weight is registered as a buffer with value 1.0,
    // so output = rms_norm(x) * 1.0 = rms_norm(x), magnitude ~1
    {
        const float eps = 1e-6f;
        cur = ggml_rms_norm(ctx0, cur, eps);
        if (model.mm_post_proj_norm_w) {
            // If the weight is loaded, multiply (it should be ~1.0 anyway)
            cur = ggml_mul(ctx0, cur, model.mm_post_proj_norm_w);
        }
    }

    ggml_build_forward_expand(gf, cur);
    return gf;
}