1
0

llama-graph.h 23 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735
  1. #pragma once
  2. #include "llama-arch.h"
  3. #include "llama-hparams.h"
  4. #include "llama-adapter.h"
  5. #include <cstdint>
  6. #include <vector>
  7. #include <memory>
  8. #include <set>
  9. #include <functional>
  10. struct ggml_cgraph;
  11. struct ggml_context;
  12. struct ggml_tensor;
  13. struct llama_ubatch;
  14. struct llama_cparams;
  15. struct llama_memory_context_i;
  16. class llama_kv_cache_unified_context;
  17. class llama_kv_cache_unified_iswa_context;
  18. class llama_memory_recurrent_context;
  19. class llama_memory_hybrid_context;
  20. // certain models (typically multi-modal) can produce different types of graphs
  21. enum llm_graph_type {
  22. LLM_GRAPH_TYPE_DEFAULT,
  23. LLM_GRAPH_TYPE_ENCODER,
  24. LLM_GRAPH_TYPE_DECODER,
  25. };
  26. enum llm_ffn_op_type {
  27. LLM_FFN_SILU,
  28. LLM_FFN_GELU,
  29. LLM_FFN_RELU,
  30. LLM_FFN_RELU_SQR,
  31. LLM_FFN_SWIGLU,
  32. LLM_FFN_GEGLU,
  33. LLM_FFN_REGLU,
  34. };
  35. enum llm_ffn_gate_type {
  36. LLM_FFN_SEQ,
  37. LLM_FFN_PAR, // ffn_gate is parallel to ffn_up
  38. };
  39. enum llm_norm_type {
  40. LLM_NORM,
  41. LLM_NORM_RMS,
  42. LLM_NORM_GROUP,
  43. };
  44. // TODO: tmp - need something better to pass the data from the encoder to the decoder
  45. struct llama_cross {
  46. // the output embeddings from the encoder as a ggml tensor
  47. // TODO: this needs more work to be correct, for now copy the embeddings data to host memory
  48. // ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524
  49. //ggml_tensor * t_embd = nullptr;
  50. int64_t n_embd = 0;
  51. int64_t n_enc = 0;
  52. // embeddings data copied to host memory (tmp)
  53. std::vector<float> v_embd;
  54. // needed to construct the cross-attention mask in the decoder
  55. std::vector<std::set<llama_seq_id>> seq_ids_enc;
  56. };
  57. //
  58. // llm_graph_input
  59. //
  60. class llm_graph_input_i {
  61. public:
  62. virtual ~llm_graph_input_i() = default;
  63. virtual void set_input(const llama_ubatch * ubatch) = 0;
  64. };
  65. using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
  66. class llm_graph_input_embd : public llm_graph_input_i {
  67. public:
  68. llm_graph_input_embd() = default;
  69. virtual ~llm_graph_input_embd() = default;
  70. void set_input(const llama_ubatch * ubatch) override;
  71. ggml_tensor * tokens = nullptr; // I32 [n_batch]
  72. ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
  73. };
  74. class llm_graph_input_pos : public llm_graph_input_i {
  75. public:
  76. llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
  77. virtual ~llm_graph_input_pos() = default;
  78. void set_input(const llama_ubatch * ubatch) override;
  79. ggml_tensor * pos = nullptr; // I32 [n_batch]
  80. const uint32_t n_pos_per_embd = 1;
  81. };
  82. // temperature tuning, used by llama4
  83. class llm_graph_input_attn_temp : public llm_graph_input_i {
  84. public:
  85. llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
  86. : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
  87. virtual ~llm_graph_input_attn_temp() = default;
  88. void set_input(const llama_ubatch * ubatch) override;
  89. ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
  90. const uint32_t n_attn_temp_floor_scale;
  91. const float f_attn_temp_scale;
  92. };
  93. class llm_graph_input_pos_bucket : public llm_graph_input_i {
  94. public:
  95. llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {}
  96. virtual ~llm_graph_input_pos_bucket() = default;
  97. void set_input(const llama_ubatch * ubatch) override;
  98. ggml_tensor * pos_bucket = nullptr; // I32 [n_batch, n_batch]
  99. const llama_hparams & hparams;
  100. };
  101. class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
  102. public:
  103. llm_graph_input_pos_bucket_kv(
  104. const llama_hparams & hparams,
  105. const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
  106. virtual ~llm_graph_input_pos_bucket_kv() = default;
  107. void set_input(const llama_ubatch * ubatch) override;
  108. ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
  109. const llama_hparams & hparams;
  110. const llama_kv_cache_unified_context * mctx;
  111. };
  112. class llm_graph_input_out_ids : public llm_graph_input_i {
  113. public:
  114. llm_graph_input_out_ids(
  115. const llama_hparams & hparams,
  116. const llama_cparams & cparams,
  117. int32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
  118. virtual ~llm_graph_input_out_ids() = default;
  119. void set_input(const llama_ubatch * ubatch) override;
  120. ggml_tensor * out_ids; // I32 [n_outputs]
  121. const llama_hparams & hparams;
  122. const llama_cparams & cparams;
  123. const int32_t n_outputs;
  124. };
  125. class llm_graph_input_mean : public llm_graph_input_i {
  126. public:
  127. llm_graph_input_mean(const llama_cparams & cparams) : cparams(cparams) {}
  128. virtual ~llm_graph_input_mean() = default;
  129. void set_input(const llama_ubatch * ubatch) override;
  130. ggml_tensor * mean; // F32 [n_batch, n_batch]
  131. const llama_cparams & cparams;
  132. };
  133. class llm_graph_input_cls : public llm_graph_input_i {
  134. public:
  135. llm_graph_input_cls(const llama_cparams & cparams) : cparams(cparams) {}
  136. virtual ~llm_graph_input_cls() = default;
  137. void set_input(const llama_ubatch * ubatch) override;
  138. ggml_tensor * cls; // I32 [n_batch]
  139. const llama_cparams & cparams;
  140. };
  141. class llm_graph_input_rs : public llm_graph_input_i {
  142. public:
  143. llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
  144. virtual ~llm_graph_input_rs() = default;
  145. void set_input(const llama_ubatch * ubatch) override;
  146. ggml_tensor * s_copy; // I32 [kv_size]
  147. const llama_memory_recurrent_context * mctx;
  148. };
  149. class llm_graph_input_cross_embd : public llm_graph_input_i {
  150. public:
  151. llm_graph_input_cross_embd(
  152. const llama_cross * cross) : cross(cross) {}
  153. virtual ~llm_graph_input_cross_embd() = default;
  154. void set_input(const llama_ubatch * ubatch) override;
  155. ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc]
  156. const llama_cross * cross;
  157. };
  158. class llm_graph_input_attn_no_cache : public llm_graph_input_i {
  159. public:
  160. llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
  161. hparams(hparams),
  162. cparams(cparams) {
  163. }
  164. ~llm_graph_input_attn_no_cache() = default;
  165. void set_input(const llama_ubatch * ubatch) override;
  166. ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
  167. ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch, 1, 1]
  168. ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch, 1, 1]
  169. const llama_hparams & hparams;
  170. const llama_cparams & cparams;
  171. };
  172. class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
  173. public:
  174. llm_graph_input_attn_kv_unified(
  175. const llama_hparams & hparams,
  176. const llama_cparams & cparams,
  177. const llama_kv_cache_unified_context * mctx) :
  178. hparams(hparams),
  179. cparams(cparams),
  180. mctx(mctx) {
  181. }
  182. ~llm_graph_input_attn_kv_unified() = default;
  183. void set_input(const llama_ubatch * ubatch) override;
  184. ggml_tensor * get_k_idxs() const { return self_k_idxs; }
  185. ggml_tensor * get_v_idxs() const { return self_v_idxs; }
  186. ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
  187. ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
  188. ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
  189. ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
  190. ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
  191. const llama_hparams & hparams;
  192. const llama_cparams & cparams;
  193. const llama_kv_cache_unified_context * mctx;
  194. };
  195. class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
  196. public:
  197. llm_graph_input_attn_kv_unified_iswa(
  198. const llama_hparams & hparams,
  199. const llama_cparams & cparams,
  200. const llama_kv_cache_unified_iswa_context * mctx) :
  201. hparams(hparams),
  202. cparams(cparams),
  203. mctx(mctx) {
  204. }
  205. ~llm_graph_input_attn_kv_unified_iswa() = default;
  206. void set_input(const llama_ubatch * ubatch) override;
  207. ggml_tensor * get_k_idxs() const { return self_k_idxs; }
  208. ggml_tensor * get_v_idxs() const { return self_v_idxs; }
  209. ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
  210. ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
  211. ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
  212. ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
  213. ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
  214. ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
  215. ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
  216. ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
  217. ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
  218. ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
  219. ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch, 1, 1]
  220. ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch, 1, 1]
  221. const llama_hparams & hparams;
  222. const llama_cparams & cparams;
  223. const llama_kv_cache_unified_iswa_context * mctx;
  224. };
  225. class llm_graph_input_attn_cross : public llm_graph_input_i {
  226. public:
  227. llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
  228. ~llm_graph_input_attn_cross() = default;
  229. void set_input(const llama_ubatch * ubatch) override;
  230. ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
  231. ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
  232. ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
  233. const llama_cross * cross = nullptr;
  234. };
  235. class llm_graph_input_mem_hybrid : public llm_graph_input_i {
  236. public:
  237. llm_graph_input_mem_hybrid(
  238. const llama_hparams & hparams,
  239. const llama_cparams & cparams,
  240. const llama_memory_hybrid_context * mctx) :
  241. hparams(hparams),
  242. cparams(cparams),
  243. mctx(mctx) {
  244. }
  245. virtual ~llm_graph_input_mem_hybrid() = default;
  246. void set_input(const llama_ubatch * ubatch) override;
  247. ggml_tensor * s_copy; // I32 [kv_size]
  248. ggml_tensor * get_k_idxs() const { return self_k_idxs; }
  249. ggml_tensor * get_v_idxs() const { return self_v_idxs; }
  250. ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
  251. ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
  252. ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
  253. ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
  254. ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
  255. const llama_hparams & hparams;
  256. const llama_cparams & cparams;
  257. const llama_memory_hybrid_context * mctx;
  258. };
  259. // TODO: remove this when ggml_scale_add is implemented
  260. class llm_graph_input_one : public llm_graph_input_i {
  261. public:
  262. llm_graph_input_one() {}
  263. virtual ~llm_graph_input_one() = default;
  264. void set_input(const llama_ubatch * ubatch) override;
  265. ggml_tensor * one = nullptr; // F32
  266. };
  267. //
  268. // llm_graph_result
  269. //
// these objects deliver the result from the graph build process back to the llama_context
// note that the input tensors created for the graph are referenced here - the goal is to be able to populate their
// specific data, by calling the set_inputs() method
// along with the input tensors, the object also provides commonly used output tensors, such as logits, embeddings, etc.
// these are used by the llama_context to extract the relevant data, based on the compute parameters
  275. class llm_graph_result_i {
  276. public:
  277. virtual ~llm_graph_result_i() = default;
  278. virtual ggml_tensor * get_tokens() = 0;
  279. virtual ggml_tensor * get_logits() = 0;
  280. virtual ggml_tensor * get_embd() = 0;
  281. virtual ggml_tensor * get_embd_pooled() = 0;
  282. virtual void set_inputs(const llama_ubatch * ubatch) = 0;
  283. };
  284. using llm_graph_result_ptr = std::unique_ptr<llm_graph_result_i>;
  285. class llm_graph_result : public llm_graph_result_i {
  286. public:
  287. virtual ~llm_graph_result() = default;
  288. ggml_tensor * get_tokens() override { return t_tokens; }
  289. ggml_tensor * get_logits() override { return t_logits; }
  290. ggml_tensor * get_embd() override { return t_embd; }
  291. ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
  292. void set_inputs(const llama_ubatch * ubatch) override {
  293. for (auto & input : inputs) {
  294. input->set_input(ubatch);
  295. }
  296. }
  297. llm_graph_input_i * add_input(llm_graph_input_ptr input) {
  298. inputs.emplace_back(std::move(input));
  299. return inputs.back().get();
  300. }
  301. // important graph nodes
  302. ggml_tensor * t_tokens = nullptr;
  303. ggml_tensor * t_logits = nullptr;
  304. ggml_tensor * t_embd = nullptr;
  305. ggml_tensor * t_embd_pooled = nullptr;
  306. std::vector<llm_graph_input_ptr> inputs;
  307. };
  308. //
  309. // llm_graph_context
  310. //
// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
// arguments: the ubatch being processed, the tensor, its name and the layer index `il`
using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;

// Everything needed to construct an llm_graph_context (see llm_graph_context(const llm_graph_params &)).
//
// This is an aggregate (no constructors) with const and reference members: it cannot be reassigned after
// construction, and brace-initialization by callers depends on the field order - do not reorder fields.
// No member is owned: the struct has no destructor, and hparams/cparams/ubatch/cb are references, so every
// referenced or pointed-to object must outlive all uses of this struct.
struct llm_graph_params {
    ggml_context * ctx;

    const llm_arch arch;

    const llama_hparams & hparams;
    const llama_cparams & cparams;
    const llama_ubatch & ubatch;

    ggml_backend_sched_t sched;
    ggml_backend_t backend_cpu;

    const llama_adapter_cvec * cvec;
    const llama_adapter_loras * loras;
    const llama_memory_context_i * mctx;
    const llama_cross * cross;

    uint32_t n_outputs;

    // stored by reference - the callback object must outlive this struct
    const llm_graph_cb & cb;
};
// used in build_rs to properly order writes and avoid unnecessary copies
// called as fn(ctx, states, ids); the build_rs overloads default it to ggml_get_rows
using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;

// Shared state and helper builders used while constructing a compute graph.
//
// - The const members have no in-class initializers, so the (out-of-line) constructor taking llm_graph_params
//   sets them once; they never change afterwards.
// - hparams, cparams, ubatch and cb_func are references: the referenced objects must outlive this context.
// - Member declaration order is also the initialization order, so keep it in sync with the constructor's
//   initializer list.
// - NOTE(review): the destructor is virtual, so this is presumably used as a base for per-model graph
//   builders - confirm against the model code.
struct llm_graph_context {
    const llm_arch arch;

    const llama_hparams & hparams;
    const llama_cparams & cparams;
    const llama_ubatch & ubatch;

    // model / context dimensions
    const int64_t n_embd;
    const int64_t n_layer;
    const int64_t n_rot;
    const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
    const int64_t n_head;
    const int64_t n_head_kv;
    const int64_t n_embd_head_k;
    const int64_t n_embd_k_gqa;
    const int64_t n_embd_head_v;
    const int64_t n_embd_v_gqa;
    const int64_t n_expert;
    const int64_t n_expert_used;

    // presumably rope/yarn and normalization settings copied from hparams/cparams - TODO confirm in the ctor
    const float freq_base;
    const float freq_scale;
    const float ext_factor;
    const float attn_factor;
    const float beta_fast;
    const float beta_slow;
    const float norm_eps;
    const float norm_rms_eps;

    const int64_t n_tokens;
    const int64_t n_outputs;
    const int32_t n_ctx_orig; // yarn

    const enum llama_pooling_type pooling_type;
    const enum llama_rope_type rope_type;

    ggml_context * ctx0 = nullptr;

    ggml_backend_sched_t sched;

    ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?

    // non-owning pointers taken from llm_graph_params
    const llama_adapter_cvec * cvec;
    const llama_adapter_loras * loras;
    const llama_memory_context_i * mctx;
    const llama_cross * cross;

    // stored by reference - the callback object must outlive this context
    const llm_graph_cb & cb_func;

    // result of the graph build (output tensors + owned graph inputs)
    // NOTE(review): the build_inp_* / build_attn_inp_* helpers presumably register their inputs here via
    // res->add_input() - confirm in the .cpp
    std::unique_ptr<llm_graph_result> res;

    llm_graph_context(const llm_graph_params & params);
    virtual ~llm_graph_context() = default;

    // invoke cb_func for tensor `cur` with the given name and layer index
    void cb(ggml_tensor * cur, const char * name, int il) const;

    //
    // common
    //

    ggml_tensor * build_cvec(
             ggml_tensor * cur,
                     int   il) const;

    // do mat_mul, while optionally apply lora
    ggml_tensor * build_lora_mm(
              ggml_tensor * w,
              ggml_tensor * cur) const;

    // do mat_mul_id, while optionally apply lora
    ggml_tensor * build_lora_mm_id(
              ggml_tensor * w,   // ggml_tensor * as
              ggml_tensor * cur, // ggml_tensor * b
              ggml_tensor * ids) const;

    ggml_tensor * build_norm(
             ggml_tensor * cur,
             ggml_tensor * mw,
             ggml_tensor * mb,
           llm_norm_type   type,
                     int   il) const;

    // feed-forward block; type_op selects the activation, type_gate how the gate is wired
    ggml_tensor * build_ffn(
             ggml_tensor * cur,
             ggml_tensor * up,
             ggml_tensor * up_b,
             ggml_tensor * up_s,
             ggml_tensor * gate,
             ggml_tensor * gate_b,
             ggml_tensor * gate_s,
             ggml_tensor * down,
             ggml_tensor * down_b,
             ggml_tensor * down_s,
             ggml_tensor * act_scales,
         llm_ffn_op_type   type_op,
       llm_ffn_gate_type   type_gate,
                     int   il) const;

    // mixture-of-experts feed-forward block
    ggml_tensor * build_moe_ffn(
             ggml_tensor * cur,
             ggml_tensor * gate_inp,
             ggml_tensor * up_exps,
             ggml_tensor * gate_exps,
             ggml_tensor * down_exps,
             ggml_tensor * exp_probs_b,
                 int64_t   n_expert,
                 int64_t   n_expert_used,
         llm_ffn_op_type   type_op,
                    bool   norm_w,
                    bool   scale_w,
                   float   w_scale,
            llama_expert_gating_func_type gating_op,
                     int   il) const;

    //
    // inputs
    //

    ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
    ggml_tensor * build_inp_pos() const;
    ggml_tensor * build_inp_attn_scale() const;
    ggml_tensor * build_inp_out_ids() const;
    ggml_tensor * build_inp_mean() const;
    ggml_tensor * build_inp_cls() const;

    ggml_tensor * build_inp_cross_embd() const;
    ggml_tensor * build_inp_pos_bucket_enc() const;
    ggml_tensor * build_inp_pos_bucket_dec() const;
    ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;

    llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;

    //
    // attention
    //

    // core multi-head attention shared by the build_attn overloads below
    ggml_tensor * build_attn_mha(
            ggml_cgraph * gf,
            ggml_tensor * q,       // [n_embd_head_q, n_head_q, n_tokens]
            ggml_tensor * k,       // [n_embd_head_k, n_head_k, n_tokens]
            ggml_tensor * v,       // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
            ggml_tensor * kq_b,
            ggml_tensor * kq_mask,
            ggml_tensor * v_mla,   // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                  float   kq_scale) const;

    // attention without a KV cache
    llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;

    ggml_tensor * build_attn(
            llm_graph_input_attn_no_cache * inp,
            ggml_cgraph * gf,
            ggml_tensor * wo,
            ggml_tensor * wo_b,
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
            ggml_tensor * kq_b,
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                  float   kq_scale,
                    int   il) const;

    // attention over a unified KV cache
    llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;

    ggml_tensor * build_attn(
            llm_graph_input_attn_kv_unified * inp,
            ggml_cgraph * gf,
            ggml_tensor * wo,
            ggml_tensor * wo_b,
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
            ggml_tensor * kq_b,
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                  float   kq_scale,
                    int   il) const;

    // attention over an interleaved-SWA unified KV cache
    llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;

    // note: if k_cur or v_cur are not provided, they will not be stored in the memory
    ggml_tensor * build_attn(
            llm_graph_input_attn_kv_unified_iswa * inp,
            ggml_cgraph * gf,
            ggml_tensor * wo,
            ggml_tensor * wo_b,
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
            ggml_tensor * kq_b,
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                  float   kq_scale,
                    int   il) const;

    // cross-attention (decoder attending to encoder outputs)
    llm_graph_input_attn_cross * build_attn_inp_cross() const;

    ggml_tensor * build_attn(
            llm_graph_input_attn_cross * inp,
            ggml_cgraph * gf,
            ggml_tensor * wo,
            ggml_tensor * wo_b,
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
            ggml_tensor * kq_b,
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                  float   kq_scale,
                    int   il) const;

    // attention using the attention part of a hybrid memory input
    ggml_tensor * build_attn(
            llm_graph_input_mem_hybrid * inp,
            ggml_cgraph * gf,
            ggml_tensor * wo,
            ggml_tensor * wo_b,
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
            ggml_tensor * kq_b,
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                  float   kq_scale,
                    int   il) const;

    //
    // recurrent
    //

    // TODO: avoid notion of "kv"
    // TODO: move this implementation to llama_memory_recurrent.
    //       this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
    //       when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
    //       implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
    //       `llama_memory_recurrent`
    ggml_tensor * build_rs(
            ggml_cgraph * gf,
            ggml_tensor * s,
            ggml_tensor * state_copy,
                int32_t   state_size,
                int32_t   n_seqs,
               uint32_t   n_kv,
               uint32_t   kv_head,
               uint32_t   kv_size,
                int32_t   rs_zero,
            const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;

    llm_graph_input_rs * build_rs_inp() const;

    ggml_tensor * build_rs(
            llm_graph_input_rs * inp,
            ggml_cgraph * gf,
            ggml_tensor * s,
                int32_t   state_size,
                int32_t   n_seqs,
            const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;

    // recurrent-state access using the recurrent part of a hybrid memory input
    ggml_tensor * build_rs(
            llm_graph_input_mem_hybrid * inp,
            ggml_cgraph * gf,
            ggml_tensor * s,
                int32_t   state_size,
                int32_t   n_seqs,
            const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;

    ggml_tensor * build_rwkv_token_shift_load(
        llm_graph_input_rs * inp,
               ggml_cgraph * gf,
        const llama_ubatch & ubatch,
                       int   il) const;

    ggml_tensor * build_rwkv_token_shift_store(
             ggml_tensor * token_shift,
      const llama_ubatch & ubatch,
                     int   il) const;

    //
    // pooling
    //

    void build_pooling(
            ggml_cgraph * gf,
            ggml_tensor * cls,
            ggml_tensor * cls_b,
            ggml_tensor * cls_out,
            ggml_tensor * cls_out_b) const;
};
  568. // TODO: better name
  569. int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);