llama-graph.h 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574
  1. #pragma once
  2. #include "llama-arch.h"
  3. #include "llama-hparams.h"
  4. #include "llama-adapter.h"
  5. #include <cstdint>
  6. #include <vector>
  7. #include <memory>
  8. #include <set>
  9. #include <functional>
  // Forward declarations: every type below is only referenced through a pointer or a
  // reference in this header, so their full definitions are not required here.

  // ggml types
  struct ggml_cgraph;
  struct ggml_context;
  struct ggml_tensor;

  // llama plain structs
  struct llama_ubatch;
  struct llama_cparams;

  // llama classes - keep the `class` tag exactly as originally declared
  // (some compilers, e.g. MSVC, warn when struct/class tags are mixed)
  class llama_memory_i;
  class llama_kv_cache_unified;
  // Kind of graph to build. Certain models (typically multi-modal) can produce different
  // types of graphs, e.g. a separate encoder pass and decoder pass.
  enum llm_graph_type {
      LLM_GRAPH_TYPE_DEFAULT = 0,
      LLM_GRAPH_TYPE_ENCODER = 1,
      LLM_GRAPH_TYPE_DECODER = 2,
  };
  // Activation selector passed as `type_op` to llm_graph_context::build_ffn / build_moe_ffn.
  enum llm_ffn_op_type {
      LLM_FFN_SILU     = 0,
      LLM_FFN_GELU     = 1,
      LLM_FFN_RELU     = 2,
      LLM_FFN_RELU_SQR = 3,
      LLM_FFN_SWIGLU   = 4,
  };
  // How the FFN gate is wired; passed as `type_gate` to llm_graph_context::build_ffn.
  enum llm_ffn_gate_type {
      LLM_FFN_SEQ = 0, // gate not parallel to ffn_up (presumably applied sequentially)
      LLM_FFN_PAR = 1, // ffn_gate is parallel to ffn_up
  };
  // Normalization selector passed as `type` to llm_graph_context::build_norm.
  enum llm_norm_type {
      LLM_NORM       = 0,
      LLM_NORM_RMS   = 1,
      LLM_NORM_GROUP = 2,
  };
  39. // TODO: tmp - need something better to pass the data from the encoder to the decoder
  40. struct llama_cross {
  41. // the output embeddings from the encoder as a ggml tensor
  42. // TODO: this needs more work to be correct, for now copy the embeddings data to host memory
  43. // ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524
  44. //ggml_tensor * t_embd = nullptr;
  45. int64_t n_embd = 0;
  46. int64_t n_enc = 0;
  47. // embeddings data copied to host memory (tmp)
  48. std::vector<float> v_embd;
  49. // needed to construct the cross-attention mask in the decoder
  50. std::vector<std::set<llama_seq_id>> seq_ids_enc;
  51. };
  52. //
  53. // llm_graph_input
  54. //
  55. class llm_graph_input_i {
  56. public:
  57. virtual ~llm_graph_input_i() = default;
  58. virtual void set_input(const llama_ubatch * ubatch) = 0;
  59. };
  60. using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
  61. class llm_graph_input_embd : public llm_graph_input_i {
  62. public:
  63. llm_graph_input_embd() = default;
  64. virtual ~llm_graph_input_embd() = default;
  65. void set_input(const llama_ubatch * ubatch) override;
  66. ggml_tensor * tokens = nullptr; // I32 [n_batch]
  67. ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
  68. };
  69. class llm_graph_input_pos : public llm_graph_input_i {
  70. public:
  71. llm_graph_input_pos(int64_t n_pos_per_token) : n_pos_per_token(n_pos_per_token) {}
  72. virtual ~llm_graph_input_pos() = default;
  73. void set_input(const llama_ubatch * ubatch) override;
  74. ggml_tensor * pos = nullptr; // I32 [n_batch]
  75. const int64_t n_pos_per_token = 1;
  76. };
  77. class llm_graph_input_pos_bucket : public llm_graph_input_i {
  78. public:
  79. llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {}
  80. virtual ~llm_graph_input_pos_bucket() = default;
  81. void set_input(const llama_ubatch * ubatch) override;
  82. ggml_tensor * pos_bucket = nullptr; // I32 [n_batch, n_batch]
  83. const llama_hparams & hparams;
  84. };
  85. class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
  86. public:
  87. llm_graph_input_pos_bucket_kv(
  88. const llama_hparams & hparams,
  89. const llama_kv_cache_unified * kv_self) : hparams(hparams), kv_self(kv_self) {}
  90. virtual ~llm_graph_input_pos_bucket_kv() = default;
  91. void set_input(const llama_ubatch * ubatch) override;
  92. ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
  93. const llama_hparams & hparams;
  94. const llama_kv_cache_unified * kv_self;
  95. };
  96. class llm_graph_input_out_ids : public llm_graph_input_i {
  97. public:
  98. llm_graph_input_out_ids(
  99. const llama_hparams & hparams,
  100. const llama_cparams & cparams,
  101. int32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
  102. virtual ~llm_graph_input_out_ids() = default;
  103. void set_input(const llama_ubatch * ubatch) override;
  104. ggml_tensor * out_ids; // I32 [n_outputs]
  105. const llama_hparams & hparams;
  106. const llama_cparams & cparams;
  107. const int32_t n_outputs;
  108. };
  109. class llm_graph_input_mean : public llm_graph_input_i {
  110. public:
  111. llm_graph_input_mean(const llama_cparams & cparams) : cparams(cparams) {}
  112. virtual ~llm_graph_input_mean() = default;
  113. void set_input(const llama_ubatch * ubatch) override;
  114. ggml_tensor * mean; // F32 [n_batch, n_batch]
  115. const llama_cparams & cparams;
  116. };
  117. class llm_graph_input_cls : public llm_graph_input_i {
  118. public:
  119. llm_graph_input_cls(const llama_cparams & cparams) : cparams(cparams) {}
  120. virtual ~llm_graph_input_cls() = default;
  121. void set_input(const llama_ubatch * ubatch) override;
  122. ggml_tensor * cls; // I32 [n_batch]
  123. const llama_cparams & cparams;
  124. };
  125. class llm_graph_input_s_copy : public llm_graph_input_i {
  126. public:
  127. llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
  128. virtual ~llm_graph_input_s_copy() = default;
  129. void set_input(const llama_ubatch * ubatch) override;
  130. ggml_tensor * s_copy; // I32 [kv_size]
  131. const llama_kv_cache_unified * kv_self;
  132. };
  133. class llm_graph_input_s_mask : public llm_graph_input_i {
  134. public:
  135. llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
  136. virtual ~llm_graph_input_s_mask() = default;
  137. void set_input(const llama_ubatch * ubatch) override;
  138. ggml_tensor * s_mask; // F32 [1, n_kv]
  139. const llama_kv_cache_unified * kv_self;
  140. };
  141. class llm_graph_input_cross_embd : public llm_graph_input_i {
  142. public:
  143. llm_graph_input_cross_embd(
  144. const llama_cross * cross) : cross(cross) {}
  145. virtual ~llm_graph_input_cross_embd() = default;
  146. void set_input(const llama_ubatch * ubatch) override;
  147. ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc]
  148. const llama_cross * cross;
  149. };
  150. class llm_graph_input_attn_no_cache : public llm_graph_input_i {
  151. public:
  152. llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
  153. hparams(hparams),
  154. cparams(cparams) {
  155. }
  156. ~llm_graph_input_attn_no_cache() = default;
  157. void set_input(const llama_ubatch * ubatch) override;
  158. ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
  159. ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch]
  160. ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch]
  161. const llama_hparams & hparams;
  162. const llama_cparams & cparams;
  163. };
  164. class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
  165. public:
  166. llm_graph_input_attn_kv_unified(
  167. const llama_hparams & hparams,
  168. const llama_cparams & cparams,
  169. const llama_kv_cache_unified * kv_self) :
  170. hparams(hparams),
  171. cparams(cparams),
  172. kv_self(kv_self) {
  173. }
  174. ~llm_graph_input_attn_kv_unified() = default;
  175. void set_input(const llama_ubatch * ubatch) override;
  176. ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
  177. ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
  178. ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
  179. ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
  180. ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
  181. ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch]
  182. const llama_hparams & hparams;
  183. const llama_cparams & cparams;
  184. const llama_kv_cache_unified * kv_self;
  185. };
  186. class llm_graph_input_attn_cross : public llm_graph_input_i {
  187. public:
  188. llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
  189. ~llm_graph_input_attn_cross() = default;
  190. void set_input(const llama_ubatch * ubatch) override;
  191. ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
  192. ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch]
  193. ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch]
  194. const llama_cross * cross = nullptr;
  195. };
  196. //
  197. // llm_graph_result
  198. //
  199. // these objects deliver the result from the graph build process back to the llama_context
  200. // note that the input tensors created for the graph are referenced here - the goal is to be able to populate their
  201. // specific data, by calling the set_inputs() method
  202. // along with the input tensors, the object also provides commonly used outputs tensors, such as logits, embeddings, etc.
  203. // these are used by the llama_context to extact the relevant data, based on the compute parameters
  204. class llm_graph_result_i {
  205. public:
  206. virtual ~llm_graph_result_i() = default;
  207. virtual ggml_tensor * get_logits() = 0;
  208. virtual ggml_tensor * get_embd() = 0;
  209. virtual ggml_tensor * get_embd_pooled() = 0;
  210. virtual void set_inputs(const llama_ubatch * ubatch) = 0;
  211. };
  212. using llm_graph_result_ptr = std::unique_ptr<llm_graph_result_i>;
  213. class llm_graph_result : public llm_graph_result_i {
  214. public:
  215. virtual ~llm_graph_result() = default;
  216. ggml_tensor * get_logits() override { return t_logits; }
  217. ggml_tensor * get_embd() override { return t_embd; }
  218. ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
  219. void set_inputs(const llama_ubatch * ubatch) override {
  220. for (auto & input : inputs) {
  221. input->set_input(ubatch);
  222. }
  223. }
  224. llm_graph_input_i * add_input(llm_graph_input_ptr input) {
  225. inputs.emplace_back(std::move(input));
  226. return inputs.back().get();
  227. }
  228. // important graph nodes
  229. ggml_tensor * t_logits = nullptr;
  230. ggml_tensor * t_embd = nullptr;
  231. ggml_tensor * t_embd_pooled = nullptr;
  232. std::vector<llm_graph_input_ptr> inputs;
  233. };
  234. //
  235. // llm_graph_context
  236. //
  237. // callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
  238. using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
  239. struct llm_graph_params {
  240. ggml_context * ctx;
  241. const llm_arch arch;
  242. const llama_hparams & hparams;
  243. const llama_cparams & cparams;
  244. const llama_ubatch & ubatch;
  245. ggml_backend_sched * sched;
  246. ggml_backend * backend_cpu;
  247. const llama_adapter_cvec * cvec;
  248. const llama_adapter_loras * loras;
  249. const llama_memory_i * memory;
  250. const llama_cross * cross;
  251. int32_t n_outputs;
  252. const llm_graph_cb & cb;
  253. };
  254. struct llm_graph_context {
  255. const llm_arch arch;
  256. const llama_hparams & hparams;
  257. const llama_cparams & cparams;
  258. const llama_ubatch & ubatch;
  259. const int64_t n_embd;
  260. const int64_t n_layer;
  261. const int64_t n_rot;
  262. const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
  263. const int64_t n_ctx_per_seq;
  264. const int64_t n_head;
  265. const int64_t n_head_kv;
  266. const int64_t n_embd_head_k;
  267. const int64_t n_embd_k_gqa;
  268. const int64_t n_embd_head_v;
  269. const int64_t n_embd_v_gqa;
  270. const int64_t n_expert;
  271. const int64_t n_expert_used;
  272. const float freq_base;
  273. const float freq_scale;
  274. const float ext_factor;
  275. const float attn_factor;
  276. const float beta_fast;
  277. const float beta_slow;
  278. const float norm_eps;
  279. const float norm_rms_eps;
  280. const int32_t n_tokens;
  281. const int32_t n_outputs;
  282. const int32_t n_ctx_orig; // yarn
  283. const enum llama_pooling_type pooling_type;
  284. const enum llama_rope_type rope_type;
  285. ggml_context * ctx0 = nullptr;
  286. ggml_backend_sched * sched;
  287. ggml_backend * backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
  288. const llama_adapter_cvec * cvec;
  289. const llama_adapter_loras * loras;
  290. const llama_memory_i * memory;
  291. const llama_cross * cross;
  292. const llm_graph_cb & cb_func;
  293. std::unique_ptr<llm_graph_result> res;
  294. llm_graph_context(const llm_graph_params & params);
  295. int64_t n_pos_per_token() const;
  296. void cb(ggml_tensor * cur, const char * name, int il) const;
  297. //
  298. // common
  299. //
  300. ggml_tensor * build_cvec(
  301. ggml_tensor * cur,
  302. int il) const;
  303. // do mat_mul, while optionally apply lora
  304. ggml_tensor * build_lora_mm(
  305. ggml_tensor * w,
  306. ggml_tensor * cur) const;
  307. // do mat_mul_id, while optionally apply lora
  308. ggml_tensor * build_lora_mm_id(
  309. ggml_tensor * w, // ggml_tensor * as
  310. ggml_tensor * cur, // ggml_tensor * b
  311. ggml_tensor * ids) const;
  312. ggml_tensor * build_norm(
  313. ggml_tensor * cur,
  314. ggml_tensor * mw,
  315. ggml_tensor * mb,
  316. llm_norm_type type,
  317. int il) const;
  318. ggml_tensor * build_ffn(
  319. ggml_tensor * cur,
  320. ggml_tensor * up,
  321. ggml_tensor * up_b,
  322. ggml_tensor * up_s,
  323. ggml_tensor * gate,
  324. ggml_tensor * gate_b,
  325. ggml_tensor * gate_s,
  326. ggml_tensor * down,
  327. ggml_tensor * down_b,
  328. ggml_tensor * down_s,
  329. ggml_tensor * act_scales,
  330. llm_ffn_op_type type_op,
  331. llm_ffn_gate_type type_gate,
  332. int il) const;
  333. ggml_tensor * build_moe_ffn(
  334. ggml_tensor * cur,
  335. ggml_tensor * gate_inp,
  336. ggml_tensor * up_exps,
  337. ggml_tensor * gate_exps,
  338. ggml_tensor * down_exps,
  339. ggml_tensor * exp_probs_b,
  340. int64_t n_expert,
  341. int64_t n_expert_used,
  342. llm_ffn_op_type type_op,
  343. bool norm_w,
  344. bool scale_w,
  345. float w_scale,
  346. llama_expert_gating_func_type gating_op,
  347. int il) const;
  348. //
  349. // inputs
  350. //
  351. ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
  352. ggml_tensor * build_inp_pos() const;
  353. ggml_tensor * build_inp_out_ids() const;
  354. ggml_tensor * build_inp_mean() const;
  355. ggml_tensor * build_inp_cls() const;
  356. ggml_tensor * build_inp_s_copy() const;
  357. ggml_tensor * build_inp_s_mask() const;
  358. ggml_tensor * build_inp_cross_embd() const;
  359. ggml_tensor * build_inp_pos_bucket_enc() const;
  360. ggml_tensor * build_inp_pos_bucket_dec() const;
  361. ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
  362. //
  363. // attention
  364. //
  365. ggml_tensor * build_attn_mha(
  366. ggml_cgraph * gf,
  367. ggml_tensor * q,
  368. ggml_tensor * k,
  369. ggml_tensor * v,
  370. ggml_tensor * kq_b,
  371. ggml_tensor * kq_mask,
  372. bool v_trans,
  373. float kq_scale) const;
  374. llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
  375. ggml_tensor * build_attn(
  376. llm_graph_input_attn_no_cache * inp,
  377. ggml_cgraph * gf,
  378. ggml_tensor * wo,
  379. ggml_tensor * wo_b,
  380. ggml_tensor * q_cur,
  381. ggml_tensor * k_cur,
  382. ggml_tensor * v_cur,
  383. ggml_tensor * kq_b,
  384. float kq_scale,
  385. int il) const;
  386. llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
  387. ggml_tensor * build_attn(
  388. llm_graph_input_attn_kv_unified * inp,
  389. ggml_cgraph * gf,
  390. ggml_tensor * wo,
  391. ggml_tensor * wo_b,
  392. ggml_tensor * q_cur,
  393. ggml_tensor * k_cur,
  394. ggml_tensor * v_cur,
  395. ggml_tensor * kq_b,
  396. float kq_scale,
  397. int il) const;
  398. llm_graph_input_attn_cross * build_attn_inp_cross() const;
  399. ggml_tensor * build_attn(
  400. llm_graph_input_attn_cross * inp,
  401. ggml_cgraph * gf,
  402. ggml_tensor * wo,
  403. ggml_tensor * wo_b,
  404. ggml_tensor * q_cur,
  405. ggml_tensor * k_cur,
  406. ggml_tensor * v_cur,
  407. ggml_tensor * kq_b,
  408. float kq_scale,
  409. int il) const;
  410. //
  411. // recurrent
  412. //
  413. ggml_tensor * build_copy_mask_state(
  414. ggml_cgraph * gf,
  415. ggml_tensor * s,
  416. ggml_tensor * state_copy,
  417. ggml_tensor * state_mask,
  418. int32_t n_state,
  419. int32_t n_seqs) const;
  420. ggml_tensor * build_rwkv_token_shift_load(
  421. ggml_cgraph * gf,
  422. ggml_tensor * state_copy,
  423. ggml_tensor * state_mask,
  424. const llama_ubatch & ubatch,
  425. int il) const;
  426. ggml_tensor * build_rwkv_token_shift_store(
  427. ggml_tensor * token_shift,
  428. const llama_ubatch & ubatch,
  429. int il) const;
  430. //
  431. // pooling
  432. //
  433. void build_pooling(
  434. ggml_cgraph * gf,
  435. ggml_tensor * cls,
  436. ggml_tensor * cls_b,
  437. ggml_tensor * cls_out,
  438. ggml_tensor * cls_out_b) const;
  439. };