llama-graph.h 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592
  1. #pragma once
  2. #include "llama-arch.h"
  3. #include "llama-hparams.h"
  4. #include "llama-adapter.h"
  5. #include <cstdint>
  6. #include <vector>
  7. #include <memory>
  8. #include <set>
  9. #include <functional>
  10. struct ggml_cgraph;
  11. struct ggml_context;
  12. struct ggml_tensor;
  13. struct llama_ubatch;
  14. struct llama_cparams;
  15. class llama_memory_i;
  16. class llama_kv_cache_unified;
  17. // certain models (typically multi-modal) can produce different types of graphs
  18. enum llm_graph_type {
  19. LLM_GRAPH_TYPE_DEFAULT,
  20. LLM_GRAPH_TYPE_ENCODER,
  21. LLM_GRAPH_TYPE_DECODER,
  22. };
  23. enum llm_ffn_op_type {
  24. LLM_FFN_SILU,
  25. LLM_FFN_GELU,
  26. LLM_FFN_RELU,
  27. LLM_FFN_RELU_SQR,
  28. LLM_FFN_SWIGLU,
  29. };
  30. enum llm_ffn_gate_type {
  31. LLM_FFN_SEQ,
  32. LLM_FFN_PAR, // ffn_gate is parallel to ffn_up
  33. };
  34. enum llm_norm_type {
  35. LLM_NORM,
  36. LLM_NORM_RMS,
  37. LLM_NORM_GROUP,
  38. };
  39. // TODO: tmp - need something better to pass the data from the encoder to the decoder
  40. struct llama_cross {
  41. // the output embeddings from the encoder as a ggml tensor
  42. // TODO: this needs more work to be correct, for now copy the embeddings data to host memory
  43. // ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524
  44. //ggml_tensor * t_embd = nullptr;
  45. int64_t n_embd = 0;
  46. int64_t n_enc = 0;
  47. // embeddings data copied to host memory (tmp)
  48. std::vector<float> v_embd;
  49. // needed to construct the cross-attention mask in the decoder
  50. std::vector<std::set<llama_seq_id>> seq_ids_enc;
  51. };
  52. //
  53. // llm_graph_input
  54. //
  55. class llm_graph_input_i {
  56. public:
  57. virtual ~llm_graph_input_i() = default;
  58. virtual void set_input(const llama_ubatch * ubatch) = 0;
  59. };
  60. using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
  61. class llm_graph_input_embd : public llm_graph_input_i {
  62. public:
  63. llm_graph_input_embd() = default;
  64. virtual ~llm_graph_input_embd() = default;
  65. void set_input(const llama_ubatch * ubatch) override;
  66. ggml_tensor * tokens = nullptr; // I32 [n_batch]
  67. ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
  68. };
  69. class llm_graph_input_pos : public llm_graph_input_i {
  70. public:
  71. llm_graph_input_pos(int64_t n_pos_per_token) : n_pos_per_token(n_pos_per_token) {}
  72. virtual ~llm_graph_input_pos() = default;
  73. void set_input(const llama_ubatch * ubatch) override;
  74. ggml_tensor * pos = nullptr; // I32 [n_batch]
  75. const int64_t n_pos_per_token = 1;
  76. };
  77. // temperature tuning, used by llama4
  78. class llm_graph_input_attn_temp : public llm_graph_input_i {
  79. public:
  80. llm_graph_input_attn_temp(int64_t n_pos_per_token, uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
  81. : n_pos_per_token(n_pos_per_token), n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
  82. virtual ~llm_graph_input_attn_temp() = default;
  83. void set_input(const llama_ubatch * ubatch) override;
  84. ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
  85. const int64_t n_pos_per_token = 1;
  86. const uint32_t n_attn_temp_floor_scale;
  87. const float f_attn_temp_scale;
  88. };
  89. class llm_graph_input_pos_bucket : public llm_graph_input_i {
  90. public:
  91. llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {}
  92. virtual ~llm_graph_input_pos_bucket() = default;
  93. void set_input(const llama_ubatch * ubatch) override;
  94. ggml_tensor * pos_bucket = nullptr; // I32 [n_batch, n_batch]
  95. const llama_hparams & hparams;
  96. };
  97. class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
  98. public:
  99. llm_graph_input_pos_bucket_kv(
  100. const llama_hparams & hparams,
  101. const llama_kv_cache_unified * kv_self) : hparams(hparams), kv_self(kv_self) {}
  102. virtual ~llm_graph_input_pos_bucket_kv() = default;
  103. void set_input(const llama_ubatch * ubatch) override;
  104. ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
  105. const llama_hparams & hparams;
  106. const llama_kv_cache_unified * kv_self;
  107. };
  108. class llm_graph_input_out_ids : public llm_graph_input_i {
  109. public:
  110. llm_graph_input_out_ids(
  111. const llama_hparams & hparams,
  112. const llama_cparams & cparams,
  113. int32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
  114. virtual ~llm_graph_input_out_ids() = default;
  115. void set_input(const llama_ubatch * ubatch) override;
  116. ggml_tensor * out_ids; // I32 [n_outputs]
  117. const llama_hparams & hparams;
  118. const llama_cparams & cparams;
  119. const int32_t n_outputs;
  120. };
  121. class llm_graph_input_mean : public llm_graph_input_i {
  122. public:
  123. llm_graph_input_mean(const llama_cparams & cparams) : cparams(cparams) {}
  124. virtual ~llm_graph_input_mean() = default;
  125. void set_input(const llama_ubatch * ubatch) override;
  126. ggml_tensor * mean; // F32 [n_batch, n_batch]
  127. const llama_cparams & cparams;
  128. };
  129. class llm_graph_input_cls : public llm_graph_input_i {
  130. public:
  131. llm_graph_input_cls(const llama_cparams & cparams) : cparams(cparams) {}
  132. virtual ~llm_graph_input_cls() = default;
  133. void set_input(const llama_ubatch * ubatch) override;
  134. ggml_tensor * cls; // I32 [n_batch]
  135. const llama_cparams & cparams;
  136. };
  137. class llm_graph_input_s_copy : public llm_graph_input_i {
  138. public:
  139. llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
  140. virtual ~llm_graph_input_s_copy() = default;
  141. void set_input(const llama_ubatch * ubatch) override;
  142. ggml_tensor * s_copy; // I32 [kv_size]
  143. const llama_kv_cache_unified * kv_self;
  144. };
  145. class llm_graph_input_s_mask : public llm_graph_input_i {
  146. public:
  147. llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
  148. virtual ~llm_graph_input_s_mask() = default;
  149. void set_input(const llama_ubatch * ubatch) override;
  150. ggml_tensor * s_mask; // F32 [1, n_kv]
  151. const llama_kv_cache_unified * kv_self;
  152. };
  153. class llm_graph_input_cross_embd : public llm_graph_input_i {
  154. public:
  155. llm_graph_input_cross_embd(
  156. const llama_cross * cross) : cross(cross) {}
  157. virtual ~llm_graph_input_cross_embd() = default;
  158. void set_input(const llama_ubatch * ubatch) override;
  159. ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc]
  160. const llama_cross * cross;
  161. };
  162. class llm_graph_input_attn_no_cache : public llm_graph_input_i {
  163. public:
  164. llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
  165. hparams(hparams),
  166. cparams(cparams) {
  167. }
  168. ~llm_graph_input_attn_no_cache() = default;
  169. void set_input(const llama_ubatch * ubatch) override;
  170. ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
  171. ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch]
  172. ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch]
  173. const llama_hparams & hparams;
  174. const llama_cparams & cparams;
  175. };
  176. class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
  177. public:
  178. llm_graph_input_attn_kv_unified(
  179. const llama_hparams & hparams,
  180. const llama_cparams & cparams,
  181. const llama_kv_cache_unified * kv_self) :
  182. hparams(hparams),
  183. cparams(cparams),
  184. kv_self(kv_self) {
  185. }
  186. ~llm_graph_input_attn_kv_unified() = default;
  187. void set_input(const llama_ubatch * ubatch) override;
  188. ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
  189. ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
  190. ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
  191. ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
  192. ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
  193. ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch]
  194. const llama_hparams & hparams;
  195. const llama_cparams & cparams;
  196. const llama_kv_cache_unified * kv_self;
  197. };
  198. class llm_graph_input_attn_cross : public llm_graph_input_i {
  199. public:
  200. llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
  201. ~llm_graph_input_attn_cross() = default;
  202. void set_input(const llama_ubatch * ubatch) override;
  203. ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
  204. ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch]
  205. ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch]
  206. const llama_cross * cross = nullptr;
  207. };
  208. //
  209. // llm_graph_result
  210. //
// these objects deliver the result from the graph build process back to the llama_context
// note that the input tensors created for the graph are referenced here - the goal is to be able to populate their
// specific data by calling the set_inputs() method
// along with the input tensors, the object also provides commonly used output tensors, such as logits, embeddings, etc.
// these are used by the llama_context to extract the relevant data, based on the compute parameters
  216. class llm_graph_result_i {
  217. public:
  218. virtual ~llm_graph_result_i() = default;
  219. virtual ggml_tensor * get_logits() = 0;
  220. virtual ggml_tensor * get_embd() = 0;
  221. virtual ggml_tensor * get_embd_pooled() = 0;
  222. virtual void set_inputs(const llama_ubatch * ubatch) = 0;
  223. };
  224. using llm_graph_result_ptr = std::unique_ptr<llm_graph_result_i>;
  225. class llm_graph_result : public llm_graph_result_i {
  226. public:
  227. virtual ~llm_graph_result() = default;
  228. ggml_tensor * get_logits() override { return t_logits; }
  229. ggml_tensor * get_embd() override { return t_embd; }
  230. ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
  231. void set_inputs(const llama_ubatch * ubatch) override {
  232. for (auto & input : inputs) {
  233. input->set_input(ubatch);
  234. }
  235. }
  236. llm_graph_input_i * add_input(llm_graph_input_ptr input) {
  237. inputs.emplace_back(std::move(input));
  238. return inputs.back().get();
  239. }
  240. // important graph nodes
  241. ggml_tensor * t_logits = nullptr;
  242. ggml_tensor * t_embd = nullptr;
  243. ggml_tensor * t_embd_pooled = nullptr;
  244. std::vector<llm_graph_input_ptr> inputs;
  245. };
  246. //
  247. // llm_graph_context
  248. //
  249. // callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
  250. using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
  251. struct llm_graph_params {
  252. ggml_context * ctx;
  253. const llm_arch arch;
  254. const llama_hparams & hparams;
  255. const llama_cparams & cparams;
  256. const llama_ubatch & ubatch;
  257. ggml_backend_sched * sched;
  258. ggml_backend * backend_cpu;
  259. const llama_adapter_cvec * cvec;
  260. const llama_adapter_loras * loras;
  261. const llama_memory_i * memory;
  262. const llama_cross * cross;
  263. int32_t n_outputs;
  264. const llm_graph_cb & cb;
  265. };
  266. struct llm_graph_context {
  267. const llm_arch arch;
  268. const llama_hparams & hparams;
  269. const llama_cparams & cparams;
  270. const llama_ubatch & ubatch;
  271. const int64_t n_embd;
  272. const int64_t n_layer;
  273. const int64_t n_rot;
  274. const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
  275. const int64_t n_ctx_per_seq;
  276. const int64_t n_head;
  277. const int64_t n_head_kv;
  278. const int64_t n_embd_head_k;
  279. const int64_t n_embd_k_gqa;
  280. const int64_t n_embd_head_v;
  281. const int64_t n_embd_v_gqa;
  282. const int64_t n_expert;
  283. const int64_t n_expert_used;
  284. const float freq_base;
  285. const float freq_scale;
  286. const float ext_factor;
  287. const float attn_factor;
  288. const float beta_fast;
  289. const float beta_slow;
  290. const float norm_eps;
  291. const float norm_rms_eps;
  292. const int32_t n_tokens;
  293. const int32_t n_outputs;
  294. const int32_t n_ctx_orig; // yarn
  295. const enum llama_pooling_type pooling_type;
  296. const enum llama_rope_type rope_type;
  297. ggml_context * ctx0 = nullptr;
  298. ggml_backend_sched * sched;
  299. ggml_backend * backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
  300. const llama_adapter_cvec * cvec;
  301. const llama_adapter_loras * loras;
  302. const llama_memory_i * memory;
  303. const llama_cross * cross;
  304. const llm_graph_cb & cb_func;
  305. std::unique_ptr<llm_graph_result> res;
  306. llm_graph_context(const llm_graph_params & params);
  307. int64_t n_pos_per_token() const;
  308. void cb(ggml_tensor * cur, const char * name, int il) const;
  309. //
  310. // common
  311. //
  312. ggml_tensor * build_cvec(
  313. ggml_tensor * cur,
  314. int il) const;
  315. // do mat_mul, while optionally apply lora
  316. ggml_tensor * build_lora_mm(
  317. ggml_tensor * w,
  318. ggml_tensor * cur) const;
  319. // do mat_mul_id, while optionally apply lora
  320. ggml_tensor * build_lora_mm_id(
  321. ggml_tensor * w, // ggml_tensor * as
  322. ggml_tensor * cur, // ggml_tensor * b
  323. ggml_tensor * ids) const;
  324. ggml_tensor * build_norm(
  325. ggml_tensor * cur,
  326. ggml_tensor * mw,
  327. ggml_tensor * mb,
  328. llm_norm_type type,
  329. int il) const;
  330. ggml_tensor * build_ffn(
  331. ggml_tensor * cur,
  332. ggml_tensor * up,
  333. ggml_tensor * up_b,
  334. ggml_tensor * up_s,
  335. ggml_tensor * gate,
  336. ggml_tensor * gate_b,
  337. ggml_tensor * gate_s,
  338. ggml_tensor * down,
  339. ggml_tensor * down_b,
  340. ggml_tensor * down_s,
  341. ggml_tensor * act_scales,
  342. llm_ffn_op_type type_op,
  343. llm_ffn_gate_type type_gate,
  344. int il) const;
  345. ggml_tensor * build_moe_ffn(
  346. ggml_tensor * cur,
  347. ggml_tensor * gate_inp,
  348. ggml_tensor * up_exps,
  349. ggml_tensor * gate_exps,
  350. ggml_tensor * down_exps,
  351. ggml_tensor * exp_probs_b,
  352. int64_t n_expert,
  353. int64_t n_expert_used,
  354. llm_ffn_op_type type_op,
  355. bool norm_w,
  356. bool scale_w,
  357. float w_scale,
  358. llama_expert_gating_func_type gating_op,
  359. int il) const;
  360. //
  361. // inputs
  362. //
  363. ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
  364. ggml_tensor * build_inp_pos() const;
  365. ggml_tensor * build_inp_attn_scale() const;
  366. ggml_tensor * build_inp_out_ids() const;
  367. ggml_tensor * build_inp_mean() const;
  368. ggml_tensor * build_inp_cls() const;
  369. ggml_tensor * build_inp_s_copy() const;
  370. ggml_tensor * build_inp_s_mask() const;
  371. ggml_tensor * build_inp_cross_embd() const;
  372. ggml_tensor * build_inp_pos_bucket_enc() const;
  373. ggml_tensor * build_inp_pos_bucket_dec() const;
  374. ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
  375. //
  376. // attention
  377. //
  378. ggml_tensor * build_attn_mha(
  379. ggml_cgraph * gf,
  380. ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q]
  381. ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k]
  382. ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
  383. ggml_tensor * kq_b,
  384. ggml_tensor * kq_mask,
  385. bool v_trans,
  386. float kq_scale) const;
  387. llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
  388. ggml_tensor * build_attn(
  389. llm_graph_input_attn_no_cache * inp,
  390. ggml_cgraph * gf,
  391. ggml_tensor * wo,
  392. ggml_tensor * wo_b,
  393. ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
  394. ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
  395. ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
  396. ggml_tensor * kq_b,
  397. float kq_scale,
  398. int il) const;
  399. llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
  400. ggml_tensor * build_attn(
  401. llm_graph_input_attn_kv_unified * inp,
  402. ggml_cgraph * gf,
  403. ggml_tensor * wo,
  404. ggml_tensor * wo_b,
  405. ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
  406. ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
  407. ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
  408. ggml_tensor * kq_b,
  409. float kq_scale,
  410. int il) const;
  411. llm_graph_input_attn_cross * build_attn_inp_cross() const;
  412. ggml_tensor * build_attn(
  413. llm_graph_input_attn_cross * inp,
  414. ggml_cgraph * gf,
  415. ggml_tensor * wo,
  416. ggml_tensor * wo_b,
  417. ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
  418. ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
  419. ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
  420. ggml_tensor * kq_b,
  421. float kq_scale,
  422. int il) const;
  423. //
  424. // recurrent
  425. //
  426. ggml_tensor * build_copy_mask_state(
  427. ggml_cgraph * gf,
  428. ggml_tensor * s,
  429. ggml_tensor * state_copy,
  430. ggml_tensor * state_mask,
  431. int32_t n_state,
  432. int32_t n_seqs) const;
  433. ggml_tensor * build_rwkv_token_shift_load(
  434. ggml_cgraph * gf,
  435. ggml_tensor * state_copy,
  436. ggml_tensor * state_mask,
  437. const llama_ubatch & ubatch,
  438. int il) const;
  439. ggml_tensor * build_rwkv_token_shift_store(
  440. ggml_tensor * token_shift,
  441. const llama_ubatch & ubatch,
  442. int il) const;
  443. //
  444. // pooling
  445. //
  446. void build_pooling(
  447. ggml_cgraph * gf,
  448. ggml_tensor * cls,
  449. ggml_tensor * cls_b,
  450. ggml_tensor * cls_out,
  451. ggml_tensor * cls_out_b) const;
  452. };