llama-kv-cache.cpp 67 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052
  1. #include "llama-kv-cache.h"
  2. #include "llama-impl.h"
  3. #include "llama-io.h"
  4. #include "llama-model.h"
  5. #include "llama-context.h"
  6. #include <algorithm>
  7. #include <cassert>
  8. #include <cmath>
  9. #include <cstring>
  10. #include <limits>
  11. #include <map>
  12. #include <stdexcept>
  13. //
  14. // llama_kv_cache
  15. //
  16. llama_kv_cache::llama_kv_cache(
  17. const llama_model & model,
  18. ggml_type type_k,
  19. ggml_type type_v,
  20. bool v_trans,
  21. bool offload,
  22. bool unified,
  23. uint32_t kv_size,
  24. uint32_t n_seq_max,
  25. uint32_t n_pad,
  26. uint32_t n_swa,
  27. llama_swa_type swa_type,
  28. const layer_filter_cb & filter,
  29. const layer_reuse_cb & reuse) :
  30. model(model), hparams(model.hparams), v_trans(v_trans),
  31. n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
  32. GGML_ASSERT(kv_size % n_pad == 0);
  33. const uint32_t n_layer_kv = hparams.n_layer_kv();
  34. // define a comparator for the buft -> ctx map to ensure that the order is well-defined:
  35. struct ggml_backend_buft_comparator {
  36. bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
  37. return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
  38. }
  39. };
  40. std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
  41. // create a context for each buffer type
  42. auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
  43. auto it = ctx_map.find(buft);
  44. if (it == ctx_map.end()) {
  45. ggml_init_params params = {
  46. /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()),
  47. /*.mem_buffer =*/ NULL,
  48. /*.no_alloc =*/ true,
  49. };
  50. ggml_context * ctx = ggml_init(params);
  51. if (!ctx) {
  52. return nullptr;
  53. }
  54. ctx_map.emplace(buft, ctx);
  55. return ctx;
  56. }
  57. return it->second.get();
  58. };
  59. GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);
  60. v_heads.resize(n_stream);
  61. for (uint32_t s = 0; s < n_stream; ++s) {
  62. v_heads[s] = 0;
  63. }
  64. v_cells.resize(n_stream);
  65. for (uint32_t s = 0; s < n_stream; ++s) {
  66. v_cells[s].resize(kv_size);
  67. }
  68. // by default, all sequence ids are mapped to the 0th stream
  69. seq_to_stream.resize(LLAMA_MAX_SEQ, 0);
  70. if (n_stream > 1) {
  71. seq_to_stream.resize(n_stream, 0);
  72. for (uint32_t s = 0; s < n_stream; ++s) {
  73. seq_to_stream[s] = s;
  74. }
  75. }
  76. // [TAG_V_CACHE_VARIABLE]
  77. if (v_trans && hparams.is_n_embd_v_gqa_variable()) {
  78. LLAMA_LOG_WARN("%s: the V embeddings have different sizes across layers and FA is not enabled - padding V cache to %d\n",
  79. __func__, hparams.n_embd_v_gqa_max());
  80. }
  81. for (uint32_t il = 0; il < hparams.n_layer; il++) {
  82. if (!hparams.has_kv(il)) {
  83. LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il);
  84. continue;
  85. }
  86. if (filter && !filter(il)) {
  87. LLAMA_LOG_DEBUG("%s: layer %3d: filtered\n", __func__, il);
  88. continue;
  89. }
  90. // [TAG_V_CACHE_VARIABLE]
  91. const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
  92. const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();
  93. const char * dev_name = "CPU";
  94. ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
  95. if (offload) {
  96. auto * dev = model.dev_layer(il);
  97. buft = ggml_backend_dev_buffer_type(dev);
  98. dev_name = ggml_backend_dev_name(dev);
  99. }
  100. LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, il, dev_name);
  101. ggml_context * ctx = ctx_for_buft(buft);
  102. if (!ctx) {
  103. throw std::runtime_error("failed to create ggml context for kv cache");
  104. }
  105. ggml_tensor * k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream);
  106. ggml_tensor * v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream);
  107. ggml_format_name(k, "cache_k_l%d", il);
  108. ggml_format_name(v, "cache_v_l%d", il);
  109. std::vector<ggml_tensor *> k_stream;
  110. std::vector<ggml_tensor *> v_stream;
  111. for (uint32_t s = 0; s < n_stream; ++s) {
  112. k_stream.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]));
  113. v_stream.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]));
  114. }
  115. map_layer_ids[il] = layers.size();
  116. layers.push_back({ il, k, v, k_stream, v_stream, });
  117. }
  118. if (reuse) {
  119. LLAMA_LOG_DEBUG("%s: reusing layers:\n", __func__);
  120. for (uint32_t il = 0; il < hparams.n_layer; il++) {
  121. const int32_t il_reuse = reuse(il);
  122. if (il_reuse < 0) {
  123. LLAMA_LOG_DEBUG("%s: - layer %3d: no reuse\n", __func__, il);
  124. continue;
  125. }
  126. if (filter && !filter(il)) {
  127. LLAMA_LOG_DEBUG("%s: - layer %3d: filtered\n", __func__, il);
  128. continue;
  129. }
  130. GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());
  131. map_layer_ids[il] = map_layer_ids[il_reuse];
  132. LLAMA_LOG_DEBUG("%s: - layer %3d: reuse layer %d, is_swa = %d\n", __func__, il, il_reuse, hparams.is_swa(il));
  133. }
  134. }
  135. // allocate tensors and initialize the buffers to avoid NaNs in the padding
  136. for (auto & [buft, ctx] : ctx_map) {
  137. ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
  138. if (!buf) {
  139. throw std::runtime_error("failed to allocate buffer for kv cache");
  140. }
  141. LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
  142. ggml_backend_buffer_clear(buf, 0);
  143. ctxs_bufs.emplace_back(std::move(ctx), buf);
  144. }
  145. {
  146. const size_t memory_size_k = size_k_bytes();
  147. const size_t memory_size_v = size_v_bytes();
  148. LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
  149. (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_stream,
  150. ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
  151. ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
  152. }
  153. const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
  154. debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
  155. }
  156. void llama_kv_cache::clear(bool data) {
  157. for (uint32_t s = 0; s < n_stream; ++s) {
  158. v_cells[s].reset();
  159. v_heads[s] = 0;
  160. }
  161. if (data) {
  162. for (auto & [_, buf] : ctxs_bufs) {
  163. ggml_backend_buffer_clear(buf.get(), 0);
  164. }
  165. }
  166. }
  167. bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
  168. GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
  169. if (p0 < 0) {
  170. p0 = 0;
  171. }
  172. if (p1 < 0) {
  173. p1 = std::numeric_limits<llama_pos>::max();
  174. }
  175. if (seq_id >= 0) {
  176. auto & cells = v_cells[seq_to_stream[seq_id]];
  177. auto & head = v_heads[seq_to_stream[seq_id]];
  178. uint32_t new_head = cells.size();
  179. for (uint32_t i = 0; i < cells.size(); ++i) {
  180. if (!cells.pos_in(i, p0, p1)) {
  181. continue;
  182. }
  183. if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
  184. if (new_head == cells.size()) {
  185. new_head = i;
  186. }
  187. }
  188. }
  189. // If we freed up a slot, set head to it so searching can start there.
  190. if (new_head != cells.size() && new_head < head) {
  191. head = new_head;
  192. }
  193. } else {
  194. // match any sequence
  195. for (uint32_t s = 0; s < n_stream; ++s) {
  196. auto & cells = v_cells[s];
  197. auto & head = v_heads[s];
  198. uint32_t new_head = cells.size();
  199. for (uint32_t i = 0; i < cells.size(); ++i) {
  200. if (!cells.pos_in(i, p0, p1)) {
  201. continue;
  202. }
  203. cells.rm(i);
  204. if (new_head == cells.size()) {
  205. new_head = i;
  206. }
  207. }
  208. // If we freed up a slot, set head to it so searching can start there.
  209. if (new_head != cells.size() && new_head < head) {
  210. head = new_head;
  211. }
  212. }
  213. }
  214. return true;
  215. }
  216. void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
  217. GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size());
  218. GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size());
  219. const auto s0 = seq_to_stream[seq_id_src];
  220. const auto s1 = seq_to_stream[seq_id_dst];
  221. if (s0 == s1) {
  222. // since both sequences are in the same stream, no data copy is necessary
  223. // we just have to update the cells meta data
  224. auto & cells = v_cells[s0];
  225. if (seq_id_src == seq_id_dst) {
  226. return;
  227. }
  228. if (p0 < 0) {
  229. p0 = 0;
  230. }
  231. if (p1 < 0) {
  232. p1 = std::numeric_limits<llama_pos>::max();
  233. }
  234. for (uint32_t i = 0; i < cells.size(); ++i) {
  235. if (!cells.pos_in(i, p0, p1)) {
  236. continue;
  237. }
  238. if (cells.seq_has(i, seq_id_src)) {
  239. cells.seq_add(i, seq_id_dst);
  240. }
  241. }
  242. return;
  243. }
  244. // cross-stream sequence copies require to copy the actual buffer data
  245. bool is_full = true;
  246. if (p0 > 0 && p0 + 1 < (int) get_size()) {
  247. is_full = false;
  248. }
  249. if (p1 > 0 && p1 + 1 < (int) get_size()) {
  250. is_full = false;
  251. }
  252. GGML_ASSERT(is_full && "seq_cp() is only supported for full KV buffers");
  253. // enqueue the copy operation - the buffer copy will be performed during the next update
  254. sc_info.ssrc.push_back(s0);
  255. sc_info.sdst.push_back(s1);
  256. v_cells[s1].reset();
  257. for (uint32_t i = 0; i < v_cells[s0].size(); ++i) {
  258. if (v_cells[s0].seq_has(i, seq_id_src)) {
  259. llama_pos pos = v_cells[s0].pos_get(i);
  260. llama_pos shift = v_cells[s0].get_shift(i);
  261. llama_kv_cell_ext ext = v_cells[s0].ext_get(i);
  262. if (shift != 0) {
  263. pos -= shift;
  264. assert(pos >= 0);
  265. }
  266. v_cells[s1].pos_set(i, pos);
  267. v_cells[s1].seq_add(i, seq_id_dst);
  268. if (shift != 0) {
  269. v_cells[s1].pos_add(i, shift);
  270. }
  271. v_cells[s1].ext_set(i, ext);
  272. }
  273. }
  274. v_heads[s1] = v_heads[s0];
  275. //for (uint32_t s = 0; s < n_stream; ++s) {
  276. // LLAMA_LOG_WARN("%s: seq %d: min = %d, max = %d\n", __func__, s, v_cells[s].seq_pos_min(s), v_cells[s].seq_pos_max(s));
  277. //}
  278. }
  279. void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
  280. GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
  281. auto & cells = v_cells[seq_to_stream[seq_id]];
  282. auto & head = v_heads[seq_to_stream[seq_id]];
  283. uint32_t new_head = cells.size();
  284. for (uint32_t i = 0; i < cells.size(); ++i) {
  285. if (cells.seq_keep(i, seq_id)) {
  286. if (new_head == cells.size()) {
  287. new_head = i;
  288. }
  289. }
  290. }
  291. // If we freed up a slot, set head to it so searching can start there.
  292. if (new_head != cells.size() && new_head < head) {
  293. head = new_head;
  294. }
  295. }
  296. void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
  297. GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
  298. GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_add() is only supported for n_pos_per_embd() == 1");
  299. auto & cells = v_cells[seq_to_stream[seq_id]];
  300. auto & head = v_heads[seq_to_stream[seq_id]];
  301. if (shift == 0) {
  302. return;
  303. }
  304. uint32_t new_head = cells.size();
  305. if (p0 < 0) {
  306. p0 = 0;
  307. }
  308. if (p1 < 0) {
  309. p1 = std::numeric_limits<llama_pos>::max();
  310. }
  311. // If there is no range then return early to avoid looping over all cells.
  312. if (p0 == p1) {
  313. return;
  314. }
  315. for (uint32_t i = 0; i < cells.size(); ++i) {
  316. if (!cells.pos_in(i, p0, p1)) {
  317. continue;
  318. }
  319. if (cells.seq_has(i, seq_id)) {
  320. if (cells.pos_add(i, shift)) {
  321. if (new_head == cells.size()) {
  322. new_head = i;
  323. }
  324. }
  325. }
  326. }
  327. // If we freed up a slot, set head to it so searching can start there.
  328. // Otherwise we just start the next search from the beginning.
  329. head = new_head != cells.size() ? new_head : 0;
  330. }
  331. void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
  332. GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
  333. GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_div() is only supported for n_pos_per_embd() == 1");
  334. auto & cells = v_cells[seq_to_stream[seq_id]];
  335. if (d == 1) {
  336. return;
  337. }
  338. if (p0 < 0) {
  339. p0 = 0;
  340. }
  341. if (p1 < 0) {
  342. p1 = std::numeric_limits<llama_pos>::max();
  343. }
  344. // If there is no range then return early to avoid looping over the cache.
  345. if (p0 == p1) {
  346. return;
  347. }
  348. for (uint32_t i = 0; i < cells.size(); ++i) {
  349. if (!cells.pos_in(i, p0, p1)) {
  350. continue;
  351. }
  352. if (cells.seq_has(i, seq_id)) {
  353. cells.pos_div(i, d);
  354. }
  355. }
  356. }
  357. llama_pos llama_kv_cache::seq_pos_min(llama_seq_id seq_id) const {
  358. GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
  359. const auto & cells = v_cells[seq_to_stream[seq_id]];
  360. return cells.seq_pos_min(seq_id);
  361. }
  362. llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {
  363. GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
  364. const auto & cells = v_cells[seq_to_stream[seq_id]];
  365. return cells.seq_pos_max(seq_id);
  366. }
  367. std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache::memory_breakdown() const {
  368. std::map<ggml_backend_buffer_type_t, size_t> ret;
  369. for (const auto & [_, buf] : ctxs_bufs) {
  370. ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
  371. }
  372. return ret;
  373. }
  374. llama_memory_context_ptr llama_kv_cache::init_batch(
  375. llama_batch_allocr & balloc,
  376. uint32_t n_ubatch,
  377. bool embd_all) {
  378. GGML_UNUSED(embd_all);
  379. do {
  380. balloc.split_reset();
  381. std::vector<llama_ubatch> ubatches;
  382. while (true) {
  383. auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true);
  384. if (ubatch.n_tokens == 0) {
  385. break;
  386. }
  387. ubatches.push_back(std::move(ubatch)); // NOLINT
  388. }
  389. if (balloc.get_n_used() < balloc.get_n_tokens()) {
  390. // failed to find a suitable split
  391. break;
  392. }
  393. auto sinfos = prepare(ubatches);
  394. if (sinfos.empty()) {
  395. break;
  396. }
  397. return std::make_unique<llama_kv_cache_context>(
  398. this, std::move(sinfos), std::move(ubatches));
  399. } while (false);
  400. return std::make_unique<llama_kv_cache_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
  401. }
  402. llama_memory_context_ptr llama_kv_cache::init_full() {
  403. return std::make_unique<llama_kv_cache_context>(this);
  404. }
  405. llama_memory_context_ptr llama_kv_cache::init_update(llama_context * lctx, bool optimize) {
  406. GGML_UNUSED(optimize);
  407. bool do_shift = get_has_shift();
  408. return std::make_unique<llama_kv_cache_context>(this, lctx, do_shift, std::move(sc_info));
  409. }
  410. llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_ubatch> & ubatches) {
  411. llama_kv_cache::slot_info_vec_t res;
  412. struct state_t {
  413. slot_info sinfo; // slot info for the ubatch
  414. std::vector<uint32_t> v_heads_old; // old positions of the heads, before placing the ubatch
  415. std::vector<llama_kv_cells> v_cells; // copy of the old cells, before placing the ubatch
  416. };
  417. // remember the old state of the cells so we can restore it in the end
  418. std::vector<state_t> states;
  419. bool success = true;
  420. for (const auto & ubatch : ubatches) {
  421. // only find a suitable slot for the ubatch. don't modify the cells yet
  422. const auto sinfo_new = find_slot(ubatch, false);
  423. if (sinfo_new.empty()) {
  424. success = false;
  425. break;
  426. }
  427. // remeber the position that we found
  428. res.push_back(sinfo_new);
  429. // store the old state of the cells in the recovery stack
  430. {
  431. state_t state = { sinfo_new, v_heads, {} };
  432. for (uint32_t s = 0; s < sinfo_new.n_stream(); ++s) {
  433. auto & cells = v_cells[sinfo_new.strm[s]];
  434. state.v_cells.push_back(cells.cp(sinfo_new.idxs[s]));
  435. }
  436. states.push_back(std::move(state));
  437. }
  438. // now emplace the ubatch
  439. apply_ubatch(sinfo_new, ubatch);
  440. }
  441. GGML_ASSERT(!states.empty() || !success);
  442. // iterate backwards and restore the cells to their original state
  443. for (auto it = states.rbegin(); it != states.rend(); ++it) {
  444. const auto & sinfo = it->sinfo;
  445. for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
  446. auto & cells = v_cells[sinfo.strm[s]];
  447. auto & head = v_heads[sinfo.strm[s]];
  448. cells.set(sinfo.idxs[s], it->v_cells[s]);
  449. head = it->v_heads_old[s];
  450. }
  451. }
  452. if (!success) {
  453. return {};
  454. }
  455. return res;
  456. }
  457. bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info) {
  458. bool updated = false;
  459. auto * sched = lctx->get_sched();
  460. if (!sc_info.empty()) {
  461. assert(n_stream > 1 && "stream copy should never happen with a single stream");
  462. llama_synchronize(lctx);
  463. const size_t n_copy = sc_info.ssrc.size();
  464. for (size_t i = 0; i < n_copy; ++i) {
  465. const auto ssrc = sc_info.ssrc[i];
  466. const auto sdst = sc_info.sdst[i];
  467. assert(ssrc < n_stream);
  468. assert(sdst < n_stream);
  469. LLAMA_LOG_DEBUG("%s: copying KV buffer: stream %d to stream %d\n", __func__, ssrc, sdst);
  470. assert(ssrc != sdst);
  471. for (uint32_t il = 0; il < layers.size(); ++il) {
  472. const auto & layer = layers[il];
  473. ggml_backend_tensor_copy(layer.k_stream[ssrc], layer.k_stream[sdst]);
  474. ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]);
  475. }
  476. }
  477. }
  478. if (do_shift) {
  479. if (!get_can_shift()) {
  480. GGML_ABORT("The current KV cache / model configuration does not support K-shift");
  481. }
  482. LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
  483. // apply K-shift if needed
  484. if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
  485. ggml_backend_sched_reset(sched);
  486. auto * res = lctx->get_gf_res_reserve();
  487. res->reset();
  488. auto * gf = build_graph_shift(res, lctx);
  489. if (!ggml_backend_sched_alloc_graph(sched, gf)) {
  490. LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__);
  491. return updated;
  492. }
  493. res->set_inputs(nullptr);
  494. if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
  495. LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
  496. return updated;
  497. }
  498. updated = true;
  499. }
  500. for (uint32_t s = 0; s < n_stream; ++s) {
  501. auto & cells = v_cells[s];
  502. cells.reset_shift();
  503. }
  504. }
  505. return updated;
  506. }
  507. llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch, bool cont) const {
  508. if (debug > 0) {
  509. for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
  510. const auto seq_id = ubatch.seq_id_unq[s];
  511. const auto stream_id = seq_to_stream[seq_id];
  512. const auto & cells = v_cells[stream_id];
  513. const uint32_t head_cur = v_heads[stream_id];
  514. LLAMA_LOG_DEBUG("%s: stream[%d], n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n",
  515. __func__, stream_id, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa);
  516. if ((debug == 2 && n_swa > 0) || debug > 2) {
  517. std::string ss;
  518. for (uint32_t i = 0; i < cells.size(); ++i) {
  519. if (cells.is_empty(i)) {
  520. ss += '.';
  521. } else {
  522. assert(cells.seq_count(i) >= 1);
  523. if (cells.seq_count(i) == 1) {
  524. ss += std::to_string(cells.seq_get(i));
  525. } else {
  526. ss += 'M';
  527. }
  528. }
  529. if (i%256 == 255) {
  530. ss += " *";
  531. ss += '\n';
  532. }
  533. }
  534. LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
  535. }
  536. if ((debug == 2 && n_swa > 0) || debug > 2) {
  537. std::string ss;
  538. for (uint32_t i = 0; i < cells.size(); ++i) {
  539. std::string cur;
  540. if (cells.is_empty(i)) {
  541. cur = '.';
  542. } else {
  543. cur = std::to_string(cells.pos_get(i));
  544. }
  545. const int n = cur.size();
  546. for (int j = 0; j < 5 - n; ++j) {
  547. cur += ' ';
  548. }
  549. ss += cur;
  550. if (i%256 == 255) {
  551. ss += " *";
  552. }
  553. if (i%64 == 63) {
  554. ss += '\n';
  555. }
  556. }
  557. LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
  558. }
  559. for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
  560. if (cells.seq_pos_min(s) < 0) {
  561. continue;
  562. }
  563. LLAMA_LOG_DEBUG("%s: stream[%d] min[%d] = %5d, max[%d] = %5d\n", __func__, stream_id, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s));
  564. }
  565. }
  566. }
  567. uint32_t n_tokens = ubatch.n_tokens;
  568. uint32_t n_seqs = 1;
  569. if (n_stream > 1) {
  570. GGML_ASSERT(n_tokens % ubatch.n_seqs_unq == 0);
  571. n_seqs = ubatch.n_seqs_unq;
  572. n_tokens = n_tokens / n_seqs;
  573. }
  574. slot_info res = {
  575. /*.s0 =*/ LLAMA_MAX_SEQ,
  576. /*.s1 =*/ 0,
  577. /*.strm =*/ { },
  578. /*.idxs =*/ { },
  579. };
  580. res.resize(n_seqs);
  581. for (uint32_t s = 0; s < n_seqs; ++s) {
  582. const auto seq_id = ubatch.seq_id_unq[s];
  583. if (n_stream > 1) {
  584. GGML_ASSERT(ubatch.n_seq_id[s*n_tokens] == 1);
  585. GGML_ASSERT(ubatch.seq_id [s*n_tokens][0] == seq_id);
  586. }
  587. res.s0 = std::min<uint32_t>(res.s0, seq_to_stream[seq_id]);
  588. res.s1 = std::max<uint32_t>(res.s1, seq_to_stream[seq_id]);
  589. res.strm[s] = seq_to_stream[seq_id];
  590. res.idxs[s].reserve(n_tokens);
  591. const auto & cells = v_cells[seq_to_stream[seq_id]];
  592. uint32_t head_cur = v_heads[seq_to_stream[seq_id]];
  593. // if we have enough unused cells before the current head ->
  594. // better to start searching from the beginning of the cache, hoping to fill it
  595. if (head_cur > cells.get_used() + 2*n_tokens) {
  596. head_cur = 0;
  597. }
  598. if (n_tokens > cells.size()) {
  599. LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
  600. return { };
  601. }
  602. uint32_t n_tested = 0;
  603. // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head
  604. // for non-continuous slots, we test the tokens one by one
  605. const uint32_t n_test = cont ? n_tokens : 1;
  606. while (true) {
  607. if (head_cur + n_test > cells.size()) {
  608. n_tested += cells.size() - head_cur;
  609. head_cur = 0;
  610. continue;
  611. }
  612. for (uint32_t i = 0; i < n_test; i++) {
  613. const auto idx = head_cur;
  614. head_cur++;
  615. n_tested++;
  616. //const llama_pos pos = ubatch.pos[i];
  617. //const llama_seq_id seq_id = ubatch.seq_id[i][0];
  618. // can we use this cell? either:
  619. // - the cell is empty
  620. // - the cell is occupied only by one sequence:
  621. // - (disabled) mask causally, if the sequence is the same as the one we are inserting
  622. // - mask SWA, using current max pos for that sequence in the cache
  623. // always insert in the cell with minimum pos
  624. bool can_use = cells.is_empty(idx);
  625. if (!can_use && cells.seq_count(idx) == 1) {
  626. const llama_pos pos_cell = cells.pos_get(idx);
  627. // (disabled) causal mask
  628. // note: it's better to purge any "future" tokens beforehand
  629. //if (cells.seq_has(idx, seq_id)) {
  630. // can_use = pos_cell >= pos;
  631. //}
  632. if (!can_use) {
  633. const llama_seq_id seq_id_cell = cells.seq_get(idx);
  634. // SWA mask
  635. if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
  636. can_use = true;
  637. }
  638. }
  639. }
  640. if (can_use) {
  641. res.idxs[s].push_back(idx);
  642. } else {
  643. if (cont) {
  644. break;
  645. }
  646. }
  647. }
  648. if (res.idxs[s].size() == n_tokens) {
  649. break;
  650. }
  651. if (cont) {
  652. res.idxs[s].clear();
  653. }
  654. if (n_tested >= cells.size()) {
  655. //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
  656. return { };
  657. }
  658. }
  659. // we didn't find a suitable slot - return empty result
  660. if (res.idxs[s].size() < n_tokens) {
  661. return { };
  662. }
  663. }
  664. assert(res.s1 >= res.s0);
  665. return res;
  666. }
  667. void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
  668. // keep track of the max sequence position that we would overwrite with this ubatch
  669. // for non-SWA cache, this would be always empty
  670. llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
  671. for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
  672. seq_pos_max_rm[s] = -1;
  673. }
  674. assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size());
  675. for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
  676. for (uint32_t ii = 0; ii < sinfo.size(); ++ii) {
  677. const uint32_t i = s*sinfo.size() + ii;
  678. auto & cells = v_cells[sinfo.strm[s]];
  679. const auto idx = sinfo.idxs[s][ii];
  680. if (!cells.is_empty(idx)) {
  681. assert(cells.seq_count(idx) == 1);
  682. const llama_seq_id seq_id = cells.seq_get(idx);
  683. const llama_pos pos = cells.pos_get(idx);
  684. seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
  685. cells.rm(idx);
  686. }
  687. cells.pos_set(idx, ubatch.pos[i]);
  688. if (ubatch.is_pos_2d()) {
  689. llama_kv_cell_ext ext {
  690. /*.x =*/ ubatch.pos[i + ubatch.n_tokens*2],
  691. /*.y =*/ ubatch.pos[i + ubatch.n_tokens],
  692. };
  693. cells.ext_set(idx, ext);
  694. }
  695. for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
  696. cells.seq_add(idx, ubatch.seq_id[i][s]);
  697. }
  698. }
  699. }
  700. // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
  701. // will be present in the cache. so we have to purge any position which is less than those we would overwrite
  702. // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
  703. for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
  704. if (seq_pos_max_rm[s] == -1) {
  705. continue;
  706. }
  707. GGML_ASSERT(s < seq_to_stream.size());
  708. auto & cells = v_cells[seq_to_stream[s]];
  709. if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
  710. LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
  711. __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
  712. seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
  713. }
  714. }
  715. // move the head at the end of the slot
  716. for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
  717. auto & head = v_heads[sinfo.strm[s]];
  718. head = sinfo.idxs[s].back() + 1;
  719. }
  720. }
  721. bool llama_kv_cache::get_can_shift() const {
  722. return true;
  723. }
  724. uint32_t llama_kv_cache::get_size() const {
  725. const auto & cells = v_cells[seq_to_stream[0]];
  726. return cells.size();
  727. }
  728. uint32_t llama_kv_cache::get_n_stream() const {
  729. return n_stream;
  730. }
  731. bool llama_kv_cache::get_has_shift() const {
  732. bool result = false;
  733. for (uint32_t s = 0; s < n_stream; ++s) {
  734. result |= v_cells[s].get_has_shift();
  735. }
  736. return result;
  737. }
  738. uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
  739. uint32_t result = 0;
  740. // pad the n_kv value so that the graph remains constant across batches and can be reused
  741. // note: this also helps some backends with performance (f.ex https://github.com/ggml-org/llama.cpp/pull/16812#issuecomment-3455112220)
  742. const uint32_t n_pad_cur = std::max(n_pad, 256u);
  743. for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
  744. const auto & cells = v_cells[sinfo.strm[s]];
  745. result = std::max(std::min(cells.size(), std::max(n_pad_cur, GGML_PAD(cells.used_max_p1(), n_pad_cur))), result);
  746. }
  747. return result;
  748. }
  749. ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
  750. const int32_t ikv = map_layer_ids.at(il);
  751. auto * k = layers[ikv].k;
  752. const uint64_t kv_size = get_size();
  753. const uint64_t n_embd_k_gqa = k->ne[0];
  754. assert(n_embd_k_gqa == hparams.n_embd_k_gqa(il));
  755. const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
  756. return ggml_view_4d(ctx, k,
  757. hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns,
  758. ggml_row_size(k->type, hparams.n_embd_head_k),
  759. ggml_row_size(k->type, n_embd_k_gqa),
  760. ggml_row_size(k->type, n_embd_k_gqa*kv_size),
  761. ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0);
  762. }
  763. ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
  764. const int32_t ikv = map_layer_ids.at(il);
  765. auto * v = layers[ikv].v;
  766. const uint64_t kv_size = get_size();
  767. const uint64_t n_embd_v_gqa = v->ne[0];
  768. // [TAG_V_CACHE_VARIABLE]
  769. assert(n_embd_v_gqa >= hparams.n_embd_v_gqa(il));
  770. const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
  771. if (!v_trans) {
  772. // note: v->nb[1] <= v->nb[2]
  773. return ggml_view_4d(ctx, v,
  774. hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
  775. ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
  776. ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2]
  777. ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3]
  778. ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0);
  779. }
  780. // note: v->nb[1] > v->nb[2]
  781. return ggml_view_4d(ctx, v,
  782. n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
  783. ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1]
  784. ggml_row_size(v->type, kv_size), // v->nb[2]
  785. ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3]
  786. ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
  787. }
  788. ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
  789. GGML_UNUSED(sinfo);
  790. const int32_t ikv = map_layer_ids.at(il);
  791. ggml_tensor * k = layers[ikv].k;
  792. const int64_t n_embd_head = k_cur->ne[0];
  793. const int64_t n_head = k_cur->ne[1];
  794. const int64_t n_tokens = k_cur->ne[2];
  795. const int64_t n_embd_gqa = n_embd_head*n_head;
  796. // we can merge dims 0 and 1
  797. // TODO: add ggml helper function for this?
  798. GGML_ASSERT(ggml_row_size(k_cur->type, n_embd_head) == k_cur->nb[1]);
  799. k_cur = ggml_view_2d(ctx, k_cur, n_embd_gqa, n_tokens, k_cur->nb[2], 0);
  800. const int64_t n_stream = k->ne[2];
  801. if (n_stream > 1) {
  802. const int64_t kv_size = get_size();
  803. assert(n_embd_gqa == k->ne[0]);
  804. assert(kv_size == k->ne[1]);
  805. // merge the buffer across all streams because the idxs are global
  806. k = ggml_reshape_2d(ctx, k, n_embd_gqa, kv_size*n_stream);
  807. }
  808. // store the current K values into the cache
  809. return ggml_set_rows(ctx, k, k_cur, k_idxs);
  810. }
  811. ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
  812. GGML_UNUSED(sinfo);
  813. const int32_t ikv = map_layer_ids.at(il);
  814. auto * v = layers[ikv].v;
  815. const int64_t n_embd_head = v_cur->ne[0];
  816. const int64_t n_head = v_cur->ne[1];
  817. const int64_t n_tokens = v_cur->ne[2];
  818. const int64_t n_embd_gqa = n_embd_head*n_head;
  819. // we can merge dims 0 and 1
  820. GGML_ASSERT(ggml_row_size(v_cur->type, n_embd_head) == v_cur->nb[1]);
  821. const int64_t n_stream = v->ne[2];
  822. // take this branch when FA is enabled (the V cache is not transposed)
  823. if (!v_trans) {
  824. v_cur = ggml_view_2d(ctx, v_cur, n_embd_gqa, n_tokens, v_cur->nb[2], 0);
  825. if (n_stream > 1) {
  826. const int64_t kv_size = get_size();
  827. assert(n_embd_gqa == v->ne[0]);
  828. assert(kv_size == v->ne[1]);
  829. // merge the buffer across all streams because the idxs are global
  830. v = ggml_reshape_2d(ctx, v, n_embd_gqa, kv_size*n_stream);
  831. }
  832. return ggml_set_rows(ctx, v, v_cur, v_idxs);
  833. }
  834. if (ggml_row_size(v_cur->type, n_embd_gqa) == v_cur->nb[2]) {
  835. // we can merge dims 0, 1 and 2
  836. v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_gqa, n_tokens);
  837. } else {
  838. // otherwise -> make a copy to get contiguous data
  839. v_cur = ggml_cont_2d (ctx, v_cur, n_embd_gqa, n_tokens);
  840. }
  841. // [TAG_V_CACHE_VARIABLE]
  842. if (n_embd_gqa < v->ne[0]) {
  843. v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_gqa, 0, 0, 0);
  844. }
  845. // in this branch the v_idxs are constructed in such a way that each row is a single head element
  846. ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, ggml_nelements(v));
  847. v_cur = ggml_reshape_2d(ctx, v_cur, 1, ggml_nelements(v_cur));
  848. return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
  849. }
  850. ggml_tensor * llama_kv_cache::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
  851. const uint32_t n_tokens = ubatch.n_tokens;
  852. ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
  853. ggml_set_input(k_idxs);
  854. return k_idxs;
  855. }
  856. ggml_tensor * llama_kv_cache::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
  857. const uint32_t n_tokens = ubatch.n_tokens;
  858. ggml_tensor * v_idxs;
  859. if (!v_trans) {
  860. v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
  861. } else {
  862. v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa_max());
  863. }
  864. ggml_set_input(v_idxs);
  865. return v_idxs;
  866. }
  867. void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
  868. const uint32_t n_tokens = ubatch->n_tokens;
  869. GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
  870. GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
  871. int64_t * data = (int64_t *) dst->data;
  872. for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
  873. const int64_t offs = sinfo.strm[s]*get_size();
  874. for (uint32_t i = 0; i < sinfo.size(); ++i) {
  875. data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
  876. }
  877. }
  878. }
  879. void llama_kv_cache::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
  880. const uint32_t n_tokens = ubatch->n_tokens;
  881. GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
  882. GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
  883. int64_t * data = (int64_t *) dst->data;
  884. if (!v_trans) {
  885. for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
  886. const int64_t offs = sinfo.strm[s]*get_size();
  887. for (uint32_t i = 0; i < sinfo.size(); ++i) {
  888. data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
  889. }
  890. }
  891. } else {
  892. // note: the V cache is transposed when not using flash attention
  893. const int64_t kv_size = get_size();
  894. const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa_max();
  895. for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
  896. const int64_t offs = sinfo.strm[s]*kv_size*n_embd_v_gqa;
  897. for (uint32_t i = 0; i < sinfo.size(); ++i) {
  898. for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
  899. data[s*sinfo.size()*n_embd_v_gqa + i*n_embd_v_gqa + j] = offs + j*kv_size + sinfo.idxs[s][i];
  900. }
  901. }
  902. }
  903. }
  904. }
  905. void llama_kv_cache::set_input_k_shift(ggml_tensor * dst) const {
  906. GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
  907. int32_t * data = (int32_t *) dst->data;
  908. for (uint32_t s = 0; s < n_stream; ++s) {
  909. const auto & cells = v_cells[s];
  910. for (uint32_t i = 0; i < cells.size(); ++i) {
  911. data[s*cells.size() + i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
  912. }
  913. }
  914. }
  915. void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
  916. const uint32_t n_tokens = ubatch->n_tokens;
  917. GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
  918. float * data = (float *) dst->data;
  919. const int64_t n_kv = dst->ne[0];
  920. const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch
  921. GGML_ASSERT(n_tokens%n_stream == 0);
  922. // n_tps == n_tokens_per_stream
  923. const int64_t n_tps = n_tokens/n_stream;
  924. const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD);
  925. std::fill(data, data + ggml_nelements(dst), -INFINITY);
  926. // Use only the previous KV cells of the correct sequence for each token of the ubatch.
  927. // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
  928. // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
  929. // Causal mask:
  930. // xxx-------
  931. // xxxx------
  932. // xxxxx-----
  933. // Non-causal mask:
  934. // xxxxx-----
  935. // xxxxx-----
  936. // xxxxx-----
  937. // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
  938. // TODO: optimize this section
  939. for (uint32_t h = 0; h < 1; ++h) {
  940. for (uint32_t s = 0; s < n_stream; ++s) {
  941. for (uint32_t ii = 0; ii < n_tps; ++ii) {
  942. const uint32_t i = s*n_tps + ii;
  943. const llama_seq_id seq_id = ubatch->seq_id[i][0];
  944. const auto & cells = v_cells[seq_to_stream[seq_id]];
  945. const llama_pos p1 = ubatch->pos[i];
  946. // for M-RoPE
  947. const bool is_2d = ubatch->is_pos_2d();
  948. const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
  949. const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;
  950. const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
  951. for (uint32_t j = 0; j < n_kv; ++j) {
  952. if (cells.is_empty(j)) {
  953. continue;
  954. }
  955. // mask the token if not the same sequence
  956. if (!cells.seq_has(j, seq_id)) {
  957. continue;
  958. }
  959. const llama_pos p0 = cells.pos_get(j);
  960. // mask future tokens
  961. if (causal_attn && p0 > p1) {
  962. continue;
  963. }
  964. // M-RoPE causal mask
  965. if (causal_attn && is_2d && p0 == p1) {
  966. const auto & p0_ext = cells.ext_get(j);
  967. if (p0_ext.is_2d_gt(p1_x, p1_y)) {
  968. continue;
  969. }
  970. }
  971. // apply SWA if any
  972. if (is_masked_swa(p0, p1)) {
  973. continue;
  974. }
  975. data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
  976. }
  977. }
  978. }
  979. }
  980. }
  981. void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
  982. const int64_t n_tokens = ubatch->n_tokens;
  983. GGML_ASSERT(n_stream == 1 && "TODO: support multiple streams");
  984. const auto & cells = v_cells[0];
  985. GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
  986. GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
  987. int32_t * data = (int32_t *) dst->data;
  988. const int32_t n_kv = dst->ne[0];
  989. for (int h = 0; h < 1; ++h) {
  990. for (int i = 0; i < n_tokens; ++i) {
  991. for (int j = 0; j < n_kv; ++j) {
  992. // the position when the cells is empty is irrelevant - it will be masked out later in the attention
  993. const llama_pos p0 = cells.is_empty(j) ? -1 : cells.pos_get(j);
  994. data[h*(n_kv*n_tokens) + i*n_kv + j] = llama_relative_position_bucket(p0, ubatch->pos[i], hparams.n_rel_attn_bkts, false);
  995. }
  996. }
  997. }
  998. }
  999. size_t llama_kv_cache::total_size() const {
  1000. size_t size = 0;
  1001. for (const auto & [_, buf] : ctxs_bufs) {
  1002. size += ggml_backend_buffer_get_size(buf.get());
  1003. }
  1004. return size;
  1005. }
  1006. size_t llama_kv_cache::size_k_bytes() const {
  1007. size_t size_k_bytes = 0;
  1008. for (const auto & layer : layers) {
  1009. size_k_bytes += ggml_nbytes(layer.k);
  1010. }
  1011. return size_k_bytes;
  1012. }
  1013. size_t llama_kv_cache::size_v_bytes() const {
  1014. size_t size_v_bytes = 0;
  1015. for (const auto & layer : layers) {
  1016. size_v_bytes += ggml_nbytes(layer.v);
  1017. }
  1018. return size_v_bytes;
  1019. }
  1020. ggml_tensor * llama_kv_cache::build_rope_shift(
  1021. const llama_cparams & cparams,
  1022. ggml_context * ctx,
  1023. ggml_tensor * cur,
  1024. ggml_tensor * shift,
  1025. ggml_tensor * factors,
  1026. float freq_base,
  1027. float freq_scale) const {
  1028. const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
  1029. const auto & yarn_ext_factor = cparams.yarn_ext_factor;
  1030. const auto & yarn_beta_fast = cparams.yarn_beta_fast;
  1031. const auto & yarn_beta_slow = cparams.yarn_beta_slow;
  1032. const auto & n_rot = hparams.n_rot;
  1033. const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE
  1034. // @ngxson : this is a workaround
  1035. // for M-RoPE, we want to rotate the whole vector when doing KV shift
  1036. // a normal RoPE should work, we just need to use the correct ordering
  1037. // ref: https://github.com/ggml-org/llama.cpp/pull/13870
  1038. ? LLAMA_ROPE_TYPE_NEOX
  1039. : hparams.rope_type;
  1040. // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
  1041. // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
  1042. const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2
  1043. ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale))
  1044. : cparams.yarn_attn_factor;
  1045. ggml_tensor * tmp;
  1046. if (ggml_is_quantized(cur->type)) {
  1047. // dequantize to f32 -> RoPE -> quantize back
  1048. tmp = ggml_cast(ctx, cur, GGML_TYPE_F32);
  1049. tmp = ggml_rope_ext(ctx, tmp,
  1050. shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
  1051. yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
  1052. tmp = ggml_cpy(ctx, tmp, cur);
  1053. } else {
  1054. // we rotate only the first n_rot dimensions
  1055. tmp = ggml_rope_ext_inplace(ctx, cur,
  1056. shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
  1057. yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
  1058. }
  1059. return tmp;
  1060. }
  1061. class llm_graph_input_k_shift : public llm_graph_input_i {
  1062. public:
  1063. llm_graph_input_k_shift(const llama_kv_cache * kv_self) : kv_self(kv_self) {}
  1064. virtual ~llm_graph_input_k_shift() = default;
  1065. void set_input(const llama_ubatch * ubatch) override;
  1066. ggml_tensor * k_shift; // I32 [kv_size*n_stream]
  1067. const llama_kv_cache * kv_self;
  1068. };
  1069. void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
  1070. GGML_UNUSED(ubatch);
  1071. if (k_shift) {
  1072. kv_self->set_input_k_shift(k_shift);
  1073. }
  1074. }
  1075. ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
  1076. auto * ctx = res->get_ctx();
  1077. auto * gf = res->get_gf();
  1078. const auto & n_embd_head_k = hparams.n_embd_head_k;
  1079. //const auto & n_embd_head_v = hparams.n_embd_head_v;
  1080. auto inp = std::make_unique<llm_graph_input_k_shift>(this);
  1081. inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
  1082. ggml_set_input(inp->k_shift);
  1083. const auto & cparams = lctx->get_cparams();
  1084. for (const auto & layer : layers) {
  1085. const uint32_t il = layer.il;
  1086. const int64_t n_head_kv = hparams.n_head_kv(il);
  1087. const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
  1088. const float freq_base_l = model.get_rope_freq_base (cparams, il);
  1089. const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
  1090. ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
  1091. ggml_tensor * k =
  1092. ggml_view_3d(ctx, layer.k,
  1093. n_embd_head_k, n_head_kv, get_size()*n_stream,
  1094. ggml_row_size(layer.k->type, n_embd_head_k),
  1095. ggml_row_size(layer.k->type, n_embd_k_gqa),
  1096. 0);
  1097. ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
  1098. ggml_build_forward_expand(gf, cur);
  1099. }
  1100. res->add_input(std::move(inp));
  1101. return gf;
  1102. }
  1103. bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
  1104. return llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1);
  1105. }
  1106. void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
  1107. GGML_UNUSED(flags);
  1108. io.write(&n_stream, sizeof(n_stream));
  1109. for (uint32_t s = 0; s < n_stream; ++s) {
  1110. cell_ranges_t cr { s, {} };
  1111. uint32_t cell_count = 0;
  1112. const auto & cells = v_cells[s];
  1113. // Count the number of cells with the specified seq_id
  1114. // Find all the ranges of cells with this seq id (or all, when -1)
  1115. uint32_t cell_range_begin = cells.size();
  1116. for (uint32_t i = 0; i < cells.size(); ++i) {
  1117. if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
  1118. ++cell_count;
  1119. if (cell_range_begin == cells.size()) {
  1120. cell_range_begin = i;
  1121. }
  1122. } else {
  1123. if (cell_range_begin != cells.size()) {
  1124. cr.data.emplace_back(cell_range_begin, i);
  1125. cell_range_begin = cells.size();
  1126. }
  1127. }
  1128. }
  1129. if (cell_range_begin != cells.size()) {
  1130. cr.data.emplace_back(cell_range_begin, cells.size());
  1131. }
  1132. // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
  1133. uint32_t cell_count_check = 0;
  1134. for (const auto & range : cr.data) {
  1135. cell_count_check += range.second - range.first;
  1136. }
  1137. GGML_ASSERT(cell_count == cell_count_check);
  1138. io.write(&cell_count, sizeof(cell_count));
  1139. // skip empty streams
  1140. if (cell_count == 0) {
  1141. continue;
  1142. }
  1143. state_write_meta(io, cr, seq_id);
  1144. state_write_data(io, cr);
  1145. }
  1146. }
  1147. void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
  1148. GGML_UNUSED(flags);
  1149. GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
  1150. uint32_t n_stream_cur;
  1151. io.read_to(&n_stream_cur, sizeof(n_stream_cur));
  1152. if (n_stream_cur != n_stream) {
  1153. throw std::runtime_error("n_stream mismatch");
  1154. }
  1155. for (uint32_t s = 0; s < n_stream; ++s) {
  1156. uint32_t cell_count;
  1157. io.read_to(&cell_count, sizeof(cell_count));
  1158. if (cell_count == 0) {
  1159. continue;
  1160. }
  1161. const uint32_t strm = seq_id == -1 ? s : seq_to_stream[seq_id];
  1162. bool res = true;
  1163. res = res && state_read_meta(io, strm, cell_count, seq_id);
  1164. res = res && state_read_data(io, strm, cell_count);
  1165. if (!res) {
  1166. if (seq_id == -1) {
  1167. clear(true);
  1168. } else {
  1169. seq_rm(seq_id, -1, -1);
  1170. }
  1171. throw std::runtime_error("failed to restore kv cache");
  1172. }
  1173. }
  1174. }
  1175. void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const {
  1176. const auto & cells = v_cells[cr.strm];
  1177. for (const auto & range : cr.data) {
  1178. for (uint32_t i = range.first; i < range.second; ++i) {
  1179. std::vector<llama_seq_id> seq_ids;
  1180. for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) {
  1181. if (cur == seq_id || seq_id == -1) {
  1182. if (cells.seq_has(i, cur)) {
  1183. seq_ids.push_back(cur);
  1184. }
  1185. }
  1186. }
  1187. const llama_pos pos = cells.pos_get(i);
  1188. const uint32_t n_seq_id = seq_ids.size();
  1189. io.write(&pos, sizeof(pos));
  1190. io.write(&n_seq_id, sizeof(n_seq_id));
  1191. // TODO: we also need to save llama_kv_cell_ext when apply_ubatch() support loading it
  1192. // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
  1193. for (const auto & seq_id : seq_ids) {
  1194. io.write(&seq_id, sizeof(seq_id));
  1195. }
  1196. }
  1197. }
  1198. }
  1199. void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const {
  1200. const auto & cells = v_cells[cr.strm];
  1201. const uint32_t v_trans = this->v_trans ? 1 : 0;
  1202. const uint32_t n_layer = layers.size();
  1203. io.write(&v_trans, sizeof(v_trans));
  1204. io.write(&n_layer, sizeof(n_layer));
  1205. std::vector<uint8_t> tmp_buf;
  1206. // Iterate and write all the keys first, each row is a cell
  1207. // Get whole range at a time
  1208. for (const auto & layer : layers) {
  1209. const uint32_t il = layer.il;
  1210. const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
  1211. auto * k = layer.k_stream[cr.strm];
  1212. // Write key type
  1213. const int32_t k_type_i = (int32_t) k->type;
  1214. io.write(&k_type_i, sizeof(k_type_i));
  1215. // Write row size of key
  1216. const uint64_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
  1217. io.write(&k_size_row, sizeof(k_size_row));
  1218. // Read each range of cells of k_size length each into tmp_buf and write out
  1219. for (const auto & range : cr.data) {
  1220. const size_t range_size = range.second - range.first;
  1221. const size_t buf_size = range_size * k_size_row;
  1222. io.write_tensor(k, range.first * k_size_row, buf_size);
  1223. }
  1224. }
  1225. if (!v_trans) {
  1226. for (const auto & layer : layers) {
  1227. const uint32_t il = layer.il;
  1228. const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
  1229. auto * v = layer.v_stream[cr.strm];
  1230. // Write value type
  1231. const int32_t v_type_i = (int32_t) v->type;
  1232. io.write(&v_type_i, sizeof(v_type_i));
  1233. // Write row size of value
  1234. const uint64_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
  1235. io.write(&v_size_row, sizeof(v_size_row));
  1236. // Read each range of cells of v_size length each into tmp_buf and write out
  1237. for (const auto & range : cr.data) {
  1238. const size_t range_size = range.second - range.first;
  1239. const size_t buf_size = range_size * v_size_row;
  1240. io.write_tensor(v, range.first * v_size_row, buf_size);
  1241. }
  1242. }
  1243. } else {
  1244. // When v is transposed, we also need the element size and get the element ranges from each row
  1245. const uint32_t kv_size = cells.size();
  1246. for (const auto & layer : layers) {
  1247. const uint32_t il = layer.il;
  1248. const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
  1249. auto * v = layer.v_stream[cr.strm];
  1250. // Write value type
  1251. const int32_t v_type_i = (int32_t) v->type;
  1252. io.write(&v_type_i, sizeof(v_type_i));
  1253. // Write element size
  1254. const uint32_t v_size_el = ggml_type_size(v->type);
  1255. io.write(&v_size_el, sizeof(v_size_el));
  1256. // Write GQA embedding size
  1257. io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
  1258. // For each row, we get the element values of each cell
  1259. for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
  1260. // Read each range of cells of v_size_el length each into tmp_buf and write out
  1261. for (const auto & range : cr.data) {
  1262. const size_t range_size = range.second - range.first;
  1263. const size_t src_offset = (range.first + j * kv_size) * v_size_el;
  1264. const size_t buf_size = range_size * v_size_el;
  1265. io.write_tensor(v, src_offset, buf_size);
  1266. }
  1267. }
  1268. }
  1269. }
  1270. }
  1271. bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) {
  1272. auto & cells = v_cells[strm];
  1273. auto & head = v_heads[strm];
  1274. if (dest_seq_id != -1) {
  1275. // single sequence
  1276. seq_rm(dest_seq_id, -1, -1);
  1277. llama_batch_allocr balloc(hparams.n_pos_per_embd());
  1278. llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
  1279. ubatch.seq_id_unq[0] = dest_seq_id;
  1280. for (uint32_t i = 0; i < cell_count; ++i) {
  1281. llama_pos pos;
  1282. uint32_t n_seq_id;
  1283. io.read_to(&pos, sizeof(pos));
  1284. io.read_to(&n_seq_id, sizeof(n_seq_id));
  1285. if (n_seq_id != 1) {
  1286. LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
  1287. return false;
  1288. }
  1289. // read the sequence id, but directly discard it - we will use dest_seq_id instead
  1290. {
  1291. llama_seq_id seq_id;
  1292. io.read_to(&seq_id, sizeof(seq_id));
  1293. }
  1294. ubatch.pos[i] = pos;
  1295. ubatch.n_seq_id[i] = n_seq_id;
  1296. ubatch.seq_id[i] = &dest_seq_id;
  1297. }
  1298. const auto sinfo = find_slot(ubatch, true);
  1299. if (sinfo.empty()) {
  1300. LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
  1301. return false;
  1302. }
  1303. // TODO: we cannot yet restore llama_kv_cell_ext as the apply_ubatch() does not support it yet
  1304. // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
  1305. apply_ubatch(sinfo, ubatch);
  1306. const auto head_cur = sinfo.head();
  1307. // keep the head at the old position because we will read the KV data into it in state_read_data()
  1308. head = head_cur;
  1309. LLAMA_LOG_DEBUG("%s: head_cur = %d, head = %d, cell_count = %d, dest_seq_id = %d\n", __func__, head_cur, head, cell_count, dest_seq_id);
  1310. // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
  1311. // Assume that this is one contiguous block of cells
  1312. GGML_ASSERT(head_cur + cell_count <= cells.size());
  1313. GGML_ASSERT(cells.pos_get(head_cur) == ubatch.pos[0]);
  1314. GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
  1315. GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id));
  1316. GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
  1317. } else {
  1318. // whole KV cache restore
  1319. if (cell_count > cells.size()) {
  1320. LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
  1321. return false;
  1322. }
  1323. clear(true);
  1324. for (uint32_t i = 0; i < cell_count; ++i) {
  1325. llama_pos pos;
  1326. uint32_t n_seq_id;
  1327. io.read_to(&pos, sizeof(pos));
  1328. io.read_to(&n_seq_id, sizeof(n_seq_id));
  1329. cells.pos_set(i, pos);
  1330. for (uint32_t j = 0; j < n_seq_id; ++j) {
  1331. llama_seq_id seq_id;
  1332. io.read_to(&seq_id, sizeof(seq_id));
  1333. if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
  1334. LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max);
  1335. return false;
  1336. }
  1337. cells.seq_add(i, seq_id);
  1338. }
  1339. }
  1340. head = 0;
  1341. }
  1342. return true;
  1343. }
  1344. bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) {
  1345. auto & cells = v_cells[strm];
  1346. auto & head = v_heads[strm];
  1347. uint32_t v_trans;
  1348. uint32_t n_layer;
  1349. io.read_to(&v_trans, sizeof(v_trans));
  1350. io.read_to(&n_layer, sizeof(n_layer));
  1351. if (n_layer != layers.size()) {
  1352. LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
  1353. return false;
  1354. }
  1355. if (cell_count > cells.size()) {
  1356. LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size());
  1357. return false;
  1358. }
  1359. if (this->v_trans != (bool) v_trans) {
  1360. LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
  1361. return false;
  1362. }
  1363. // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
  1364. for (const auto & layer : layers) {
  1365. const uint32_t il = layer.il;
  1366. const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
  1367. auto * k = layer.k_stream[strm];
  1368. // Read type of key
  1369. int32_t k_type_i_ref;
  1370. io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
  1371. const int32_t k_type_i = (int32_t) k->type;
  1372. if (k_type_i != k_type_i_ref) {
  1373. LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
  1374. return false;
  1375. }
  1376. // Read row size of key
  1377. uint64_t k_size_row_ref;
  1378. io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
  1379. const size_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
  1380. if (k_size_row != k_size_row_ref) {
  1381. LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
  1382. return false;
  1383. }
  1384. if (cell_count) {
  1385. // Read and set the keys for the whole cell range
  1386. ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
  1387. }
  1388. }
  1389. if (!this->v_trans) {
  1390. for (const auto & layer : layers) {
  1391. const uint32_t il = layer.il;
  1392. const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
  1393. auto * v = layer.v_stream[strm];
  1394. // Read type of value
  1395. int32_t v_type_i_ref;
  1396. io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
  1397. const int32_t v_type_i = (int32_t) v->type;
  1398. if (v_type_i != v_type_i_ref) {
  1399. LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
  1400. return false;
  1401. }
  1402. // Read row size of value
  1403. uint64_t v_size_row_ref;
  1404. io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
  1405. const size_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
  1406. if (v_size_row != v_size_row_ref) {
  1407. LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
  1408. return false;
  1409. }
  1410. if (cell_count) {
  1411. // Read and set the values for the whole cell range
  1412. ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
  1413. }
  1414. }
  1415. } else {
  1416. // For each layer, read the values for each cell (transposed)
  1417. for (const auto & layer : layers) {
  1418. const uint32_t il = layer.il;
  1419. const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
  1420. auto * v = layer.v_stream[strm];
  1421. // Read type of value
  1422. int32_t v_type_i_ref;
  1423. io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
  1424. const int32_t v_type_i = (int32_t) v->type;
  1425. if (v_type_i != v_type_i_ref) {
  1426. LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
  1427. return false;
  1428. }
  1429. // Read element size of value
  1430. uint32_t v_size_el_ref;
  1431. io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
  1432. const size_t v_size_el = ggml_type_size(v->type);
  1433. if (v_size_el != v_size_el_ref) {
  1434. LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
  1435. return false;
  1436. }
  1437. // Read GQA embedding size
  1438. uint32_t n_embd_v_gqa_ref;
  1439. io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
  1440. if (n_embd_v_gqa != n_embd_v_gqa_ref) {
  1441. LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
  1442. return false;
  1443. }
  1444. if (cell_count) {
  1445. // For each row in the transposed matrix, read the values for the whole cell range
  1446. for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
  1447. const size_t dst_offset = (head + j * cells.size()) * v_size_el;
  1448. ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
  1449. }
  1450. }
  1451. }
  1452. }
  1453. return true;
  1454. }
  1455. //
  1456. // llama_kv_cache_context
  1457. //
  1458. llama_kv_cache_context::llama_kv_cache_context(llama_memory_status status) : status(status) {}
  1459. llama_kv_cache_context::llama_kv_cache_context(
  1460. llama_kv_cache * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
  1461. n_kv = kv->get_size();
  1462. const uint32_t n_stream = kv->get_n_stream();
  1463. // create a dummy slot info - the actual data is irrelevant. we just need to build the graph
  1464. sinfos.resize(1);
  1465. sinfos[0].s0 = 0;
  1466. sinfos[0].s1 = n_stream - 1;
  1467. sinfos[0].idxs.resize(n_stream);
  1468. for (uint32_t s = 0; s < n_stream; ++s) {
  1469. sinfos[0].strm.push_back(s);
  1470. sinfos[0].idxs[s].resize(1, 0);
  1471. }
  1472. }
  1473. llama_kv_cache_context::llama_kv_cache_context(
  1474. llama_kv_cache * kv,
  1475. llama_context * lctx,
  1476. bool do_shift,
  1477. stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), sc_info(std::move(sc_info)) {
  1478. if (!do_shift && this->sc_info.empty()) {
  1479. status = LLAMA_MEMORY_STATUS_NO_UPDATE;
  1480. }
  1481. }
  1482. llama_kv_cache_context::llama_kv_cache_context(
  1483. llama_kv_cache * kv,
  1484. llama_kv_cache::slot_info_vec_t sinfos,
  1485. std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
  1486. }
  1487. llama_kv_cache_context::~llama_kv_cache_context() = default;
  1488. bool llama_kv_cache_context::next() {
  1489. assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
  1490. if (++i_cur >= ubatches.size()) {
  1491. return false;
  1492. }
  1493. return true;
  1494. }
  1495. bool llama_kv_cache_context::apply() {
  1496. assert(!llama_memory_status_is_fail(status));
  1497. // no ubatches -> this is a KV cache update
  1498. if (ubatches.empty()) {
  1499. kv->update(lctx, do_shift, sc_info);
  1500. return true;
  1501. }
  1502. kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
  1503. n_kv = kv->get_n_kv(sinfos[i_cur]);
  1504. return true;
  1505. }
  1506. llama_memory_status llama_kv_cache_context::get_status() const {
  1507. return status;
  1508. }
  1509. const llama_ubatch & llama_kv_cache_context::get_ubatch() const {
  1510. assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
  1511. return ubatches[i_cur];
  1512. }
  1513. uint32_t llama_kv_cache_context::get_n_kv() const {
  1514. return n_kv;
  1515. }
  1516. ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const {
  1517. return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
  1518. }
  1519. ggml_tensor * llama_kv_cache_context::get_v(ggml_context * ctx, int32_t il) const {
  1520. return kv->get_v(ctx, il, n_kv, sinfos[i_cur]);
  1521. }
  1522. ggml_tensor * llama_kv_cache_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
  1523. return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
  1524. }
  1525. ggml_tensor * llama_kv_cache_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
  1526. return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
  1527. }
  1528. ggml_tensor * llama_kv_cache_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
  1529. return kv->build_input_k_idxs(ctx, ubatch);
  1530. }
  1531. ggml_tensor * llama_kv_cache_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
  1532. return kv->build_input_v_idxs(ctx, ubatch);
  1533. }
  1534. void llama_kv_cache_context::set_input_k_shift(ggml_tensor * dst) const {
  1535. kv->set_input_k_shift(dst);
  1536. }
  1537. void llama_kv_cache_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
  1538. kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
  1539. }
  1540. void llama_kv_cache_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
  1541. kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
  1542. }
  1543. void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
  1544. kv->set_input_kq_mask(dst, ubatch, causal_attn);
  1545. }
  1546. void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
  1547. kv->set_input_pos_bucket(dst, ubatch);
  1548. }