// siglip.cpp
#include "models.h"
  2. ggml_cgraph * clip_graph_siglip::build() {
  3. ggml_tensor * inp = build_inp();
  4. ggml_tensor * learned_pos_embd = model.position_embeddings;
  5. if (proj_type == PROJECTOR_TYPE_LFM2) {
  6. learned_pos_embd = resize_position_embeddings();
  7. }
  8. ggml_tensor * cur = build_vit(
  9. inp, n_patches,
  10. NORM_TYPE_NORMAL,
  11. hparams.ffn_op,
  12. learned_pos_embd,
  13. nullptr);
  14. if (proj_type == PROJECTOR_TYPE_GEMMA3) {
  15. const int batch_size = 1;
  16. GGML_ASSERT(n_patches_x == n_patches_y);
  17. const int patches_per_image = n_patches_x;
  18. const int kernel_size = hparams.n_merge;
  19. cur = ggml_transpose(ctx0, cur);
  20. cur = ggml_cont_4d(ctx0, cur, patches_per_image, patches_per_image, n_embd, batch_size);
  21. // doing a pool2d to reduce the number of output tokens
  22. cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_AVG, kernel_size, kernel_size, kernel_size, kernel_size, 0, 0);
  23. cur = ggml_reshape_3d(ctx0, cur, cur->ne[0] * cur->ne[0], n_embd, batch_size);
  24. cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
  25. // apply norm before projection
  26. cur = ggml_rms_norm(ctx0, cur, eps);
  27. cur = ggml_mul(ctx0, cur, model.mm_soft_emb_norm_w);
  28. // apply projection
  29. cur = ggml_mul_mat(ctx0,
  30. ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_input_proj_w)),
  31. cur);
  32. } else if (proj_type == PROJECTOR_TYPE_IDEFICS3) {
  33. // pixel_shuffle
  34. // https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578
  35. const int scale_factor = model.hparams.n_merge;
  36. cur = build_patch_merge_permute(cur, scale_factor);
  37. cur = ggml_mul_mat(ctx0, model.projection, cur);
  38. } else if (proj_type == PROJECTOR_TYPE_LFM2) {
  39. // pixel unshuffle block
  40. const int scale_factor = model.hparams.n_merge;
  41. cur = build_patch_merge_permute(cur, scale_factor);
  42. // projection, in LFM2-VL input norm is optional
  43. if (model.mm_input_norm_w) {
  44. cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm
  45. cur = ggml_mul(ctx0, cur, model.mm_input_norm_w);
  46. }
  47. if (model.mm_input_norm_b) {
  48. cur = ggml_add(ctx0, cur, model.mm_input_norm_b);
  49. }
  50. cur = build_ffn(cur,
  51. model.mm_1_w, model.mm_1_b,
  52. nullptr, nullptr,
  53. model.mm_2_w, model.mm_2_b,
  54. FFN_GELU,
  55. -1);
  56. } else if (proj_type == PROJECTOR_TYPE_JANUS_PRO) {
  57. cur = build_ffn(cur,
  58. model.mm_0_w, model.mm_0_b,
  59. nullptr, nullptr,
  60. model.mm_1_w, model.mm_1_b,
  61. hparams.ffn_op,
  62. -1);
  63. } else {
  64. GGML_ABORT("SigLIP: Unsupported projector type");
  65. }
  66. // build the graph
  67. ggml_build_forward_expand(gf, cur);
  68. return gf;
  69. }