llama-graph.h 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992
#pragma once

#include "llama-arch.h"
#include "llama-batch.h"
#include "llama-hparams.h"
#include "llama-adapter.h"

#include <cstdint>
#include <cstdlib>
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <vector>
  12. struct ggml_cgraph;
  13. struct ggml_context;
  14. struct ggml_tensor;
  15. struct llama_cparams;
  16. struct llama_memory_context_i;
  17. class llama_kv_cache_context;
  18. class llama_kv_cache_iswa_context;
  19. class llama_memory_recurrent_context;
  20. class llama_memory_hybrid_context;
  21. class llama_memory_hybrid_iswa_context;
  22. // certain models (typically multi-modal) can produce different types of graphs
  23. enum llm_graph_type {
  24. LLM_GRAPH_TYPE_DEFAULT,
  25. LLM_GRAPH_TYPE_ENCODER,
  26. LLM_GRAPH_TYPE_DECODER,
  27. };
  28. enum llm_ffn_op_type {
  29. LLM_FFN_SILU,
  30. LLM_FFN_GELU,
  31. LLM_FFN_RELU,
  32. LLM_FFN_RELU_SQR,
  33. LLM_FFN_SWIGLU,
  34. LLM_FFN_GEGLU,
  35. LLM_FFN_REGLU,
  36. LLM_FFN_SWIGLU_OAI_MOE,
  37. };
  38. enum llm_ffn_gate_type {
  39. LLM_FFN_SEQ,
  40. LLM_FFN_PAR, // ffn_gate is parallel to ffn_up
  41. };
  42. enum llm_norm_type {
  43. LLM_NORM,
  44. LLM_NORM_RMS,
  45. LLM_NORM_GROUP,
  46. };
  47. // TODO: tmp - need something better to pass the data from the encoder to the decoder
  48. struct llama_cross {
  49. // the output embeddings from the encoder as a ggml tensor
  50. // TODO: this needs more work to be correct, for now copy the embeddings data to host memory
  51. // ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524
  52. //ggml_tensor * t_embd = nullptr;
  53. int64_t n_embd = 0;
  54. int64_t n_enc = 0;
  55. // embeddings data copied to host memory (tmp)
  56. std::vector<float> v_embd;
  57. // needed to construct the cross-attention mask in the decoder
  58. std::vector<std::set<llama_seq_id>> seq_ids_enc;
  59. };
  60. struct llm_graph_params;
  61. //
  62. // llm_graph_input
  63. //
  64. class llm_graph_input_i {
  65. public:
  66. llm_graph_input_i() {
  67. const char * LLAMA_GRAPH_INPUT_DEBUG = getenv("LLAMA_GRAPH_INPUT_DEBUG");
  68. debug = LLAMA_GRAPH_INPUT_DEBUG ? atoi(LLAMA_GRAPH_INPUT_DEBUG) : 0;
  69. }
  70. virtual ~llm_graph_input_i() = default;
  71. virtual void set_input(const llama_ubatch * ubatch) = 0;
  72. // return true if the resulting input tensors using the provided graph parameters would be
  73. // the same as the previous input tensors that we have currently stored in the object
  74. virtual bool can_reuse(const llm_graph_params & params) {
  75. // returning false here by default will prevent from reusing the graph if the check
  76. // for the input type has not been implemented yet
  77. GGML_UNUSED(params);
  78. return false;
  79. }
  80. protected:
  81. // env: LLAMA_GRAPH_INPUT_DEBUG
  82. int debug = 0;
  83. };
  84. using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
  85. class llm_graph_input_embd : public llm_graph_input_i {
  86. public:
  87. llm_graph_input_embd(int64_t n_embd) : n_embd(n_embd) {}
  88. virtual ~llm_graph_input_embd() = default;
  89. void set_input(const llama_ubatch * ubatch) override;
  90. bool can_reuse(const llm_graph_params & params) override;
  91. ggml_tensor * tokens = nullptr; // I32 [n_batch]
  92. ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
  93. const int64_t n_embd = 0;
  94. };
  95. class llm_graph_input_pos : public llm_graph_input_i {
  96. public:
  97. llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
  98. virtual ~llm_graph_input_pos() = default;
  99. void set_input(const llama_ubatch * ubatch) override;
  100. bool can_reuse(const llm_graph_params & params) override;
  101. ggml_tensor * pos = nullptr; // I32 [n_batch]
  102. const uint32_t n_pos_per_embd = 1;
  103. };
  104. // temperature tuning, used by llama4
  105. class llm_graph_input_attn_temp : public llm_graph_input_i {
  106. public:
  107. llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale, float f_attn_temp_offset)
  108. : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale), f_attn_temp_offset(f_attn_temp_offset) {}
  109. virtual ~llm_graph_input_attn_temp() = default;
  110. void set_input(const llama_ubatch * ubatch) override;
  111. ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
  112. const uint32_t n_attn_temp_floor_scale;
  113. const float f_attn_temp_scale;
  114. const float f_attn_temp_offset;
  115. };
  116. class llm_graph_input_pos_bucket : public llm_graph_input_i {
  117. public:
  118. llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {}
  119. virtual ~llm_graph_input_pos_bucket() = default;
  120. void set_input(const llama_ubatch * ubatch) override;
  121. ggml_tensor * pos_bucket = nullptr; // I32 [n_batch, n_batch]
  122. const llama_hparams hparams;
  123. };
  124. class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
  125. public:
  126. llm_graph_input_pos_bucket_kv(
  127. const llama_hparams & hparams,
  128. const llama_kv_cache_context * mctx) : hparams(hparams), mctx(mctx) {}
  129. virtual ~llm_graph_input_pos_bucket_kv() = default;
  130. void set_input(const llama_ubatch * ubatch) override;
  131. ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
  132. const llama_hparams hparams;
  133. const llama_kv_cache_context * mctx;
  134. };
  135. class llm_graph_input_out_ids : public llm_graph_input_i {
  136. public:
  137. llm_graph_input_out_ids(
  138. const llama_hparams & hparams,
  139. const llama_cparams & cparams,
  140. uint32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
  141. virtual ~llm_graph_input_out_ids() = default;
  142. void set_input(const llama_ubatch * ubatch) override;
  143. bool can_reuse(const llm_graph_params & params) override;
  144. ggml_tensor * out_ids; // I32 [n_outputs]
  145. const llama_hparams hparams;
  146. const llama_cparams cparams;
  147. const uint32_t n_outputs;
  148. };
  149. class llm_graph_input_mean : public llm_graph_input_i {
  150. public:
  151. llm_graph_input_mean(const llama_cparams & cparams) : cparams(cparams) {}
  152. virtual ~llm_graph_input_mean() = default;
  153. void set_input(const llama_ubatch * ubatch) override;
  154. ggml_tensor * mean; // F32 [n_batch, n_batch]
  155. const llama_cparams cparams;
  156. };
  157. class llm_graph_input_cls : public llm_graph_input_i {
  158. public:
  159. llm_graph_input_cls(const llama_cparams & cparams, const llm_arch arch) : cparams(cparams), arch(arch) {}
  160. virtual ~llm_graph_input_cls() = default;
  161. void set_input(const llama_ubatch * ubatch) override;
  162. ggml_tensor * cls; // I32 [n_batch]
  163. const llama_cparams cparams;
  164. const llm_arch arch;
  165. };
  166. class llm_graph_input_rs : public llm_graph_input_i {
  167. public:
  168. llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
  169. virtual ~llm_graph_input_rs() = default;
  170. void set_input(const llama_ubatch * ubatch) override;
  171. bool can_reuse(const llm_graph_params & params) override;
  172. ggml_tensor * s_copy; // I32 [n_rs]
  173. // views of s_copy, computed once per graph
  174. // and shared across layers which use build_rs
  175. ggml_tensor * s_copy_main; // I32 [n_seqs]
  176. ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs]
  177. const llama_memory_recurrent_context * mctx;
  178. // used in view offsets, need to match for valid graph reuse
  179. uint32_t head;
  180. int32_t rs_z;
  181. };
  182. class llm_graph_input_cross_embd : public llm_graph_input_i {
  183. public:
  184. llm_graph_input_cross_embd(
  185. const llama_cross * cross) : cross(cross) {}
  186. virtual ~llm_graph_input_cross_embd() = default;
  187. void set_input(const llama_ubatch * ubatch) override;
  188. ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc]
  189. const llama_cross * cross;
  190. };
  191. class llm_graph_input_attn_no_cache : public llm_graph_input_i {
  192. public:
  193. llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
  194. hparams(hparams),
  195. cparams(cparams) {
  196. }
  197. ~llm_graph_input_attn_no_cache() = default;
  198. void set_input(const llama_ubatch * ubatch) override;
  199. ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
  200. ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
  201. // n_tokens == n_batch
  202. ggml_tensor * self_kq_mask = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
  203. ggml_tensor * self_kq_mask_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream]
  204. ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
  205. ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream]
  206. const llama_hparams hparams;
  207. const llama_cparams cparams;
  208. };
  209. class llm_graph_input_attn_kv : public llm_graph_input_i {
  210. public:
  211. llm_graph_input_attn_kv(
  212. const llama_hparams & hparams,
  213. const llama_cparams & cparams,
  214. const llama_kv_cache_context * mctx) :
  215. hparams(hparams),
  216. cparams(cparams),
  217. mctx(mctx) {
  218. }
  219. ~llm_graph_input_attn_kv() = default;
  220. void set_input(const llama_ubatch * ubatch) override;
  221. bool can_reuse(const llm_graph_params & params) override;
  222. ggml_tensor * get_k_idxs() const { return self_k_idxs; }
  223. ggml_tensor * get_v_idxs() const { return self_v_idxs; }
  224. ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
  225. ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
  226. ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
  227. ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
  228. ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
  229. // note: these have to be copies because in order to be able to reuse a graph, its inputs
  230. // need to carry these parameters with them. otherwise, they can point to freed
  231. // llm_graph_params from a previous batch, causing stack-use-after-return
  232. const llama_hparams hparams;
  233. const llama_cparams cparams;
  234. const llama_kv_cache_context * mctx;
  235. };
  236. // V-less input for the KV cache
  237. // ref: https://github.com/ggml-org/llama.cpp/pull/19067
  238. class llm_graph_input_attn_k : public llm_graph_input_i {
  239. public:
  240. llm_graph_input_attn_k(
  241. const llama_hparams & hparams,
  242. const llama_cparams & cparams,
  243. const llama_kv_cache_context * mctx) :
  244. hparams(hparams),
  245. cparams(cparams),
  246. mctx(mctx) {
  247. }
  248. ~llm_graph_input_attn_k() = default;
  249. void set_input(const llama_ubatch * ubatch) override;
  250. bool can_reuse(const llm_graph_params & params) override;
  251. ggml_tensor * get_k_idxs() const { return self_k_idxs; }
  252. ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
  253. ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
  254. ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
  255. ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
  256. const llama_hparams hparams;
  257. const llama_cparams cparams;
  258. const llama_kv_cache_context * mctx;
  259. };
  260. class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
  261. public:
  262. llm_graph_input_attn_kv_iswa(
  263. const llama_hparams & hparams,
  264. const llama_cparams & cparams,
  265. const llama_kv_cache_iswa_context * mctx) :
  266. hparams(hparams),
  267. cparams(cparams),
  268. mctx(mctx) {
  269. }
  270. ~llm_graph_input_attn_kv_iswa() = default;
  271. void set_input(const llama_ubatch * ubatch) override;
  272. bool can_reuse(const llm_graph_params & params) override;
  273. ggml_tensor * get_k_idxs() const { return self_k_idxs; }
  274. ggml_tensor * get_v_idxs() const { return self_v_idxs; }
  275. ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
  276. ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
  277. ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
  278. ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
  279. ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
  280. ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
  281. ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
  282. ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
  283. ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
  284. ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
  285. ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
  286. ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
  287. const llama_hparams hparams;
  288. const llama_cparams cparams;
  289. const llama_kv_cache_iswa_context * mctx;
  290. };
  291. class llm_graph_input_attn_cross : public llm_graph_input_i {
  292. public:
  293. llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
  294. ~llm_graph_input_attn_cross() = default;
  295. void set_input(const llama_ubatch * ubatch) override;
  296. ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
  297. ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
  298. ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
  299. const llama_cross * cross = nullptr;
  300. };
  301. class llm_graph_input_mem_hybrid : public llm_graph_input_i {
  302. public:
  303. llm_graph_input_mem_hybrid(
  304. const llama_cparams & cparams,
  305. std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
  306. std::unique_ptr<llm_graph_input_rs> inp_rs,
  307. const llama_memory_hybrid_context * mctx) :
  308. inp_attn(std::move(inp_attn)),
  309. inp_rs(std::move(inp_rs)),
  310. cparams(cparams),
  311. mctx(mctx) { }
  312. virtual ~llm_graph_input_mem_hybrid() = default;
  313. void set_input(const llama_ubatch * ubatch) override;
  314. bool can_reuse(const llm_graph_params & params) override;
  315. std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
  316. std::unique_ptr<llm_graph_input_rs> inp_rs;
  317. llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
  318. llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
  319. const llama_cparams cparams;
  320. const llama_memory_hybrid_context * mctx;
  321. };
  322. class llm_graph_input_mem_hybrid_iswa : public llm_graph_input_i {
  323. public:
  324. llm_graph_input_mem_hybrid_iswa(
  325. const llama_cparams & cparams,
  326. std::unique_ptr<llm_graph_input_attn_kv_iswa> inp_attn,
  327. std::unique_ptr<llm_graph_input_rs> inp_rs,
  328. const llama_memory_hybrid_iswa_context * mctx) :
  329. inp_attn(std::move(inp_attn)),
  330. inp_rs(std::move(inp_rs)),
  331. cparams(cparams),
  332. mctx(mctx) { }
  333. virtual ~llm_graph_input_mem_hybrid_iswa() = default;
  334. void set_input(const llama_ubatch * ubatch) override;
  335. bool can_reuse(const llm_graph_params & params) override;
  336. std::unique_ptr<llm_graph_input_attn_kv_iswa> inp_attn;
  337. std::unique_ptr<llm_graph_input_rs> inp_rs;
  338. llm_graph_input_attn_kv_iswa * get_attn() const { return inp_attn.get(); }
  339. llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
  340. const llama_cparams cparams;
  341. const llama_memory_hybrid_iswa_context * mctx;
  342. };
  343. class llm_graph_input_sampling : public llm_graph_input_i {
  344. public:
  345. llm_graph_input_sampling(std::map<llama_seq_id, llama_sampler *> samplers) :
  346. samplers(std::move(samplers)) { }
  347. virtual ~llm_graph_input_sampling() = default;
  348. void set_input(const llama_ubatch * ubatch) override;
  349. bool can_reuse(const llm_graph_params & params) override;
  350. std::map<llama_seq_id, llama_sampler *> samplers;
  351. };
  352. //
  353. // llm_graph_result
  354. //
// these objects deliver the result from the graph build process back to the llama_context
// note that the input tensors created for the graph are referenced here - the goal is to be able to populate their
// specific data by calling the set_inputs() method
// along with the input tensors, the object also provides commonly used output tensors, such as logits, embeddings, etc.
// these are used by the llama_context to extract the relevant data, based on the compute parameters
  360. // callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
  361. using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
  362. class llm_graph_result;
  363. struct llm_graph_params {
  364. llm_arch arch = LLM_ARCH_UNKNOWN;
  365. llama_hparams hparams;
  366. llama_cparams cparams;
  367. llama_ubatch ubatch; // note: intentionally make a copy
  368. llm_graph_type gtype;
  369. ggml_backend_sched_t sched;
  370. ggml_backend_t backend_cpu;
  371. const llama_adapter_cvec * cvec;
  372. const llama_adapter_loras * loras;
  373. const llama_memory_context_i * mctx;
  374. const llama_cross * cross;
  375. std::map<llama_seq_id, llama_sampler *> samplers;
  376. static bool samplers_equal(
  377. const std::map<llama_seq_id, llama_sampler *> & lhs,
  378. const std::map<llama_seq_id, llama_sampler *> & rhs) {
  379. if (lhs.size() != rhs.size()) {
  380. return false;
  381. }
  382. for (const auto & [seq_id, sampler] : lhs) {
  383. auto it = rhs.find(seq_id);
  384. if (it == rhs.end() || it->second != sampler) {
  385. return false;
  386. }
  387. }
  388. return true;
  389. }
  390. uint32_t n_outputs;
  391. llm_graph_cb cb;
  392. llm_graph_result * res;
  393. // return true if the "other" params would result in a graph with the same topology as with the current params
  394. // having the same topology allows us to reuse the graph in some cases
  395. bool allow_reuse(const llm_graph_params & other) const {
  396. // first check the ubatch
  397. bool can_reuse_ubatch =
  398. ubatch.equal_seqs() == other.ubatch.equal_seqs() &&
  399. ubatch.n_tokens == other.ubatch.n_tokens &&
  400. ubatch.n_seq_tokens == other.ubatch.n_seq_tokens &&
  401. ubatch.n_seqs == other.ubatch.n_seqs &&
  402. ubatch.n_seqs_unq == other.ubatch.n_seqs_unq &&
  403. (
  404. (!ubatch.token && !other.ubatch.token) ||
  405. (!ubatch.embd && !other.ubatch.embd)
  406. );
  407. // when we split the batch using "equal_seqs" we have to verify that the participating sequences are the same
  408. // the reason is because the set of attention streams would be different for different sequences
  409. if (can_reuse_ubatch && ubatch.equal_seqs()) {
  410. if (!ubatch.data) {
  411. // if the old ubatch does not own it's data, then we cannot guarantee that it is still alive, and
  412. // therefore we cannot perform the sequence id check. normally should never happen
  413. can_reuse_ubatch = false;
  414. } else {
  415. for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
  416. can_reuse_ubatch &= ubatch.seq_id_unq[s] == other.ubatch.seq_id_unq[s];
  417. }
  418. }
  419. }
  420. if (!can_reuse_ubatch) {
  421. return false;
  422. }
  423. if (n_outputs != other.n_outputs) {
  424. return false;
  425. }
  426. if (!samplers_equal(samplers, other.samplers)) {
  427. return false;
  428. }
  429. if (samplers.size() > 0) {
  430. if (!ubatch.data || !other.ubatch.data) {
  431. return false;
  432. }
  433. // check that the outputs are the same for all samplers
  434. for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
  435. if (ubatch.output[i] != other.ubatch.output[i] ||
  436. ubatch.seq_id[i][0] != other.ubatch.seq_id[i][0]) {
  437. return false;
  438. }
  439. }
  440. }
  441. return
  442. cparams.embeddings == other.cparams.embeddings &&
  443. cparams.causal_attn == other.cparams.causal_attn &&
  444. arch == other.arch &&
  445. gtype == other.gtype &&
  446. cvec == other.cvec &&
  447. loras == other.loras &&
  448. cross == other.cross;
  449. }
  450. };
  451. class llm_graph_result {
  452. public:
  453. llm_graph_result(int64_t max_nodes);
  454. virtual ~llm_graph_result() = default;
  455. ggml_tensor * get_inp_tokens() const { return t_inp_tokens; }
  456. ggml_tensor * get_logits() const { return t_logits; }
  457. ggml_tensor * get_embd() const { return t_embd; }
  458. ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
  459. ggml_cgraph * get_gf() const { return gf; }
  460. ggml_context * get_ctx() const { return ctx_compute.get(); }
  461. int64_t get_max_nodes() const;
  462. void reset();
  463. void set_inputs(const llama_ubatch * ubatch);
  464. void set_outputs();
  465. // try to update the existing graph result using the new graph parameters in order to reuse it
  466. // this can only be done if we determine that the resulting graph using the new graph parameters
  467. // would be identical to the existing graph. in that case, we simply have to update the memory
  468. // contexts of the input tensors of the graph and we can reuse it for another computation
  469. // return true if the graph was updated and can be reused
  470. bool can_reuse(const llm_graph_params & params);
  471. llm_graph_input_i * add_input(llm_graph_input_ptr input);
  472. void set_params(const llm_graph_params & params);
  473. // important graph nodes
  474. ggml_tensor * t_inp_tokens = nullptr;
  475. ggml_tensor * t_inp_embd = nullptr; // [n_embd_inp, n_tokens]
  476. ggml_tensor * t_logits = nullptr;
  477. ggml_tensor * t_embd = nullptr;
  478. ggml_tensor * t_embd_pooled = nullptr;
  479. std::map<llama_seq_id, ggml_tensor*> t_sampled_logits;
  480. std::map<llama_seq_id, ggml_tensor*> t_candidates;
  481. std::map<llama_seq_id, ggml_tensor*> t_sampled;
  482. std::map<llama_seq_id, ggml_tensor*> t_sampled_probs;
  483. std::vector<llm_graph_input_ptr> inputs;
  484. ggml_context_ptr ctx_compute;
  485. // memory buffers used to evaluate the model
  486. std::vector<uint8_t> buf_compute_meta;
  487. ggml_cgraph * gf;
  488. int64_t max_nodes;
  489. private:
  490. // keep a copy of the previous graph parameters
  491. // we will use this to determine whether the graph can be reused by comparing them with the new parameters
  492. // note: these are updated after constructing the new graph
  493. llm_graph_params params;
  494. // env: LLAMA_GRAPH_RESULT_DEBUG
  495. int debug = 0;
  496. };
  497. using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>;
  498. //
  499. // llm_graph_context
  500. //
  501. // used in build_rs to properly order writes and avoid unnecessary copies
  502. using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
  503. struct llm_graph_context {
  504. const llm_arch arch;
  505. const llama_hparams & hparams;
  506. const llama_cparams & cparams;
  507. const llama_ubatch & ubatch;
  508. const int64_t n_embd;
  509. const int64_t n_layer;
  510. const int64_t n_rot;
  511. const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
  512. const int64_t n_head;
  513. const int64_t n_head_kv;
  514. const int64_t n_embd_head_k;
  515. const int64_t n_embd_k_gqa;
  516. const int64_t n_embd_head_v;
  517. const int64_t n_embd_v_gqa;
  518. const int64_t n_expert;
  519. const int64_t n_expert_used;
  520. const float freq_base;
  521. const float freq_scale;
  522. const float ext_factor;
  523. const float attn_factor;
  524. const float beta_fast;
  525. const float beta_slow;
  526. const float norm_eps;
  527. const float norm_rms_eps;
  528. const int64_t n_tokens;
  529. const int64_t n_outputs;
  530. const int32_t n_ctx_orig; // yarn
  531. const enum llama_pooling_type pooling_type;
  532. const enum llama_rope_type rope_type;
  533. ggml_backend_sched_t sched;
  534. ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
  535. const llama_adapter_cvec * cvec;
  536. const llama_adapter_loras * loras;
  537. const llama_memory_context_i * mctx;
  538. const llama_cross * cross;
  539. std::map<llama_seq_id, llama_sampler *> samplers;
  540. const llm_graph_cb & cb_func;
  541. llm_graph_result * res;
  542. ggml_context * ctx0 = nullptr;
  543. ggml_cgraph * gf = nullptr;
  544. llm_graph_context(const llm_graph_params & params);
  545. virtual ~llm_graph_context() = default;
  546. void cb(ggml_tensor * cur, const char * name, int il) const;
  547. //
  548. // common
  549. //
  550. ggml_tensor * build_cvec(
  551. ggml_tensor * cur,
  552. int il) const;
  553. // do mat_mul, while optionally apply lora
  554. ggml_tensor * build_lora_mm(
  555. ggml_tensor * w,
  556. ggml_tensor * cur) const;
  557. // do mat_mul_id, while optionally apply lora
  558. ggml_tensor * build_lora_mm_id(
  559. ggml_tensor * w, // ggml_tensor * as
  560. ggml_tensor * cur, // ggml_tensor * b
  561. ggml_tensor * ids) const;
  562. ggml_tensor * build_norm(
  563. ggml_tensor * cur,
  564. ggml_tensor * mw,
  565. ggml_tensor * mb,
  566. llm_norm_type type,
  567. int il) const;
  568. ggml_tensor * build_ffn(
  569. ggml_tensor * cur,
  570. ggml_tensor * up,
  571. ggml_tensor * up_b,
  572. ggml_tensor * up_s,
  573. ggml_tensor * gate,
  574. ggml_tensor * gate_b,
  575. ggml_tensor * gate_s,
  576. ggml_tensor * down,
  577. ggml_tensor * down_b,
  578. ggml_tensor * down_s,
  579. ggml_tensor * act_scales,
  580. llm_ffn_op_type type_op,
  581. llm_ffn_gate_type type_gate,
  582. int il) const;
  583. // build MoE FFN without bias tensors
  584. ggml_tensor * build_moe_ffn(
  585. ggml_tensor * cur,
  586. ggml_tensor * gate_inp,
  587. ggml_tensor * up_exps,
  588. ggml_tensor * gate_exps,
  589. ggml_tensor * down_exps,
  590. ggml_tensor * exp_probs_b,
  591. int64_t n_expert,
  592. int64_t n_expert_used,
  593. llm_ffn_op_type type_op,
  594. bool norm_w,
  595. bool scale_w,
  596. float w_scale,
  597. llama_expert_gating_func_type gating_op,
  598. int il,
  599. ggml_tensor * probs_in = nullptr) const;
  600. ggml_tensor * build_moe_ffn(
  601. ggml_tensor * cur,
  602. ggml_tensor * gate_inp,
  603. ggml_tensor * gate_inp_b,
  604. ggml_tensor * up_exps,
  605. ggml_tensor * up_exps_b,
  606. ggml_tensor * gate_exps,
  607. ggml_tensor * gate_exps_b,
  608. ggml_tensor * down_exps,
  609. ggml_tensor * down_exps_b,
  610. ggml_tensor * exp_probs_b,
  611. int64_t n_expert,
  612. int64_t n_expert_used,
  613. llm_ffn_op_type type_op,
  614. bool norm_w,
  615. bool scale_w,
  616. float w_scale,
  617. llama_expert_gating_func_type gating_op,
  618. int il,
  619. ggml_tensor * probs_in = nullptr) const;
  620. //
  621. // inputs
  622. //
  623. ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
  624. ggml_tensor * build_inp_pos() const;
  625. ggml_tensor * build_inp_attn_scale() const;
  626. ggml_tensor * build_inp_out_ids() const;
  627. ggml_tensor * build_inp_mean() const;
  628. ggml_tensor * build_inp_cls() const;
  629. ggml_tensor * build_inp_cross_embd() const;
  630. ggml_tensor * build_inp_pos_bucket_enc() const;
  631. ggml_tensor * build_inp_pos_bucket_dec() const;
  632. ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
  633. //
  634. // attention
  635. //
  636. ggml_tensor * build_attn_mha(
  637. ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
  638. ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
  639. ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
  640. ggml_tensor * kq_b,
  641. ggml_tensor * kq_mask,
  642. ggml_tensor * sinks, // [n_head_q]
  643. ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  644. float kq_scale,
  645. int il) const;
  646. llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
  647. ggml_tensor * build_attn(
  648. llm_graph_input_attn_no_cache * inp,
  649. ggml_tensor * wo,
  650. ggml_tensor * wo_b,
  651. ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
  652. ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
  653. ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
  654. ggml_tensor * kq_b,
  655. ggml_tensor * sinks, // [n_head_q]
  656. ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  657. float kq_scale,
  658. int il) const;
  // Creates the llm_graph_input_attn_kv object consumed by the build_attn overload below.
  659. llm_graph_input_attn_kv * build_attn_inp_kv() const;
  // build_attn overload selected by an llm_graph_input_attn_kv input (same parameter list as the
  // other build_attn overloads; only the `inp` type differs).
  // NOTE(review): only this overload marks `v_mla` with "TODO: remove"; the sibling overloads
  // keep the same parameter without that note — the intent should be made consistent.
  660. ggml_tensor * build_attn(
  661. llm_graph_input_attn_kv * inp,
  662. ggml_tensor * wo,
  663. ggml_tensor * wo_b,
  664. ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
  665. ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
  666. ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
  667. ggml_tensor * kq_b,
  668. ggml_tensor * sinks, // [n_head_q]
  669. ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] // TODO: remove
  670. float kq_scale,
  671. int il) const;
  // Creates the llm_graph_input_attn_k object consumed by the build_attn overload below.
  672. llm_graph_input_attn_k * build_attn_inp_k() const;
  // build_attn overload selected by an llm_graph_input_attn_k input (same parameter list as the
  // other build_attn overloads).
  // NOTE(review): the type name suggests a K-only cache, yet v_cur is still a parameter — confirm
  // how v_cur is used in the definition.
  673. ggml_tensor * build_attn(
  674. llm_graph_input_attn_k * inp,
  675. ggml_tensor * wo,
  676. ggml_tensor * wo_b,
  677. ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
  678. ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
  679. ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
  680. ggml_tensor * kq_b,
  681. ggml_tensor * sinks, // [n_head_q]
  682. ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  683. float kq_scale,
  684. int il) const;
  // Creates the llm_graph_input_attn_kv_iswa object consumed by the build_attn overload below.
  685. llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const;
  686. // note: if k_cur or v_cur are not provided, they will not be stored in the memory
  // build_attn overload selected by an llm_graph_input_attn_kv_iswa input. This is the only
  // overload whose k_cur / v_cur are marked optional; "not provided" presumably means nullptr
  // — TODO confirm in the definition.
  687. ggml_tensor * build_attn(
  688. llm_graph_input_attn_kv_iswa * inp,
  689. ggml_tensor * wo,
  690. ggml_tensor * wo_b,
  691. ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
  692. ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
  693. ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
  694. ggml_tensor * kq_b,
  695. ggml_tensor * sinks, // [n_head_q]
  696. ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  697. float kq_scale,
  698. int il) const;
  // Creates the llm_graph_input_attn_cross object consumed by the build_attn overload below.
  699. llm_graph_input_attn_cross * build_attn_inp_cross() const;
  // build_attn overload selected by an llm_graph_input_attn_cross input (same parameter list as
  // the other build_attn overloads).
  // NOTE(review): presumably used for encoder-decoder cross attention alongside
  // build_inp_cross_embd — confirm which of k_cur / v_cur are actually consumed here.
  700. ggml_tensor * build_attn(
  701. llm_graph_input_attn_cross * inp,
  702. ggml_tensor * wo,
  703. ggml_tensor * wo_b,
  704. ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
  705. ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
  706. ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
  707. ggml_tensor * kq_b,
  708. ggml_tensor * sinks, // [n_head_q]
  709. ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  710. float kq_scale,
  711. int il) const;
  712. //
  713. // recurrent
  714. //
  715. // TODO: move this implementation to llama_memory_recurrent.
  716. // this is analogous to llama_kv_cache::cpy_k / cpy_v
  717. // when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
  718. // implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
  719. // `llama_memory_recurrent`
  // build_rs (explicit-state overload): every recurrent-state parameter is passed explicitly
  // (state_copy_main / state_copy_extra tensors plus n_rs, rs_head, rs_size, rs_zero).
  // The overload that takes an llm_graph_input_rs * supplies a shorter parameter list instead.
  // `get_state_rows` is the row-gather callback and defaults to ggml_get_rows.
  // NOTE(review): semantics of n_rs / rs_head / rs_size / rs_zero (and why rs_zero is signed
  // while the others are unsigned) are not visible here — document from the definition.
  720. ggml_tensor * build_rs(
  721. ggml_tensor * s,
  722. ggml_tensor * state_copy_main,
  723. ggml_tensor * state_copy_extra,
  724. int32_t state_size,
  725. int32_t n_seqs,
  726. uint32_t n_rs,
  727. uint32_t rs_head,
  728. uint32_t rs_size,
  729. int32_t rs_zero,
  730. const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
  // Creates the llm_graph_input_rs object consumed by the build_rs overload below and by
  // build_rwkv_token_shift_load.
  731. llm_graph_input_rs * build_rs_inp() const;
  // build_rs (input-object overload): compared with the explicit-state overload it drops
  // state_copy_main, state_copy_extra, n_rs, rs_head, rs_size and rs_zero. Those are presumably
  // taken from `inp` — TODO confirm. `get_state_rows` again defaults to ggml_get_rows.
  732. ggml_tensor * build_rs(
  733. llm_graph_input_rs * inp,
  734. ggml_tensor * s,
  735. int32_t state_size,
  736. int32_t n_seqs,
  737. const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
  // RWKV token-shift helpers (load / store pair), both keyed by the ubatch and `il`
  // (presumably the layer index).
  // The load side reads through an llm_graph_input_rs * (as returned by build_rs_inp()).
  738. ggml_tensor * build_rwkv_token_shift_load(
  739. llm_graph_input_rs * inp,
  740. const llama_ubatch & ubatch,
  741. int il) const;
  // The store side takes the `token_shift` tensor to persist and has no `inp` parameter.
  // NOTE(review): what the returned tensor represents is not visible here — confirm in the
  // definition.
  742. ggml_tensor * build_rwkv_token_shift_store(
  743. ggml_tensor * token_shift,
  744. const llama_ubatch & ubatch,
  745. int il) const;
  746. //
  747. // hybrid
  748. //
  // Input-object builders for hybrid memory: a plain variant and an _iswa variant.
  // NOTE(review): presumably combine attention and recurrent-state inputs — confirm which
  // build_attn / build_rs overloads consume these objects.
  749. llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
  750. llm_graph_input_mem_hybrid_iswa * build_inp_mem_hybrid_iswa() const;
  751. //
  752. // pooling
  753. //
  // build_pooling returns void, so its output is not handed back to the caller; it must be
  // recorded elsewhere (not visible here).
  // NOTE(review): cls / cls_b / cls_out / cls_out_b are presumably optional classification-head
  // weights and biases that may be nullptr — confirm in the definition.
  754. void build_pooling(
  755. ggml_tensor * cls,
  756. ggml_tensor * cls_b,
  757. ggml_tensor * cls_out,
  758. ggml_tensor * cls_out_b) const;
  759. //
  760. // sampling (backend sampling)
  761. //
  // Takes no parameters and returns void; per the section header it concerns backend sampling.
  // NOTE(review): which context state it reads and writes is not visible here.
  762. void build_sampling() const;
  763. //
  764. // dense (out)
  765. //
  // Dense output stage taking two tensors (dense_2, dense_3); returns void.
  // NOTE(review): presumably optional projection weights applied after pooling — confirm whether
  // nullptr is allowed for either one.
  766. void build_dense_out(
  767. ggml_tensor * dense_2,
  768. ggml_tensor * dense_3) const;
  769. };
  770. // TODO: better name
  // Free function (declared after the class closes) mapping a pair of positions (x, y) to an
  // int32_t bucket index, given a bucket count and a bidirectional flag.
  // NOTE(review): presumably backs build_inp_pos_bucket_enc/_dec; the exact bucketing scheme and
  // argument order (x vs y) are defined in the .cpp — confirm before relying on them.
  771. int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);