llama-graph.h 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642
  1. #pragma once
  2. #include "llama-arch.h"
  3. #include "llama-hparams.h"
  4. #include "llama-adapter.h"
  5. #include <cstdint>
  6. #include <vector>
  7. #include <memory>
  8. #include <set>
  9. #include <functional>
  10. struct ggml_cgraph;
  11. struct ggml_context;
  12. struct ggml_tensor;
  13. struct llama_ubatch;
  14. struct llama_cparams;
  15. struct llama_memory_state_i;
  16. class llama_kv_cache_unified_state;
  17. class llama_kv_cache_unified_iswa_state;
  18. class llama_kv_cache_recurrent_state;
  // Identifies which kind of graph is being built.
  // certain models (typically multi-modal) can produce different types of graphs
  enum llm_graph_type {
      LLM_GRAPH_TYPE_DEFAULT = 0,
      LLM_GRAPH_TYPE_ENCODER = 1,
      LLM_GRAPH_TYPE_DECODER = 2,
  };
  // Operation selector accepted as the `type_op` argument of
  // llm_graph_context::build_ffn() and llm_graph_context::build_moe_ffn().
  enum llm_ffn_op_type {
      LLM_FFN_SILU     = 0,
      LLM_FFN_GELU     = 1,
      LLM_FFN_RELU     = 2,
      LLM_FFN_RELU_SQR = 3,
      LLM_FFN_SWIGLU   = 4,
      LLM_FFN_GEGLU    = 5,
  };
  // Gate arrangement accepted as the `type_gate` argument of llm_graph_context::build_ffn().
  enum llm_ffn_gate_type {
      LLM_FFN_SEQ = 0,
      LLM_FFN_PAR = 1, // ffn_gate is parallel to ffn_up
  };
  // Normalization selector accepted as the `type` argument of llm_graph_context::build_norm().
  enum llm_norm_type {
      LLM_NORM       = 0,
      LLM_NORM_RMS   = 1,
      LLM_NORM_GROUP = 2,
  };
  42. // TODO: tmp - need something better to pass the data from the encoder to the decoder
  43. struct llama_cross {
  44. // the output embeddings from the encoder as a ggml tensor
  45. // TODO: this needs more work to be correct, for now copy the embeddings data to host memory
  46. // ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524
  47. //ggml_tensor * t_embd = nullptr;
  48. int64_t n_embd = 0;
  49. int64_t n_enc = 0;
  50. // embeddings data copied to host memory (tmp)
  51. std::vector<float> v_embd;
  52. // needed to construct the cross-attention mask in the decoder
  53. std::vector<std::set<llama_seq_id>> seq_ids_enc;
  54. };
  55. //
  56. // llm_graph_input
  57. //
  58. class llm_graph_input_i {
  59. public:
  60. virtual ~llm_graph_input_i() = default;
  61. virtual void set_input(const llama_ubatch * ubatch) = 0;
  62. };
  63. using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
  64. class llm_graph_input_embd : public llm_graph_input_i {
  65. public:
  66. llm_graph_input_embd() = default;
  67. virtual ~llm_graph_input_embd() = default;
  68. void set_input(const llama_ubatch * ubatch) override;
  69. ggml_tensor * tokens = nullptr; // I32 [n_batch]
  70. ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
  71. };
  72. class llm_graph_input_pos : public llm_graph_input_i {
  73. public:
  74. llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
  75. virtual ~llm_graph_input_pos() = default;
  76. void set_input(const llama_ubatch * ubatch) override;
  77. ggml_tensor * pos = nullptr; // I32 [n_batch]
  78. const int64_t n_pos_per_embd = 1;
  79. };
  80. // temperature tuning, used by llama4
  81. class llm_graph_input_attn_temp : public llm_graph_input_i {
  82. public:
  83. llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
  84. : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
  85. virtual ~llm_graph_input_attn_temp() = default;
  86. void set_input(const llama_ubatch * ubatch) override;
  87. ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
  88. const uint32_t n_attn_temp_floor_scale;
  89. const float f_attn_temp_scale;
  90. };
  91. class llm_graph_input_pos_bucket : public llm_graph_input_i {
  92. public:
  93. llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {}
  94. virtual ~llm_graph_input_pos_bucket() = default;
  95. void set_input(const llama_ubatch * ubatch) override;
  96. ggml_tensor * pos_bucket = nullptr; // I32 [n_batch, n_batch]
  97. const llama_hparams & hparams;
  98. };
  99. class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
  100. public:
  101. llm_graph_input_pos_bucket_kv(
  102. const llama_hparams & hparams,
  103. const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {}
  104. virtual ~llm_graph_input_pos_bucket_kv() = default;
  105. void set_input(const llama_ubatch * ubatch) override;
  106. ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
  107. const llama_hparams & hparams;
  108. const llama_kv_cache_unified_state * kv_state;
  109. };
  110. class llm_graph_input_out_ids : public llm_graph_input_i {
  111. public:
  112. llm_graph_input_out_ids(
  113. const llama_hparams & hparams,
  114. const llama_cparams & cparams,
  115. int32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
  116. virtual ~llm_graph_input_out_ids() = default;
  117. void set_input(const llama_ubatch * ubatch) override;
  118. ggml_tensor * out_ids; // I32 [n_outputs]
  119. const llama_hparams & hparams;
  120. const llama_cparams & cparams;
  121. const int32_t n_outputs;
  122. };
  123. class llm_graph_input_mean : public llm_graph_input_i {
  124. public:
  125. llm_graph_input_mean(const llama_cparams & cparams) : cparams(cparams) {}
  126. virtual ~llm_graph_input_mean() = default;
  127. void set_input(const llama_ubatch * ubatch) override;
  128. ggml_tensor * mean; // F32 [n_batch, n_batch]
  129. const llama_cparams & cparams;
  130. };
  131. class llm_graph_input_cls : public llm_graph_input_i {
  132. public:
  133. llm_graph_input_cls(const llama_cparams & cparams) : cparams(cparams) {}
  134. virtual ~llm_graph_input_cls() = default;
  135. void set_input(const llama_ubatch * ubatch) override;
  136. ggml_tensor * cls; // I32 [n_batch]
  137. const llama_cparams & cparams;
  138. };
  139. class llm_graph_input_s_copy : public llm_graph_input_i {
  140. public:
  141. llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
  142. virtual ~llm_graph_input_s_copy() = default;
  143. void set_input(const llama_ubatch * ubatch) override;
  144. ggml_tensor * s_copy; // I32 [kv_size]
  145. const llama_kv_cache_recurrent_state * kv_state;
  146. };
  147. class llm_graph_input_s_mask : public llm_graph_input_i {
  148. public:
  149. llm_graph_input_s_mask(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
  150. virtual ~llm_graph_input_s_mask() = default;
  151. void set_input(const llama_ubatch * ubatch) override;
  152. ggml_tensor * s_mask; // F32 [1, n_kv]
  153. const llama_kv_cache_recurrent_state * kv_state;
  154. };
  155. class llm_graph_input_cross_embd : public llm_graph_input_i {
  156. public:
  157. llm_graph_input_cross_embd(
  158. const llama_cross * cross) : cross(cross) {}
  159. virtual ~llm_graph_input_cross_embd() = default;
  160. void set_input(const llama_ubatch * ubatch) override;
  161. ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc]
  162. const llama_cross * cross;
  163. };
  164. class llm_graph_input_attn_no_cache : public llm_graph_input_i {
  165. public:
  166. llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
  167. hparams(hparams),
  168. cparams(cparams) {
  169. }
  170. ~llm_graph_input_attn_no_cache() = default;
  171. void set_input(const llama_ubatch * ubatch) override;
  172. ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
  173. ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch]
  174. ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch]
  175. const llama_hparams & hparams;
  176. const llama_cparams & cparams;
  177. };
  178. class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
  179. public:
  180. llm_graph_input_attn_kv_unified(
  181. const llama_hparams & hparams,
  182. const llama_cparams & cparams,
  183. const llama_kv_cache_unified_state * kv_state) :
  184. hparams(hparams),
  185. cparams(cparams),
  186. kv_state(kv_state) {
  187. }
  188. ~llm_graph_input_attn_kv_unified() = default;
  189. void set_input(const llama_ubatch * ubatch) override;
  190. ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
  191. ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
  192. ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
  193. const llama_hparams & hparams;
  194. const llama_cparams & cparams;
  195. const llama_kv_cache_unified_state * kv_state;
  196. };
  197. class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
  198. public:
  199. llm_graph_input_attn_kv_unified_iswa(
  200. const llama_hparams & hparams,
  201. const llama_cparams & cparams,
  202. const llama_kv_cache_unified_iswa_state * kv_state) :
  203. hparams(hparams),
  204. cparams(cparams),
  205. kv_state(kv_state) {
  206. }
  207. ~llm_graph_input_attn_kv_unified_iswa() = default;
  208. void set_input(const llama_ubatch * ubatch) override;
  209. ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
  210. ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
  211. ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
  212. ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
  213. ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
  214. ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch]
  215. const llama_hparams & hparams;
  216. const llama_cparams & cparams;
  217. const llama_kv_cache_unified_iswa_state * kv_state;
  218. };
  219. class llm_graph_input_attn_cross : public llm_graph_input_i {
  220. public:
  221. llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
  222. ~llm_graph_input_attn_cross() = default;
  223. void set_input(const llama_ubatch * ubatch) override;
  224. ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
  225. ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch]
  226. ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch]
  227. const llama_cross * cross = nullptr;
  228. };
  229. //
  230. // llm_graph_result
  231. //
  232. // these objects deliver the result from the graph build process back to the llama_context
  233. // note that the input tensors created for the graph are referenced here - the goal is to be able to populate their
  234. // specific data, by calling the set_inputs() method
  235. // along with the input tensors, the object also provides commonly used output tensors, such as logits, embeddings, etc.
  236. // these are used by the llama_context to extract the relevant data, based on the compute parameters
  237. class llm_graph_result_i {
  238. public:
  239. virtual ~llm_graph_result_i() = default;
  240. virtual ggml_tensor * get_tokens() = 0;
  241. virtual ggml_tensor * get_logits() = 0;
  242. virtual ggml_tensor * get_embd() = 0;
  243. virtual ggml_tensor * get_embd_pooled() = 0;
  244. virtual void set_inputs(const llama_ubatch * ubatch) = 0;
  245. };
  246. using llm_graph_result_ptr = std::unique_ptr<llm_graph_result_i>;
  247. class llm_graph_result : public llm_graph_result_i {
  248. public:
  249. virtual ~llm_graph_result() = default;
  250. ggml_tensor * get_tokens() override { return t_tokens; }
  251. ggml_tensor * get_logits() override { return t_logits; }
  252. ggml_tensor * get_embd() override { return t_embd; }
  253. ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
  254. void set_inputs(const llama_ubatch * ubatch) override {
  255. for (auto & input : inputs) {
  256. input->set_input(ubatch);
  257. }
  258. }
  259. llm_graph_input_i * add_input(llm_graph_input_ptr input) {
  260. inputs.emplace_back(std::move(input));
  261. return inputs.back().get();
  262. }
  263. // important graph nodes
  264. ggml_tensor * t_tokens = nullptr;
  265. ggml_tensor * t_logits = nullptr;
  266. ggml_tensor * t_embd = nullptr;
  267. ggml_tensor * t_embd_pooled = nullptr;
  268. std::vector<llm_graph_input_ptr> inputs;
  269. };
  270. //
  271. // llm_graph_context
  272. //
  273. // callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
  274. using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
  275. struct llm_graph_params {
  276. ggml_context * ctx;
  277. const llm_arch arch;
  278. const llama_hparams & hparams;
  279. const llama_cparams & cparams;
  280. const llama_ubatch & ubatch;
  281. ggml_backend_sched_t sched;
  282. ggml_backend_t backend_cpu;
  283. const llama_adapter_cvec * cvec;
  284. const llama_adapter_loras * loras;
  285. const llama_memory_state_i * mstate;
  286. const llama_cross * cross;
  287. int32_t n_outputs;
  288. const llm_graph_cb & cb;
  289. };
  290. struct llm_graph_context {
  291. const llm_arch arch;
  292. const llama_hparams & hparams;
  293. const llama_cparams & cparams;
  294. const llama_ubatch & ubatch;
  295. const int64_t n_embd;
  296. const int64_t n_layer;
  297. const int64_t n_rot;
  298. const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
  299. const int64_t n_head;
  300. const int64_t n_head_kv;
  301. const int64_t n_embd_head_k;
  302. const int64_t n_embd_k_gqa;
  303. const int64_t n_embd_head_v;
  304. const int64_t n_embd_v_gqa;
  305. const int64_t n_expert;
  306. const int64_t n_expert_used;
  307. const float freq_base;
  308. const float freq_scale;
  309. const float ext_factor;
  310. const float attn_factor;
  311. const float beta_fast;
  312. const float beta_slow;
  313. const float norm_eps;
  314. const float norm_rms_eps;
  315. const int32_t n_tokens;
  316. const int32_t n_outputs;
  317. const int32_t n_ctx_orig; // yarn
  318. const enum llama_pooling_type pooling_type;
  319. const enum llama_rope_type rope_type;
  320. ggml_context * ctx0 = nullptr;
  321. ggml_backend_sched_t sched;
  322. ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
  323. const llama_adapter_cvec * cvec;
  324. const llama_adapter_loras * loras;
  325. const llama_memory_state_i * mstate;
  326. const llama_cross * cross;
  327. const llm_graph_cb & cb_func;
  328. std::unique_ptr<llm_graph_result> res;
  329. llm_graph_context(const llm_graph_params & params);
  330. int64_t n_pos_per_embd() const;
  331. void cb(ggml_tensor * cur, const char * name, int il) const;
  332. //
  333. // common
  334. //
  335. ggml_tensor * build_cvec(
  336. ggml_tensor * cur,
  337. int il) const;
  338. // do mat_mul, while optionally apply lora
  339. ggml_tensor * build_lora_mm(
  340. ggml_tensor * w,
  341. ggml_tensor * cur) const;
  342. // do mat_mul_id, while optionally apply lora
  343. ggml_tensor * build_lora_mm_id(
  344. ggml_tensor * w, // ggml_tensor * as
  345. ggml_tensor * cur, // ggml_tensor * b
  346. ggml_tensor * ids) const;
  347. ggml_tensor * build_norm(
  348. ggml_tensor * cur,
  349. ggml_tensor * mw,
  350. ggml_tensor * mb,
  351. llm_norm_type type,
  352. int il) const;
  353. ggml_tensor * build_ffn(
  354. ggml_tensor * cur,
  355. ggml_tensor * up,
  356. ggml_tensor * up_b,
  357. ggml_tensor * up_s,
  358. ggml_tensor * gate,
  359. ggml_tensor * gate_b,
  360. ggml_tensor * gate_s,
  361. ggml_tensor * down,
  362. ggml_tensor * down_b,
  363. ggml_tensor * down_s,
  364. ggml_tensor * act_scales,
  365. llm_ffn_op_type type_op,
  366. llm_ffn_gate_type type_gate,
  367. int il) const;
  368. ggml_tensor * build_moe_ffn(
  369. ggml_tensor * cur,
  370. ggml_tensor * gate_inp,
  371. ggml_tensor * up_exps,
  372. ggml_tensor * gate_exps,
  373. ggml_tensor * down_exps,
  374. ggml_tensor * exp_probs_b,
  375. int64_t n_expert,
  376. int64_t n_expert_used,
  377. llm_ffn_op_type type_op,
  378. bool norm_w,
  379. bool scale_w,
  380. float w_scale,
  381. llama_expert_gating_func_type gating_op,
  382. int il) const;
  383. //
  384. // inputs
  385. //
  386. ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
  387. ggml_tensor * build_inp_pos() const;
  388. ggml_tensor * build_inp_attn_scale() const;
  389. ggml_tensor * build_inp_out_ids() const;
  390. ggml_tensor * build_inp_mean() const;
  391. ggml_tensor * build_inp_cls() const;
  392. ggml_tensor * build_inp_s_copy() const;
  393. ggml_tensor * build_inp_s_mask() const;
  394. ggml_tensor * build_inp_cross_embd() const;
  395. ggml_tensor * build_inp_pos_bucket_enc() const;
  396. ggml_tensor * build_inp_pos_bucket_dec() const;
  397. ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
  398. //
  399. // attention
  400. //
  401. ggml_tensor * build_attn_mha(
  402. ggml_cgraph * gf,
  403. ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
  404. ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
  405. ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
  406. ggml_tensor * kq_b,
  407. ggml_tensor * kq_mask,
  408. ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  409. float kq_scale) const;
  410. llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
  411. ggml_tensor * build_attn(
  412. llm_graph_input_attn_no_cache * inp,
  413. ggml_cgraph * gf,
  414. ggml_tensor * wo,
  415. ggml_tensor * wo_b,
  416. ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
  417. ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
  418. ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
  419. ggml_tensor * kq_b,
  420. ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  421. float kq_scale,
  422. int il) const;
  423. llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
  424. ggml_tensor * build_attn(
  425. llm_graph_input_attn_kv_unified * inp,
  426. ggml_cgraph * gf,
  427. ggml_tensor * wo,
  428. ggml_tensor * wo_b,
  429. ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
  430. ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
  431. ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
  432. ggml_tensor * kq_b,
  433. ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  434. float kq_scale,
  435. int il) const;
  436. llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
  437. ggml_tensor * build_attn(
  438. llm_graph_input_attn_kv_unified_iswa * inp,
  439. ggml_cgraph * gf,
  440. ggml_tensor * wo,
  441. ggml_tensor * wo_b,
  442. ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
  443. ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
  444. ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
  445. ggml_tensor * kq_b,
  446. ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  447. float kq_scale,
  448. int il) const;
  449. llm_graph_input_attn_cross * build_attn_inp_cross() const;
  450. ggml_tensor * build_attn(
  451. llm_graph_input_attn_cross * inp,
  452. ggml_cgraph * gf,
  453. ggml_tensor * wo,
  454. ggml_tensor * wo_b,
  455. ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
  456. ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
  457. ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
  458. ggml_tensor * kq_b,
  459. ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  460. float kq_scale,
  461. int il) const;
  462. //
  463. // recurrent
  464. //
  465. ggml_tensor * build_copy_mask_state(
  466. ggml_cgraph * gf,
  467. ggml_tensor * s,
  468. ggml_tensor * state_copy,
  469. ggml_tensor * state_mask,
  470. int32_t n_state,
  471. int32_t n_seqs) const;
  472. ggml_tensor * build_rwkv_token_shift_load(
  473. ggml_cgraph * gf,
  474. ggml_tensor * state_copy,
  475. ggml_tensor * state_mask,
  476. const llama_ubatch & ubatch,
  477. int il) const;
  478. ggml_tensor * build_rwkv_token_shift_store(
  479. ggml_tensor * token_shift,
  480. const llama_ubatch & ubatch,
  481. int il) const;
  482. //
  483. // pooling
  484. //
  485. void build_pooling(
  486. ggml_cgraph * gf,
  487. ggml_tensor * cls,
  488. ggml_tensor * cls_b,
  489. ggml_tensor * cls_out,
  490. ggml_tensor * cls_out_b) const;
  491. };
  492. // TODO: better name
  493. int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);