llama-graph.h 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810
  1. #pragma once
  2. #include "llama-arch.h"
  3. #include "llama-batch.h"
  4. #include "llama-hparams.h"
  5. #include "llama-adapter.h"
  6. #include <cstdint>
  7. #include <vector>
  8. #include <memory>
  9. #include <set>
  10. #include <functional>
  11. struct ggml_cgraph;
  12. struct ggml_context;
  13. struct ggml_tensor;
  14. struct llama_cparams;
  15. struct llama_memory_context_i;
  16. class llama_kv_cache_context;
  17. class llama_kv_cache_iswa_context;
  18. class llama_memory_recurrent_context;
  19. class llama_memory_hybrid_context;
  // Which kind of graph to build.
  // certain models (typically multi-modal) can produce different types of graphs,
  // e.g. a separate encoder graph and decoder graph; everything else uses the default.
  enum llm_graph_type {
      LLM_GRAPH_TYPE_DEFAULT = 0,
      LLM_GRAPH_TYPE_ENCODER = 1,
      LLM_GRAPH_TYPE_DECODER = 2,
  };
  // Activation / gating operation selected through the `type_op` argument of
  // build_ffn() and build_moe_ffn().
  enum llm_ffn_op_type {
      LLM_FFN_SILU           = 0,
      LLM_FFN_GELU           = 1,
      LLM_FFN_RELU           = 2,
      LLM_FFN_RELU_SQR       = 3,
      LLM_FFN_SWIGLU         = 4,
      LLM_FFN_GEGLU          = 5,
      LLM_FFN_REGLU          = 6,
      LLM_FFN_SWIGLU_OAI_MOE = 7,
  };
  // Placement of the FFN gate projection, selected through the `type_gate`
  // argument of build_ffn().
  enum llm_ffn_gate_type {
      LLM_FFN_SEQ = 0, // presumably: gate applied sequentially after ffn_up — confirm
      LLM_FFN_PAR = 1, // ffn_gate is parallel to ffn_up
  };
  // Normalization flavour, selected through the `type` argument of build_norm().
  enum llm_norm_type {
      LLM_NORM       = 0,
      LLM_NORM_RMS   = 1,
      LLM_NORM_GROUP = 2,
  };
  45. // TODO: tmp - need something better to pass the data from the encoder to the decoder
  46. struct llama_cross {
  47. // the output embeddings from the encoder as a ggml tensor
  48. // TODO: this needs more work to be correct, for now copy the embeddings data to host memory
  49. // ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524
  50. //ggml_tensor * t_embd = nullptr;
  51. int64_t n_embd = 0;
  52. int64_t n_enc = 0;
  53. // embeddings data copied to host memory (tmp)
  54. std::vector<float> v_embd;
  55. // needed to construct the cross-attention mask in the decoder
  56. std::vector<std::set<llama_seq_id>> seq_ids_enc;
  57. };
  58. struct llm_graph_params;
  59. //
  60. // llm_graph_input
  61. //
  62. class llm_graph_input_i {
  63. public:
  64. virtual ~llm_graph_input_i() = default;
  65. virtual void set_input(const llama_ubatch * ubatch) = 0;
  66. // return true if the resulting input tensors using the provided graph parameters would be
  67. // the same as the previous input tensors that we have currently stored in the object
  68. virtual bool can_reuse(const llm_graph_params & params) {
  69. // returning false here by default will prevent from reusing the graph if the check
  70. // for the input type has not been implemented yet
  71. GGML_UNUSED(params);
  72. return false;
  73. }
  74. };
  75. using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
  76. class llm_graph_input_embd : public llm_graph_input_i {
  77. public:
  78. llm_graph_input_embd() = default;
  79. virtual ~llm_graph_input_embd() = default;
  80. void set_input(const llama_ubatch * ubatch) override;
  81. bool can_reuse(const llm_graph_params & params) override;
  82. ggml_tensor * tokens = nullptr; // I32 [n_batch]
  83. ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
  84. };
  85. class llm_graph_input_pos : public llm_graph_input_i {
  86. public:
  87. llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
  88. virtual ~llm_graph_input_pos() = default;
  89. void set_input(const llama_ubatch * ubatch) override;
  90. bool can_reuse(const llm_graph_params & params) override;
  91. ggml_tensor * pos = nullptr; // I32 [n_batch]
  92. const uint32_t n_pos_per_embd = 1;
  93. };
  94. // temperature tuning, used by llama4
  95. class llm_graph_input_attn_temp : public llm_graph_input_i {
  96. public:
  97. llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
  98. : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
  99. virtual ~llm_graph_input_attn_temp() = default;
  100. void set_input(const llama_ubatch * ubatch) override;
  101. ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
  102. const uint32_t n_attn_temp_floor_scale;
  103. const float f_attn_temp_scale;
  104. };
  105. class llm_graph_input_pos_bucket : public llm_graph_input_i {
  106. public:
  107. llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {}
  108. virtual ~llm_graph_input_pos_bucket() = default;
  109. void set_input(const llama_ubatch * ubatch) override;
  110. ggml_tensor * pos_bucket = nullptr; // I32 [n_batch, n_batch]
  111. const llama_hparams hparams;
  112. };
  113. class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
  114. public:
  115. llm_graph_input_pos_bucket_kv(
  116. const llama_hparams & hparams,
  117. const llama_kv_cache_context * mctx) : hparams(hparams), mctx(mctx) {}
  118. virtual ~llm_graph_input_pos_bucket_kv() = default;
  119. void set_input(const llama_ubatch * ubatch) override;
  120. ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
  121. const llama_hparams hparams;
  122. const llama_kv_cache_context * mctx;
  123. };
  124. class llm_graph_input_out_ids : public llm_graph_input_i {
  125. public:
  126. llm_graph_input_out_ids(
  127. const llama_hparams & hparams,
  128. const llama_cparams & cparams,
  129. uint32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
  130. virtual ~llm_graph_input_out_ids() = default;
  131. void set_input(const llama_ubatch * ubatch) override;
  132. bool can_reuse(const llm_graph_params & params) override;
  133. ggml_tensor * out_ids; // I32 [n_outputs]
  134. const llama_hparams hparams;
  135. const llama_cparams cparams;
  136. const uint32_t n_outputs;
  137. };
  138. class llm_graph_input_mean : public llm_graph_input_i {
  139. public:
  140. llm_graph_input_mean(const llama_cparams & cparams) : cparams(cparams) {}
  141. virtual ~llm_graph_input_mean() = default;
  142. void set_input(const llama_ubatch * ubatch) override;
  143. ggml_tensor * mean; // F32 [n_batch, n_batch]
  144. const llama_cparams cparams;
  145. };
  146. class llm_graph_input_cls : public llm_graph_input_i {
  147. public:
  148. llm_graph_input_cls(const llama_cparams & cparams) : cparams(cparams) {}
  149. virtual ~llm_graph_input_cls() = default;
  150. void set_input(const llama_ubatch * ubatch) override;
  151. ggml_tensor * cls; // I32 [n_batch]
  152. const llama_cparams cparams;
  153. };
  154. class llm_graph_input_rs : public llm_graph_input_i {
  155. public:
  156. llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
  157. virtual ~llm_graph_input_rs() = default;
  158. void set_input(const llama_ubatch * ubatch) override;
  159. ggml_tensor * s_copy; // I32 [n_rs]
  160. // views of s_copy, computed once per graph
  161. // and shared across layers which use build_rs
  162. ggml_tensor * s_copy_main; // I32 [n_seqs]
  163. ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs]
  164. const llama_memory_recurrent_context * mctx;
  165. };
  166. class llm_graph_input_cross_embd : public llm_graph_input_i {
  167. public:
  168. llm_graph_input_cross_embd(
  169. const llama_cross * cross) : cross(cross) {}
  170. virtual ~llm_graph_input_cross_embd() = default;
  171. void set_input(const llama_ubatch * ubatch) override;
  172. ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc]
  173. const llama_cross * cross;
  174. };
  175. class llm_graph_input_attn_no_cache : public llm_graph_input_i {
  176. public:
  177. llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
  178. hparams(hparams),
  179. cparams(cparams) {
  180. }
  181. ~llm_graph_input_attn_no_cache() = default;
  182. void set_input(const llama_ubatch * ubatch) override;
  183. ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
  184. ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch, 1, 1]
  185. ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch, 1, 1]
  186. const llama_hparams hparams;
  187. const llama_cparams cparams;
  188. };
  189. class llm_graph_input_attn_kv : public llm_graph_input_i {
  190. public:
  191. llm_graph_input_attn_kv(
  192. const llama_hparams & hparams,
  193. const llama_cparams & cparams,
  194. const llama_kv_cache_context * mctx) :
  195. hparams(hparams),
  196. cparams(cparams),
  197. mctx(mctx) {
  198. }
  199. ~llm_graph_input_attn_kv() = default;
  200. void set_input(const llama_ubatch * ubatch) override;
  201. bool can_reuse(const llm_graph_params & params) override;
  202. ggml_tensor * get_k_idxs() const { return self_k_idxs; }
  203. ggml_tensor * get_v_idxs() const { return self_v_idxs; }
  204. ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
  205. ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
  206. ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
  207. ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
  208. ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
  209. // note: these have to be copies because in order to be able to reuse a graph, its inputs
  210. // need to carry these parameters with them. otherwise, they can point to freed
  211. // llm_graph_params from a previous batch, causing stack-use-after-return
  212. const llama_hparams hparams;
  213. const llama_cparams cparams;
  214. const llama_kv_cache_context * mctx;
  215. };
  216. class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
  217. public:
  218. llm_graph_input_attn_kv_iswa(
  219. const llama_hparams & hparams,
  220. const llama_cparams & cparams,
  221. const llama_kv_cache_iswa_context * mctx) :
  222. hparams(hparams),
  223. cparams(cparams),
  224. mctx(mctx) {
  225. }
  226. ~llm_graph_input_attn_kv_iswa() = default;
  227. void set_input(const llama_ubatch * ubatch) override;
  228. bool can_reuse(const llm_graph_params & params) override;
  229. ggml_tensor * get_k_idxs() const { return self_k_idxs; }
  230. ggml_tensor * get_v_idxs() const { return self_v_idxs; }
  231. ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
  232. ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
  233. ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
  234. ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
  235. ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
  236. ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
  237. ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
  238. ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
  239. ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
  240. ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
  241. ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
  242. ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
  243. const llama_hparams hparams;
  244. const llama_cparams cparams;
  245. const llama_kv_cache_iswa_context * mctx;
  246. };
  247. class llm_graph_input_attn_cross : public llm_graph_input_i {
  248. public:
  249. llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
  250. ~llm_graph_input_attn_cross() = default;
  251. void set_input(const llama_ubatch * ubatch) override;
  252. ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
  253. ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
  254. ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
  255. const llama_cross * cross = nullptr;
  256. };
  257. class llm_graph_input_mem_hybrid : public llm_graph_input_i {
  258. public:
  259. llm_graph_input_mem_hybrid(
  260. std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
  261. std::unique_ptr<llm_graph_input_rs> inp_rs,
  262. const llama_memory_hybrid_context * mctx) :
  263. inp_attn(std::move(inp_attn)),
  264. inp_rs(std::move(inp_rs)),
  265. mctx(mctx) { }
  266. virtual ~llm_graph_input_mem_hybrid() = default;
  267. void set_input(const llama_ubatch * ubatch) override;
  268. std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
  269. std::unique_ptr<llm_graph_input_rs> inp_rs;
  270. llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
  271. llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
  272. const llama_memory_hybrid_context * mctx;
  273. };
  274. //
  275. // llm_graph_result
  276. //
  277. // these objects deliver the result from the graph build process back to the llama_context
  278. // note that the input tensors created for the graph are referenced here - the goal is to be able to populate their
  279. // specific data, by calling the set_inputs() method
  280. // along with the input tensors, the object also provides commonly used outputs tensors, such as logits, embeddings, etc.
  281. // these are used by the llama_context to extact the relevant data, based on the compute parameters
  282. // callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
  283. using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
  284. class llm_graph_result;
  285. struct llm_graph_params {
  286. llm_arch arch = LLM_ARCH_UNKNOWN;
  287. llama_hparams hparams;
  288. llama_cparams cparams;
  289. llama_ubatch ubatch; // note: intentionally make a copy
  290. llm_graph_type gtype;
  291. ggml_backend_sched_t sched;
  292. ggml_backend_t backend_cpu;
  293. const llama_adapter_cvec * cvec;
  294. const llama_adapter_loras * loras;
  295. const llama_memory_context_i * mctx;
  296. const llama_cross * cross;
  297. uint32_t n_outputs;
  298. llm_graph_cb cb;
  299. llm_graph_result * res;
  300. // return true if the "other" params would result in a graph with the same topology as with the current params
  301. // having the same topology allows us to reuse the graph in some cases
  302. bool allow_reuse(const llm_graph_params & other) const {
  303. // first check the ubatch
  304. bool can_reuse_ubatch =
  305. ubatch.equal_seqs() == other.ubatch.equal_seqs() &&
  306. ubatch.n_tokens == other.ubatch.n_tokens &&
  307. ubatch.n_seq_tokens == other.ubatch.n_seq_tokens &&
  308. ubatch.n_seqs == other.ubatch.n_seqs &&
  309. ubatch.n_seqs_unq == other.ubatch.n_seqs_unq &&
  310. (
  311. (!ubatch.token && !other.ubatch.token) ||
  312. (!ubatch.embd && !other.ubatch.embd)
  313. );
  314. // when we split the batch using "equal_seqs" we have to verify that the participating sequences are the same
  315. // the reason is because the set of attention streams would be different for different sequences
  316. if (can_reuse_ubatch && ubatch.equal_seqs()) {
  317. if (!ubatch.data) {
  318. // if the old ubatch does not own it's data, then we cannot guarantee that it is still alive, and
  319. // therefore we cannot perform the sequence id check. normally should never happen
  320. can_reuse_ubatch = false;
  321. } else {
  322. for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
  323. can_reuse_ubatch &= ubatch.seq_id_unq[s] == other.ubatch.seq_id_unq[s];
  324. }
  325. }
  326. }
  327. if (!can_reuse_ubatch) {
  328. return false;
  329. }
  330. return
  331. cparams.embeddings == other.cparams.embeddings &&
  332. cparams.causal_attn == other.cparams.causal_attn &&
  333. arch == other.arch &&
  334. gtype == other.gtype &&
  335. cvec == other.cvec &&
  336. loras == other.loras &&
  337. cross == other.cross &&
  338. n_outputs == other.n_outputs;
  339. }
  340. };
  341. class llm_graph_result {
  342. public:
  343. llm_graph_result(int64_t max_nodes);
  344. virtual ~llm_graph_result() = default;
  345. ggml_tensor * get_tokens() const { return t_tokens; }
  346. ggml_tensor * get_logits() const { return t_logits; }
  347. ggml_tensor * get_embd() const { return t_embd; }
  348. ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
  349. ggml_cgraph * get_gf() const { return gf; }
  350. ggml_context * get_ctx() const { return ctx_compute.get(); }
  351. int64_t get_max_nodes() const;
  352. void reset();
  353. void set_inputs(const llama_ubatch * ubatch);
  354. // try to update the existing graph result using the new graph parameters in order to reuse it
  355. // this can only be done if we determine that the resulting graph using the new graph parameters
  356. // would be identical to the existing graph. in that case, we simply have to update the memory
  357. // contexts of the input tensors of the graph and we can reuse it for another computation
  358. // return true if the graph was updated and can be reused
  359. bool can_reuse(const llm_graph_params & params);
  360. llm_graph_input_i * add_input(llm_graph_input_ptr input);
  361. void set_params(const llm_graph_params & params);
  362. // important graph nodes
  363. ggml_tensor * t_tokens = nullptr;
  364. ggml_tensor * t_logits = nullptr;
  365. ggml_tensor * t_embd = nullptr;
  366. ggml_tensor * t_embd_pooled = nullptr;
  367. std::vector<llm_graph_input_ptr> inputs;
  368. ggml_context_ptr ctx_compute;
  369. // memory buffers used to evaluate the model
  370. std::vector<uint8_t> buf_compute_meta;
  371. ggml_cgraph * gf;
  372. int64_t max_nodes;
  373. private:
  374. // keep a copy of the previous graph parameters
  375. // we will use this to determine whether the graph can be reused by comparing them with the new parameters
  376. // note: these are updated after constructing the new graph
  377. llm_graph_params params;
  378. // env: LLAMA_GRAPH_RESULT_DEBUG
  379. int debug = 0;
  380. };
  381. using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>;
  382. //
  383. // llm_graph_context
  384. //
  385. // used in build_rs to properly order writes and avoid unnecessary copies
  386. using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
  387. struct llm_graph_context {
  388. const llm_arch arch;
  389. const llama_hparams & hparams;
  390. const llama_cparams & cparams;
  391. const llama_ubatch & ubatch;
  392. const int64_t n_embd;
  393. const int64_t n_layer;
  394. const int64_t n_rot;
  395. const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
  396. const int64_t n_head;
  397. const int64_t n_head_kv;
  398. const int64_t n_embd_head_k;
  399. const int64_t n_embd_k_gqa;
  400. const int64_t n_embd_head_v;
  401. const int64_t n_embd_v_gqa;
  402. const int64_t n_expert;
  403. const int64_t n_expert_used;
  404. const float freq_base;
  405. const float freq_scale;
  406. const float ext_factor;
  407. const float attn_factor;
  408. const float beta_fast;
  409. const float beta_slow;
  410. const float norm_eps;
  411. const float norm_rms_eps;
  412. const int64_t n_tokens;
  413. const int64_t n_outputs;
  414. const int32_t n_ctx_orig; // yarn
  415. const enum llama_pooling_type pooling_type;
  416. const enum llama_rope_type rope_type;
  417. ggml_backend_sched_t sched;
  418. ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
  419. const llama_adapter_cvec * cvec;
  420. const llama_adapter_loras * loras;
  421. const llama_memory_context_i * mctx;
  422. const llama_cross * cross;
  423. const llm_graph_cb & cb_func;
  424. llm_graph_result * res;
  425. ggml_context * ctx0 = nullptr;
  426. ggml_cgraph * gf = nullptr;
  427. llm_graph_context(const llm_graph_params & params);
  428. virtual ~llm_graph_context() = default;
  429. void cb(ggml_tensor * cur, const char * name, int il) const;
  430. //
  431. // common
  432. //
  433. ggml_tensor * build_cvec(
  434. ggml_tensor * cur,
  435. int il) const;
  436. // do mat_mul, while optionally apply lora
  437. ggml_tensor * build_lora_mm(
  438. ggml_tensor * w,
  439. ggml_tensor * cur) const;
  440. // do mat_mul_id, while optionally apply lora
  441. ggml_tensor * build_lora_mm_id(
  442. ggml_tensor * w, // ggml_tensor * as
  443. ggml_tensor * cur, // ggml_tensor * b
  444. ggml_tensor * ids) const;
  445. ggml_tensor * build_norm(
  446. ggml_tensor * cur,
  447. ggml_tensor * mw,
  448. ggml_tensor * mb,
  449. llm_norm_type type,
  450. int il) const;
  451. ggml_tensor * build_ffn(
  452. ggml_tensor * cur,
  453. ggml_tensor * up,
  454. ggml_tensor * up_b,
  455. ggml_tensor * up_s,
  456. ggml_tensor * gate,
  457. ggml_tensor * gate_b,
  458. ggml_tensor * gate_s,
  459. ggml_tensor * down,
  460. ggml_tensor * down_b,
  461. ggml_tensor * down_s,
  462. ggml_tensor * act_scales,
  463. llm_ffn_op_type type_op,
  464. llm_ffn_gate_type type_gate,
  465. int il) const;
  466. // build MoE FFN without bias tensors
  467. ggml_tensor * build_moe_ffn(
  468. ggml_tensor * cur,
  469. ggml_tensor * gate_inp,
  470. ggml_tensor * up_exps,
  471. ggml_tensor * gate_exps,
  472. ggml_tensor * down_exps,
  473. ggml_tensor * exp_probs_b,
  474. int64_t n_expert,
  475. int64_t n_expert_used,
  476. llm_ffn_op_type type_op,
  477. bool norm_w,
  478. bool scale_w,
  479. float w_scale,
  480. llama_expert_gating_func_type gating_op,
  481. int il,
  482. ggml_tensor * probs_in = nullptr) const;
  483. ggml_tensor * build_moe_ffn(
  484. ggml_tensor * cur,
  485. ggml_tensor * gate_inp,
  486. ggml_tensor * gate_inp_b,
  487. ggml_tensor * up_exps,
  488. ggml_tensor * up_exps_b,
  489. ggml_tensor * gate_exps,
  490. ggml_tensor * gate_exps_b,
  491. ggml_tensor * down_exps,
  492. ggml_tensor * down_exps_b,
  493. ggml_tensor * exp_probs_b,
  494. int64_t n_expert,
  495. int64_t n_expert_used,
  496. llm_ffn_op_type type_op,
  497. bool norm_w,
  498. bool scale_w,
  499. float w_scale,
  500. llama_expert_gating_func_type gating_op,
  501. int il,
  502. ggml_tensor * probs_in = nullptr) const;
  503. //
  504. // inputs
  505. //
  506. ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
  507. ggml_tensor * build_inp_pos() const;
  508. ggml_tensor * build_inp_attn_scale() const;
  509. ggml_tensor * build_inp_out_ids() const;
  510. ggml_tensor * build_inp_mean() const;
  511. ggml_tensor * build_inp_cls() const;
  512. ggml_tensor * build_inp_cross_embd() const;
  513. ggml_tensor * build_inp_pos_bucket_enc() const;
  514. ggml_tensor * build_inp_pos_bucket_dec() const;
  515. ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
  516. //
  517. // attention
  518. //
  519. ggml_tensor * build_attn_mha(
  520. ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
  521. ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
  522. ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
  523. ggml_tensor * kq_b,
  524. ggml_tensor * kq_mask,
  525. ggml_tensor * sinks, // [n_head_q]
  526. ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  527. float kq_scale) const;
  528. llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
  529. ggml_tensor * build_attn(
  530. llm_graph_input_attn_no_cache * inp,
  531. ggml_tensor * wo,
  532. ggml_tensor * wo_b,
  533. ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
  534. ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
  535. ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
  536. ggml_tensor * kq_b,
  537. ggml_tensor * sinks, // [n_head_q]
  538. ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  539. float kq_scale,
  540. int il) const;
  541. llm_graph_input_attn_kv * build_attn_inp_kv() const;
  542. ggml_tensor * build_attn(
  543. llm_graph_input_attn_kv * inp,
  544. ggml_tensor * wo,
  545. ggml_tensor * wo_b,
  546. ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
  547. ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
  548. ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
  549. ggml_tensor * kq_b,
  550. ggml_tensor * sinks, // [n_head_q]
  551. ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  552. float kq_scale,
  553. int il) const;
  554. llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const;
  555. // note: if k_cur or v_cur are not provided, they will not be stored in the memory
  556. ggml_tensor * build_attn(
  557. llm_graph_input_attn_kv_iswa * inp,
  558. ggml_tensor * wo,
  559. ggml_tensor * wo_b,
  560. ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
  561. ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
  562. ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
  563. ggml_tensor * kq_b,
  564. ggml_tensor * sinks, // [n_head_q]
  565. ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  566. float kq_scale,
  567. int il) const;
  568. llm_graph_input_attn_cross * build_attn_inp_cross() const;
  569. ggml_tensor * build_attn(
  570. llm_graph_input_attn_cross * inp,
  571. ggml_tensor * wo,
  572. ggml_tensor * wo_b,
  573. ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
  574. ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
  575. ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
  576. ggml_tensor * kq_b,
  577. ggml_tensor * sinks, // [n_head_q]
  578. ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  579. float kq_scale,
  580. int il) const;
  581. //
  582. // recurrent
  583. //
  584. // TODO: move this implementation to llama_memory_recurrent.
  585. // this is analogous to llama_kv_cache::cpy_k / cpy_v
  586. // when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
  587. // implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
  588. // `llama_memory_recurrent`
  589. ggml_tensor * build_rs(
  590. ggml_tensor * s,
  591. ggml_tensor * state_copy_main,
  592. ggml_tensor * state_copy_extra,
  593. int32_t state_size,
  594. int32_t n_seqs,
  595. uint32_t n_rs,
  596. uint32_t rs_head,
  597. uint32_t rs_size,
  598. int32_t rs_zero,
  599. const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
  600. llm_graph_input_rs * build_rs_inp() const;
  601. ggml_tensor * build_rs(
  602. llm_graph_input_rs * inp,
  603. ggml_tensor * s,
  604. int32_t state_size,
  605. int32_t n_seqs,
  606. const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
  607. ggml_tensor * build_rwkv_token_shift_load(
  608. llm_graph_input_rs * inp,
  609. const llama_ubatch & ubatch,
  610. int il) const;
  611. ggml_tensor * build_rwkv_token_shift_store(
  612. ggml_tensor * token_shift,
  613. const llama_ubatch & ubatch,
  614. int il) const;
  615. //
  616. // hybrid
  617. //
  618. llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
  619. //
  620. // pooling
  621. //
  622. void build_pooling(
  623. ggml_tensor * cls,
  624. ggml_tensor * cls_b,
  625. ggml_tensor * cls_out,
  626. ggml_tensor * cls_out_b) const;
  627. };
  628. // TODO: better name
  629. int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);