siglip.cpp

#include "models.h"

ggml_cgraph * clip_graph_siglip::build() {
    ggml_tensor * inp = build_inp();

    ggml_tensor * learned_pos_embd = model.position_embeddings;
    if (proj_type == PROJECTOR_TYPE_LFM2) {
        learned_pos_embd = resize_position_embeddings();
    }

    ggml_tensor * cur = build_vit(
                            inp, n_patches,
                            NORM_TYPE_NORMAL,
                            hparams.ffn_op,
                            learned_pos_embd,
                            nullptr);
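
    // cur now holds the ViT patch embeddings; the branches below apply the
    // projector that maps them into the language model's embedding space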
    if (proj_type == PROJECTOR_TYPE_GEMMA3) {
        const int batch_size = 1;
        GGML_ASSERT(n_patches_x == n_patches_y);
        const int patches_per_image = n_patches_x;
        const int kernel_size = hparams.n_merge;

        cur = ggml_transpose(ctx0, cur);
        cur = ggml_cont_4d(ctx0, cur, patches_per_image, patches_per_image, n_embd, batch_size);

        // doing a pool2d to reduce the number of output tokens
        cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_AVG, kernel_size, kernel_size, kernel_size, kernel_size, 0, 0);
        cur = ggml_reshape_3d(ctx0, cur, cur->ne[0] * cur->ne[0], n_embd, batch_size);
        cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
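
        // stride == kernel_size, so each spatial side shrinks by a factor of n_merge;
        // cur is now [n_embd, (patches_per_image / kernel_size)^2, batch_size] in ggml ne order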

        // apply norm before projection
        cur = ggml_rms_norm(ctx0, cur, eps);
        cur = ggml_mul(ctx0, cur, model.mm_soft_emb_norm_w);

        // apply projection
        cur = ggml_mul_mat(ctx0,
            ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_input_proj_w)),
            cur);
    } else if (proj_type == PROJECTOR_TYPE_IDEFICS3) {
        // pixel_shuffle
        // https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578
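        // each scale_factor x scale_factor group of neighboring patch embeddings is
        // concatenated along the embedding dim, cutting the token count by scale_factor^2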
        const int scale_factor = model.hparams.n_merge;
        cur = build_patch_merge_permute(cur, scale_factor);
        cur = ggml_mul_mat(ctx0, model.projection, cur);
    } else if (proj_type == PROJECTOR_TYPE_LFM2) {
        // pixel unshuffle block
        const int scale_factor = model.hparams.n_merge;
        cur = build_patch_merge_permute(cur, scale_factor);

        // projection
        cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm
        cur = ggml_mul(ctx0, cur, model.mm_input_norm_w);
        cur = ggml_add(ctx0, cur, model.mm_input_norm_b);
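
        // two-layer MLP projector: mm_1 -> GELU -> mm_2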
        cur = build_ffn(cur,
            model.mm_1_w, model.mm_1_b,
            nullptr, nullptr,
            model.mm_2_w, model.mm_2_b,
            FFN_GELU,
            -1);
    } else if (proj_type == PROJECTOR_TYPE_JANUS_PRO) {
        cur = build_ffn(cur,
            model.mm_0_w, model.mm_0_b,
            nullptr, nullptr,
            model.mm_1_w, model.mm_1_b,
            hparams.ffn_op,
            -1);
    } else {
        GGML_ABORT("SigLIP: Unsupported projector type");
    }

    // build the graph
    ggml_build_forward_expand(gf, cur);

    return gf;
}