#include "models.h"

// this graph is used by llava, granite and glm
// due to having embedding_stack (used by granite), we cannot reuse build_vit
ggml_cgraph * clip_graph_llava::build() {
    const int batch_size = 1;
    const int n_pos      = n_patches + (model.class_embedding ? 1 : 0);
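    // e.g. with the standard LLaVA-1.5 vision tower (CLIP ViT-L/14 at 336x336),
    // n_patches = 24*24 = 576 and, with the class embedding, n_pos = 577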
    GGML_ASSERT(n_patches_x == n_patches_y && "only square images supported");

    // Calculate the deepest feature layer based on hparams and projector type
    int max_feature_layer = n_layer;
    {
        // Get the index of the second to last layer; this is the default for models that have a llava projector
        int il_last = hparams.n_layer - 1;
        int deepest_feature_layer = -1;

        if (proj_type == PROJECTOR_TYPE_MINICPMV || proj_type == PROJECTOR_TYPE_GLM_EDGE) {
            il_last += 1;
        }

        // If we set explicit vision feature layers, only go up to the deepest one
        // NOTE: only used by granite-vision models for now
        for (const auto & feature_layer : hparams.vision_feature_layer) {
            if (feature_layer > deepest_feature_layer) {
                deepest_feature_layer = feature_layer;
            }
        }
        max_feature_layer = deepest_feature_layer < 0 ? il_last : deepest_feature_layer;
    }
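    // worked example (hypothetical config): with vision_feature_layer = {2, 5},
    // max_feature_layer = 5, so the encoder loop below only runs layers 0..4;
    // the input of layer 2 is pushed onto embedding_stack inside the loop, the
    // (post-ln) output of layer 4 is pushed by the check after the loop, and the
    // two are concatenated along dim 0, doubling the feature size fed to the projector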
    ggml_tensor * inp = build_inp();

    // concat class_embeddings and patch_embeddings
    if (model.class_embedding) {
        inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
    }

    ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos);
    ggml_set_name(positions, "positions");
    ggml_set_input(positions);

    inp = ggml_add(ctx0, inp, ggml_get_rows(ctx0, model.position_embeddings, positions));

    ggml_tensor * inpL = inp;

    // pre-layernorm
    if (model.pre_ln_w) {
        inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, NORM_TYPE_NORMAL, eps, -1);
        cb(inpL, "pre_ln", -1);
    }

    std::vector<ggml_tensor *> embedding_stack;
    const auto & vision_feature_layer = hparams.vision_feature_layer;

    // loop over layers
    for (int il = 0; il < max_feature_layer; il++) {
        auto & layer = model.layers[il];
        ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states

        // If this is an embedding feature layer, save the output.
        // NOTE: 0 index here refers to the input to the encoder.
        if (vision_feature_layer.find(il) != vision_feature_layer.end()) {
            embedding_stack.push_back(cur);
        }

        // layernorm1
        cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, eps, il);
        cb(cur, "layer_inp_normed", il);

        // self-attention
        {
            ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
            if (layer.q_b) {
                Qcur = ggml_add(ctx0, Qcur, layer.q_b);
            }

            ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
            if (layer.k_b) {
                Kcur = ggml_add(ctx0, Kcur, layer.k_b);
            }

            ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
            if (layer.v_b) {
                Vcur = ggml_add(ctx0, Vcur, layer.v_b);
            }

            Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos);
            Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos);
            Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos);

            cb(Qcur, "Qcur", il);
            cb(Kcur, "Kcur", il);
            cb(Vcur, "Vcur", il);
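            // full (unmasked) attention over all n_pos tokens (mask is nullptr);
            // kq_scale is presumably the usual 1/sqrt(d_head) softmax scaling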
            cur = build_attn(layer.o_w, layer.o_b,
                    Qcur, Kcur, Vcur, nullptr, kq_scale, il);
            cb(cur, "attn_out", il);
        }
        // re-add the layer input, i.e. the residual connection
        cur = ggml_add(ctx0, cur, inpL);

        inpL = cur; // inpL = residual, cur = hidden_states
        cb(cur, "ffn_inp", il);

        // layernorm2
        cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il);
        cb(cur, "ffn_inp_normed", il);

        // ffn
        cur = build_ffn(cur,
                layer.ff_up_w, layer.ff_up_b,
                layer.ff_gate_w, layer.ff_gate_b,
                layer.ff_down_w, layer.ff_down_b,
                hparams.ffn_op, il);

        cb(cur, "ffn_out", il);

        // residual 2
        cur = ggml_add(ctx0, inpL, cur);
        cb(cur, "layer_out", il);

        inpL = cur;
    }
    // post-layernorm
    if (model.post_ln_w) {
        inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, NORM_TYPE_NORMAL, eps, -1);
    }

    ggml_tensor * embeddings = inpL;

    // process vision feature layers (used by granite)
    {
        // final layer is a vision feature layer
        if (vision_feature_layer.find(max_feature_layer) != vision_feature_layer.end()) {
            embedding_stack.push_back(inpL);
        }

        // If feature layers are explicitly set, stack them (if we have multiple)
        if (!embedding_stack.empty()) {
            embeddings = embedding_stack[0];
            for (size_t i = 1; i < embedding_stack.size(); i++) {
                embeddings = ggml_concat(ctx0, embeddings, embedding_stack[i], 0);
            }
        }
    }
    // llava projector (also used by granite)
    if (hparams.has_llava_projector) {
        embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);

        ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
        ggml_set_name(patches, "patches");
        ggml_set_input(patches);
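        // "patches" is filled in by the caller with the row indices of the patch
        // tokens, so the get_rows below keeps exactly n_patches of the n_pos
        // encoder outputs (dropping the class token, if present)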
        // shape [1, 576, 1024]
        // ne is whcn, ne = [1024, 576, 1, 1]
        embeddings = ggml_get_rows(ctx0, embeddings, patches);

        // print_tensor_info(embeddings, "embeddings");

        // llava projector
        if (proj_type == PROJECTOR_TYPE_MLP) {
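            // llava-1.5 style projector: linear -> GELU -> optional second linear,
            // mapping vision features to the LLM embedding size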
            embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
            embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);

            embeddings = ggml_gelu(ctx0, embeddings);
            if (model.mm_2_w) {
                embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
                embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
            }
        }
        else if (proj_type == PROJECTOR_TYPE_MLP_NORM) {
            embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
            embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
            // ggml_tensor_printf(embeddings, "mm_0_w", 0, true, false);

            // First LayerNorm
            embeddings = ggml_norm(ctx0, embeddings, eps);
            embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_1_w),
                model.mm_1_b);

            // GELU activation
            embeddings = ggml_gelu(ctx0, embeddings);

            // Second linear layer
            embeddings = ggml_mul_mat(ctx0, model.mm_3_w, embeddings);
            embeddings = ggml_add(ctx0, embeddings, model.mm_3_b);

            // Second LayerNorm
            embeddings = ggml_norm(ctx0, embeddings, eps);
            embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_4_w),
                model.mm_4_b);
        }
        else if (proj_type == PROJECTOR_TYPE_LDP) {
            // MobileVLM projector
            int n_patch = 24;
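            // the 24x24 grid is hard-coded here, presumably from the usual
            // 336px / 14px-patch setup (24*24 = 576 patch tokens)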
            ggml_tensor * mlp_1 = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, embeddings);
            mlp_1 = ggml_add(ctx0, mlp_1, model.mm_model_mlp_1_b);
            mlp_1 = ggml_gelu(ctx0, mlp_1);
            ggml_tensor * mlp_3 = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, mlp_1);
            mlp_3 = ggml_add(ctx0, mlp_3, model.mm_model_mlp_3_b);
            // mlp_3 shape = [1, 576, 2048], ne = [2048, 576, 1, 1]

            // block 1
            ggml_tensor * block_1 = nullptr;
            {
                // transpose from [1, 576, 2048] --> [1, 2048, 576] --> [1, 2048, 24, 24]
                mlp_3 = ggml_permute(ctx0, mlp_3, 1, 0, 2, 3);
                mlp_3 = ggml_cont_4d(ctx0, mlp_3, n_patch, n_patch, mlp_3->ne[1], mlp_3->ne[2]);
                // stride = 1, padding = 1, bias is nullptr
                block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_1_block_0_0_w, mlp_3, 1, 1, 1, 1, 1, 1);

                // layer norm
                // block_1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]
                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 2, 0, 3));
                // block_1 shape = [1, 24, 24, 2048], ne = [2048, 24, 24, 1]
                block_1 = ggml_norm(ctx0, block_1, eps);
                block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_1_block_0_1_w), model.mm_model_block_1_block_0_1_b);
                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3));

                // block_1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]
                // hardswish
                ggml_tensor * block_1_hw = ggml_hardswish(ctx0, block_1);

                block_1 = ggml_pool_2d(ctx0, block_1_hw, GGML_OP_POOL_AVG, block_1_hw->ne[0], block_1_hw->ne[1], block_1_hw->ne[0], block_1_hw->ne[1], 0, 0);
                // block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
                // pointwise conv
                block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]);
                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc1_w, block_1);
                block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc1_b);
                block_1 = ggml_relu(ctx0, block_1);
                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc2_w, block_1);
                block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc2_b);
                block_1 = ggml_hardsigmoid(ctx0, block_1);
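                // the fc1 -> relu -> fc2 -> hardsigmoid chain on the pooled features
                // is a squeeze-excitation style gate: it yields one weight per channel,
                // which scales block_1_hw channel-wise just below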
                // block_1_hw shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1], block_1 shape = [1, 2048], ne = [2048, 1, 1, 1]
                block_1 = ggml_reshape_4d(ctx0, block_1, 1, 1, block_1->ne[0], block_1->ne[1]);
                block_1 = ggml_mul(ctx0, block_1_hw, block_1);

                int w = block_1->ne[0], h = block_1->ne[1];
                block_1 = ggml_reshape_3d(ctx0, block_1, w*h, block_1->ne[2], block_1->ne[3]);
                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3));

                // block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1]
                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_2_0_w, block_1);
                block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]);

                // block_1 shape = [1, 24, 24, 2048], ne = [2048, 24, 24, 1]
                block_1 = ggml_norm(ctx0, block_1, eps);
                block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_1_block_2_1_w), model.mm_model_block_1_block_2_1_b);
                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3));
                // block_1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]
                // residual
                block_1 = ggml_add(ctx0, mlp_3, block_1);
            }
            // block_2
            {
                // stride = 2
                block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_2_block_0_0_w, block_1, 2, 2, 1, 1, 1, 1);

                // block_1 shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1]
                // layer norm
                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 2, 0, 3));
                // block_1 shape = [1, 12, 12, 2048], ne = [2048, 12, 12, 1]
                block_1 = ggml_norm(ctx0, block_1, eps);
                block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_2_block_0_1_w), model.mm_model_block_2_block_0_1_b);
                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3));

                // block_1 shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1]
                // hardswish
                ggml_tensor * block_1_hw = ggml_hardswish(ctx0, block_1);
                // not entirely sure these parameters are right for global average pooling
                block_1 = ggml_pool_2d(ctx0, block_1_hw, GGML_OP_POOL_AVG, block_1_hw->ne[0], block_1_hw->ne[1], block_1_hw->ne[0], block_1_hw->ne[1], 0, 0);
                // block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
                // pointwise conv
                block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]);
                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc1_w, block_1);
                block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc1_b);
                block_1 = ggml_relu(ctx0, block_1);
                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc2_w, block_1);
                block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc2_b);
                block_1 = ggml_hardsigmoid(ctx0, block_1);

                // block_1_hw shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1], block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
                block_1 = ggml_reshape_4d(ctx0, block_1, 1, 1, block_1->ne[0], block_1->ne[1]);
                block_1 = ggml_mul(ctx0, block_1_hw, block_1);

                int w = block_1->ne[0], h = block_1->ne[1];
                block_1 = ggml_reshape_3d(ctx0, block_1, w*h, block_1->ne[2], block_1->ne[3]);
                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3));
                // block_1 shape = [1, 12*12, 2048], ne = [12*12, 2048, 1]
                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_2_0_w, block_1);
                block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]);

                // block_1 shape = [1, 12, 12, 2048], ne = [2048, 12, 12, 1]
                block_1 = ggml_norm(ctx0, block_1, eps);
                block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_2_block_2_1_w), model.mm_model_block_2_block_2_1_b);
                block_1 = ggml_reshape_3d(ctx0, block_1, block_1->ne[0], block_1->ne[1] * block_1->ne[2], block_1->ne[3]);
                // block_1 shape = [1, 144, 2048], ne = [2048, 144, 1]
            }
            embeddings = block_1;
        }
        else if (proj_type == PROJECTOR_TYPE_LDPV2) {
            int n_patch = 24;
            ggml_tensor * mlp_0 = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
            mlp_0 = ggml_add(ctx0, mlp_0, model.mm_model_mlp_0_b);
            mlp_0 = ggml_gelu(ctx0, mlp_0);
            ggml_tensor * mlp_2 = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, mlp_0);
            mlp_2 = ggml_add(ctx0, mlp_2, model.mm_model_mlp_2_b);
            // mlp_2 ne = [2048, 576, 1, 1]

            // AVG Pool Layer 2*2, strides = 2
            mlp_2 = ggml_permute(ctx0, mlp_2, 1, 0, 2, 3);
            // mlp_2 ne = [576, 2048, 1, 1]
            mlp_2 = ggml_cont_4d(ctx0, mlp_2, n_patch, n_patch, mlp_2->ne[1], mlp_2->ne[2]);
            // mlp_2 ne = [24, 24, 2048, 1]
            mlp_2 = ggml_pool_2d(ctx0, mlp_2, GGML_OP_POOL_AVG, 2, 2, 2, 2, 0, 0);
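            // the 2x2/stride-2 average pool halves the grid (24x24 -> 12x12),
            // reducing 576 tokens to 144; the depthwise 3x3 conv below effectively
            // acts as a positional encoding generator (PEG) and is added back residually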
            // weight ne = [3, 3, 2048, 1]
            ggml_tensor * peg_0 = ggml_conv_2d_dw(ctx0, model.mm_model_peg_0_w, mlp_2, 1, 1, 1, 1, 1, 1);
            peg_0 = ggml_cont(ctx0, ggml_permute(ctx0, peg_0, 1, 2, 0, 3));
            peg_0 = ggml_add(ctx0, peg_0, model.mm_model_peg_0_b);
            mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 2, 0, 3));
            peg_0 = ggml_add(ctx0, peg_0, mlp_2);
            peg_0 = ggml_reshape_3d(ctx0, peg_0, peg_0->ne[0], peg_0->ne[1] * peg_0->ne[2], peg_0->ne[3]);
            embeddings = peg_0;
        }
        else {
            GGML_ABORT("llava: unknown llava projector type");
        }
    }

    // glm projector
    else if (proj_type == PROJECTOR_TYPE_GLM_EDGE) {
        size_t gridsz = (size_t)sqrt(embeddings->ne[1]);
        embeddings = ggml_permute(ctx0, embeddings, 1, 0, 2, 3);
        embeddings = ggml_cont_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]);
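        // the stride-2 adapter conv below downsamples the spatial grid, roughly
        // halving each side and cutting the token count ~4x before the GLU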
        embeddings = ggml_conv_2d(ctx0, model.mm_model_adapter_conv_w, embeddings, 2, 2, 0, 0, 1, 1);
        embeddings = ggml_reshape_3d(ctx0, embeddings, embeddings->ne[0]*embeddings->ne[1], embeddings->ne[2], batch_size);
        embeddings = ggml_cont(ctx0, ggml_permute(ctx0, embeddings, 1, 0, 2, 3));
        embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b);

        // GLU
        {
            embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
            embeddings = ggml_norm(ctx0, embeddings, eps);
            embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
            embeddings = ggml_gelu_inplace(ctx0, embeddings);

            ggml_tensor * x = embeddings;
            embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings);
            x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, x);
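            // gated unit: the two parallel projections above form a SwiGLU pair;
            // one branch is SiLU-activated and multiplied element-wise with the other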
            embeddings = ggml_swiglu_split(ctx0, embeddings, x);
            embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings);
        }
        // arrangement of BOI/EOI token embeddings
        // note: these embeddings are not present in the text model, hence we cannot process them as text tokens
        // see: https://huggingface.co/THUDM/glm-edge-v-2b/blob/main/siglip.py#L53
        {
            embeddings = ggml_concat(ctx0, model.mm_boi, embeddings, 1); // BOI
            embeddings = ggml_concat(ctx0, embeddings, model.mm_eoi, 1); // EOI
        }
    }
    else {
        GGML_ABORT("llava: unknown projector type");
    }

    // build the graph
    ggml_build_forward_expand(gf, embeddings);

    return gf;
}