llama-graph.h 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628
  1. #pragma once
  2. #include "llama-arch.h"
  3. #include "llama-hparams.h"
  4. #include "llama-adapter.h"
  5. #include <cstdint>
  6. #include <vector>
  7. #include <memory>
  8. #include <set>
  9. #include <functional>
  10. struct ggml_cgraph;
  11. struct ggml_context;
  12. struct ggml_tensor;
  13. struct llama_ubatch;
  14. struct llama_cparams;
  15. struct llama_memory_state_i;
  16. class llama_kv_cache_unified_state;
  17. class llama_kv_cache_unified_iswa_state;
  18. class llama_kv_cache_recurrent_state;
  19. // certain models (typically multi-modal) can produce different types of graphs
  20. enum llm_graph_type {
  21. LLM_GRAPH_TYPE_DEFAULT,
  22. LLM_GRAPH_TYPE_ENCODER,
  23. LLM_GRAPH_TYPE_DECODER,
  24. };
  25. enum llm_ffn_op_type {
  26. LLM_FFN_SILU,
  27. LLM_FFN_GELU,
  28. LLM_FFN_RELU,
  29. LLM_FFN_RELU_SQR,
  30. LLM_FFN_SWIGLU,
  31. LLM_FFN_GEGLU,
  32. };
  33. enum llm_ffn_gate_type {
  34. LLM_FFN_SEQ,
  35. LLM_FFN_PAR, // ffn_gate is parallel to ffn_up
  36. };
  37. enum llm_norm_type {
  38. LLM_NORM,
  39. LLM_NORM_RMS,
  40. LLM_NORM_GROUP,
  41. };
  42. // TODO: tmp - need something better to pass the data from the encoder to the decoder
  43. struct llama_cross {
  44. // the output embeddings from the encoder as a ggml tensor
  45. // TODO: this needs more work to be correct, for now copy the embeddings data to host memory
  46. // ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524
  47. //ggml_tensor * t_embd = nullptr;
  48. int64_t n_embd = 0;
  49. int64_t n_enc = 0;
  50. // embeddings data copied to host memory (tmp)
  51. std::vector<float> v_embd;
  52. // needed to construct the cross-attention mask in the decoder
  53. std::vector<std::set<llama_seq_id>> seq_ids_enc;
  54. };
  55. //
  56. // llm_graph_input
  57. //
  58. class llm_graph_input_i {
  59. public:
  60. virtual ~llm_graph_input_i() = default;
  61. virtual void set_input(const llama_ubatch * ubatch) = 0;
  62. };
  63. using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
  64. class llm_graph_input_embd : public llm_graph_input_i {
  65. public:
  66. llm_graph_input_embd() = default;
  67. virtual ~llm_graph_input_embd() = default;
  68. void set_input(const llama_ubatch * ubatch) override;
  69. ggml_tensor * tokens = nullptr; // I32 [n_batch]
  70. ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
  71. };
  72. class llm_graph_input_pos : public llm_graph_input_i {
  73. public:
  74. llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
  75. virtual ~llm_graph_input_pos() = default;
  76. void set_input(const llama_ubatch * ubatch) override;
  77. ggml_tensor * pos = nullptr; // I32 [n_batch]
  78. const int64_t n_pos_per_embd = 1;
  79. };
  80. // temperature tuning, used by llama4
  81. class llm_graph_input_attn_temp : public llm_graph_input_i {
  82. public:
  83. llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
  84. : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
  85. virtual ~llm_graph_input_attn_temp() = default;
  86. void set_input(const llama_ubatch * ubatch) override;
  87. ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
  88. const uint32_t n_attn_temp_floor_scale;
  89. const float f_attn_temp_scale;
  90. };
  91. class llm_graph_input_pos_bucket : public llm_graph_input_i {
  92. public:
  93. llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {}
  94. virtual ~llm_graph_input_pos_bucket() = default;
  95. void set_input(const llama_ubatch * ubatch) override;
  96. ggml_tensor * pos_bucket = nullptr; // I32 [n_batch, n_batch]
  97. const llama_hparams & hparams;
  98. };
  99. class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
  100. public:
  101. llm_graph_input_pos_bucket_kv(
  102. const llama_hparams & hparams,
  103. const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {}
  104. virtual ~llm_graph_input_pos_bucket_kv() = default;
  105. void set_input(const llama_ubatch * ubatch) override;
  106. ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
  107. const llama_hparams & hparams;
  108. const llama_kv_cache_unified_state * kv_state;
  109. };
  110. class llm_graph_input_out_ids : public llm_graph_input_i {
  111. public:
  112. llm_graph_input_out_ids(
  113. const llama_hparams & hparams,
  114. const llama_cparams & cparams,
  115. int32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
  116. virtual ~llm_graph_input_out_ids() = default;
  117. void set_input(const llama_ubatch * ubatch) override;
  118. ggml_tensor * out_ids; // I32 [n_outputs]
  119. const llama_hparams & hparams;
  120. const llama_cparams & cparams;
  121. const int32_t n_outputs;
  122. };
  123. class llm_graph_input_mean : public llm_graph_input_i {
  124. public:
  125. llm_graph_input_mean(const llama_cparams & cparams) : cparams(cparams) {}
  126. virtual ~llm_graph_input_mean() = default;
  127. void set_input(const llama_ubatch * ubatch) override;
  128. ggml_tensor * mean; // F32 [n_batch, n_batch]
  129. const llama_cparams & cparams;
  130. };
  131. class llm_graph_input_cls : public llm_graph_input_i {
  132. public:
  133. llm_graph_input_cls(const llama_cparams & cparams) : cparams(cparams) {}
  134. virtual ~llm_graph_input_cls() = default;
  135. void set_input(const llama_ubatch * ubatch) override;
  136. ggml_tensor * cls; // I32 [n_batch]
  137. const llama_cparams & cparams;
  138. };
  139. class llm_graph_input_s_copy : public llm_graph_input_i {
  140. public:
  141. llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
  142. virtual ~llm_graph_input_s_copy() = default;
  143. void set_input(const llama_ubatch * ubatch) override;
  144. ggml_tensor * s_copy; // I32 [kv_size]
  145. const llama_kv_cache_recurrent_state * kv_state;
  146. };
  147. class llm_graph_input_cross_embd : public llm_graph_input_i {
  148. public:
  149. llm_graph_input_cross_embd(
  150. const llama_cross * cross) : cross(cross) {}
  151. virtual ~llm_graph_input_cross_embd() = default;
  152. void set_input(const llama_ubatch * ubatch) override;
  153. ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc]
  154. const llama_cross * cross;
  155. };
  156. class llm_graph_input_attn_no_cache : public llm_graph_input_i {
  157. public:
  158. llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
  159. hparams(hparams),
  160. cparams(cparams) {
  161. }
  162. ~llm_graph_input_attn_no_cache() = default;
  163. void set_input(const llama_ubatch * ubatch) override;
  164. ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
  165. ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch]
  166. ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch]
  167. const llama_hparams & hparams;
  168. const llama_cparams & cparams;
  169. };
  170. class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
  171. public:
  172. llm_graph_input_attn_kv_unified(
  173. const llama_hparams & hparams,
  174. const llama_cparams & cparams,
  175. const llama_kv_cache_unified_state * kv_state) :
  176. hparams(hparams),
  177. cparams(cparams),
  178. kv_state(kv_state) {
  179. }
  180. ~llm_graph_input_attn_kv_unified() = default;
  181. void set_input(const llama_ubatch * ubatch) override;
  182. ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
  183. ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
  184. ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
  185. const llama_hparams & hparams;
  186. const llama_cparams & cparams;
  187. const llama_kv_cache_unified_state * kv_state;
  188. };
  189. class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
  190. public:
  191. llm_graph_input_attn_kv_unified_iswa(
  192. const llama_hparams & hparams,
  193. const llama_cparams & cparams,
  194. const llama_kv_cache_unified_iswa_state * kv_state) :
  195. hparams(hparams),
  196. cparams(cparams),
  197. kv_state(kv_state) {
  198. }
  199. ~llm_graph_input_attn_kv_unified_iswa() = default;
  200. void set_input(const llama_ubatch * ubatch) override;
  201. ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
  202. ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
  203. ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
  204. ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
  205. ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
  206. ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch]
  207. const llama_hparams & hparams;
  208. const llama_cparams & cparams;
  209. const llama_kv_cache_unified_iswa_state * kv_state;
  210. };
  211. class llm_graph_input_attn_cross : public llm_graph_input_i {
  212. public:
  213. llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
  214. ~llm_graph_input_attn_cross() = default;
  215. void set_input(const llama_ubatch * ubatch) override;
  216. ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
  217. ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch]
  218. ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch]
  219. const llama_cross * cross = nullptr;
  220. };
  221. //
  222. // llm_graph_result
  223. //
// these objects deliver the result from the graph build process back to the llama_context
// note that the input tensors created for the graph are referenced here - the goal is to be able to populate their
// specific data by calling the set_inputs() method
// along with the input tensors, the object also provides commonly used output tensors, such as logits, embeddings, etc.
// these are used by the llama_context to extract the relevant data, based on the compute parameters
  229. class llm_graph_result_i {
  230. public:
  231. virtual ~llm_graph_result_i() = default;
  232. virtual ggml_tensor * get_tokens() = 0;
  233. virtual ggml_tensor * get_logits() = 0;
  234. virtual ggml_tensor * get_embd() = 0;
  235. virtual ggml_tensor * get_embd_pooled() = 0;
  236. virtual void set_inputs(const llama_ubatch * ubatch) = 0;
  237. };
  238. using llm_graph_result_ptr = std::unique_ptr<llm_graph_result_i>;
  239. class llm_graph_result : public llm_graph_result_i {
  240. public:
  241. virtual ~llm_graph_result() = default;
  242. ggml_tensor * get_tokens() override { return t_tokens; }
  243. ggml_tensor * get_logits() override { return t_logits; }
  244. ggml_tensor * get_embd() override { return t_embd; }
  245. ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
  246. void set_inputs(const llama_ubatch * ubatch) override {
  247. for (auto & input : inputs) {
  248. input->set_input(ubatch);
  249. }
  250. }
  251. llm_graph_input_i * add_input(llm_graph_input_ptr input) {
  252. inputs.emplace_back(std::move(input));
  253. return inputs.back().get();
  254. }
  255. // important graph nodes
  256. ggml_tensor * t_tokens = nullptr;
  257. ggml_tensor * t_logits = nullptr;
  258. ggml_tensor * t_embd = nullptr;
  259. ggml_tensor * t_embd_pooled = nullptr;
  260. std::vector<llm_graph_input_ptr> inputs;
  261. };
  262. //
  263. // llm_graph_context
  264. //
  265. // callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
  266. using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
  267. struct llm_graph_params {
  268. ggml_context * ctx;
  269. const llm_arch arch;
  270. const llama_hparams & hparams;
  271. const llama_cparams & cparams;
  272. const llama_ubatch & ubatch;
  273. ggml_backend_sched_t sched;
  274. ggml_backend_t backend_cpu;
  275. const llama_adapter_cvec * cvec;
  276. const llama_adapter_loras * loras;
  277. const llama_memory_state_i * mstate;
  278. const llama_cross * cross;
  279. int32_t n_outputs;
  280. const llm_graph_cb & cb;
  281. };
// shared state and building blocks used to assemble a model's ggml compute graph:
// cached hyper-parameters, plus helpers for norm, ffn/moe, graph inputs, attention,
// recurrent state and pooling
// NOTE(review): the constructor is defined elsewhere; if it initializes members from
// earlier members (e.g. n_embd from hparams), the declaration order below matters -
// do not reorder members without checking the constructor
struct llm_graph_context {
    const llm_arch arch;

    const llama_hparams & hparams;
    const llama_cparams & cparams;
    const llama_ubatch & ubatch;

    // hyper-parameters cached for convenient access by the builders
    // (presumably copied from hparams / cparams / ubatch in the constructor - TODO confirm)
    const int64_t n_embd;
    const int64_t n_layer;
    const int64_t n_rot;
    const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
    const int64_t n_head;
    const int64_t n_head_kv;
    const int64_t n_embd_head_k;
    const int64_t n_embd_k_gqa;
    const int64_t n_embd_head_v;
    const int64_t n_embd_v_gqa;
    const int64_t n_expert;
    const int64_t n_expert_used;

    // rope / yarn and normalization parameters
    const float freq_base;
    const float freq_scale;
    const float ext_factor;
    const float attn_factor;
    const float beta_fast;
    const float beta_slow;
    const float norm_eps;
    const float norm_rms_eps;

    const int32_t n_tokens;
    const int32_t n_outputs;
    const int32_t n_ctx_orig; // yarn

    const enum llama_pooling_type pooling_type;
    const enum llama_rope_type rope_type;

    // presumably the ggml context used to create the graph tensors (from llm_graph_params::ctx)
    ggml_context * ctx0 = nullptr;

    ggml_backend_sched_t sched;

    ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?

    // optional adapters: control vector (presumably applied by build_cvec) and LoRAs
    // (applied by build_lora_mm / build_lora_mm_id)
    const llama_adapter_cvec * cvec;
    const llama_adapter_loras * loras;
    const llama_memory_state_i * mstate;
    const llama_cross * cross;

    // per-tensor user callback (see llm_graph_cb); presumably invoked through cb()
    const llm_graph_cb & cb_func;

    // the result being built: output tensors plus the owned graph inputs
    // (presumably registered by the build_inp_* / build_attn_inp_* helpers - TODO confirm)
    std::unique_ptr<llm_graph_result> res;

    llm_graph_context(const llm_graph_params & params);

    // number of position values per token embedding (see llm_graph_input_pos)
    int64_t n_pos_per_embd() const;

    // report the named tensor cur of layer il to the user callback (presumably cb_func)
    void cb(ggml_tensor * cur, const char * name, int il) const;

    //
    // common
    //

    // presumably adds the control vector for layer il to cur (when cvec is set) - TODO confirm
    ggml_tensor * build_cvec(
    ggml_tensor * cur,
    int il) const;

    // do mat_mul, while optionally applying lora
    ggml_tensor * build_lora_mm(
    ggml_tensor * w,
    ggml_tensor * cur) const;

    // do mat_mul_id, while optionally applying lora
    ggml_tensor * build_lora_mm_id(
    ggml_tensor * w, // ggml_tensor * as
    ggml_tensor * cur, // ggml_tensor * b
    ggml_tensor * ids) const;

    // normalize cur using the given llm_norm_type; mw / mb are presumably the optional
    // weight and bias (may be null) - TODO confirm
    ggml_tensor * build_norm(
    ggml_tensor * cur,
    ggml_tensor * mw,
    ggml_tensor * mb,
    llm_norm_type type,
    int il) const;

    // feed-forward block: up / gate / down projections, each with an optional bias (*_b)
    // and scale (*_s); type_op selects the operation and type_gate the gate arrangement
    ggml_tensor * build_ffn(
    ggml_tensor * cur,
    ggml_tensor * up,
    ggml_tensor * up_b,
    ggml_tensor * up_s,
    ggml_tensor * gate,
    ggml_tensor * gate_b,
    ggml_tensor * gate_s,
    ggml_tensor * down,
    ggml_tensor * down_b,
    ggml_tensor * down_s,
    ggml_tensor * act_scales,
    llm_ffn_op_type type_op,
    llm_ffn_gate_type type_gate,
    int il) const;

    // mixture-of-experts feed-forward: gate_inp presumably scores the n_expert experts, of
    // which n_expert_used are applied per token; norm_w / scale_w / w_scale and gating_op
    // control how the expert weights are computed - TODO confirm
    ggml_tensor * build_moe_ffn(
    ggml_tensor * cur,
    ggml_tensor * gate_inp,
    ggml_tensor * up_exps,
    ggml_tensor * gate_exps,
    ggml_tensor * down_exps,
    ggml_tensor * exp_probs_b,
    int64_t n_expert,
    int64_t n_expert_used,
    llm_ffn_op_type type_op,
    bool norm_w,
    bool scale_w,
    float w_scale,
    llama_expert_gating_func_type gating_op,
    int il) const;

    //
    // inputs
    //

    // each helper returns the graph input tensor; presumably it also registers the matching
    // llm_graph_input_* object in res so its data is filled by set_inputs() - TODO confirm
    ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
    ggml_tensor * build_inp_pos() const;
    ggml_tensor * build_inp_attn_scale() const;
    ggml_tensor * build_inp_out_ids() const;
    ggml_tensor * build_inp_mean() const;
    ggml_tensor * build_inp_cls() const;
    ggml_tensor * build_inp_s_copy() const;
    ggml_tensor * build_inp_cross_embd() const;
    ggml_tensor * build_inp_pos_bucket_enc() const;
    ggml_tensor * build_inp_pos_bucket_dec() const;
    ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;

    //
    // attention
    //

    // core multi-head attention over q / k / v with an optional kq bias (kq_b), the given mask,
    // an optional MLA projection (v_mla) and the score scale kq_scale
    ggml_tensor * build_attn_mha(
    ggml_cgraph * gf,
    ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
    ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
    ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
    ggml_tensor * kq_b,
    ggml_tensor * kq_mask,
    ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
    float kq_scale) const;

    // each attention variant pairs an input builder (build_attn_inp_*) with a build_attn
    // overload taking that input object; wo / wo_b are presumably the output projection
    // weight and bias - TODO confirm

    // variant without a KV cache
    llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
    ggml_tensor * build_attn(
    llm_graph_input_attn_no_cache * inp,
    ggml_cgraph * gf,
    ggml_tensor * wo,
    ggml_tensor * wo_b,
    ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
    ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
    ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
    ggml_tensor * kq_b,
    ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
    float kq_scale,
    int il) const;

    // variant using the unified KV cache
    llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
    ggml_tensor * build_attn(
    llm_graph_input_attn_kv_unified * inp,
    ggml_cgraph * gf,
    ggml_tensor * wo,
    ggml_tensor * wo_b,
    ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
    ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
    ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
    ggml_tensor * kq_b,
    ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
    float kq_scale,
    int il) const;

    // variant using the unified "iswa" KV cache (regular + SWA masks)
    llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
    ggml_tensor * build_attn(
    llm_graph_input_attn_kv_unified_iswa * inp,
    ggml_cgraph * gf,
    ggml_tensor * wo,
    ggml_tensor * wo_b,
    ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
    ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
    ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
    ggml_tensor * kq_b,
    ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
    float kq_scale,
    int il) const;

    // cross-attention variant (decoder attending to the encoder outputs)
    llm_graph_input_attn_cross * build_attn_inp_cross() const;
    ggml_tensor * build_attn(
    llm_graph_input_attn_cross * inp,
    ggml_cgraph * gf,
    ggml_tensor * wo,
    ggml_tensor * wo_b,
    ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
    ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
    ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
    ggml_tensor * kq_b,
    ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
    float kq_scale,
    int il) const;

    //
    // recurrent
    //

    // presumably gathers the recurrent state s for the current sequences using the
    // state_copy indices (see llm_graph_input_s_copy) - TODO confirm
    ggml_tensor * build_recurrent_state(
    ggml_cgraph * gf,
    ggml_tensor * s,
    ggml_tensor * state_copy,
    int32_t state_size,
    int32_t n_seqs,
    bool avoid_copies = false) const;

    // load / store of the RWKV token-shift state of layer il
    ggml_tensor * build_rwkv_token_shift_load(
    ggml_cgraph * gf,
    ggml_tensor * state_copy,
    const llama_ubatch & ubatch,
    int il) const;

    ggml_tensor * build_rwkv_token_shift_store(
    ggml_tensor * token_shift,
    const llama_ubatch & ubatch,
    int il) const;

    //
    // pooling
    //

    // presumably applies pooling_type to the embeddings and, when given, the classifier
    // head (cls / cls_b, cls_out / cls_out_b); returns nothing, so results are
    // presumably stored in res - TODO confirm
    void build_pooling(
    ggml_cgraph * gf,
    ggml_tensor * cls,
    ggml_tensor * cls_b,
    ggml_tensor * cls_out,
    ggml_tensor * cls_out_b) const;
};
  482. // TODO: better name
  483. int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);