// models.h
  #pragma once
  #include "../llama-model.h"
  #include "../llama-graph.h"
  // TODO: remove in follow-up PR - move to .cpp files
  #include "../llama-memory-recurrent.h"
  #include <cmath>
  #include <cstdint>
  7. struct llm_graph_context_mamba : public llm_graph_context {
  8. llm_graph_context_mamba(const llm_graph_params & params);
  9. virtual ~llm_graph_context_mamba() = default;
  10. ggml_tensor * build_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il);
  11. ggml_tensor * build_mamba2_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il) const;
  12. };
  13. // Base class for RWKV-related models
  14. struct llm_build_rwkv6_base : public llm_graph_context {
  15. const llama_model & model;
  16. llm_build_rwkv6_base(const llama_model & model, const llm_graph_params & params);
  17. virtual ~llm_build_rwkv6_base() = default;
  18. ggml_tensor * build_rwkv6_channel_mix(const llama_layer * layer,
  19. ggml_tensor * cur,
  20. ggml_tensor * x_prev,
  21. llm_arch arch) const;
  22. ggml_tensor * build_rwkv6_time_mix(llm_graph_input_rs * inp,
  23. ggml_tensor * cur,
  24. ggml_tensor * x_prev,
  25. const llama_ubatch & ubatch,
  26. int il) const;
  27. };
  28. // Base class for RWKV7-related models
  29. struct llm_build_rwkv7_base : public llm_graph_context {
  30. const llama_model & model;
  31. llm_build_rwkv7_base(const llama_model & model, const llm_graph_params & params);
  32. virtual ~llm_build_rwkv7_base() = default;
  33. // RWKV7-specific graph building methods
  34. ggml_tensor * build_rwkv7_channel_mix(const llama_layer * layer,
  35. ggml_tensor * cur,
  36. ggml_tensor * x_prev,
  37. llm_arch arch) const;
  38. ggml_tensor * build_rwkv7_time_mix(llm_graph_input_rs * inp,
  39. ggml_tensor * cur,
  40. ggml_tensor * x_prev,
  41. ggml_tensor *& first_layer_value,
  42. const llama_ubatch & ubatch,
  43. int il) const;
  44. };
  45. struct llm_build_afmoe : public llm_graph_context {
  46. llm_build_afmoe(const llama_model & model, const llm_graph_params & params);
  47. };
  48. struct llm_build_apertus : public llm_graph_context {
  49. llm_build_apertus(const llama_model & model, const llm_graph_params & params);
  50. };
  51. struct llm_build_arcee : public llm_graph_context {
  52. llm_build_arcee(const llama_model & model, const llm_graph_params & params);
  53. };
  54. struct llm_build_arctic : public llm_graph_context {
  55. llm_build_arctic(const llama_model & model, const llm_graph_params & params);
  56. };
  57. struct llm_build_arwkv7 : public llm_build_rwkv7_base {
  58. llm_build_arwkv7(const llama_model & model, const llm_graph_params & params);
  59. };
  60. struct llm_build_baichuan : public llm_graph_context {
  61. llm_build_baichuan(const llama_model & model, const llm_graph_params & params);
  62. };
  63. struct llm_build_bailingmoe2 : public llm_graph_context {
  64. llm_build_bailingmoe2(const llama_model & model, const llm_graph_params & params);
  65. };
  66. struct llm_build_bailingmoe : public llm_graph_context {
  67. llm_build_bailingmoe(const llama_model & model, const llm_graph_params & params);
  68. };
  69. struct llm_build_bert : public llm_graph_context {
  70. llm_build_bert(const llama_model & model, const llm_graph_params & params);
  71. };
  72. struct llm_build_bitnet : public llm_graph_context {
  73. llm_build_bitnet(const llama_model & model, const llm_graph_params & params);
  74. };
  75. struct llm_build_bloom : public llm_graph_context {
  76. llm_build_bloom(const llama_model & model, const llm_graph_params & params);
  77. };
  78. struct llm_build_chameleon : public llm_graph_context {
  79. llm_build_chameleon(const llama_model & model, const llm_graph_params & params);
  80. };
  81. struct llm_build_chatglm : public llm_graph_context {
  82. llm_build_chatglm(const llama_model & model, const llm_graph_params & params);
  83. };
  84. struct llm_build_codeshell : public llm_graph_context {
  85. llm_build_codeshell(const llama_model & model, const llm_graph_params & params);
  86. };
  87. struct llm_build_cogvlm : public llm_graph_context {
  88. llm_build_cogvlm(const llama_model & model, const llm_graph_params & params);
  89. };
  90. struct llm_build_cohere2_iswa : public llm_graph_context {
  91. llm_build_cohere2_iswa(const llama_model & model, const llm_graph_params & params);
  92. };
  93. struct llm_build_command_r : public llm_graph_context {
  94. llm_build_command_r(const llama_model & model, const llm_graph_params & params);
  95. };
  96. struct llm_build_dbrx : public llm_graph_context {
  97. llm_build_dbrx(const llama_model & model, const llm_graph_params & params);
  98. };
  99. struct llm_build_deci : public llm_graph_context {
  100. llm_build_deci(const llama_model & model, const llm_graph_params & params);
  101. };
  102. struct llm_build_deepseek2 : public llm_graph_context {
  103. llm_build_deepseek2(const llama_model & model, const llm_graph_params & params);
  104. };
  105. struct llm_build_deepseek : public llm_graph_context {
  106. llm_build_deepseek(const llama_model & model, const llm_graph_params & params);
  107. };
  108. struct llm_build_dots1 : public llm_graph_context {
  109. llm_build_dots1(const llama_model & model, const llm_graph_params & params);
  110. };
  111. struct llm_build_dream : public llm_graph_context {
  112. llm_build_dream(const llama_model & model, const llm_graph_params & params);
  113. };
  114. struct llm_build_ernie4_5 : public llm_graph_context {
  115. llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params);
  116. };
  117. struct llm_build_ernie4_5_moe : public llm_graph_context {
  118. llm_build_ernie4_5_moe(const llama_model & model, const llm_graph_params & params);
  119. };
  120. template <bool iswa>
  121. struct llm_build_exaone4 : public llm_graph_context {
  122. llm_build_exaone4(const llama_model & model, const llm_graph_params & params);
  123. };
  124. struct llm_build_exaone : public llm_graph_context {
  125. llm_build_exaone(const llama_model & model, const llm_graph_params & params);
  126. };
  127. struct llm_build_falcon : public llm_graph_context {
  128. llm_build_falcon(const llama_model & model, const llm_graph_params & params);
  129. };
  130. struct llm_build_falcon_h1 : public llm_graph_context_mamba {
  131. llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params);
  132. };
  133. struct llm_build_gemma2_iswa : public llm_graph_context {
  134. llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params);
  135. };
  136. template <bool iswa>
  137. struct llm_build_gemma3 : public llm_graph_context {
  138. llm_build_gemma3(const llama_model & model, const llm_graph_params & params);
  139. };
  140. struct llm_build_gemma3n_iswa : public llm_graph_context {
  141. const llama_model & model;
  142. const int64_t n_embd_head;
  143. const int64_t n_embd_altup;
  144. const int64_t n_altup;
  145. const int i_altup_act;
  146. const int n_layer_sparsity = 10; // number of layers using activation sparsity
  147. const float f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95)
  148. llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params);
  149. ggml_tensor * calc_magnitude(ggml_tensor * x);
  150. ggml_tensor * view_2d_slice(ggml_tensor * x, int idx);
  151. ggml_tensor * get_per_layer_inputs();
  152. ggml_tensor * project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer);
  153. ggml_tensor * gaussian_topk(ggml_tensor * x);
  154. ggml_tensor * altup_compute_router_modalities(ggml_tensor * x, int il);
  155. ggml_tensor * altup_predict(ggml_tensor * cur, int il);
  156. ggml_tensor * laurel(ggml_tensor * cur, int il);
  157. ggml_tensor * altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il);
  158. };
  159. struct llm_build_gemma_embedding : public llm_graph_context {
  160. llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params);
  161. };
  162. struct llm_build_gemma : public llm_graph_context {
  163. llm_build_gemma(const llama_model & model, const llm_graph_params & params);
  164. };
  165. struct llm_build_glm4 : public llm_graph_context {
  166. llm_build_glm4(const llama_model & model, const llm_graph_params & params);
  167. };
  168. struct llm_build_glm4_moe : public llm_graph_context {
  169. llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params);
  170. };
  171. struct llm_build_gpt2 : public llm_graph_context {
  172. llm_build_gpt2(const llama_model & model, const llm_graph_params & params);
  173. };
  174. struct llm_build_gptneox : public llm_graph_context {
  175. llm_build_gptneox(const llama_model & model, const llm_graph_params & params);
  176. };
  177. struct llm_build_granite : public llm_graph_context {
  178. llm_build_granite(const llama_model & model, const llm_graph_params & params);
  179. private:
  180. ggml_tensor * build_attention_layer(
  181. ggml_tensor * cur,
  182. ggml_tensor * inp_pos,
  183. llm_graph_input_attn_kv * inp_attn,
  184. const llama_model & model,
  185. const int64_t n_embd_head,
  186. const int il);
  187. ggml_tensor * build_layer_ffn(
  188. ggml_tensor * cur,
  189. ggml_tensor * inpSA,
  190. const llama_model & model,
  191. const int il);
  192. };
  193. struct llm_build_granite_hybrid : public llm_graph_context_mamba {
  194. llm_build_granite_hybrid(const llama_model & model, const llm_graph_params & params);
  195. ggml_tensor * build_layer_ffn(ggml_tensor * cur, ggml_tensor * inpSA, const llama_model & model, const int il);
  196. ggml_tensor * build_attention_layer(ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv * inp_attn,
  197. const llama_model & model,const int64_t n_embd_head, const int il);
  198. };
  199. struct llm_build_grok : public llm_graph_context {
  200. llm_build_grok(const llama_model & model, const llm_graph_params & params);
  201. };
  202. struct llm_build_grovemoe : public llm_graph_context {
  203. llm_build_grovemoe(const llama_model & model, const llm_graph_params & params);
  204. };
  205. struct llm_build_hunyuan_dense : public llm_graph_context {
  206. llm_build_hunyuan_dense(const llama_model & model, const llm_graph_params & params);
  207. };
  208. struct llm_build_hunyuan_moe : public llm_graph_context {
  209. llm_build_hunyuan_moe(const llama_model & model, const llm_graph_params & params);
  210. };
  211. struct llm_build_internlm2 : public llm_graph_context {
  212. llm_build_internlm2(const llama_model & model, const llm_graph_params & params);
  213. };
  214. struct llm_build_jais : public llm_graph_context {
  215. llm_build_jais(const llama_model & model, const llm_graph_params & params);
  216. };
  217. struct llm_build_jamba : public llm_graph_context_mamba {
  218. llm_build_jamba(const llama_model & model, const llm_graph_params & params);
  219. };
  220. struct llm_build_lfm2 : public llm_graph_context {
  221. const llama_model & model;
  222. llm_build_lfm2(const llama_model & model, const llm_graph_params & params);
  223. ggml_tensor * build_moe_feed_forward(ggml_tensor * cur, int il) const;
  224. ggml_tensor * build_dense_feed_forward(ggml_tensor * cur, int il) const;
  225. ggml_tensor * build_attn_block(ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv * inp_attn, int il) const;
  226. ggml_tensor * build_shortconv_block(ggml_tensor * cur, llm_graph_input_rs * inp_recr, int il);
  227. };
  228. struct llm_build_llada : public llm_graph_context {
  229. llm_build_llada(const llama_model & model, const llm_graph_params & params);
  230. };
  231. struct llm_build_llada_moe : public llm_graph_context {
  232. llm_build_llada_moe(const llama_model & model, const llm_graph_params & params);
  233. };
  234. template <bool embed>
  235. struct llm_build_llama : public llm_graph_context {
  236. llm_build_llama(const llama_model & model, const llm_graph_params & params);
  237. };
  238. struct llm_build_llama_iswa : public llm_graph_context {
  239. llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params);
  240. };
  241. struct llm_build_mamba : public llm_graph_context_mamba {
  242. llm_build_mamba(const llama_model & model, const llm_graph_params & params);
  243. };
  244. struct llm_build_mimo2_iswa : public llm_graph_context {
  245. llm_build_mimo2_iswa(const llama_model & model, const llm_graph_params & params);
  246. };
  247. struct llm_build_minicpm3 : public llm_graph_context {
  248. llm_build_minicpm3(const llama_model & model, const llm_graph_params & params);
  249. };
  250. struct llm_build_minimax_m2 : public llm_graph_context {
  251. llm_build_minimax_m2(const llama_model & model, const llm_graph_params & params);
  252. };
  253. struct llm_build_mistral3 : public llm_graph_context {
  254. llm_build_mistral3(const llama_model & model, const llm_graph_params & params);
  255. };
  256. template <bool iswa>
  257. struct llm_build_modern_bert : public llm_graph_context {
  258. llm_build_modern_bert(const llama_model & model, const llm_graph_params & params);
  259. };
  260. struct llm_build_mpt : public llm_graph_context {
  261. llm_build_mpt(const llama_model & model, const llm_graph_params & params);
  262. };
  263. struct llm_build_nemotron : public llm_graph_context {
  264. llm_build_nemotron(const llama_model & model, const llm_graph_params & params);
  265. };
  266. struct llm_build_nemotron_h : public llm_graph_context_mamba {
  267. llm_build_nemotron_h(const llama_model & model, const llm_graph_params & params);
  268. ggml_tensor * build_ffn_layer(ggml_tensor * cur, const llama_model & model, const int il);
  269. ggml_tensor * build_attention_layer(ggml_tensor * cur, llm_graph_input_attn_kv * inp_attn,
  270. const llama_model & model, const int64_t n_embd_head, const int il);
  271. };
  272. struct llm_build_neo_bert : public llm_graph_context {
  273. llm_build_neo_bert(const llama_model & model, const llm_graph_params & params);
  274. };
  275. template <bool iswa>
  276. struct llm_build_olmo2 : public llm_graph_context {
  277. llm_build_olmo2(const llama_model & model, const llm_graph_params & params);
  278. };
  279. struct llm_build_olmoe : public llm_graph_context {
  280. llm_build_olmoe(const llama_model & model, const llm_graph_params & params);
  281. };
  282. struct llm_build_olmo : public llm_graph_context {
  283. llm_build_olmo(const llama_model & model, const llm_graph_params & params);
  284. };
  285. struct llm_build_openai_moe_iswa : public llm_graph_context {
  286. llm_build_openai_moe_iswa(const llama_model & model, const llm_graph_params & params);
  287. };
  288. struct llm_build_openelm : public llm_graph_context {
  289. llm_build_openelm(const llama_model & model, const llm_graph_params & params);
  290. };
  291. struct llm_build_orion : public llm_graph_context {
  292. llm_build_orion(const llama_model & model, const llm_graph_params & params);
  293. };
  294. struct llm_build_pangu_embedded : public llm_graph_context {
  295. llm_build_pangu_embedded(const llama_model & model, const llm_graph_params & params);
  296. };
  297. struct llm_build_phi2 : public llm_graph_context {
  298. llm_build_phi2(const llama_model & model, const llm_graph_params & params);
  299. };
  300. template<bool iswa>
  301. struct llm_build_phi3 : public llm_graph_context {
  302. llm_build_phi3(const llama_model & model, const llm_graph_params & params);
  303. };
  304. struct llm_build_plamo2 : public llm_graph_context_mamba {
  305. llm_build_plamo2(const llama_model & model, const llm_graph_params & params);
  306. private:
  307. ggml_tensor * build_plamo2_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il);
  308. ggml_tensor * build_plamo2_attn_layer(llm_graph_input_attn_kv * inp, ggml_tensor * inp_pos, ggml_tensor * cur,
  309. const llama_model & model, int il);
  310. };
  311. struct llm_build_plamo : public llm_graph_context {
  312. llm_build_plamo(const llama_model & model, const llm_graph_params & params);
  313. };
  314. template <bool iswa>
  315. struct llm_build_plamo3 : public llm_graph_context {
  316. llm_build_plamo3(const llama_model & model, const llm_graph_params & params);
  317. };
  318. struct llm_build_plm : public llm_graph_context {
  319. llm_build_plm(const llama_model & model, const llm_graph_params & params);
  320. };
  321. struct llm_build_qwen2 : public llm_graph_context {
  322. llm_build_qwen2(const llama_model & model, const llm_graph_params & params);
  323. };
  324. struct llm_build_qwen2moe : public llm_graph_context {
  325. llm_build_qwen2moe(const llama_model & model, const llm_graph_params & params);
  326. };
  327. struct llm_build_qwen2vl : public llm_graph_context {
  328. llm_build_qwen2vl(const llama_model & model, const llm_graph_params & params);
  329. };
  330. struct llm_build_qwen3 : public llm_graph_context {
  331. llm_build_qwen3(const llama_model & model, const llm_graph_params & params);
  332. };
  333. struct llm_build_qwen3moe : public llm_graph_context {
  334. llm_build_qwen3moe(const llama_model & model, const llm_graph_params & params);
  335. };
  336. struct llm_build_qwen3vl : public llm_graph_context {
  337. llm_build_qwen3vl(const llama_model & model, const llm_graph_params & params);
  338. };
  339. struct llm_build_qwen3vlmoe : public llm_graph_context {
  340. llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params);
  341. };
  342. struct llm_build_qwen3next : public llm_graph_context_mamba {
  343. llm_build_qwen3next(const llama_model & model, const llm_graph_params & params);
  344. private:
  345. ggml_tensor * build_layer_attn(
  346. llm_graph_input_attn_kv * inp_attn,
  347. ggml_tensor * cur,
  348. ggml_tensor * inp_pos,
  349. int il);
  350. ggml_tensor * build_layer_attn_linear(
  351. llm_graph_input_rs * inp,
  352. ggml_tensor * cur,
  353. ggml_tensor * causal_mask,
  354. ggml_tensor * identity,
  355. ggml_tensor * diag_mask,
  356. int il);
  357. ggml_tensor * build_layer_ffn(
  358. ggml_tensor * cur,
  359. int il);
  360. ggml_tensor * build_delta_net_chunking(
  361. ggml_tensor * q,
  362. ggml_tensor * k,
  363. ggml_tensor * v,
  364. ggml_tensor * g,
  365. ggml_tensor * beta,
  366. ggml_tensor * state,
  367. ggml_tensor * causal_mask,
  368. ggml_tensor * identity,
  369. ggml_tensor * diag_mask,
  370. int il);
  371. ggml_tensor * build_delta_net_autoregressive(
  372. ggml_tensor * q,
  373. ggml_tensor * k,
  374. ggml_tensor * v,
  375. ggml_tensor * g,
  376. ggml_tensor * beta,
  377. ggml_tensor * state,
  378. int il);
  379. ggml_tensor * build_norm_gated(
  380. ggml_tensor * input,
  381. ggml_tensor * weights,
  382. ggml_tensor * gate,
  383. int layer);
  384. const llama_model & model;
  385. };
  386. struct llm_build_qwen : public llm_graph_context {
  387. llm_build_qwen(const llama_model & model, const llm_graph_params & params);
  388. };
  389. struct llm_build_refact : public llm_graph_context {
  390. llm_build_refact(const llama_model & model, const llm_graph_params & params);
  391. };
  392. struct llm_build_rnd1 : public llm_graph_context {
  393. llm_build_rnd1(const llama_model & model, const llm_graph_params & params);
  394. };
  395. struct llm_build_rwkv6 : public llm_build_rwkv6_base {
  396. llm_build_rwkv6(const llama_model & model, const llm_graph_params & params);
  397. };
  398. struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
  399. llm_build_rwkv6qwen2(const llama_model & model, const llm_graph_params & params);
  400. };
  401. struct llm_build_rwkv7 : public llm_build_rwkv7_base {
  402. llm_build_rwkv7(const llama_model & model, const llm_graph_params & params);
  403. };
  404. struct llm_build_seed_oss : public llm_graph_context {
  405. llm_build_seed_oss(const llama_model & model, const llm_graph_params & params);
  406. };
  407. template <bool iswa>
  408. struct llm_build_smallthinker : public llm_graph_context {
  409. llm_build_smallthinker(const llama_model & model, const llm_graph_params & params);
  410. };
  411. struct llm_build_smollm3 : public llm_graph_context {
  412. llm_build_smollm3(const llama_model & model, const llm_graph_params & params);
  413. };
  414. struct llm_build_stablelm : public llm_graph_context {
  415. llm_build_stablelm(const llama_model & model, const llm_graph_params & params);
  416. };
  417. struct llm_build_starcoder2 : public llm_graph_context {
  418. llm_build_starcoder2(const llama_model & model, const llm_graph_params & params);
  419. };
  420. struct llm_build_starcoder : public llm_graph_context {
  421. llm_build_starcoder(const llama_model & model, const llm_graph_params & params);
  422. };
  423. struct llm_build_t5_dec : public llm_graph_context {
  424. llm_build_t5_dec(const llama_model & model, const llm_graph_params & params);
  425. };
  426. struct llm_build_t5_enc : public llm_graph_context {
  427. llm_build_t5_enc(const llama_model & model, const llm_graph_params & params);
  428. };
  429. struct llm_build_wavtokenizer_dec : public llm_graph_context {
  430. llm_build_wavtokenizer_dec(const llama_model & model, const llm_graph_params & params);
  431. };
  432. struct llm_build_xverse : public llm_graph_context {
  433. llm_build_xverse(const llama_model & model, const llm_graph_params & params);
  434. };