llama-graph.h 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697
  1. #pragma once
  2. #include "llama-arch.h"
  3. #include "llama-hparams.h"
  4. #include "llama-adapter.h"
  5. #include <cstdint>
  6. #include <vector>
  7. #include <memory>
  8. #include <set>
  9. #include <functional>
  // ggml types - this header only ever refers to them through pointers
struct ggml_cgraph;
struct ggml_context;
struct ggml_tensor;

// llama types referenced by pointer or by const reference below
struct llama_ubatch;
struct llama_cparams;

// memory-state types that the graph inputs below hold by pointer
struct llama_memory_state_i;
class llama_kv_cache_unified_state;
class llama_kv_cache_unified_iswa_state;
class llama_memory_recurrent_state;
class llama_memory_hybrid_state;
  // the kind of graph to build - some models (typically multi-modal ones) build more than one kind
enum llm_graph_type {
    LLM_GRAPH_TYPE_DEFAULT = 0,
    LLM_GRAPH_TYPE_ENCODER = 1,
    LLM_GRAPH_TYPE_DECODER = 2,
};
  // operation selector passed as `type_op` to build_ffn() / build_moe_ffn()
enum llm_ffn_op_type {
    LLM_FFN_SILU     = 0,
    LLM_FFN_GELU     = 1,
    LLM_FFN_RELU     = 2,
    LLM_FFN_RELU_SQR = 3,
    LLM_FFN_SWIGLU   = 4,
    LLM_FFN_GEGLU    = 5,
};
  // gate placement selector passed as `type_gate` to build_ffn()
enum llm_ffn_gate_type {
    LLM_FFN_SEQ = 0,
    LLM_FFN_PAR = 1, // ffn_gate is parallel to ffn_up
};
  // normalization variant selector passed as `type` to build_norm()
enum llm_norm_type {
    LLM_NORM       = 0,
    LLM_NORM_RMS   = 1,
    LLM_NORM_GROUP = 2,
};
  43. // TODO: tmp - need something better to pass the data from the encoder to the decoder
  44. struct llama_cross {
  45. // the output embeddings from the encoder as a ggml tensor
  46. // TODO: this needs more work to be correct, for now copy the embeddings data to host memory
  47. // ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524
  48. //ggml_tensor * t_embd = nullptr;
  49. int64_t n_embd = 0;
  50. int64_t n_enc = 0;
  51. // embeddings data copied to host memory (tmp)
  52. std::vector<float> v_embd;
  53. // needed to construct the cross-attention mask in the decoder
  54. std::vector<std::set<llama_seq_id>> seq_ids_enc;
  55. };
  56. //
  57. // llm_graph_input
  58. //
  59. class llm_graph_input_i {
  60. public:
  61. virtual ~llm_graph_input_i() = default;
  62. virtual void set_input(const llama_ubatch * ubatch) = 0;
  63. };
  64. using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
  65. class llm_graph_input_embd : public llm_graph_input_i {
  66. public:
  67. llm_graph_input_embd() = default;
  68. virtual ~llm_graph_input_embd() = default;
  69. void set_input(const llama_ubatch * ubatch) override;
  70. ggml_tensor * tokens = nullptr; // I32 [n_batch]
  71. ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
  72. };
  73. class llm_graph_input_pos : public llm_graph_input_i {
  74. public:
  75. llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
  76. virtual ~llm_graph_input_pos() = default;
  77. void set_input(const llama_ubatch * ubatch) override;
  78. ggml_tensor * pos = nullptr; // I32 [n_batch]
  79. const int64_t n_pos_per_embd = 1;
  80. };
  81. // temperature tuning, used by llama4
  82. class llm_graph_input_attn_temp : public llm_graph_input_i {
  83. public:
  84. llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
  85. : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
  86. virtual ~llm_graph_input_attn_temp() = default;
  87. void set_input(const llama_ubatch * ubatch) override;
  88. ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
  89. const uint32_t n_attn_temp_floor_scale;
  90. const float f_attn_temp_scale;
  91. };
  92. class llm_graph_input_pos_bucket : public llm_graph_input_i {
  93. public:
  94. llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {}
  95. virtual ~llm_graph_input_pos_bucket() = default;
  96. void set_input(const llama_ubatch * ubatch) override;
  97. ggml_tensor * pos_bucket = nullptr; // I32 [n_batch, n_batch]
  98. const llama_hparams & hparams;
  99. };
  100. class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
  101. public:
  102. llm_graph_input_pos_bucket_kv(
  103. const llama_hparams & hparams,
  104. const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {}
  105. virtual ~llm_graph_input_pos_bucket_kv() = default;
  106. void set_input(const llama_ubatch * ubatch) override;
  107. ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
  108. const llama_hparams & hparams;
  109. const llama_kv_cache_unified_state * kv_state;
  110. };
  111. class llm_graph_input_out_ids : public llm_graph_input_i {
  112. public:
  113. llm_graph_input_out_ids(
  114. const llama_hparams & hparams,
  115. const llama_cparams & cparams,
  116. int32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
  117. virtual ~llm_graph_input_out_ids() = default;
  118. void set_input(const llama_ubatch * ubatch) override;
  119. ggml_tensor * out_ids; // I32 [n_outputs]
  120. const llama_hparams & hparams;
  121. const llama_cparams & cparams;
  122. const int32_t n_outputs;
  123. };
  124. class llm_graph_input_mean : public llm_graph_input_i {
  125. public:
  126. llm_graph_input_mean(const llama_cparams & cparams) : cparams(cparams) {}
  127. virtual ~llm_graph_input_mean() = default;
  128. void set_input(const llama_ubatch * ubatch) override;
  129. ggml_tensor * mean; // F32 [n_batch, n_batch]
  130. const llama_cparams & cparams;
  131. };
  132. class llm_graph_input_cls : public llm_graph_input_i {
  133. public:
  134. llm_graph_input_cls(const llama_cparams & cparams) : cparams(cparams) {}
  135. virtual ~llm_graph_input_cls() = default;
  136. void set_input(const llama_ubatch * ubatch) override;
  137. ggml_tensor * cls; // I32 [n_batch]
  138. const llama_cparams & cparams;
  139. };
  140. class llm_graph_input_rs : public llm_graph_input_i {
  141. public:
  142. llm_graph_input_rs(const llama_memory_recurrent_state * mem_state) : mem_state(mem_state) {}
  143. virtual ~llm_graph_input_rs() = default;
  144. void set_input(const llama_ubatch * ubatch) override;
  145. ggml_tensor * s_copy; // I32 [kv_size]
  146. const llama_memory_recurrent_state * mem_state;
  147. };
  148. class llm_graph_input_cross_embd : public llm_graph_input_i {
  149. public:
  150. llm_graph_input_cross_embd(
  151. const llama_cross * cross) : cross(cross) {}
  152. virtual ~llm_graph_input_cross_embd() = default;
  153. void set_input(const llama_ubatch * ubatch) override;
  154. ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc]
  155. const llama_cross * cross;
  156. };
  157. class llm_graph_input_attn_no_cache : public llm_graph_input_i {
  158. public:
  159. llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
  160. hparams(hparams),
  161. cparams(cparams) {
  162. }
  163. ~llm_graph_input_attn_no_cache() = default;
  164. void set_input(const llama_ubatch * ubatch) override;
  165. ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
  166. ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch]
  167. ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch]
  168. const llama_hparams & hparams;
  169. const llama_cparams & cparams;
  170. };
  171. class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
  172. public:
  173. llm_graph_input_attn_kv_unified(
  174. const llama_hparams & hparams,
  175. const llama_cparams & cparams,
  176. const llama_kv_cache_unified_state * kv_state) :
  177. hparams(hparams),
  178. cparams(cparams),
  179. kv_state(kv_state) {
  180. }
  181. ~llm_graph_input_attn_kv_unified() = default;
  182. void set_input(const llama_ubatch * ubatch) override;
  183. ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
  184. ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
  185. ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
  186. const llama_hparams & hparams;
  187. const llama_cparams & cparams;
  188. const llama_kv_cache_unified_state * kv_state;
  189. };
  190. class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
  191. public:
  192. llm_graph_input_attn_kv_unified_iswa(
  193. const llama_hparams & hparams,
  194. const llama_cparams & cparams,
  195. const llama_kv_cache_unified_iswa_state * kv_state) :
  196. hparams(hparams),
  197. cparams(cparams),
  198. kv_state(kv_state) {
  199. }
  200. ~llm_graph_input_attn_kv_unified_iswa() = default;
  201. void set_input(const llama_ubatch * ubatch) override;
  202. ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
  203. ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
  204. ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
  205. ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
  206. ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
  207. ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch]
  208. const llama_hparams & hparams;
  209. const llama_cparams & cparams;
  210. const llama_kv_cache_unified_iswa_state * kv_state;
  211. };
  212. class llm_graph_input_attn_cross : public llm_graph_input_i {
  213. public:
  214. llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
  215. ~llm_graph_input_attn_cross() = default;
  216. void set_input(const llama_ubatch * ubatch) override;
  217. ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
  218. ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch]
  219. ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch]
  220. const llama_cross * cross = nullptr;
  221. };
  222. class llm_graph_input_mem_hybrid : public llm_graph_input_i {
  223. public:
  224. llm_graph_input_mem_hybrid(
  225. const llama_hparams & hparams,
  226. const llama_cparams & cparams,
  227. const llama_memory_hybrid_state * mem_state) :
  228. hparams(hparams),
  229. cparams(cparams),
  230. mem_state(mem_state) {
  231. }
  232. virtual ~llm_graph_input_mem_hybrid() = default;
  233. void set_input(const llama_ubatch * ubatch) override;
  234. ggml_tensor * s_copy; // I32 [kv_size]
  235. ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
  236. ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
  237. ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
  238. const llama_hparams & hparams;
  239. const llama_cparams & cparams;
  240. const llama_memory_hybrid_state * mem_state;
  241. };
  242. //
  243. // llm_graph_result
  244. //
  245. // these objects deliver the result from the graph build process back to the llama_context
  246. // note that the input tensors created for the graph are referenced here - the goal is to be able to populate their
  247. // specific data, by calling the set_inputs() method
  248. // along with the input tensors, the object also provides commonly used outputs tensors, such as logits, embeddings, etc.
  249. // these are used by the llama_context to extact the relevant data, based on the compute parameters
  250. class llm_graph_result_i {
  251. public:
  252. virtual ~llm_graph_result_i() = default;
  253. virtual ggml_tensor * get_tokens() = 0;
  254. virtual ggml_tensor * get_logits() = 0;
  255. virtual ggml_tensor * get_embd() = 0;
  256. virtual ggml_tensor * get_embd_pooled() = 0;
  257. virtual void set_inputs(const llama_ubatch * ubatch) = 0;
  258. };
  259. using llm_graph_result_ptr = std::unique_ptr<llm_graph_result_i>;
  260. class llm_graph_result : public llm_graph_result_i {
  261. public:
  262. virtual ~llm_graph_result() = default;
  263. ggml_tensor * get_tokens() override { return t_tokens; }
  264. ggml_tensor * get_logits() override { return t_logits; }
  265. ggml_tensor * get_embd() override { return t_embd; }
  266. ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
  267. void set_inputs(const llama_ubatch * ubatch) override {
  268. for (auto & input : inputs) {
  269. input->set_input(ubatch);
  270. }
  271. }
  272. llm_graph_input_i * add_input(llm_graph_input_ptr input) {
  273. inputs.emplace_back(std::move(input));
  274. return inputs.back().get();
  275. }
  276. // important graph nodes
  277. ggml_tensor * t_tokens = nullptr;
  278. ggml_tensor * t_logits = nullptr;
  279. ggml_tensor * t_embd = nullptr;
  280. ggml_tensor * t_embd_pooled = nullptr;
  281. std::vector<llm_graph_input_ptr> inputs;
  282. };
  283. //
  284. // llm_graph_context
  285. //
  286. // callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
  287. using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
  288. struct llm_graph_params {
  289. ggml_context * ctx;
  290. const llm_arch arch;
  291. const llama_hparams & hparams;
  292. const llama_cparams & cparams;
  293. const llama_ubatch & ubatch;
  294. ggml_backend_sched_t sched;
  295. ggml_backend_t backend_cpu;
  296. const llama_adapter_cvec * cvec;
  297. const llama_adapter_loras * loras;
  298. const llama_memory_state_i * mstate;
  299. const llama_cross * cross;
  300. uint32_t n_outputs;
  301. const llm_graph_cb & cb;
  302. };
  303. struct llm_graph_context {
  304. const llm_arch arch;
  305. const llama_hparams & hparams;
  306. const llama_cparams & cparams;
  307. const llama_ubatch & ubatch;
  308. const int64_t n_embd;
  309. const int64_t n_layer;
  310. const int64_t n_rot;
  311. const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
  312. const int64_t n_head;
  313. const int64_t n_head_kv;
  314. const int64_t n_embd_head_k;
  315. const int64_t n_embd_k_gqa;
  316. const int64_t n_embd_head_v;
  317. const int64_t n_embd_v_gqa;
  318. const int64_t n_expert;
  319. const int64_t n_expert_used;
  320. const float freq_base;
  321. const float freq_scale;
  322. const float ext_factor;
  323. const float attn_factor;
  324. const float beta_fast;
  325. const float beta_slow;
  326. const float norm_eps;
  327. const float norm_rms_eps;
  328. const int64_t n_tokens;
  329. const int64_t n_outputs;
  330. const int32_t n_ctx_orig; // yarn
  331. const enum llama_pooling_type pooling_type;
  332. const enum llama_rope_type rope_type;
  333. ggml_context * ctx0 = nullptr;
  334. ggml_backend_sched_t sched;
  335. ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
  336. const llama_adapter_cvec * cvec;
  337. const llama_adapter_loras * loras;
  338. const llama_memory_state_i * mstate;
  339. const llama_cross * cross;
  340. const llm_graph_cb & cb_func;
  341. std::unique_ptr<llm_graph_result> res;
  342. llm_graph_context(const llm_graph_params & params);
  343. int64_t n_pos_per_embd() const;
  344. void cb(ggml_tensor * cur, const char * name, int il) const;
  345. //
  346. // common
  347. //
  348. ggml_tensor * build_cvec(
  349. ggml_tensor * cur,
  350. int il) const;
  351. // do mat_mul, while optionally apply lora
  352. ggml_tensor * build_lora_mm(
  353. ggml_tensor * w,
  354. ggml_tensor * cur) const;
  355. // do mat_mul_id, while optionally apply lora
  356. ggml_tensor * build_lora_mm_id(
  357. ggml_tensor * w, // ggml_tensor * as
  358. ggml_tensor * cur, // ggml_tensor * b
  359. ggml_tensor * ids) const;
  360. ggml_tensor * build_norm(
  361. ggml_tensor * cur,
  362. ggml_tensor * mw,
  363. ggml_tensor * mb,
  364. llm_norm_type type,
  365. int il) const;
  366. ggml_tensor * build_ffn(
  367. ggml_tensor * cur,
  368. ggml_tensor * up,
  369. ggml_tensor * up_b,
  370. ggml_tensor * up_s,
  371. ggml_tensor * gate,
  372. ggml_tensor * gate_b,
  373. ggml_tensor * gate_s,
  374. ggml_tensor * down,
  375. ggml_tensor * down_b,
  376. ggml_tensor * down_s,
  377. ggml_tensor * act_scales,
  378. llm_ffn_op_type type_op,
  379. llm_ffn_gate_type type_gate,
  380. int il) const;
  381. ggml_tensor * build_moe_ffn(
  382. ggml_tensor * cur,
  383. ggml_tensor * gate_inp,
  384. ggml_tensor * up_exps,
  385. ggml_tensor * gate_exps,
  386. ggml_tensor * down_exps,
  387. ggml_tensor * exp_probs_b,
  388. int64_t n_expert,
  389. int64_t n_expert_used,
  390. llm_ffn_op_type type_op,
  391. bool norm_w,
  392. bool scale_w,
  393. float w_scale,
  394. llama_expert_gating_func_type gating_op,
  395. int il) const;
  396. //
  397. // inputs
  398. //
  399. ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
  400. ggml_tensor * build_inp_pos() const;
  401. ggml_tensor * build_inp_attn_scale() const;
  402. ggml_tensor * build_inp_out_ids() const;
  403. ggml_tensor * build_inp_mean() const;
  404. ggml_tensor * build_inp_cls() const;
  405. ggml_tensor * build_inp_cross_embd() const;
  406. ggml_tensor * build_inp_pos_bucket_enc() const;
  407. ggml_tensor * build_inp_pos_bucket_dec() const;
  408. ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
  409. llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
  410. //
  411. // attention
  412. //
  413. ggml_tensor * build_attn_mha(
  414. ggml_cgraph * gf,
  415. ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
  416. ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
  417. ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
  418. ggml_tensor * kq_b,
  419. ggml_tensor * kq_mask,
  420. ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  421. float kq_scale) const;
  422. llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
  423. ggml_tensor * build_attn(
  424. llm_graph_input_attn_no_cache * inp,
  425. ggml_cgraph * gf,
  426. ggml_tensor * wo,
  427. ggml_tensor * wo_b,
  428. ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
  429. ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
  430. ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
  431. ggml_tensor * kq_b,
  432. ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  433. float kq_scale,
  434. int il) const;
  435. llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
  436. ggml_tensor * build_attn(
  437. llm_graph_input_attn_kv_unified * inp,
  438. ggml_cgraph * gf,
  439. ggml_tensor * wo,
  440. ggml_tensor * wo_b,
  441. ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
  442. ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
  443. ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
  444. ggml_tensor * kq_b,
  445. ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  446. float kq_scale,
  447. int il) const;
  448. llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
  449. ggml_tensor * build_attn(
  450. llm_graph_input_attn_kv_unified_iswa * inp,
  451. ggml_cgraph * gf,
  452. ggml_tensor * wo,
  453. ggml_tensor * wo_b,
  454. ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
  455. ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
  456. ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
  457. ggml_tensor * kq_b,
  458. ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  459. float kq_scale,
  460. int il) const;
  461. llm_graph_input_attn_cross * build_attn_inp_cross() const;
  462. ggml_tensor * build_attn(
  463. llm_graph_input_attn_cross * inp,
  464. ggml_cgraph * gf,
  465. ggml_tensor * wo,
  466. ggml_tensor * wo_b,
  467. ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
  468. ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
  469. ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
  470. ggml_tensor * kq_b,
  471. ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  472. float kq_scale,
  473. int il) const;
  474. ggml_tensor * build_attn(
  475. llm_graph_input_mem_hybrid * inp,
  476. ggml_cgraph * gf,
  477. ggml_tensor * wo,
  478. ggml_tensor * wo_b,
  479. ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
  480. ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
  481. ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
  482. ggml_tensor * kq_b,
  483. ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  484. float kq_scale,
  485. int il) const;
  486. //
  487. // recurrent
  488. //
  489. // TODO: avoid notion of "kv"
  490. // TODO: move this implementation to llama_memory_recurrent.
  491. // this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
  492. // when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
  493. // implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
  494. // `llama_memory_recurrent`
  495. ggml_tensor * build_rs(
  496. ggml_cgraph * gf,
  497. ggml_tensor * s,
  498. ggml_tensor * state_copy,
  499. int32_t state_size,
  500. int32_t n_seqs,
  501. uint32_t n_kv,
  502. uint32_t kv_head,
  503. uint32_t kv_size,
  504. int32_t rs_zero,
  505. bool avoid_copies = false) const;
  506. llm_graph_input_rs * build_rs_inp() const;
  507. ggml_tensor * build_rs(
  508. llm_graph_input_rs * inp,
  509. ggml_cgraph * gf,
  510. ggml_tensor * s,
  511. int32_t state_size,
  512. int32_t n_seqs,
  513. bool avoid_copies = false) const;
  514. ggml_tensor * build_rs(
  515. llm_graph_input_mem_hybrid * inp,
  516. ggml_cgraph * gf,
  517. ggml_tensor * s,
  518. int32_t state_size,
  519. int32_t n_seqs,
  520. bool avoid_copies = false) const;
  521. ggml_tensor * build_rwkv_token_shift_load(
  522. llm_graph_input_rs * inp,
  523. ggml_cgraph * gf,
  524. const llama_ubatch & ubatch,
  525. int il) const;
  526. ggml_tensor * build_rwkv_token_shift_store(
  527. ggml_tensor * token_shift,
  528. const llama_ubatch & ubatch,
  529. int il) const;
  530. //
  531. // pooling
  532. //
  533. void build_pooling(
  534. ggml_cgraph * gf,
  535. ggml_tensor * cls,
  536. ggml_tensor * cls_b,
  537. ggml_tensor * cls_out,
  538. ggml_tensor * cls_out_b) const;
  539. };
  540. // TODO: better name
  541. int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);