llama-graph.h 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819
  1. #pragma once
#include "llama-arch.h"
#include "llama-batch.h"
#include "llama-hparams.h"
#include "llama-adapter.h"

#include <cstdint>
#include <cstdlib>
#include <functional>
#include <memory>
#include <set>
#include <vector>
  11. struct ggml_cgraph;
  12. struct ggml_context;
  13. struct ggml_tensor;
  14. struct llama_cparams;
  15. struct llama_memory_context_i;
  16. class llama_kv_cache_context;
  17. class llama_kv_cache_iswa_context;
  18. class llama_memory_recurrent_context;
  19. class llama_memory_hybrid_context;
  20. // certain models (typically multi-modal) can produce different types of graphs
  21. enum llm_graph_type {
  22. LLM_GRAPH_TYPE_DEFAULT,
  23. LLM_GRAPH_TYPE_ENCODER,
  24. LLM_GRAPH_TYPE_DECODER,
  25. };
  26. enum llm_ffn_op_type {
  27. LLM_FFN_SILU,
  28. LLM_FFN_GELU,
  29. LLM_FFN_RELU,
  30. LLM_FFN_RELU_SQR,
  31. LLM_FFN_SWIGLU,
  32. LLM_FFN_GEGLU,
  33. LLM_FFN_REGLU,
  34. LLM_FFN_SWIGLU_OAI_MOE,
  35. };
  36. enum llm_ffn_gate_type {
  37. LLM_FFN_SEQ,
  38. LLM_FFN_PAR, // ffn_gate is parallel to ffn_up
  39. };
  40. enum llm_norm_type {
  41. LLM_NORM,
  42. LLM_NORM_RMS,
  43. LLM_NORM_GROUP,
  44. };
  45. // TODO: tmp - need something better to pass the data from the encoder to the decoder
  46. struct llama_cross {
  47. // the output embeddings from the encoder as a ggml tensor
  48. // TODO: this needs more work to be correct, for now copy the embeddings data to host memory
  49. // ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524
  50. //ggml_tensor * t_embd = nullptr;
  51. int64_t n_embd = 0;
  52. int64_t n_enc = 0;
  53. // embeddings data copied to host memory (tmp)
  54. std::vector<float> v_embd;
  55. // needed to construct the cross-attention mask in the decoder
  56. std::vector<std::set<llama_seq_id>> seq_ids_enc;
  57. };
  58. struct llm_graph_params;
  59. //
  60. // llm_graph_input
  61. //
  62. class llm_graph_input_i {
  63. public:
  64. llm_graph_input_i() {
  65. const char * LLAMA_GRAPH_INPUT_DEBUG = getenv("LLAMA_GRAPH_INPUT_DEBUG");
  66. debug = LLAMA_GRAPH_INPUT_DEBUG ? atoi(LLAMA_GRAPH_INPUT_DEBUG) : 0;
  67. }
  68. virtual ~llm_graph_input_i() = default;
  69. virtual void set_input(const llama_ubatch * ubatch) = 0;
  70. // return true if the resulting input tensors using the provided graph parameters would be
  71. // the same as the previous input tensors that we have currently stored in the object
  72. virtual bool can_reuse(const llm_graph_params & params) {
  73. // returning false here by default will prevent from reusing the graph if the check
  74. // for the input type has not been implemented yet
  75. GGML_UNUSED(params);
  76. return false;
  77. }
  78. protected:
  79. // env: LLAMA_GRAPH_INPUT_DEBUG
  80. int debug = 0;
  81. };
  82. using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
  83. class llm_graph_input_embd : public llm_graph_input_i {
  84. public:
  85. llm_graph_input_embd() = default;
  86. virtual ~llm_graph_input_embd() = default;
  87. void set_input(const llama_ubatch * ubatch) override;
  88. bool can_reuse(const llm_graph_params & params) override;
  89. ggml_tensor * tokens = nullptr; // I32 [n_batch]
  90. ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
  91. };
  92. class llm_graph_input_pos : public llm_graph_input_i {
  93. public:
  94. llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
  95. virtual ~llm_graph_input_pos() = default;
  96. void set_input(const llama_ubatch * ubatch) override;
  97. bool can_reuse(const llm_graph_params & params) override;
  98. ggml_tensor * pos = nullptr; // I32 [n_batch]
  99. const uint32_t n_pos_per_embd = 1;
  100. };
  101. // temperature tuning, used by llama4
  102. class llm_graph_input_attn_temp : public llm_graph_input_i {
  103. public:
  104. llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
  105. : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
  106. virtual ~llm_graph_input_attn_temp() = default;
  107. void set_input(const llama_ubatch * ubatch) override;
  108. ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
  109. const uint32_t n_attn_temp_floor_scale;
  110. const float f_attn_temp_scale;
  111. };
  112. class llm_graph_input_pos_bucket : public llm_graph_input_i {
  113. public:
  114. llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {}
  115. virtual ~llm_graph_input_pos_bucket() = default;
  116. void set_input(const llama_ubatch * ubatch) override;
  117. ggml_tensor * pos_bucket = nullptr; // I32 [n_batch, n_batch]
  118. const llama_hparams hparams;
  119. };
  120. class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
  121. public:
  122. llm_graph_input_pos_bucket_kv(
  123. const llama_hparams & hparams,
  124. const llama_kv_cache_context * mctx) : hparams(hparams), mctx(mctx) {}
  125. virtual ~llm_graph_input_pos_bucket_kv() = default;
  126. void set_input(const llama_ubatch * ubatch) override;
  127. ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
  128. const llama_hparams hparams;
  129. const llama_kv_cache_context * mctx;
  130. };
  131. class llm_graph_input_out_ids : public llm_graph_input_i {
  132. public:
  133. llm_graph_input_out_ids(
  134. const llama_hparams & hparams,
  135. const llama_cparams & cparams,
  136. uint32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
  137. virtual ~llm_graph_input_out_ids() = default;
  138. void set_input(const llama_ubatch * ubatch) override;
  139. bool can_reuse(const llm_graph_params & params) override;
  140. ggml_tensor * out_ids; // I32 [n_outputs]
  141. const llama_hparams hparams;
  142. const llama_cparams cparams;
  143. const uint32_t n_outputs;
  144. };
  145. class llm_graph_input_mean : public llm_graph_input_i {
  146. public:
  147. llm_graph_input_mean(const llama_cparams & cparams) : cparams(cparams) {}
  148. virtual ~llm_graph_input_mean() = default;
  149. void set_input(const llama_ubatch * ubatch) override;
  150. ggml_tensor * mean; // F32 [n_batch, n_batch]
  151. const llama_cparams cparams;
  152. };
  153. class llm_graph_input_cls : public llm_graph_input_i {
  154. public:
  155. llm_graph_input_cls(const llama_cparams & cparams) : cparams(cparams) {}
  156. virtual ~llm_graph_input_cls() = default;
  157. void set_input(const llama_ubatch * ubatch) override;
  158. ggml_tensor * cls; // I32 [n_batch]
  159. const llama_cparams cparams;
  160. };
  161. class llm_graph_input_rs : public llm_graph_input_i {
  162. public:
  163. llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
  164. virtual ~llm_graph_input_rs() = default;
  165. void set_input(const llama_ubatch * ubatch) override;
  166. ggml_tensor * s_copy; // I32 [n_rs]
  167. // views of s_copy, computed once per graph
  168. // and shared across layers which use build_rs
  169. ggml_tensor * s_copy_main; // I32 [n_seqs]
  170. ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs]
  171. const llama_memory_recurrent_context * mctx;
  172. };
  173. class llm_graph_input_cross_embd : public llm_graph_input_i {
  174. public:
  175. llm_graph_input_cross_embd(
  176. const llama_cross * cross) : cross(cross) {}
  177. virtual ~llm_graph_input_cross_embd() = default;
  178. void set_input(const llama_ubatch * ubatch) override;
  179. ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc]
  180. const llama_cross * cross;
  181. };
  182. class llm_graph_input_attn_no_cache : public llm_graph_input_i {
  183. public:
  184. llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
  185. hparams(hparams),
  186. cparams(cparams) {
  187. }
  188. ~llm_graph_input_attn_no_cache() = default;
  189. void set_input(const llama_ubatch * ubatch) override;
  190. ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
  191. ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch, 1, 1]
  192. ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch, 1, 1]
  193. const llama_hparams hparams;
  194. const llama_cparams cparams;
  195. };
  196. class llm_graph_input_attn_kv : public llm_graph_input_i {
  197. public:
  198. llm_graph_input_attn_kv(
  199. const llama_hparams & hparams,
  200. const llama_cparams & cparams,
  201. const llama_kv_cache_context * mctx) :
  202. hparams(hparams),
  203. cparams(cparams),
  204. mctx(mctx) {
  205. }
  206. ~llm_graph_input_attn_kv() = default;
  207. void set_input(const llama_ubatch * ubatch) override;
  208. bool can_reuse(const llm_graph_params & params) override;
  209. ggml_tensor * get_k_idxs() const { return self_k_idxs; }
  210. ggml_tensor * get_v_idxs() const { return self_v_idxs; }
  211. ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
  212. ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
  213. ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
  214. ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
  215. ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
  216. // note: these have to be copies because in order to be able to reuse a graph, its inputs
  217. // need to carry these parameters with them. otherwise, they can point to freed
  218. // llm_graph_params from a previous batch, causing stack-use-after-return
  219. const llama_hparams hparams;
  220. const llama_cparams cparams;
  221. const llama_kv_cache_context * mctx;
  222. };
  223. class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
  224. public:
  225. llm_graph_input_attn_kv_iswa(
  226. const llama_hparams & hparams,
  227. const llama_cparams & cparams,
  228. const llama_kv_cache_iswa_context * mctx) :
  229. hparams(hparams),
  230. cparams(cparams),
  231. mctx(mctx) {
  232. }
  233. ~llm_graph_input_attn_kv_iswa() = default;
  234. void set_input(const llama_ubatch * ubatch) override;
  235. bool can_reuse(const llm_graph_params & params) override;
  236. ggml_tensor * get_k_idxs() const { return self_k_idxs; }
  237. ggml_tensor * get_v_idxs() const { return self_v_idxs; }
  238. ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
  239. ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
  240. ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
  241. ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
  242. ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
  243. ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
  244. ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
  245. ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
  246. ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
  247. ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
  248. ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
  249. ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
  250. const llama_hparams hparams;
  251. const llama_cparams cparams;
  252. const llama_kv_cache_iswa_context * mctx;
  253. };
  254. class llm_graph_input_attn_cross : public llm_graph_input_i {
  255. public:
  256. llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
  257. ~llm_graph_input_attn_cross() = default;
  258. void set_input(const llama_ubatch * ubatch) override;
  259. ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
  260. ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
  261. ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
  262. const llama_cross * cross = nullptr;
  263. };
  264. class llm_graph_input_mem_hybrid : public llm_graph_input_i {
  265. public:
  266. llm_graph_input_mem_hybrid(
  267. std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
  268. std::unique_ptr<llm_graph_input_rs> inp_rs,
  269. const llama_memory_hybrid_context * mctx) :
  270. inp_attn(std::move(inp_attn)),
  271. inp_rs(std::move(inp_rs)),
  272. mctx(mctx) { }
  273. virtual ~llm_graph_input_mem_hybrid() = default;
  274. void set_input(const llama_ubatch * ubatch) override;
  275. std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
  276. std::unique_ptr<llm_graph_input_rs> inp_rs;
  277. llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
  278. llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
  279. const llama_memory_hybrid_context * mctx;
  280. };
  281. //
  282. // llm_graph_result
  283. //
  284. // these objects deliver the result from the graph build process back to the llama_context
  285. // note that the input tensors created for the graph are referenced here - the goal is to be able to populate their
  286. // specific data, by calling the set_inputs() method
// along with the input tensors, the object also provides commonly used output tensors, such as logits, embeddings, etc.
// these are used by the llama_context to extract the relevant data, based on the compute parameters
  289. // callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
  290. using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
  291. class llm_graph_result;
  292. struct llm_graph_params {
  293. llm_arch arch = LLM_ARCH_UNKNOWN;
  294. llama_hparams hparams;
  295. llama_cparams cparams;
  296. llama_ubatch ubatch; // note: intentionally make a copy
  297. llm_graph_type gtype;
  298. ggml_backend_sched_t sched;
  299. ggml_backend_t backend_cpu;
  300. const llama_adapter_cvec * cvec;
  301. const llama_adapter_loras * loras;
  302. const llama_memory_context_i * mctx;
  303. const llama_cross * cross;
  304. uint32_t n_outputs;
  305. llm_graph_cb cb;
  306. llm_graph_result * res;
  307. // return true if the "other" params would result in a graph with the same topology as with the current params
  308. // having the same topology allows us to reuse the graph in some cases
  309. bool allow_reuse(const llm_graph_params & other) const {
  310. // first check the ubatch
  311. bool can_reuse_ubatch =
  312. ubatch.equal_seqs() == other.ubatch.equal_seqs() &&
  313. ubatch.n_tokens == other.ubatch.n_tokens &&
  314. ubatch.n_seq_tokens == other.ubatch.n_seq_tokens &&
  315. ubatch.n_seqs == other.ubatch.n_seqs &&
  316. ubatch.n_seqs_unq == other.ubatch.n_seqs_unq &&
  317. (
  318. (!ubatch.token && !other.ubatch.token) ||
  319. (!ubatch.embd && !other.ubatch.embd)
  320. );
  321. // when we split the batch using "equal_seqs" we have to verify that the participating sequences are the same
  322. // the reason is because the set of attention streams would be different for different sequences
  323. if (can_reuse_ubatch && ubatch.equal_seqs()) {
  324. if (!ubatch.data) {
  325. // if the old ubatch does not own it's data, then we cannot guarantee that it is still alive, and
  326. // therefore we cannot perform the sequence id check. normally should never happen
  327. can_reuse_ubatch = false;
  328. } else {
  329. for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
  330. can_reuse_ubatch &= ubatch.seq_id_unq[s] == other.ubatch.seq_id_unq[s];
  331. }
  332. }
  333. }
  334. if (!can_reuse_ubatch) {
  335. return false;
  336. }
  337. return
  338. cparams.embeddings == other.cparams.embeddings &&
  339. cparams.causal_attn == other.cparams.causal_attn &&
  340. arch == other.arch &&
  341. gtype == other.gtype &&
  342. cvec == other.cvec &&
  343. loras == other.loras &&
  344. cross == other.cross &&
  345. n_outputs == other.n_outputs;
  346. }
  347. };
  348. class llm_graph_result {
  349. public:
  350. llm_graph_result(int64_t max_nodes);
  351. virtual ~llm_graph_result() = default;
  352. ggml_tensor * get_tokens() const { return t_tokens; }
  353. ggml_tensor * get_logits() const { return t_logits; }
  354. ggml_tensor * get_embd() const { return t_embd; }
  355. ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
  356. ggml_cgraph * get_gf() const { return gf; }
  357. ggml_context * get_ctx() const { return ctx_compute.get(); }
  358. int64_t get_max_nodes() const;
  359. void reset();
  360. void set_inputs(const llama_ubatch * ubatch);
  361. // try to update the existing graph result using the new graph parameters in order to reuse it
  362. // this can only be done if we determine that the resulting graph using the new graph parameters
  363. // would be identical to the existing graph. in that case, we simply have to update the memory
  364. // contexts of the input tensors of the graph and we can reuse it for another computation
  365. // return true if the graph was updated and can be reused
  366. bool can_reuse(const llm_graph_params & params);
  367. llm_graph_input_i * add_input(llm_graph_input_ptr input);
  368. void set_params(const llm_graph_params & params);
  369. // important graph nodes
  370. ggml_tensor * t_tokens = nullptr;
  371. ggml_tensor * t_logits = nullptr;
  372. ggml_tensor * t_embd = nullptr;
  373. ggml_tensor * t_embd_pooled = nullptr;
  374. std::vector<llm_graph_input_ptr> inputs;
  375. ggml_context_ptr ctx_compute;
  376. // memory buffers used to evaluate the model
  377. std::vector<uint8_t> buf_compute_meta;
  378. ggml_cgraph * gf;
  379. int64_t max_nodes;
  380. private:
  381. // keep a copy of the previous graph parameters
  382. // we will use this to determine whether the graph can be reused by comparing them with the new parameters
  383. // note: these are updated after constructing the new graph
  384. llm_graph_params params;
  385. // env: LLAMA_GRAPH_RESULT_DEBUG
  386. int debug = 0;
  387. };
  388. using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>;
  389. //
  390. // llm_graph_context
  391. //
  392. // used in build_rs to properly order writes and avoid unnecessary copies
  393. using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
// Per-build helper used by the model graph builders. It snapshots frequently used
// hyper-/compute parameters and exposes builders that add common sub-graphs
// (FFN, MoE, attention, recurrent state, pooling) to ctx0/gf.
// NOTE(review): the const/reference members are presumably initialized by the out-of-line
// constructor in declaration order - keep this member order stable (confirm in llama-graph.cpp).
struct llm_graph_context {
    const llm_arch arch;

    // NOTE(review): these are references, presumably bound to fields of the llm_graph_params
    // passed to the constructor - that object must then outlive this context (confirm)
    const llama_hparams & hparams;
    const llama_cparams & cparams;
    const llama_ubatch  & ubatch;

    // model / batch dimensions
    const int64_t n_embd;
    const int64_t n_layer;
    const int64_t n_rot;
    const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
    const int64_t n_head;
    const int64_t n_head_kv;
    const int64_t n_embd_head_k;
    const int64_t n_embd_k_gqa;
    const int64_t n_embd_head_v;
    const int64_t n_embd_v_gqa;
    const int64_t n_expert;
    const int64_t n_expert_used;

    // rope / normalization constants
    const float freq_base;
    const float freq_scale;
    const float ext_factor;
    const float attn_factor;
    const float beta_fast;
    const float beta_slow;
    const float norm_eps;
    const float norm_rms_eps;

    const int64_t n_tokens;
    const int64_t n_outputs;
    const int32_t n_ctx_orig; // yarn

    const enum llama_pooling_type pooling_type;
    const enum llama_rope_type rope_type;

    ggml_backend_sched_t sched;

    ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?

    // non-owning pointers (same fields as in llm_graph_params)
    const llama_adapter_cvec * cvec;
    const llama_adapter_loras * loras;
    const llama_memory_context_i * mctx;
    const llama_cross * cross;

    // reference to the per-tensor callback (see llm_graph_cb)
    const llm_graph_cb & cb_func;

    llm_graph_result * res;

    // compute context and graph being built
    ggml_context * ctx0 = nullptr;
    ggml_cgraph * gf = nullptr;

    llm_graph_context(const llm_graph_params & params);
    virtual ~llm_graph_context() = default;

    // presumably forwards to cb_func for the given tensor/name/layer (confirm in llama-graph.cpp)
    void cb(ggml_tensor * cur, const char * name, int il) const;

    //
    // common
    //

    ggml_tensor * build_cvec(
            ggml_tensor * cur,
            int il) const;

    // do mat_mul, while optionally apply lora
    ggml_tensor * build_lora_mm(
            ggml_tensor * w,
            ggml_tensor * cur) const;

    // do mat_mul_id, while optionally apply lora
    ggml_tensor * build_lora_mm_id(
            ggml_tensor * w, // ggml_tensor * as
            ggml_tensor * cur, // ggml_tensor * b
            ggml_tensor * ids) const;

    // mw / mb: presumably the norm weight and bias (either may be null - confirm)
    ggml_tensor * build_norm(
            ggml_tensor * cur,
            ggml_tensor * mw,
            ggml_tensor * mb,
            llm_norm_type type,
            int il) const;

    // up/gate/down projections, each with optional *_b and *_s tensors
    // (presumably bias and scale); type_gate selects LLM_FFN_SEQ vs LLM_FFN_PAR
    ggml_tensor * build_ffn(
            ggml_tensor * cur,
            ggml_tensor * up,
            ggml_tensor * up_b,
            ggml_tensor * up_s,
            ggml_tensor * gate,
            ggml_tensor * gate_b,
            ggml_tensor * gate_s,
            ggml_tensor * down,
            ggml_tensor * down_b,
            ggml_tensor * down_s,
            ggml_tensor * act_scales,
            llm_ffn_op_type type_op,
            llm_ffn_gate_type type_gate,
            int il) const;

    // build MoE FFN without bias tensors
    ggml_tensor * build_moe_ffn(
            ggml_tensor * cur,
            ggml_tensor * gate_inp,
            ggml_tensor * up_exps,
            ggml_tensor * gate_exps,
            ggml_tensor * down_exps,
            ggml_tensor * exp_probs_b,
            int64_t n_expert,
            int64_t n_expert_used,
            llm_ffn_op_type type_op,
            bool norm_w,
            bool scale_w,
            float w_scale,
            llama_expert_gating_func_type gating_op,
            int il,
            ggml_tensor * probs_in = nullptr) const;

    // build MoE FFN with the additional *_b (bias) tensors
    ggml_tensor * build_moe_ffn(
            ggml_tensor * cur,
            ggml_tensor * gate_inp,
            ggml_tensor * gate_inp_b,
            ggml_tensor * up_exps,
            ggml_tensor * up_exps_b,
            ggml_tensor * gate_exps,
            ggml_tensor * gate_exps_b,
            ggml_tensor * down_exps,
            ggml_tensor * down_exps_b,
            ggml_tensor * exp_probs_b,
            int64_t n_expert,
            int64_t n_expert_used,
            llm_ffn_op_type type_op,
            bool norm_w,
            bool scale_w,
            float w_scale,
            llama_expert_gating_func_type gating_op,
            int il,
            ggml_tensor * probs_in = nullptr) const;

    //
    // inputs
    //
    // each builder presumably registers the corresponding llm_graph_input_* in res
    // and returns the tensor to use in the graph (confirm in llama-graph.cpp)

    ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
    ggml_tensor * build_inp_pos() const;
    ggml_tensor * build_inp_attn_scale() const;
    ggml_tensor * build_inp_out_ids() const;
    ggml_tensor * build_inp_mean() const;
    ggml_tensor * build_inp_cls() const;

    ggml_tensor * build_inp_cross_embd() const;
    ggml_tensor * build_inp_pos_bucket_enc() const;
    ggml_tensor * build_inp_pos_bucket_dec() const;
    ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;

    //
    // attention
    //

    // core attention computation on already-prepared q/k/v and mask
    ggml_tensor * build_attn_mha(
            ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
            ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
            ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
            ggml_tensor * kq_b,
            ggml_tensor * kq_mask,
            ggml_tensor * sinks, // [n_head_q]
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
            float kq_scale,
            int il) const;

    // one build_attn_inp_* / build_attn pair per attention input kind:
    // no cache, KV cache, iSWA KV cache, cross-attention

    llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;

    ggml_tensor * build_attn(
            llm_graph_input_attn_no_cache * inp,
            ggml_tensor * wo,
            ggml_tensor * wo_b,
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
            ggml_tensor * kq_b,
            ggml_tensor * sinks, // [n_head_q]
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
            float kq_scale,
            int il) const;

    llm_graph_input_attn_kv * build_attn_inp_kv() const;

    ggml_tensor * build_attn(
            llm_graph_input_attn_kv * inp,
            ggml_tensor * wo,
            ggml_tensor * wo_b,
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
            ggml_tensor * kq_b,
            ggml_tensor * sinks, // [n_head_q]
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
            float kq_scale,
            int il) const;

    llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const;

    // note: if k_cur or v_cur are not provided, they will not be stored in the memory
    ggml_tensor * build_attn(
            llm_graph_input_attn_kv_iswa * inp,
            ggml_tensor * wo,
            ggml_tensor * wo_b,
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
            ggml_tensor * kq_b,
            ggml_tensor * sinks, // [n_head_q]
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
            float kq_scale,
            int il) const;

    llm_graph_input_attn_cross * build_attn_inp_cross() const;

    ggml_tensor * build_attn(
            llm_graph_input_attn_cross * inp,
            ggml_tensor * wo,
            ggml_tensor * wo_b,
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
            ggml_tensor * kq_b,
            ggml_tensor * sinks, // [n_head_q]
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
            float kq_scale,
            int il) const;

    //
    // recurrent
    //

    // TODO: move this implementation to llama_memory_recurrent.
    //       this is analogous to llama_kv_cache::cpy_k / cpy_v
    //       when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
    //         implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
    //         `llama_memory_recurrent`
    // low-level variant: all recurrent-state parameters are passed explicitly
    ggml_tensor * build_rs(
            ggml_tensor * s,
            ggml_tensor * state_copy_main,
            ggml_tensor * state_copy_extra,
            int32_t state_size,
            int32_t n_seqs,
            uint32_t n_rs,
            uint32_t rs_head,
            uint32_t rs_size,
            int32_t rs_zero,
            const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;

    llm_graph_input_rs * build_rs_inp() const;

    // convenience variant taking the recurrent-state input created by build_rs_inp()
    ggml_tensor * build_rs(
            llm_graph_input_rs * inp,
            ggml_tensor * s,
            int32_t state_size,
            int32_t n_seqs,
            const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;

    ggml_tensor * build_rwkv_token_shift_load(
            llm_graph_input_rs * inp,
            const llama_ubatch & ubatch,
            int il) const;

    ggml_tensor * build_rwkv_token_shift_store(
            ggml_tensor * token_shift,
            const llama_ubatch & ubatch,
            int il) const;

    //
    // hybrid
    //

    llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;

    //
    // pooling
    //

    // cls / cls_out (+ *_b): presumably optional classifier weights and biases applied
    // after pooling (confirm in llama-graph.cpp)
    void build_pooling(
            ggml_tensor * cls,
            ggml_tensor * cls_b,
            ggml_tensor * cls_out,
            ggml_tensor * cls_out_b) const;
};
  636. // TODO: better name
  637. int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);