// minicpmv.cpp — MiniCPM-V vision graph (ViT + resampler projector)
  1. #include "models.h"
// Build the ggml compute graph for the MiniCPM-V vision tower: run the ViT
// backbone over the image patches, then compress the patch sequence through
// the "resampler" projector (a single cross-attention block with learned
// queries) into n_query embeddings in the LLM's embedding space.
// Returns the graph whose last node is the projected embedding tensor.
ggml_cgraph * clip_graph_minicpmv::build() {
    // MiniCPM-V has no CLS token, so a class embedding must not be present.
    GGML_ASSERT(model.class_embedding == nullptr);

    const int n_pos       = n_patches;     // one position per image patch
    const int n_embd_proj = n_mmproj_embd; // projector (LLM-side) embedding dim

    // position embeddings for the projector (not for ViT)
    // see: https://huggingface.co/openbmb/MiniCPM-o-2_6/blob/main/resampler.py#L70

    // base frequency omega: each spatial axis (x, y) contributes
    // n_embd_proj/4 sin plus n_embd_proj/4 cos channels, i.e. the four
    // concatenated parts below add up to n_embd_proj channels total
    ggml_tensor * omega = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_embd_proj / 4);
    ggml_set_name(omega, "omega");
    ggml_set_input(omega);

    // 2D input positions (using float for sinusoidal embeddings)
    ggml_tensor * pos_h = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_pos);
    ggml_set_name(pos_h, "pos_h");
    ggml_set_input(pos_h);

    ggml_tensor * pos_w = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_pos);
    ggml_set_name(pos_w, "pos_w");
    ggml_set_input(pos_w);

    // for selecting learned pos embd, used by ViT (integer row indices)
    struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos);
    ggml_set_name(positions, "positions");
    ggml_set_input(positions);

    // gather the learned ViT position embeddings for the given patch indices
    ggml_tensor * learned_pos_embd = ggml_get_rows(ctx0, model.position_embeddings, positions);

    // run the ViT backbone over the raw patch input
    ggml_tensor * inp = build_inp();
    ggml_tensor * embeddings = build_vit(
                                inp, n_pos,
                                NORM_TYPE_NORMAL,
                                hparams.ffn_op,
                                learned_pos_embd,
                                nullptr);

    // resampler projector (it is just another transformer)

    ggml_tensor * q = model.mm_model_query;                               // learned query tokens
    ggml_tensor * v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, embeddings); // project ViT output to kv space

    // norm
    q = build_norm(q, model.mm_model_ln_q_w, model.mm_model_ln_q_b, NORM_TYPE_NORMAL, eps, -1);
    v = build_norm(v, model.mm_model_ln_kv_w, model.mm_model_ln_kv_b, NORM_TYPE_NORMAL, eps, -1);

    // calculate sinusoidal pos embd
    ggml_tensor * pos_embed = nullptr;
    {
        // outer product: broadcast omega across all n_pos positions, then
        // scale each row by that position's x / y coordinate
        ggml_tensor * omega_b = ggml_repeat_4d(ctx0, omega, omega->ne[0], n_pos, 1, 1); // n_pos rows
        ggml_tensor * theta_x = ggml_mul(ctx0, omega_b, pos_w);
        ggml_tensor * theta_y = ggml_mul(ctx0, omega_b, pos_h);

        // sin and cos
        ggml_tensor * pos_embd_x = ggml_concat(
            ctx0,
            ggml_sin(ctx0, theta_x),
            ggml_cos(ctx0, theta_x),
            0 // concat on first dim
        );
        ggml_tensor * pos_embd_y = ggml_concat(
            ctx0,
            ggml_sin(ctx0, theta_y),
            ggml_cos(ctx0, theta_y),
            0 // concat on first dim
        );
        // [sin(x), cos(x), sin(y), cos(y)] -> n_embd_proj channels per position
        pos_embed = ggml_concat(ctx0, pos_embd_x, pos_embd_y, 0);
    }

    // k = v + pos_embed  (keys carry the 2D positional signal; values do not)
    ggml_tensor * k = ggml_add(ctx0, v, pos_embed);

    // attention
    {
        const int d_head = 128;                 // fixed head size in the resampler
        int n_head = n_embd_proj/d_head;        // derive head count from projector dim
        // number of learned query tokens, taken from the model hparams
        int num_query = hparams.minicpmv_query_num;

        // fused linear + bias for Q/K/V projections
        ggml_tensor * Q = ggml_add(ctx0,
            ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q),
            model.mm_model_attn_q_b);
        ggml_tensor * K = ggml_add(ctx0,
            ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k),
            model.mm_model_attn_k_b);
        ggml_tensor * V = ggml_add(ctx0,
            ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v),
            model.mm_model_attn_v_b);

        // split heads: Q has num_query tokens, K/V have one per patch
        Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_query);
        K = ggml_reshape_3d(ctx0, K, d_head, n_head, n_pos);
        V = ggml_reshape_3d(ctx0, V, d_head, n_head, n_pos);
        cb(Q, "resampler_Q", -1);
        cb(K, "resampler_K", -1);
        cb(V, "resampler_V", -1);

        // standard scaled dot-product attention scale, 1/sqrt(d_head)
        float resampler_kq_scale = 1.0f/ sqrtf(float(d_head));
        embeddings = build_attn(
            model.mm_model_attn_o_w,
            model.mm_model_attn_o_b,
            Q, K, V, nullptr, resampler_kq_scale, -1);
        cb(embeddings, "resampler_attn_out", -1);
    }

    // layernorm
    embeddings = build_norm(embeddings, model.mm_model_ln_post_w, model.mm_model_ln_post_b, NORM_TYPE_NORMAL, eps, -1);

    // projection into the LLM embedding space
    embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, embeddings);

    // build the graph
    ggml_build_forward_expand(gf, embeddings);

    return gf;
}