clip-graph.h

#pragma once

#include "ggml.h"
#include "ggml-cpp.h"
#include "clip.h"
#include "clip-impl.h"
#include "clip-model.h"

#include <cstdint>    // uint32_t, int32_t, int64_t
#include <functional>
#include <vector>

#define DEFAULT_INTERPOLATION_MODE (GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS)

struct clip_graph {
    const clip_model   & model;
    const clip_hparams & hparams;
    projector_type       proj_type;

    // we only support a single image per batch
    const clip_image_f32 & img;

    const int   patch_size;
    const int   n_patches_x;
    const int   n_patches_y;
    const int   n_patches;
    const int   n_embd;
    const int   n_head;
    const int   d_head;
    const int   n_layer;
    const int   n_mmproj_embd;
    const float eps;
    const float kq_scale;
    const clip_flash_attn_type flash_attn_type;

    // for debugging
    const bool debug_graph;
    std::vector<ggml_tensor *> & debug_print_tensors;
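
    // graph-building state: ctx0_ptr owns the ggml context, ctx0 is the raw
    // pointer used while building, and gf is the cgraph under construction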
    ggml_context_ptr ctx0_ptr;
    ggml_context   * ctx0;
    ggml_cgraph    * gf;

    clip_graph(clip_ctx * ctx, const clip_image_f32 & img);
    virtual ~clip_graph() = default;

    // each model implements its own build(), returning the finished cgraph
    virtual ggml_cgraph * build() = 0;

    //
    // utility functions
    //

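    // debug callback: assigns `name` to the tensor and, when debug_graph is
    // set, records it in debug_print_tensors so it can be inspected after the
    // graph is computed (il is the layer index)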
    void cb(ggml_tensor * cur0, const char * name, int il) const;

    // siglip2 naflex
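    // resize (interpolate) the learned position embeddings to the current
    // n_patches_x by n_patches_y grid, for models with flexible input resolution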
    ggml_tensor * resize_position_embeddings(uint32_t interpolation_mode = DEFAULT_INTERPOLATION_MODE);

    // build the vision transformer (ViT) cgraph
    // this function should cover most of the models
    // if your model has specific features, you should probably duplicate this function
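    // learned_pos_embd is added to the input embeddings (may be null);
    // add_pos is an optional per-layer hook for positional encodings such as
    // 2D RoPE, applied inside each attention block (pass nullptr to skip)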
    ggml_tensor * build_vit(
            ggml_tensor * inp,
            int64_t       n_pos,
            norm_type     norm_t,
            ffn_op_type   ffn_t,
            ggml_tensor * learned_pos_embd,
            std::function<ggml_tensor *(ggml_tensor *, const clip_layer &)> add_pos);

    // build the input after conv2d (inp_raw --> patches)
    // returns tensor with shape [n_embd, n_patches]
    ggml_tensor * build_inp();

    // raw image input, shape [img width, img height, channels]
    ggml_tensor * build_inp_raw(int channels = 3);
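
    // generic normalization: layer norm or RMS norm, selected by `type`;
    // mw (scale) and mb (bias) may each be null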
    ggml_tensor * build_norm(
            ggml_tensor * cur,
            ggml_tensor * mw,
            ggml_tensor * mb,
            norm_type     type,
            float         norm_eps,
            int           il) const;
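
    // feed-forward block: up/gate/down projections with optional biases;
    // gate may be null for non-gated FFNs; the activation is selected by type_op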
    ggml_tensor * build_ffn(
            ggml_tensor * cur,
            ggml_tensor * up,
            ggml_tensor * up_b,
            ggml_tensor * gate,
            ggml_tensor * gate_b,
            ggml_tensor * down,
            ggml_tensor * down_b,
            ffn_op_type   type_op,
            int           il) const;
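
    // multi-head attention: q_cur/k_cur/v_cur are the projected Q/K/V tensors,
    // wo/wo_b form the output projection; kq_mask may be null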
    ggml_tensor * build_attn(
            ggml_tensor * wo,
            ggml_tensor * wo_b,
            ggml_tensor * q_cur,
            ggml_tensor * k_cur,
            ggml_tensor * v_cur,
            ggml_tensor * kq_mask,
            float         kq_scale,
            int           il) const;

    // implementation of 2D RoPE without adding a new op in ggml
    // this is not efficient (it uses double the memory), but it works on all backends
    // TODO: there was a more efficient implementation relying on ggml_view and
    //       ggml_rope_ext_inplace, but rope inplace does not work well with
    //       non-contiguous tensors; we should fix that and revert to the original
    //       implementation in https://github.com/ggml-org/llama.cpp/pull/13065
    ggml_tensor * build_rope_2d(
            ggml_context * ctx0,
            ggml_tensor  * cur,
            ggml_tensor  * pos_a, // first half
            ggml_tensor  * pos_b, // second half
            const float    freq_base,
            const bool     interleave_freq);

    // aka pixel_shuffle / pixel_unshuffle / patch_merger (Kimi-VL)
    // supports dynamic resolution
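    // e.g. scale_factor = 2 merges each 2x2 block of patches into a single
    // token, dividing n_patches by 4 and multiplying the embedding size by 4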
    ggml_tensor * build_patch_merge_permute(ggml_tensor * cur, int scale_factor);

    // Generic function to stack frames for audio processing
    // Abstracts out the StackAudioFrames logic used by ultravox
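    // i.e. concatenate every stack_factor consecutive frames along the
    // embedding dimension (padding the sequence as needed), so that
    // [n_embed, n_frames] becomes [n_embed * stack_factor, n_frames / stack_factor]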
    ggml_tensor * build_stack(ggml_tensor * cur, int32_t stack_factor, int32_t n_embed);
};
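
// Usage sketch (illustrative, not part of this header): a concrete model graph
// derives from clip_graph and implements build(). The subclass name and the
// enum / member names marked below are assumptions for illustration; the real
// implementations live in the corresponding .cpp file.
//
//     struct clip_graph_example : clip_graph {
//         using clip_graph::clip_graph; // inherit the (ctx, img) constructor
//
//         ggml_cgraph * build() override {
//             ggml_tensor * inp = build_inp(); // [n_embd, n_patches]
//             ggml_tensor * cur = build_vit(
//                 inp, n_patches,
//                 NORM_TYPE_NORMAL,          // assumed norm_type enumerator
//                 FFN_GELU,                  // assumed ffn_op_type enumerator
//                 model.position_embeddings, // assumed learned pos-embd tensor
//                 nullptr);                  // no per-layer pos-encoding hook
//             ggml_build_forward_expand(gf, cur);
//             return gf;
//         }
//     };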