// internvl.cpp
  1. #include "models.h"
  2. ggml_cgraph * clip_graph_internvl::build() {
  3. GGML_ASSERT(model.class_embedding != nullptr);
  4. GGML_ASSERT(model.position_embeddings != nullptr);
  5. const int n_pos = n_patches + 1;
  6. ggml_tensor * inp = build_inp();
  7. // add CLS token
  8. inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
  9. // The larger models use a different ViT, which uses RMS norm instead of layer norm
  10. // ref: https://github.com/ggml-org/llama.cpp/pull/13443#issuecomment-2869786188
  11. norm_type norm_t = (hparams.n_embd == 3200 && hparams.n_layer == 45)
  12. ? NORM_TYPE_RMS // 6B ViT (Used by InternVL 2.5/3 - 26B, 38B, 78B)
  13. : NORM_TYPE_NORMAL; // 300M ViT (Used by all smaller InternVL models)
  14. ggml_tensor * cur = build_vit(
  15. inp, n_pos,
  16. norm_t,
  17. hparams.ffn_op,
  18. model.position_embeddings,
  19. nullptr);
  20. // remove CLS token
  21. cur = ggml_view_2d(ctx0, cur,
  22. n_embd, n_patches,
  23. ggml_row_size(cur->type, n_embd), 0);
  24. // pixel shuffle
  25. {
  26. const int scale_factor = model.hparams.n_merge;
  27. const int bsz = 1; // batch size, always 1 for now since we don't support batching
  28. const int height = n_patches_y;
  29. const int width = n_patches_x;
  30. GGML_ASSERT(scale_factor > 0);
  31. cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, height / scale_factor, width, bsz);
  32. cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
  33. cur = ggml_cont_4d(ctx0, cur,
  34. n_embd * scale_factor * scale_factor,
  35. height / scale_factor,
  36. width / scale_factor,
  37. bsz);
  38. cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
  39. // flatten to 2D
  40. cur = ggml_cont_2d(ctx0, cur,
  41. n_embd * scale_factor * scale_factor,
  42. cur->ne[1] * cur->ne[2]);
  43. }
  44. // projector (always using GELU activation)
  45. {
  46. // projector LayerNorm uses pytorch's default eps = 1e-5
  47. // ref: https://huggingface.co/OpenGVLab/InternVL3-8B-Instruct/blob/a34d3e4e129a5856abfd6aa6de79776484caa14e/modeling_internvl_chat.py#L79
  48. cur = build_norm(cur, model.mm_0_w, model.mm_0_b, NORM_TYPE_NORMAL, 1e-5, -1);
  49. cur = build_ffn(cur,
  50. model.mm_1_w, model.mm_1_b,
  51. nullptr, nullptr,
  52. model.mm_3_w, model.mm_3_b,
  53. FFN_GELU,
  54. -1);
  55. }
  56. // build the graph
  57. ggml_build_forward_expand(gf, cur);
  58. return gf;
  59. }