// youtuvl.cpp

#include "models.h"

ggml_cgraph * clip_graph_youtuvl::build() {
    GGML_ASSERT(model.class_embedding == nullptr);
    const int batch_size       = 1;
    const bool use_window_attn = !hparams.wa_layer_indexes.empty();
    const int n_pos            = n_patches;
    const int num_position_ids = n_pos * 4;
    const int m  = 2;
    const int Wp = n_patches_x;
    const int Hp = n_patches_y;
    const int Hm = Hp / m;
    const int Wm = Wp / m;
    norm_type norm_t = NORM_TYPE_NORMAL;
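    // mrope_sections: the head dimension is split into four equal sections for multi-section (M-)RoPE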
    int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
    ggml_tensor * inp = build_inp_raw();

    // the conv3d patch embedding is implemented as a linear projection over flattened patches:
    // reshape and permute from (patch_size, m, Wm, patch_size, m, Hm, C) to (C, patch_size, patch_size, m, m, Wm, Hm)
    {
        inp = ggml_reshape_4d(
            ctx0, inp,
            Wm * m * patch_size, m * patch_size, Hm, 3);
        inp = ggml_permute(ctx0, inp, 1, 2, 3, 0);
        inp = ggml_cont_4d(
            ctx0, inp,
            m * patch_size * 3, Wm, m * patch_size, Hm);
        inp = ggml_permute(ctx0, inp, 0, 2, 1, 3);
        inp = ggml_cont_4d(
            ctx0, inp,
            m * patch_size * 3, patch_size, m, Hm * Wm);
        inp = ggml_permute(ctx0, inp, 1, 0, 2, 3);
        inp = ggml_cont_4d(
            ctx0, inp,
            patch_size, 3, patch_size, Hm * Wm * m * m);
        inp = ggml_permute(ctx0, inp, 2, 0, 1, 3);
        inp = ggml_cont_3d(
            ctx0, inp,
            3 * patch_size * patch_size, Hm * Wm * m * m, 1);
    }
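    // linear patch embedding: project each flattened patch (3 * patch_size * patch_size values) to n_embd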
    inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp);
    if (model.patch_bias) {
        inp = ggml_add(ctx0, inp, model.patch_bias);
    }
    inp = ggml_reshape_2d(ctx0, inp, n_embd, n_patches);
    ggml_tensor * inpL           = inp;
    ggml_tensor * window_mask    = nullptr;
    ggml_tensor * window_idx     = nullptr;
    ggml_tensor * inv_window_idx = nullptr;
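    // the window-attention tensors above are only allocated below when use_window_attn is set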
    ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
    ggml_set_name(positions, "positions");
    ggml_set_input(positions);
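    // positions holds n_pos * 4 M-RoPE position ids (four per patch), provided as a graph input at eval time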
    // pre-layernorm
    if (model.pre_ln_w) {
        inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1);
    }
    if (use_window_attn) {
        inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / 4);
        ggml_set_name(inv_window_idx, "inv_window_idx");
        ggml_set_input(inv_window_idx);
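        // inv_window_idx reorders the 2x2-merged patch groups (n_pos / 4 rows) into window order;
        // window_idx, set up after the encoder, restores the original order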
        // mask for window attention
        window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_pos, n_pos);
        ggml_set_name(window_mask, "window_mask");
        ggml_set_input(window_mask);

        // when flash attention is enabled, the mask must be cast to f16
        if (flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
            window_mask = ggml_cast(ctx0, window_mask, GGML_TYPE_F16);
        }
        // inpL shape: [n_embd, n_patches_x * n_patches_y, batch_size]
        GGML_ASSERT(batch_size == 1);
        inpL = ggml_reshape_2d(ctx0, inpL, n_embd * 4, n_patches_x * n_patches_y * batch_size / 4);
        inpL = ggml_get_rows(ctx0, inpL, inv_window_idx);
        inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_patches_x * n_patches_y, batch_size);
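        // inpL is now in window order; the original patch order is restored after the encoder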
    }
    // loop over layers
    for (int il = 0; il < n_layer; il++) {
        const auto & layer = model.layers[il];
        const bool full_attn = use_window_attn ? hparams.wa_layer_indexes.count(il) > 0 : true;
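        // layers listed in wa_layer_indexes run full (unmasked) attention; all other layers use window_mask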
        ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states

        // layernorm1
        cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il);

        // self-attention
        {
            ggml_tensor * Qcur = ggml_add(ctx0,
                ggml_mul_mat(ctx0, layer.q_w, cur), layer.q_b);
            ggml_tensor * Kcur = ggml_add(ctx0,
                ggml_mul_mat(ctx0, layer.k_w, cur), layer.k_b);
            ggml_tensor * Vcur = ggml_add(ctx0,
                ggml_mul_mat(ctx0, layer.v_w, cur), layer.v_b);

            Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_patches);
            Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_patches);
            Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_patches);
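            // apply multi-section RoPE (vision variant) to Q and K using the per-patch position ids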
            Qcur = ggml_rope_multi(
                ctx0, Qcur, positions, nullptr,
                d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
            Kcur = ggml_rope_multi(
                ctx0, Kcur, positions, nullptr,
                d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);

            ggml_tensor * attn_mask = full_attn ? nullptr : window_mask;

            cur = build_attn(layer.o_w, layer.o_b,
                Qcur, Kcur, Vcur, attn_mask, kq_scale, il);
        }
        // re-add the layer input, i.e. the residual connection
        cur = ggml_add(ctx0, cur, inpL);

        inpL = cur; // inpL = residual, cur = hidden_states

        // layernorm2
        cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il);

        // ffn
        cur = build_ffn(cur,
            layer.ff_up_w, layer.ff_up_b,
            nullptr, nullptr,
            layer.ff_down_w, layer.ff_down_b,
            hparams.ffn_op, il);

        // residual 2
        cur = ggml_add(ctx0, inpL, cur);

        inpL = cur;
    }
    ggml_tensor * embeddings = inpL;

    if (use_window_attn) {
        const int spatial_merge_unit = 4;
        window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / spatial_merge_unit);
        ggml_set_name(window_idx, "window_idx");
        ggml_set_input(window_idx);

        GGML_ASSERT(batch_size == 1);
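        // undo the inv_window_idx permutation: gather the merged patch groups back into their original order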
        embeddings = ggml_reshape_2d(ctx0, embeddings, n_embd * spatial_merge_unit, n_patches / spatial_merge_unit);
        embeddings = ggml_get_rows(ctx0, embeddings, window_idx);
        embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd, n_patches, batch_size);
        cb(embeddings, "window_order_restored", -1);
    }
    // post-layernorm (part of Siglip2VisionTransformer, applied after encoder)
    if (model.post_ln_w) {
        embeddings = build_norm(embeddings, model.post_ln_w, model.post_ln_b, norm_t, eps, n_layer);
    }
    // Now apply the merger (VLPatchMerger):
    // 1. Apply RMS norm (ln_q in VLPatchMerger)
    embeddings = build_norm(embeddings, model.mm_input_norm_w, nullptr, NORM_TYPE_RMS, 1e-6, -1);
    cb(embeddings, "merger_normed", -1);

    // 2. Reshape for the spatial merge (concatenate each 2x2 patch group)
    embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, n_pos / 4, batch_size);
    cb(embeddings, "merger_reshaped", -1);
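    // 3. Project each merged group with a two-layer MLP (mm_0 -> GELU -> mm_1)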
    embeddings = build_ffn(embeddings,
        model.mm_0_w, model.mm_0_b,
        nullptr, nullptr,
        model.mm_1_w, model.mm_1_b,
        FFN_GELU,
        -1);

    ggml_build_forward_expand(gf, embeddings);

    return gf;
}