glm4v.cpp

#include "models.h"
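
// Vision encoder graph for GLM-4V. Pipeline: dual conv2d patch embedding ->
// 2x2 block token reordering -> learned absolute position embeddings + M-RoPE ->
// ViT layers -> patch merger (downsample) -> FC + FFN projector.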
ggml_cgraph * clip_graph_glm4v::build() {
    GGML_ASSERT(model.patch_bias != nullptr);
    GGML_ASSERT(model.position_embeddings != nullptr);
    GGML_ASSERT(model.class_embedding == nullptr);

    const int batch_size = 1;
    norm_type norm_t = NORM_TYPE_RMS;
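
    // patch embedding: a conv over non-overlapping patch_size x patch_size windows,
    // split into two ggml_conv_2d halves (the second half is added in the block below)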
    ggml_tensor * inp_raw = build_inp_raw();
    ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
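
    // M-RoPE: each attention head is split into 4 equal rotary sections; the
    // "positions" input carries 4 position ids per patch, filled in at eval time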
    int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};

    ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches * 4);
    ggml_set_name(positions, "positions");
    ggml_set_input(positions);
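
    // the 2x2 token reordering below requires an even number of patches per dimension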
    GGML_ASSERT(img.nx % (patch_size * 2) == 0);
    GGML_ASSERT(img.ny % (patch_size * 2) == 0);

    // second conv dimension
    {
        auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
        inp = ggml_add(ctx0, inp, inp_1);

        inp = ggml_permute(ctx0, inp, 1, 2, 0, 3); // [w, h, c, b] -> [c, w, h, b]

        // concat each pair of horizontally adjacent patches along the channel dim,
        // then interleave the two rows of each row pair, so that every 4 consecutive
        // tokens form one 2x2 block of patches
        inp = ggml_cont_4d(
            ctx0, inp,
            n_embd * 2, n_patches_x / 2, n_patches_y, batch_size);
        inp = ggml_reshape_4d(
            ctx0, inp,
            n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2));
        inp = ggml_permute(ctx0, inp, 0, 2, 1, 3);
        inp = ggml_cont_3d(
            ctx0, inp,
            n_embd, n_patches_x * n_patches_y, batch_size);
    }
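
    // inp: [n_embd, n_patches_x * n_patches_y, batch_size]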

    // add patch bias
    inp = ggml_add(ctx0, inp, model.patch_bias);
    cb(inp, "patch_bias", -1);

    // post-conv norm
    inp = build_norm(inp, model.norm_embd_w, model.norm_embd_b, norm_t, eps, -1);

    // learned absolute position embeddings: bicubic-resized to the current grid,
    // then shuffled with the same 2x2 block reordering as the patch embeddings
    ggml_tensor * learned_pos_embd = resize_position_embeddings(GGML_SCALE_MODE_BICUBIC);
    learned_pos_embd = ggml_cont_4d(
        ctx0, learned_pos_embd,
        n_embd * 2, n_patches_x / 2, n_patches_y, batch_size);
    learned_pos_embd = ggml_reshape_4d(
        ctx0, learned_pos_embd,
        n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2));
    learned_pos_embd = ggml_permute(ctx0, learned_pos_embd, 0, 2, 1, 3);
    learned_pos_embd = ggml_cont_3d(
        ctx0, learned_pos_embd,
        n_embd, n_patches_x * n_patches_y, batch_size);
    cb(learned_pos_embd, "learned_pos_embd", -1);
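
    // per-layer position callback; build_vit applies this to the attention Q/K
    // tensors in every layer, rotating them with multi-section RoPE (M-RoPE)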
    auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
        return ggml_rope_multi(
            ctx0, cur, positions, nullptr,
            d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION,
            32768, hparams.rope_theta, 1, 0, 1, 32, 1);
    };

    ggml_tensor * cur = build_vit(
        inp, n_patches,
        norm_t,
        hparams.ffn_op,
        learned_pos_embd,
        add_pos);

    cb(cur, "vit_out", -1);
    // cb(ggml_sum(ctx0, cur), "vit_out_sum", -1);

    // GLM4V projector
    // ref: https://github.com/huggingface/transformers/blob/40dc11cd3eb4126652aa41ef8272525affd4a636/src/transformers/models/glm4v/modeling_glm4v.py#L116-L130

    // patch merger (downsample)
    {
        int n_merge = hparams.n_merge;
        GGML_ASSERT(n_merge > 0);
        int n_token_out = n_patches / n_merge / n_merge;

        cur = ggml_reshape_4d(ctx0, cur, n_embd, n_merge, n_merge, n_token_out);
        cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3)); // [n_merge, n_merge, n_embd, n_token_out]
        cur = ggml_conv_2d(ctx0, model.mm_patch_merger_w, cur, n_merge, n_merge, 0, 0, 1, 1);
        cur = ggml_reshape_2d(ctx0, cur, cur->ne[2], n_token_out); // [n_embd_out, n_token_out]
        cur = ggml_add(ctx0, cur, model.mm_patch_merger_b);
    }
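
    // each output token now summarizes one n_merge x n_merge group of patches;
    // with n_merge == 2 these groups are exactly the 2x2 blocks made contiguous above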

    // FC projector
    {
        cur = ggml_mul_mat(ctx0, model.projection, cur);
        // default LayerNorm (post_projection_norm)
        cur = build_norm(cur, model.mm_post_norm_w, model.mm_post_norm_b, NORM_TYPE_NORMAL, 1e-5, -1);
        cur = ggml_gelu_erf(ctx0, cur); // exact (erf-based) GELU
        cb(cur, "after_fc_proj", -1);
    }

    // FFN projector
    {
        cur = build_ffn(cur,
            model.mm_ffn_up_w,   model.mm_ffn_up_b,
            model.mm_ffn_gate_w, model.mm_ffn_gate_b,
            model.mm_ffn_down_w, model.mm_ffn_down_b,
            hparams.ffn_op, -1);
        cb(cur, "after_ffn_proj", -1);
        // cb(ggml_sum(ctx0, cur), "merged_sum", -1);
    }

    // build the graph
    ggml_build_forward_expand(gf, cur);

    return gf;
}