llama-graph.h 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845
  1. #pragma once
  #include "llama-arch.h"
  #include "llama-batch.h"
  #include "llama-hparams.h"
  #include "llama-adapter.h"

  #include <cstdint>
  #include <cstdlib>
  #include <functional>
  #include <memory>
  #include <set>
  #include <vector>
  // ggml types referenced from this header
  struct ggml_cgraph;
  struct ggml_context;
  struct ggml_tensor;

  // llama types defined in other headers
  struct llama_cparams;
  struct llama_memory_context_i;

  // memory-context implementations; keep the class-key in sync with their definitions
  class llama_kv_cache_context;
  class llama_kv_cache_iswa_context;
  class llama_memory_recurrent_context;
  class llama_memory_hybrid_context;
  // the kind of graph to build - certain models (typically multi-modal) can produce different types of graphs
  enum llm_graph_type {
      LLM_GRAPH_TYPE_DEFAULT = 0,
      LLM_GRAPH_TYPE_ENCODER = 1,
      LLM_GRAPH_TYPE_DECODER = 2,
  };
  // activation / gating operation applied inside a feed-forward block
  enum llm_ffn_op_type {
      LLM_FFN_SILU           = 0,
      LLM_FFN_GELU           = 1,
      LLM_FFN_RELU           = 2,
      LLM_FFN_RELU_SQR       = 3,
      LLM_FFN_SWIGLU         = 4,
      LLM_FFN_GEGLU          = 5,
      LLM_FFN_REGLU          = 6,
      LLM_FFN_SWIGLU_OAI_MOE = 7,
  };
  // how the ffn_gate projection is wired relative to ffn_up
  enum llm_ffn_gate_type {
      LLM_FFN_SEQ = 0, // presumably: gate applied sequentially after ffn_up - TODO confirm
      LLM_FFN_PAR = 1, // ffn_gate is parallel to ffn_up
  };
  // normalization flavor selected in build_norm
  enum llm_norm_type {
      LLM_NORM       = 0,
      LLM_NORM_RMS   = 1,
      LLM_NORM_GROUP = 2,
  };
  45. // TODO: tmp - need something better to pass the data from the encoder to the decoder
  46. struct llama_cross {
  47. // the output embeddings from the encoder as a ggml tensor
  48. // TODO: this needs more work to be correct, for now copy the embeddings data to host memory
  49. // ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524
  50. //ggml_tensor * t_embd = nullptr;
  51. int64_t n_embd = 0;
  52. int64_t n_enc = 0;
  53. // embeddings data copied to host memory (tmp)
  54. std::vector<float> v_embd;
  55. // needed to construct the cross-attention mask in the decoder
  56. std::vector<std::set<llama_seq_id>> seq_ids_enc;
  57. };
  58. struct llm_graph_params;
  59. //
  60. // llm_graph_input
  61. //
  62. class llm_graph_input_i {
  63. public:
  64. llm_graph_input_i() {
  65. const char * LLAMA_GRAPH_INPUT_DEBUG = getenv("LLAMA_GRAPH_INPUT_DEBUG");
  66. debug = LLAMA_GRAPH_INPUT_DEBUG ? atoi(LLAMA_GRAPH_INPUT_DEBUG) : 0;
  67. }
  68. virtual ~llm_graph_input_i() = default;
  69. virtual void set_input(const llama_ubatch * ubatch) = 0;
  70. // return true if the resulting input tensors using the provided graph parameters would be
  71. // the same as the previous input tensors that we have currently stored in the object
  72. virtual bool can_reuse(const llm_graph_params & params) {
  73. // returning false here by default will prevent from reusing the graph if the check
  74. // for the input type has not been implemented yet
  75. GGML_UNUSED(params);
  76. return false;
  77. }
  78. protected:
  79. // env: LLAMA_GRAPH_INPUT_DEBUG
  80. int debug = 0;
  81. };
  82. using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
  83. class llm_graph_input_embd : public llm_graph_input_i {
  84. public:
  85. llm_graph_input_embd() = default;
  86. virtual ~llm_graph_input_embd() = default;
  87. void set_input(const llama_ubatch * ubatch) override;
  88. bool can_reuse(const llm_graph_params & params) override;
  89. ggml_tensor * tokens = nullptr; // I32 [n_batch]
  90. ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
  91. };
  92. class llm_graph_input_pos : public llm_graph_input_i {
  93. public:
  94. llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
  95. virtual ~llm_graph_input_pos() = default;
  96. void set_input(const llama_ubatch * ubatch) override;
  97. bool can_reuse(const llm_graph_params & params) override;
  98. ggml_tensor * pos = nullptr; // I32 [n_batch]
  99. const uint32_t n_pos_per_embd = 1;
  100. };
  101. // temperature tuning, used by llama4
  102. class llm_graph_input_attn_temp : public llm_graph_input_i {
  103. public:
  104. llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale, float f_attn_temp_offset)
  105. : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale), f_attn_temp_offset(f_attn_temp_offset) {}
  106. virtual ~llm_graph_input_attn_temp() = default;
  107. void set_input(const llama_ubatch * ubatch) override;
  108. ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
  109. const uint32_t n_attn_temp_floor_scale;
  110. const float f_attn_temp_scale;
  111. const float f_attn_temp_offset;
  112. };
  113. class llm_graph_input_pos_bucket : public llm_graph_input_i {
  114. public:
  115. llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {}
  116. virtual ~llm_graph_input_pos_bucket() = default;
  117. void set_input(const llama_ubatch * ubatch) override;
  118. ggml_tensor * pos_bucket = nullptr; // I32 [n_batch, n_batch]
  119. const llama_hparams hparams;
  120. };
  121. class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
  122. public:
  123. llm_graph_input_pos_bucket_kv(
  124. const llama_hparams & hparams,
  125. const llama_kv_cache_context * mctx) : hparams(hparams), mctx(mctx) {}
  126. virtual ~llm_graph_input_pos_bucket_kv() = default;
  127. void set_input(const llama_ubatch * ubatch) override;
  128. ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
  129. const llama_hparams hparams;
  130. const llama_kv_cache_context * mctx;
  131. };
  132. class llm_graph_input_out_ids : public llm_graph_input_i {
  133. public:
  134. llm_graph_input_out_ids(
  135. const llama_hparams & hparams,
  136. const llama_cparams & cparams,
  137. uint32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
  138. virtual ~llm_graph_input_out_ids() = default;
  139. void set_input(const llama_ubatch * ubatch) override;
  140. bool can_reuse(const llm_graph_params & params) override;
  141. ggml_tensor * out_ids; // I32 [n_outputs]
  142. const llama_hparams hparams;
  143. const llama_cparams cparams;
  144. const uint32_t n_outputs;
  145. };
  146. class llm_graph_input_mean : public llm_graph_input_i {
  147. public:
  148. llm_graph_input_mean(const llama_cparams & cparams) : cparams(cparams) {}
  149. virtual ~llm_graph_input_mean() = default;
  150. void set_input(const llama_ubatch * ubatch) override;
  151. ggml_tensor * mean; // F32 [n_batch, n_batch]
  152. const llama_cparams cparams;
  153. };
  154. class llm_graph_input_cls : public llm_graph_input_i {
  155. public:
  156. llm_graph_input_cls(const llama_cparams & cparams, const llm_arch arch) : cparams(cparams), arch(arch) {}
  157. virtual ~llm_graph_input_cls() = default;
  158. void set_input(const llama_ubatch * ubatch) override;
  159. ggml_tensor * cls; // I32 [n_batch]
  160. const llama_cparams cparams;
  161. const llm_arch arch;
  162. };
  163. class llm_graph_input_rs : public llm_graph_input_i {
  164. public:
  165. llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
  166. virtual ~llm_graph_input_rs() = default;
  167. void set_input(const llama_ubatch * ubatch) override;
  168. bool can_reuse(const llm_graph_params & params) override;
  169. ggml_tensor * s_copy; // I32 [n_rs]
  170. // views of s_copy, computed once per graph
  171. // and shared across layers which use build_rs
  172. ggml_tensor * s_copy_main; // I32 [n_seqs]
  173. ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs]
  174. const llama_memory_recurrent_context * mctx;
  175. // used in view offsets, need to match for valid graph reuse
  176. uint32_t head;
  177. int32_t rs_z;
  178. };
  179. class llm_graph_input_cross_embd : public llm_graph_input_i {
  180. public:
  181. llm_graph_input_cross_embd(
  182. const llama_cross * cross) : cross(cross) {}
  183. virtual ~llm_graph_input_cross_embd() = default;
  184. void set_input(const llama_ubatch * ubatch) override;
  185. ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc]
  186. const llama_cross * cross;
  187. };
  188. class llm_graph_input_attn_no_cache : public llm_graph_input_i {
  189. public:
  190. llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
  191. hparams(hparams),
  192. cparams(cparams) {
  193. }
  194. ~llm_graph_input_attn_no_cache() = default;
  195. void set_input(const llama_ubatch * ubatch) override;
  196. ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
  197. ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
  198. // n_tokens == n_batch
  199. ggml_tensor * self_kq_mask = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
  200. ggml_tensor * self_kq_mask_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream]
  201. ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
  202. ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream]
  203. const llama_hparams hparams;
  204. const llama_cparams cparams;
  205. };
  206. class llm_graph_input_attn_kv : public llm_graph_input_i {
  207. public:
  208. llm_graph_input_attn_kv(
  209. const llama_hparams & hparams,
  210. const llama_cparams & cparams,
  211. const llama_kv_cache_context * mctx) :
  212. hparams(hparams),
  213. cparams(cparams),
  214. mctx(mctx) {
  215. }
  216. ~llm_graph_input_attn_kv() = default;
  217. void set_input(const llama_ubatch * ubatch) override;
  218. bool can_reuse(const llm_graph_params & params) override;
  219. ggml_tensor * get_k_idxs() const { return self_k_idxs; }
  220. ggml_tensor * get_v_idxs() const { return self_v_idxs; }
  221. ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
  222. ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
  223. ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
  224. ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
  225. ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
  226. // note: these have to be copies because in order to be able to reuse a graph, its inputs
  227. // need to carry these parameters with them. otherwise, they can point to freed
  228. // llm_graph_params from a previous batch, causing stack-use-after-return
  229. const llama_hparams hparams;
  230. const llama_cparams cparams;
  231. const llama_kv_cache_context * mctx;
  232. };
  233. class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
  234. public:
  235. llm_graph_input_attn_kv_iswa(
  236. const llama_hparams & hparams,
  237. const llama_cparams & cparams,
  238. const llama_kv_cache_iswa_context * mctx) :
  239. hparams(hparams),
  240. cparams(cparams),
  241. mctx(mctx) {
  242. }
  243. ~llm_graph_input_attn_kv_iswa() = default;
  244. void set_input(const llama_ubatch * ubatch) override;
  245. bool can_reuse(const llm_graph_params & params) override;
  246. ggml_tensor * get_k_idxs() const { return self_k_idxs; }
  247. ggml_tensor * get_v_idxs() const { return self_v_idxs; }
  248. ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
  249. ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
  250. ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
  251. ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
  252. ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
  253. ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
  254. ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
  255. ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
  256. ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
  257. ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
  258. ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
  259. ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
  260. const llama_hparams hparams;
  261. const llama_cparams cparams;
  262. const llama_kv_cache_iswa_context * mctx;
  263. };
  264. class llm_graph_input_attn_cross : public llm_graph_input_i {
  265. public:
  266. llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
  267. ~llm_graph_input_attn_cross() = default;
  268. void set_input(const llama_ubatch * ubatch) override;
  269. ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
  270. ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
  271. ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
  272. const llama_cross * cross = nullptr;
  273. };
  274. class llm_graph_input_mem_hybrid : public llm_graph_input_i {
  275. public:
  276. llm_graph_input_mem_hybrid(
  277. const llama_cparams & cparams,
  278. std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
  279. std::unique_ptr<llm_graph_input_rs> inp_rs,
  280. const llama_memory_hybrid_context * mctx) :
  281. inp_attn(std::move(inp_attn)),
  282. inp_rs(std::move(inp_rs)),
  283. cparams(cparams),
  284. mctx(mctx) { }
  285. virtual ~llm_graph_input_mem_hybrid() = default;
  286. void set_input(const llama_ubatch * ubatch) override;
  287. bool can_reuse(const llm_graph_params & params) override;
  288. std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
  289. std::unique_ptr<llm_graph_input_rs> inp_rs;
  290. llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
  291. llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
  292. const llama_cparams cparams;
  293. const llama_memory_hybrid_context * mctx;
  294. };
  295. //
  296. // llm_graph_result
  297. //
  // these objects deliver the result from the graph build process back to the llama_context
  // note that the input tensors created for the graph are referenced here - the goal is to be able to populate their
  // specific data by calling the set_inputs() method
  // along with the input tensors, the object also provides commonly used output tensors, such as logits, embeddings, etc.
  // these are used by the llama_context to extract the relevant data, based on the compute parameters
  303. // callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
  304. using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
  305. class llm_graph_result;
  306. struct llm_graph_params {
  307. llm_arch arch = LLM_ARCH_UNKNOWN;
  308. llama_hparams hparams;
  309. llama_cparams cparams;
  310. llama_ubatch ubatch; // note: intentionally make a copy
  311. llm_graph_type gtype;
  312. ggml_backend_sched_t sched;
  313. ggml_backend_t backend_cpu;
  314. const llama_adapter_cvec * cvec;
  315. const llama_adapter_loras * loras;
  316. const llama_memory_context_i * mctx;
  317. const llama_cross * cross;
  318. uint32_t n_outputs;
  319. llm_graph_cb cb;
  320. llm_graph_result * res;
  321. // return true if the "other" params would result in a graph with the same topology as with the current params
  322. // having the same topology allows us to reuse the graph in some cases
  323. bool allow_reuse(const llm_graph_params & other) const {
  324. // first check the ubatch
  325. bool can_reuse_ubatch =
  326. ubatch.equal_seqs() == other.ubatch.equal_seqs() &&
  327. ubatch.n_tokens == other.ubatch.n_tokens &&
  328. ubatch.n_seq_tokens == other.ubatch.n_seq_tokens &&
  329. ubatch.n_seqs == other.ubatch.n_seqs &&
  330. ubatch.n_seqs_unq == other.ubatch.n_seqs_unq &&
  331. (
  332. (!ubatch.token && !other.ubatch.token) ||
  333. (!ubatch.embd && !other.ubatch.embd)
  334. );
  335. // when we split the batch using "equal_seqs" we have to verify that the participating sequences are the same
  336. // the reason is because the set of attention streams would be different for different sequences
  337. if (can_reuse_ubatch && ubatch.equal_seqs()) {
  338. if (!ubatch.data) {
  339. // if the old ubatch does not own it's data, then we cannot guarantee that it is still alive, and
  340. // therefore we cannot perform the sequence id check. normally should never happen
  341. can_reuse_ubatch = false;
  342. } else {
  343. for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
  344. can_reuse_ubatch &= ubatch.seq_id_unq[s] == other.ubatch.seq_id_unq[s];
  345. }
  346. }
  347. }
  348. if (!can_reuse_ubatch) {
  349. return false;
  350. }
  351. return
  352. cparams.embeddings == other.cparams.embeddings &&
  353. cparams.causal_attn == other.cparams.causal_attn &&
  354. arch == other.arch &&
  355. gtype == other.gtype &&
  356. cvec == other.cvec &&
  357. loras == other.loras &&
  358. cross == other.cross &&
  359. n_outputs == other.n_outputs;
  360. }
  361. };
  362. class llm_graph_result {
  363. public:
  364. llm_graph_result(int64_t max_nodes);
  365. virtual ~llm_graph_result() = default;
  366. ggml_tensor * get_tokens() const { return t_tokens; }
  367. ggml_tensor * get_logits() const { return t_logits; }
  368. ggml_tensor * get_embd() const { return t_embd; }
  369. ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
  370. ggml_cgraph * get_gf() const { return gf; }
  371. ggml_context * get_ctx() const { return ctx_compute.get(); }
  372. int64_t get_max_nodes() const;
  373. void reset();
  374. void set_inputs(const llama_ubatch * ubatch);
  375. // try to update the existing graph result using the new graph parameters in order to reuse it
  376. // this can only be done if we determine that the resulting graph using the new graph parameters
  377. // would be identical to the existing graph. in that case, we simply have to update the memory
  378. // contexts of the input tensors of the graph and we can reuse it for another computation
  379. // return true if the graph was updated and can be reused
  380. bool can_reuse(const llm_graph_params & params);
  381. llm_graph_input_i * add_input(llm_graph_input_ptr input);
  382. void set_params(const llm_graph_params & params);
  383. // important graph nodes
  384. ggml_tensor * t_tokens = nullptr;
  385. ggml_tensor * t_logits = nullptr;
  386. ggml_tensor * t_embd = nullptr;
  387. ggml_tensor * t_embd_pooled = nullptr;
  388. std::vector<llm_graph_input_ptr> inputs;
  389. ggml_context_ptr ctx_compute;
  390. // memory buffers used to evaluate the model
  391. std::vector<uint8_t> buf_compute_meta;
  392. ggml_cgraph * gf;
  393. int64_t max_nodes;
  394. private:
  395. // keep a copy of the previous graph parameters
  396. // we will use this to determine whether the graph can be reused by comparing them with the new parameters
  397. // note: these are updated after constructing the new graph
  398. llm_graph_params params;
  399. // env: LLAMA_GRAPH_RESULT_DEBUG
  400. int debug = 0;
  401. };
  402. using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>;
  //
  // llm_graph_context
  //

  // used in build_rs to properly order writes and avoid unnecessary copies
  // signature: (ctx, states, ids) -> selected rows; build_rs defaults it to ggml_get_rows
  using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;

  // state shared while building one graph: frequently used model/context parameters plus helpers
  // that build common sub-graphs (ffn, moe, attention, recurrent state, pooling, dense output)
  // NOTE(review): hparams, cparams, ubatch and cb_func are references - presumably bound to members
  //               of the llm_graph_params passed to the constructor, which must then outlive this
  //               object. TODO confirm against the constructor definition
  struct llm_graph_context {
      const llm_arch arch;

      const llama_hparams & hparams;
      const llama_cparams & cparams;
      const llama_ubatch  & ubatch;

      // model dimensions - presumably cached from hparams/cparams by the constructor (defined elsewhere)
      const int64_t n_embd;
      const int64_t n_layer;
      const int64_t n_rot;
      const int64_t n_ctx;       // user-specified context size (can be different from n_ctx_train)
      const int64_t n_head;
      const int64_t n_head_kv;
      const int64_t n_embd_head_k;
      const int64_t n_embd_k_gqa;
      const int64_t n_embd_head_v;
      const int64_t n_embd_v_gqa;
      const int64_t n_expert;
      const int64_t n_expert_used;

      // rope / normalization parameters
      const float freq_base;
      const float freq_scale;
      const float ext_factor;
      const float attn_factor;
      const float beta_fast;
      const float beta_slow;
      const float norm_eps;
      const float norm_rms_eps;

      const int64_t n_tokens;
      const int64_t n_outputs;
      const int32_t n_ctx_orig; // yarn

      const enum llama_pooling_type pooling_type;
      const enum llama_rope_type    rope_type;

      ggml_backend_sched_t sched;

      ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?

      // non-owning pointers to adapters, memory context and encoder data
      const llama_adapter_cvec     * cvec;
      const llama_adapter_loras    * loras;
      const llama_memory_context_i * mctx;
      const llama_cross            * cross;

      const llm_graph_cb & cb_func;

      llm_graph_result * res;

      // compute context and graph being built; null until set up
      ggml_context * ctx0 = nullptr;
      ggml_cgraph  * gf   = nullptr;

      llm_graph_context(const llm_graph_params & params);
      virtual ~llm_graph_context() = default;

      // per-tensor callback hook (presumably forwards to cb_func - defined elsewhere)
      void cb(ggml_tensor * cur, const char * name, int il) const;

      //
      // common
      //

      ggml_tensor * build_cvec(
               ggml_tensor * cur,
                       int   il) const;

      // do mat_mul, while optionally apply lora
      ggml_tensor * build_lora_mm(
                ggml_tensor * w,
                ggml_tensor * cur) const;

      // do mat_mul_id, while optionally apply lora
      ggml_tensor * build_lora_mm_id(
                ggml_tensor * w,   // ggml_tensor * as
                ggml_tensor * cur, // ggml_tensor * b
                ggml_tensor * ids) const;

      ggml_tensor * build_norm(
               ggml_tensor * cur,
               ggml_tensor * mw,
               ggml_tensor * mb,
             llm_norm_type   type,
                       int   il) const;

      // dense feed-forward block; *_b and *_s are the optional bias/scale tensors of each projection
      // NOTE(review): optionality of the bias/scale tensors is inferred from naming - TODO confirm
      ggml_tensor * build_ffn(
               ggml_tensor * cur,
               ggml_tensor * up,
               ggml_tensor * up_b,
               ggml_tensor * up_s,
               ggml_tensor * gate,
               ggml_tensor * gate_b,
               ggml_tensor * gate_s,
               ggml_tensor * down,
               ggml_tensor * down_b,
               ggml_tensor * down_s,
               ggml_tensor * act_scales,
           llm_ffn_op_type   type_op,
         llm_ffn_gate_type   type_gate,
                       int   il) const;

      // build MoE FFN without bias tensors
      ggml_tensor * build_moe_ffn(
               ggml_tensor * cur,
               ggml_tensor * gate_inp,
               ggml_tensor * up_exps,
               ggml_tensor * gate_exps,
               ggml_tensor * down_exps,
               ggml_tensor * exp_probs_b,
                   int64_t   n_expert,
                   int64_t   n_expert_used,
           llm_ffn_op_type   type_op,
                      bool   norm_w,
                      bool   scale_w,
                     float   w_scale,
      llama_expert_gating_func_type gating_op,
                       int   il,
               ggml_tensor * probs_in = nullptr) const;

      // build MoE FFN with the additional *_b bias tensors for the gate input and each expert projection
      ggml_tensor * build_moe_ffn(
               ggml_tensor * cur,
               ggml_tensor * gate_inp,
               ggml_tensor * gate_inp_b,
               ggml_tensor * up_exps,
               ggml_tensor * up_exps_b,
               ggml_tensor * gate_exps,
               ggml_tensor * gate_exps_b,
               ggml_tensor * down_exps,
               ggml_tensor * down_exps_b,
               ggml_tensor * exp_probs_b,
                   int64_t   n_expert,
                   int64_t   n_expert_used,
           llm_ffn_op_type   type_op,
                      bool   norm_w,
                      bool   scale_w,
                     float   w_scale,
      llama_expert_gating_func_type gating_op,
                       int   il,
               ggml_tensor * probs_in = nullptr) const;

      //
      // inputs
      //

      ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
      ggml_tensor * build_inp_pos() const;
      ggml_tensor * build_inp_attn_scale() const;
      ggml_tensor * build_inp_out_ids() const;
      ggml_tensor * build_inp_mean() const;
      ggml_tensor * build_inp_cls() const;

      ggml_tensor * build_inp_cross_embd() const;
      ggml_tensor * build_inp_pos_bucket_enc() const;
      ggml_tensor * build_inp_pos_bucket_dec() const;
      ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;

      //
      // attention
      //

      // core multi-head attention on already-prepared q/k/v
      ggml_tensor * build_attn_mha(
              ggml_tensor * q,       // [n_embd_head_q, n_head_q, n_tokens]
              ggml_tensor * k,       // [n_embd_head_k, n_head_k, n_tokens]
              ggml_tensor * v,       // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
              ggml_tensor * kq_b,
              ggml_tensor * kq_mask,
              ggml_tensor * sinks,   // [n_head_q]
              ggml_tensor * v_mla,   // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                    float   kq_scale,
                      int   il) const;

      // each build_attn_inp_* creates the matching input; the build_attn overloads below take that input

      llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;

      ggml_tensor * build_attn(
              llm_graph_input_attn_no_cache * inp,
              ggml_tensor * wo,
              ggml_tensor * wo_b,
              ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
              ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
              ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
              ggml_tensor * kq_b,
              ggml_tensor * sinks, // [n_head_q]
              ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                    float   kq_scale,
                      int   il) const;

      llm_graph_input_attn_kv * build_attn_inp_kv() const;

      ggml_tensor * build_attn(
              llm_graph_input_attn_kv * inp,
              ggml_tensor * wo,
              ggml_tensor * wo_b,
              ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
              ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
              ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
              ggml_tensor * kq_b,
              ggml_tensor * sinks, // [n_head_q]
              ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                    float   kq_scale,
                      int   il) const;

      llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const;

      // note: if k_cur or v_cur are not provided, they will not be stored in the memory
      ggml_tensor * build_attn(
              llm_graph_input_attn_kv_iswa * inp,
              ggml_tensor * wo,
              ggml_tensor * wo_b,
              ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
              ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
              ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
              ggml_tensor * kq_b,
              ggml_tensor * sinks, // [n_head_q]
              ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                    float   kq_scale,
                      int   il) const;

      llm_graph_input_attn_cross * build_attn_inp_cross() const;

      ggml_tensor * build_attn(
              llm_graph_input_attn_cross * inp,
              ggml_tensor * wo,
              ggml_tensor * wo_b,
              ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
              ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
              ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
              ggml_tensor * kq_b,
              ggml_tensor * sinks, // [n_head_q]
              ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                    float   kq_scale,
                      int   il) const;

      //
      // recurrent
      //

      // TODO: move this implementation to llama_memory_recurrent.
      //       this is analogous to llama_kv_cache::cpy_k / cpy_v
      //       when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
      //         implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
      //         `llama_memory_recurrent`
      // low-level overload: all recurrent-state geometry is passed explicitly
      ggml_tensor * build_rs(
              ggml_tensor * s,
              ggml_tensor * state_copy_main,
              ggml_tensor * state_copy_extra,
                  int32_t   state_size,
                  int32_t   n_seqs,
                 uint32_t   n_rs,
                 uint32_t   rs_head,
                 uint32_t   rs_size,
                  int32_t   rs_zero,
      const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;

      llm_graph_input_rs * build_rs_inp() const;

      // convenience overload taking the recurrent-state input created by build_rs_inp()
      ggml_tensor * build_rs(
              llm_graph_input_rs * inp,
              ggml_tensor * s,
                  int32_t   state_size,
                  int32_t   n_seqs,
      const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;

      ggml_tensor * build_rwkv_token_shift_load(
          llm_graph_input_rs * inp,
          const llama_ubatch & ubatch,
                         int   il) const;

      ggml_tensor * build_rwkv_token_shift_store(
               ggml_tensor * token_shift,
        const llama_ubatch & ubatch,
                       int   il) const;

      //
      // hybrid
      //

      llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;

      //
      // pooling
      //

      void build_pooling(
              ggml_tensor * cls,
              ggml_tensor * cls_b,
              ggml_tensor * cls_out,
              ggml_tensor * cls_out_b) const;

      //
      // dense (out)
      //

      void build_dense_out(
              ggml_tensor * dense_2,
              ggml_tensor * dense_3) const;
  };
  656. // TODO: better name
  657. int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);