llama-kv-cache.cpp 65 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021
  1. #include "llama-kv-cache.h"
  2. #include "llama-impl.h"
  3. #include "llama-io.h"
  4. #include "llama-model.h"
  5. #include "llama-context.h"
  6. #include <algorithm>
  7. #include <cassert>
  8. #include <cmath>
  9. #include <cstring>
  10. #include <limits>
  11. #include <map>
  12. #include <stdexcept>
  13. //
  14. // llama_kv_cache
  15. //
  // Construct the KV cache.
  //
  // For every layer il in [0, hparams.n_layer) that has a KV cache (hparams.has_kv(il)) and is
  // accepted by `filter` (when one is given), the constructor creates:
  //  - one K tensor [n_embd_k_gqa, kv_size, n_stream] and one V tensor [n_embd_v_gqa, kv_size, n_stream]
  //  - one 2D view per stream of each of them
  // n_stream is 1 when `unified` is true, otherwise n_seq_max.
  //
  // Tensors are grouped into one ggml context per backend buffer type: the CPU buffer type, or,
  // when `offload` is true, the buffer type of model.dev_layer(il). Each context is then allocated
  // into a backend buffer, which is cleared to zero.
  //
  // v_trans: when true, every layer's V row size is hparams.n_embd_v_gqa_max() instead of the
  //          per-layer n_embd_v_gqa(il) (see TAG_V_CACHE_VARIABLE).
  // reuse:   optional. For a layer il with reuse(il) >= 0 that also passes `filter`,
  //          map_layer_ids[il] is set to the entry of layer reuse(il), which must already exist.
  //
  // Throws std::runtime_error when a ggml context or a backend buffer cannot be created.
  // Aborts via GGML_ASSERT when kv_size is not a multiple of n_pad.
  llama_kv_cache::llama_kv_cache(
          const llama_model & model,
          ggml_type type_k,
          ggml_type type_v,
          bool v_trans,
          bool offload,
          bool unified,
          uint32_t kv_size,
          uint32_t n_seq_max,
          uint32_t n_pad,
          uint32_t n_swa,
          llama_swa_type swa_type,
          const layer_filter_cb & filter,
          const layer_reuse_cb & reuse) :
      model(model), hparams(model.hparams), v_trans(v_trans),
      n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
      GGML_ASSERT(kv_size % n_pad == 0);
      const uint32_t n_layer_kv = hparams.n_layer_kv();
      // define a comparator for the buft -> ctx map to ensure that the order is well-defined:
      // (buffer types are ordered by name, so the allocation loop below visits them in a fixed order)
      struct ggml_backend_buft_comparator {
          bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
              return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
          }
      };
      std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
      // create a context for each buffer type
      // (lazily: a context is made the first time a buffer type is requested; nullptr on failure)
      auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
          auto it = ctx_map.find(buft);
          if (it == ctx_map.end()) {
              // metadata room per cached layer: 2 tensors (K, V) + 2*n_stream per-stream views.
              // no_alloc: tensor data is allocated later from the buffer type, not in this context
              ggml_init_params params = {
                  /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()),
                  /*.mem_buffer =*/ NULL,
                  /*.no_alloc =*/ true,
              };
              ggml_context * ctx = ggml_init(params);
              if (!ctx) {
                  return nullptr;
              }
              ctx_map.emplace(buft, ctx);
              return ctx;
          }
          return it->second.get();
      };
      GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);
      // one search head per stream, starting at cell 0
      v_heads.resize(n_stream);
      for (uint32_t s = 0; s < n_stream; ++s) {
          v_heads[s] = 0;
      }
      // kv_size cells of metadata per stream
      v_cells.resize(n_stream);
      for (uint32_t s = 0; s < n_stream; ++s) {
          v_cells[s].resize(kv_size);
      }
      // by default, all sequence ids are mapped to the 0th stream
      seq_to_stream.resize(LLAMA_MAX_SEQ, 0);
      if (n_stream > 1) {
          // one stream per sequence: the table shrinks to n_stream entries and sequence s -> stream s
          seq_to_stream.resize(n_stream, 0);
          for (uint32_t s = 0; s < n_stream; ++s) {
              seq_to_stream[s] = s;
          }
      }
      // [TAG_V_CACHE_VARIABLE]
      if (v_trans && hparams.is_n_embd_v_gqa_variable()) {
          LLAMA_LOG_WARN("%s: the V embeddings have different sizes across layers and FA is not enabled - padding V cache to %d\n",
                  __func__, hparams.n_embd_v_gqa_max());
      }
      for (uint32_t il = 0; il < hparams.n_layer; il++) {
          if (!hparams.has_kv(il)) {
              LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il);
              continue;
          }
          if (filter && !filter(il)) {
              LLAMA_LOG_DEBUG("%s: layer %3d: filtered\n", __func__, il);
              continue;
          }
          // [TAG_V_CACHE_VARIABLE]
          const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
          const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();
          // placement: CPU unless offloading, in which case the layer's device decides
          const char * dev_name = "CPU";
          ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
          if (offload) {
              auto * dev = model.dev_layer(il);
              buft = ggml_backend_dev_buffer_type(dev);
              dev_name = ggml_backend_dev_name(dev);
          }
          LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, il, dev_name);
          ggml_context * ctx = ctx_for_buft(buft);
          if (!ctx) {
              throw std::runtime_error("failed to create ggml context for kv cache");
          }
          ggml_tensor * k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream);
          ggml_tensor * v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream);
          ggml_format_name(k, "cache_k_l%d", il);
          ggml_format_name(v, "cache_v_l%d", il);
          // 2D view of stream s: same row stride (nb[1]), starting at byte offset s*nb[2]
          std::vector<ggml_tensor *> k_stream;
          std::vector<ggml_tensor *> v_stream;
          for (uint32_t s = 0; s < n_stream; ++s) {
              k_stream.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]));
              v_stream.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]));
          }
          // model layer index -> index into `layers`
          map_layer_ids[il] = layers.size();
          layers.push_back({ il, k, v, k_stream, v_stream, });
      }
      if (reuse) {
          LLAMA_LOG_DEBUG("%s: reusing layers:\n", __func__);
          for (uint32_t il = 0; il < hparams.n_layer; il++) {
              const int32_t il_reuse = reuse(il);
              if (il_reuse < 0) {
                  LLAMA_LOG_DEBUG("%s: - layer %3d: no reuse\n", __func__, il);
                  continue;
              }
              if (filter && !filter(il)) {
                  LLAMA_LOG_DEBUG("%s: - layer %3d: filtered\n", __func__, il);
                  continue;
              }
              // the reused layer must have been created by the loop above
              GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());
              map_layer_ids[il] = map_layer_ids[il_reuse];
              LLAMA_LOG_DEBUG("%s: - layer %3d: reuse layer %d, is_swa = %d\n", __func__, il, il_reuse, hparams.is_swa(il));
          }
      }
      // allocate tensors and initialize the buffers to avoid NaNs in the padding
      // (ownership of each context moves into ctxs_bufs together with its buffer)
      for (auto & [buft, ctx] : ctx_map) {
          ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
          if (!buf) {
              throw std::runtime_error("failed to allocate buffer for kv cache");
          }
          LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
          ggml_backend_buffer_clear(buf, 0);
          ctxs_bufs.emplace_back(std::move(ctx), buf);
      }
      {
          const size_t memory_size_k = size_k_bytes();
          const size_t memory_size_v = size_v_bytes();
          LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
                  (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_stream,
                  ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
                  ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
      }
      // debug verbosity from the environment; 0 when LLAMA_KV_CACHE_DEBUG is unset
      const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
      debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
  }
  156. void llama_kv_cache::clear(bool data) {
  157. for (uint32_t s = 0; s < n_stream; ++s) {
  158. v_cells[s].reset();
  159. v_heads[s] = 0;
  160. }
  161. if (data) {
  162. for (auto & [_, buf] : ctxs_bufs) {
  163. ggml_backend_buffer_clear(buf.get(), 0);
  164. }
  165. }
  166. }
  167. bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
  168. GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
  169. if (p0 < 0) {
  170. p0 = 0;
  171. }
  172. if (p1 < 0) {
  173. p1 = std::numeric_limits<llama_pos>::max();
  174. }
  175. if (seq_id >= 0) {
  176. auto & cells = v_cells[seq_to_stream[seq_id]];
  177. auto & head = v_heads[seq_to_stream[seq_id]];
  178. uint32_t new_head = cells.size();
  179. for (uint32_t i = 0; i < cells.size(); ++i) {
  180. if (!cells.pos_in(i, p0, p1)) {
  181. continue;
  182. }
  183. if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
  184. if (new_head == cells.size()) {
  185. new_head = i;
  186. }
  187. }
  188. }
  189. // If we freed up a slot, set head to it so searching can start there.
  190. if (new_head != cells.size() && new_head < head) {
  191. head = new_head;
  192. }
  193. } else {
  194. // match any sequence
  195. for (uint32_t s = 0; s < n_stream; ++s) {
  196. auto & cells = v_cells[s];
  197. auto & head = v_heads[s];
  198. uint32_t new_head = cells.size();
  199. for (uint32_t i = 0; i < cells.size(); ++i) {
  200. if (!cells.pos_in(i, p0, p1)) {
  201. continue;
  202. }
  203. cells.rm(i);
  204. if (new_head == cells.size()) {
  205. new_head = i;
  206. }
  207. }
  208. // If we freed up a slot, set head to it so searching can start there.
  209. if (new_head != cells.size() && new_head < head) {
  210. head = new_head;
  211. }
  212. }
  213. }
  214. return true;
  215. }
  216. void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
  217. GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size());
  218. GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size());
  219. const auto s0 = seq_to_stream[seq_id_src];
  220. const auto s1 = seq_to_stream[seq_id_dst];
  221. if (s0 == s1) {
  222. // since both sequences are in the same stream, no data copy is necessary
  223. // we just have to update the cells meta data
  224. auto & cells = v_cells[s0];
  225. if (seq_id_src == seq_id_dst) {
  226. return;
  227. }
  228. if (p0 < 0) {
  229. p0 = 0;
  230. }
  231. if (p1 < 0) {
  232. p1 = std::numeric_limits<llama_pos>::max();
  233. }
  234. for (uint32_t i = 0; i < cells.size(); ++i) {
  235. if (!cells.pos_in(i, p0, p1)) {
  236. continue;
  237. }
  238. if (cells.seq_has(i, seq_id_src)) {
  239. cells.seq_add(i, seq_id_dst);
  240. }
  241. }
  242. return;
  243. }
  244. // cross-stream sequence copies require to copy the actual buffer data
  245. bool is_full = true;
  246. if (p0 > 0 && p0 + 1 < (int) get_size()) {
  247. is_full = false;
  248. }
  249. if (p1 > 0 && p1 + 1 < (int) get_size()) {
  250. is_full = false;
  251. }
  252. GGML_ASSERT(is_full && "seq_cp() is only supported for full KV buffers");
  253. // enqueue the copy operation - the buffer copy will be performed during the next update
  254. sc_info.ssrc.push_back(s0);
  255. sc_info.sdst.push_back(s1);
  256. v_cells[s1].reset();
  257. for (uint32_t i = 0; i < v_cells[s0].size(); ++i) {
  258. if (v_cells[s0].seq_has(i, seq_id_src)) {
  259. llama_pos pos = v_cells[s0].pos_get(i);
  260. llama_pos shift = v_cells[s0].get_shift(i);
  261. if (shift != 0) {
  262. pos -= shift;
  263. assert(pos >= 0);
  264. }
  265. v_cells[s1].pos_set(i, pos);
  266. v_cells[s1].seq_add(i, seq_id_dst);
  267. if (shift != 0) {
  268. v_cells[s1].pos_add(i, shift);
  269. }
  270. }
  271. }
  272. v_heads[s1] = v_heads[s0];
  273. //for (uint32_t s = 0; s < n_stream; ++s) {
  274. // LLAMA_LOG_WARN("%s: seq %d: min = %d, max = %d\n", __func__, s, v_cells[s].seq_pos_min(s), v_cells[s].seq_pos_max(s));
  275. //}
  276. }
  277. void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
  278. GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
  279. auto & cells = v_cells[seq_to_stream[seq_id]];
  280. auto & head = v_heads[seq_to_stream[seq_id]];
  281. uint32_t new_head = cells.size();
  282. for (uint32_t i = 0; i < cells.size(); ++i) {
  283. if (cells.seq_keep(i, seq_id)) {
  284. if (new_head == cells.size()) {
  285. new_head = i;
  286. }
  287. }
  288. }
  289. // If we freed up a slot, set head to it so searching can start there.
  290. if (new_head != cells.size() && new_head < head) {
  291. head = new_head;
  292. }
  293. }
  294. void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
  295. GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
  296. auto & cells = v_cells[seq_to_stream[seq_id]];
  297. auto & head = v_heads[seq_to_stream[seq_id]];
  298. if (shift == 0) {
  299. return;
  300. }
  301. uint32_t new_head = cells.size();
  302. if (p0 < 0) {
  303. p0 = 0;
  304. }
  305. if (p1 < 0) {
  306. p1 = std::numeric_limits<llama_pos>::max();
  307. }
  308. // If there is no range then return early to avoid looping over all cells.
  309. if (p0 == p1) {
  310. return;
  311. }
  312. for (uint32_t i = 0; i < cells.size(); ++i) {
  313. if (!cells.pos_in(i, p0, p1)) {
  314. continue;
  315. }
  316. if (cells.seq_has(i, seq_id)) {
  317. if (cells.pos_add(i, shift)) {
  318. if (new_head == cells.size()) {
  319. new_head = i;
  320. }
  321. }
  322. }
  323. }
  324. // If we freed up a slot, set head to it so searching can start there.
  325. // Otherwise we just start the next search from the beginning.
  326. head = new_head != cells.size() ? new_head : 0;
  327. }
  328. void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
  329. GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
  330. auto & cells = v_cells[seq_to_stream[seq_id]];
  331. if (d == 1) {
  332. return;
  333. }
  334. if (p0 < 0) {
  335. p0 = 0;
  336. }
  337. if (p1 < 0) {
  338. p1 = std::numeric_limits<llama_pos>::max();
  339. }
  340. // If there is no range then return early to avoid looping over the cache.
  341. if (p0 == p1) {
  342. return;
  343. }
  344. for (uint32_t i = 0; i < cells.size(); ++i) {
  345. if (!cells.pos_in(i, p0, p1)) {
  346. continue;
  347. }
  348. if (cells.seq_has(i, seq_id)) {
  349. cells.pos_div(i, d);
  350. }
  351. }
  352. }
  353. llama_pos llama_kv_cache::seq_pos_min(llama_seq_id seq_id) const {
  354. GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
  355. const auto & cells = v_cells[seq_to_stream[seq_id]];
  356. return cells.seq_pos_min(seq_id);
  357. }
  358. llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {
  359. GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
  360. const auto & cells = v_cells[seq_to_stream[seq_id]];
  361. return cells.seq_pos_max(seq_id);
  362. }
  363. std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache::memory_breakdown() const {
  364. std::map<ggml_backend_buffer_type_t, size_t> ret;
  365. for (const auto & [_, buf] : ctxs_bufs) {
  366. ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
  367. }
  368. return ret;
  369. }
  370. llama_memory_context_ptr llama_kv_cache::init_batch(
  371. llama_batch_allocr & balloc,
  372. uint32_t n_ubatch,
  373. bool embd_all) {
  374. GGML_UNUSED(embd_all);
  375. do {
  376. balloc.split_reset();
  377. std::vector<llama_ubatch> ubatches;
  378. while (true) {
  379. auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true);
  380. if (ubatch.n_tokens == 0) {
  381. break;
  382. }
  383. ubatches.push_back(std::move(ubatch)); // NOLINT
  384. }
  385. if (balloc.get_n_used() < balloc.get_n_tokens()) {
  386. // failed to find a suitable split
  387. break;
  388. }
  389. auto sinfos = prepare(ubatches);
  390. if (sinfos.empty()) {
  391. break;
  392. }
  393. return std::make_unique<llama_kv_cache_context>(
  394. this, std::move(sinfos), std::move(ubatches));
  395. } while (false);
  396. return std::make_unique<llama_kv_cache_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
  397. }
  398. llama_memory_context_ptr llama_kv_cache::init_full() {
  399. return std::make_unique<llama_kv_cache_context>(this);
  400. }
  401. llama_memory_context_ptr llama_kv_cache::init_update(llama_context * lctx, bool optimize) {
  402. GGML_UNUSED(optimize);
  403. bool do_shift = get_has_shift();
  404. return std::make_unique<llama_kv_cache_context>(this, lctx, do_shift, std::move(sc_info));
  405. }
  406. llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_ubatch> & ubatches) {
  407. llama_kv_cache::slot_info_vec_t res;
  408. struct state_t {
  409. slot_info sinfo; // slot info for the ubatch
  410. std::vector<uint32_t> v_heads_old; // old positions of the heads, before placing the ubatch
  411. std::vector<llama_kv_cells> v_cells; // copy of the old cells, before placing the ubatch
  412. };
  413. // remember the old state of the cells so we can restore it in the end
  414. std::vector<state_t> states;
  415. bool success = true;
  416. for (const auto & ubatch : ubatches) {
  417. // only find a suitable slot for the ubatch. don't modify the cells yet
  418. const auto sinfo_new = find_slot(ubatch, false);
  419. if (sinfo_new.empty()) {
  420. success = false;
  421. break;
  422. }
  423. // remeber the position that we found
  424. res.push_back(sinfo_new);
  425. // store the old state of the cells in the recovery stack
  426. {
  427. state_t state = { sinfo_new, v_heads, {} };
  428. for (uint32_t s = 0; s < sinfo_new.n_stream(); ++s) {
  429. auto & cells = v_cells[sinfo_new.strm[s]];
  430. state.v_cells.push_back(cells.cp(sinfo_new.idxs[s]));
  431. }
  432. states.push_back(std::move(state));
  433. }
  434. // now emplace the ubatch
  435. apply_ubatch(sinfo_new, ubatch);
  436. }
  437. GGML_ASSERT(!states.empty() || !success);
  438. // iterate backwards and restore the cells to their original state
  439. for (auto it = states.rbegin(); it != states.rend(); ++it) {
  440. const auto & sinfo = it->sinfo;
  441. for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
  442. auto & cells = v_cells[sinfo.strm[s]];
  443. auto & head = v_heads[sinfo.strm[s]];
  444. cells.set(sinfo.idxs[s], it->v_cells[s]);
  445. head = it->v_heads_old[s];
  446. }
  447. }
  448. if (!success) {
  449. return {};
  450. }
  451. return res;
  452. }
  453. bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info) {
  454. bool updated = false;
  455. auto * sched = lctx->get_sched();
  456. if (!sc_info.empty()) {
  457. assert(n_stream > 1 && "stream copy should never happen with a single stream");
  458. llama_synchronize(lctx);
  459. const size_t n_copy = sc_info.ssrc.size();
  460. for (size_t i = 0; i < n_copy; ++i) {
  461. const auto ssrc = sc_info.ssrc[i];
  462. const auto sdst = sc_info.sdst[i];
  463. assert(ssrc < n_stream);
  464. assert(sdst < n_stream);
  465. LLAMA_LOG_DEBUG("%s: copying KV buffer: stream %d to stream %d\n", __func__, ssrc, sdst);
  466. assert(ssrc != sdst);
  467. for (uint32_t il = 0; il < layers.size(); ++il) {
  468. const auto & layer = layers[il];
  469. ggml_backend_tensor_copy(layer.k_stream[ssrc], layer.k_stream[sdst]);
  470. ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]);
  471. }
  472. }
  473. }
  474. if (do_shift) {
  475. if (!get_can_shift()) {
  476. GGML_ABORT("The current KV cache / model configuration does not support K-shift");
  477. }
  478. LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
  479. // apply K-shift if needed
  480. if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
  481. ggml_backend_sched_reset(sched);
  482. auto * res = lctx->get_gf_res_reserve();
  483. res->reset();
  484. auto * gf = build_graph_shift(res, lctx);
  485. if (!ggml_backend_sched_alloc_graph(sched, gf)) {
  486. LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__);
  487. return updated;
  488. }
  489. res->set_inputs(nullptr);
  490. if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
  491. LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
  492. return updated;
  493. }
  494. updated = true;
  495. }
  496. for (uint32_t s = 0; s < n_stream; ++s) {
  497. auto & cells = v_cells[s];
  498. cells.reset_shift();
  499. }
  500. }
  501. return updated;
  502. }
  // Find cells for the tokens of `ubatch` without modifying the cache.
  //
  // With a single stream, all ubatch tokens are placed in that stream.
  // With multiple streams, the ubatch must split evenly into n_seqs_unq sequences of n_tokens
  // tokens each. The asserts below check that token s*n_tokens carries only seq_id_unq[s].
  // Each sequence is placed in its own stream, seq_to_stream[seq_id].
  //
  // cont == true:  a sequence's cells must be consecutive; n_tokens cells are tested as one block.
  // cont == false: cells are tested and collected one at a time and may be scattered.
  //
  // A cell is usable when it is empty, or when it holds exactly one sequence and is_masked_swa()
  // reports its position as masked relative to that sequence's max position + 1.
  //
  // Returns a slot_info with strm[s] / idxs[s] for every sequence and s0 / s1 = min / max stream id.
  // Returns an empty slot_info if any sequence does not fit.
  llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch, bool cont) const {
      // debug > 0: per-stream usage summary; (debug == 2 with SWA) or debug > 2: full cell dumps
      if (debug > 0) {
          for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
              const auto seq_id = ubatch.seq_id_unq[s];
              const auto stream_id = seq_to_stream[seq_id];
              const auto & cells = v_cells[stream_id];
              const uint32_t head_cur = v_heads[stream_id];
              LLAMA_LOG_DEBUG("%s: stream[%d], n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n",
                      __func__, stream_id, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa);
              // occupancy map: '.' empty, the seq id for single-sequence cells, 'M' for shared cells
              if ((debug == 2 && n_swa > 0) || debug > 2) {
                  std::string ss;
                  for (uint32_t i = 0; i < cells.size(); ++i) {
                      if (cells.is_empty(i)) {
                          ss += '.';
                      } else {
                          assert(cells.seq_count(i) >= 1);
                          if (cells.seq_count(i) == 1) {
                              ss += std::to_string(cells.seq_get(i));
                          } else {
                              ss += 'M';
                          }
                      }
                      if (i%256 == 255) {
                          ss += " *";
                          ss += '\n';
                      }
                  }
                  LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
              }
              // position map: each cell's position padded to 5 columns, 64 cells per line
              if ((debug == 2 && n_swa > 0) || debug > 2) {
                  std::string ss;
                  for (uint32_t i = 0; i < cells.size(); ++i) {
                      std::string cur;
                      if (cells.is_empty(i)) {
                          cur = '.';
                      } else {
                          cur = std::to_string(cells.pos_get(i));
                      }
                      const int n = cur.size();
                      for (int j = 0; j < 5 - n; ++j) {
                          cur += ' ';
                      }
                      ss += cur;
                      if (i%256 == 255) {
                          ss += " *";
                      }
                      if (i%64 == 63) {
                          ss += '\n';
                      }
                  }
                  LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
              }
              // per-sequence position range in this stream (sequences with no cells are skipped)
              for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
                  if (cells.seq_pos_min(s) < 0) {
                      continue;
                  }
                  LLAMA_LOG_DEBUG("%s: stream[%d] min[%d] = %5d, max[%d] = %5d\n", __func__, stream_id, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s));
              }
          }
      }
      // n_tokens becomes the token count per sequence when each sequence has its own stream
      uint32_t n_tokens = ubatch.n_tokens;
      uint32_t n_seqs = 1;
      if (n_stream > 1) {
          GGML_ASSERT(n_tokens % ubatch.n_seqs_unq == 0);
          n_seqs = ubatch.n_seqs_unq;
          n_tokens = n_tokens / n_seqs;
      }
      // s0 starts at the largest value and s1 at 0, so the min/max updates below are well-defined
      slot_info res = {
          /*.s0 =*/ LLAMA_MAX_SEQ,
          /*.s1 =*/ 0,
          /*.strm =*/ { },
          /*.idxs =*/ { },
      };
      res.resize(n_seqs);
      for (uint32_t s = 0; s < n_seqs; ++s) {
          const auto seq_id = ubatch.seq_id_unq[s];
          if (n_stream > 1) {
              GGML_ASSERT(ubatch.n_seq_id[s*n_tokens] == 1);
              GGML_ASSERT(ubatch.seq_id [s*n_tokens][0] == seq_id);
          }
          res.s0 = std::min<uint32_t>(res.s0, seq_to_stream[seq_id]);
          res.s1 = std::max<uint32_t>(res.s1, seq_to_stream[seq_id]);
          res.strm[s] = seq_to_stream[seq_id];
          res.idxs[s].reserve(n_tokens);
          const auto & cells = v_cells[seq_to_stream[seq_id]];
          // local copy of the head: this method is const and only searches
          uint32_t head_cur = v_heads[seq_to_stream[seq_id]];
          // if we have enough unused cells before the current head ->
          // better to start searching from the beginning of the cache, hoping to fill it
          if (head_cur > cells.get_used() + 2*n_tokens) {
              head_cur = 0;
          }
          if (n_tokens > cells.size()) {
              LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
              return { };
          }
          // number of cells examined so far; the search gives up once every cell has been examined
          uint32_t n_tested = 0;
          // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head
          // for non-continuous slots, we test the tokens one by one
          const uint32_t n_test = cont ? n_tokens : 1;
          while (true) {
              // not enough cells left before the end: count the skipped tail as tested and wrap to 0
              if (head_cur + n_test > cells.size()) {
                  n_tested += cells.size() - head_cur;
                  head_cur = 0;
                  continue;
              }
              for (uint32_t i = 0; i < n_test; i++) {
                  const auto idx = head_cur;
                  head_cur++;
                  n_tested++;
                  //const llama_pos pos = ubatch.pos[i];
                  //const llama_seq_id seq_id = ubatch.seq_id[i][0];
                  // can we use this cell? either:
                  // - the cell is empty
                  // - the cell is occupied only by one sequence:
                  // - (disabled) mask causally, if the sequence is the same as the one we are inserting
                  // - mask SWA, using current max pos for that sequence in the cache
                  // always insert in the cell with minimum pos
                  bool can_use = cells.is_empty(idx);
                  if (!can_use && cells.seq_count(idx) == 1) {
                      const llama_pos pos_cell = cells.pos_get(idx);
                      // (disabled) causal mask
                      // note: it's better to purge any "future" tokens beforehand
                      //if (cells.seq_has(idx, seq_id)) {
                      // can_use = pos_cell >= pos;
                      //}
                      if (!can_use) {
                          const llama_seq_id seq_id_cell = cells.seq_get(idx);
                          // SWA mask
                          if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
                              can_use = true;
                          }
                      }
                  }
                  if (can_use) {
                      res.idxs[s].push_back(idx);
                  } else {
                      // a contiguous block is broken by any unusable cell: abandon this attempt
                      if (cont) {
                          break;
                      }
                  }
              }
              if (res.idxs[s].size() == n_tokens) {
                  break;
              }
              // a failed contiguous attempt is discarded; the next attempt starts at head_cur
              if (cont) {
                  res.idxs[s].clear();
              }
              if (n_tested >= cells.size()) {
                  //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
                  return { };
              }
          }
          // we didn't find a suitable slot - return empty result
          if (res.idxs[s].size() < n_tokens) {
              return { };
          }
      }
      assert(res.s1 >= res.s0);
      return res;
  }
  663. void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
  664. // keep track of the max sequence position that we would overwrite with this ubatch
  665. // for non-SWA cache, this would be always empty
  666. llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
  667. for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
  668. seq_pos_max_rm[s] = -1;
  669. }
  670. assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size());
  671. for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
  672. for (uint32_t ii = 0; ii < sinfo.size(); ++ii) {
  673. const uint32_t i = s*sinfo.size() + ii;
  674. auto & cells = v_cells[sinfo.strm[s]];
  675. const auto idx = sinfo.idxs[s][ii];
  676. if (!cells.is_empty(idx)) {
  677. assert(cells.seq_count(idx) == 1);
  678. const llama_seq_id seq_id = cells.seq_get(idx);
  679. const llama_pos pos = cells.pos_get(idx);
  680. seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
  681. cells.rm(idx);
  682. }
  683. cells.pos_set(idx, ubatch.pos[i]);
  684. for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
  685. cells.seq_add(idx, ubatch.seq_id[i][s]);
  686. }
  687. }
  688. }
  689. // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
  690. // will be present in the cache. so we have to purge any position which is less than those we would overwrite
  691. // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
  692. for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
  693. if (seq_pos_max_rm[s] == -1) {
  694. continue;
  695. }
  696. GGML_ASSERT(s < seq_to_stream.size());
  697. auto & cells = v_cells[seq_to_stream[s]];
  698. if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
  699. LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
  700. __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
  701. seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
  702. }
  703. }
  704. // move the head at the end of the slot
  705. for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
  706. auto & head = v_heads[sinfo.strm[s]];
  707. head = sinfo.idxs[s].back() + 1;
  708. }
  709. }
  710. bool llama_kv_cache::get_can_shift() const {
  711. return true;
  712. }
  713. uint32_t llama_kv_cache::get_size() const {
  714. const auto & cells = v_cells[seq_to_stream[0]];
  715. return cells.size();
  716. }
  717. uint32_t llama_kv_cache::get_n_stream() const {
  718. return n_stream;
  719. }
  720. bool llama_kv_cache::get_has_shift() const {
  721. bool result = false;
  722. for (uint32_t s = 0; s < n_stream; ++s) {
  723. result |= v_cells[s].get_has_shift();
  724. }
  725. return result;
  726. }
  727. uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
  728. uint32_t result = 0;
  729. for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
  730. const auto & cells = v_cells[sinfo.strm[s]];
  731. result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
  732. }
  733. return result;
  734. }
  735. ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
  736. const int32_t ikv = map_layer_ids.at(il);
  737. auto * k = layers[ikv].k;
  738. const uint64_t kv_size = get_size();
  739. const uint64_t n_embd_k_gqa = k->ne[0];
  740. assert(n_embd_k_gqa == hparams.n_embd_k_gqa(il));
  741. const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
  742. return ggml_view_4d(ctx, k,
  743. hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns,
  744. ggml_row_size(k->type, hparams.n_embd_head_k),
  745. ggml_row_size(k->type, n_embd_k_gqa),
  746. ggml_row_size(k->type, n_embd_k_gqa*kv_size),
  747. ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0);
  748. }
  749. ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
  750. const int32_t ikv = map_layer_ids.at(il);
  751. auto * v = layers[ikv].v;
  752. const uint64_t kv_size = get_size();
  753. const uint64_t n_embd_v_gqa = v->ne[0];
  754. // [TAG_V_CACHE_VARIABLE]
  755. assert(n_embd_v_gqa >= hparams.n_embd_v_gqa(il));
  756. const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
  757. if (!v_trans) {
  758. // note: v->nb[1] <= v->nb[2]
  759. return ggml_view_4d(ctx, v,
  760. hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
  761. ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
  762. ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2]
  763. ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3]
  764. ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0);
  765. }
  766. // note: v->nb[1] > v->nb[2]
  767. return ggml_view_4d(ctx, v,
  768. n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
  769. ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1]
  770. ggml_row_size(v->type, kv_size), // v->nb[2]
  771. ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3]
  772. ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
  773. }
  774. ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
  775. GGML_UNUSED(sinfo);
  776. const int32_t ikv = map_layer_ids.at(il);
  777. ggml_tensor * k = layers[ikv].k;
  778. const int64_t n_embd_head = k_cur->ne[0];
  779. const int64_t n_head = k_cur->ne[1];
  780. const int64_t n_tokens = k_cur->ne[2];
  781. const int64_t n_embd_gqa = n_embd_head*n_head;
  782. // we can merge dims 0 and 1
  783. // TODO: add ggml helper function for this?
  784. GGML_ASSERT(ggml_row_size(k_cur->type, n_embd_head) == k_cur->nb[1]);
  785. k_cur = ggml_view_2d(ctx, k_cur, n_embd_gqa, n_tokens, k_cur->nb[2], 0);
  786. const int64_t n_stream = k->ne[2];
  787. if (n_stream > 1) {
  788. const int64_t kv_size = get_size();
  789. assert(n_embd_gqa == k->ne[0]);
  790. assert(kv_size == k->ne[1]);
  791. // merge the buffer across all streams because the idxs are global
  792. k = ggml_reshape_2d(ctx, k, n_embd_gqa, kv_size*n_stream);
  793. }
  794. // store the current K values into the cache
  795. return ggml_set_rows(ctx, k, k_cur, k_idxs);
  796. }
  797. ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
  798. GGML_UNUSED(sinfo);
  799. const int32_t ikv = map_layer_ids.at(il);
  800. auto * v = layers[ikv].v;
  801. const int64_t n_embd_head = v_cur->ne[0];
  802. const int64_t n_head = v_cur->ne[1];
  803. const int64_t n_tokens = v_cur->ne[2];
  804. const int64_t n_embd_gqa = n_embd_head*n_head;
  805. // we can merge dims 0 and 1
  806. GGML_ASSERT(ggml_row_size(v_cur->type, n_embd_head) == v_cur->nb[1]);
  807. const int64_t n_stream = v->ne[2];
  808. // take this branch when FA is enabled (the V cache is not transposed)
  809. if (!v_trans) {
  810. v_cur = ggml_view_2d(ctx, v_cur, n_embd_gqa, n_tokens, v_cur->nb[2], 0);
  811. if (n_stream > 1) {
  812. const int64_t kv_size = get_size();
  813. assert(n_embd_gqa == v->ne[0]);
  814. assert(kv_size == v->ne[1]);
  815. // merge the buffer across all streams because the idxs are global
  816. v = ggml_reshape_2d(ctx, v, n_embd_gqa, kv_size*n_stream);
  817. }
  818. return ggml_set_rows(ctx, v, v_cur, v_idxs);
  819. }
  820. if (ggml_row_size(v_cur->type, n_embd_gqa) == v_cur->nb[2]) {
  821. // we can merge dims 0, 1 and 2
  822. v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_gqa, n_tokens);
  823. } else {
  824. // otherwise -> make a copy to get contiguous data
  825. v_cur = ggml_cont_2d (ctx, v_cur, n_embd_gqa, n_tokens);
  826. }
  827. // [TAG_V_CACHE_VARIABLE]
  828. if (n_embd_gqa < v->ne[0]) {
  829. v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_gqa, 0, 0, 0);
  830. }
  831. // in this branch the v_idxs are constructed in such a way that each row is a single head element
  832. ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, ggml_nelements(v));
  833. v_cur = ggml_reshape_2d(ctx, v_cur, 1, ggml_nelements(v_cur));
  834. return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
  835. }
  836. ggml_tensor * llama_kv_cache::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
  837. const uint32_t n_tokens = ubatch.n_tokens;
  838. ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
  839. ggml_set_input(k_idxs);
  840. return k_idxs;
  841. }
  842. ggml_tensor * llama_kv_cache::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
  843. const uint32_t n_tokens = ubatch.n_tokens;
  844. ggml_tensor * v_idxs;
  845. if (!v_trans) {
  846. v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
  847. } else {
  848. v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa_max());
  849. }
  850. ggml_set_input(v_idxs);
  851. return v_idxs;
  852. }
  853. void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
  854. const uint32_t n_tokens = ubatch->n_tokens;
  855. GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
  856. GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
  857. int64_t * data = (int64_t *) dst->data;
  858. for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
  859. const int64_t offs = sinfo.strm[s]*get_size();
  860. for (uint32_t i = 0; i < sinfo.size(); ++i) {
  861. data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
  862. }
  863. }
  864. }
  865. void llama_kv_cache::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
  866. const uint32_t n_tokens = ubatch->n_tokens;
  867. GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
  868. GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
  869. int64_t * data = (int64_t *) dst->data;
  870. if (!v_trans) {
  871. for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
  872. const int64_t offs = sinfo.strm[s]*get_size();
  873. for (uint32_t i = 0; i < sinfo.size(); ++i) {
  874. data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
  875. }
  876. }
  877. } else {
  878. // note: the V cache is transposed when not using flash attention
  879. const int64_t kv_size = get_size();
  880. const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa_max();
  881. for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
  882. const int64_t offs = sinfo.strm[s]*kv_size*n_embd_v_gqa;
  883. for (uint32_t i = 0; i < sinfo.size(); ++i) {
  884. for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
  885. data[s*sinfo.size()*n_embd_v_gqa + i*n_embd_v_gqa + j] = offs + j*kv_size + sinfo.idxs[s][i];
  886. }
  887. }
  888. }
  889. }
  890. }
  891. void llama_kv_cache::set_input_k_shift(ggml_tensor * dst) const {
  892. GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
  893. int32_t * data = (int32_t *) dst->data;
  894. for (uint32_t s = 0; s < n_stream; ++s) {
  895. const auto & cells = v_cells[s];
  896. for (uint32_t i = 0; i < cells.size(); ++i) {
  897. data[s*cells.size() + i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
  898. }
  899. }
  900. }
  901. void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
  902. const uint32_t n_tokens = ubatch->n_tokens;
  903. GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
  904. float * data = (float *) dst->data;
  905. const int64_t n_kv = dst->ne[0];
  906. const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch
  907. GGML_ASSERT(n_tokens%n_stream == 0);
  908. // n_tps == n_tokens_per_stream
  909. const int64_t n_tps = n_tokens/n_stream;
  910. const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD);
  911. std::fill(data, data + ggml_nelements(dst), -INFINITY);
  912. // Use only the previous KV cells of the correct sequence for each token of the ubatch.
  913. // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
  914. // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
  915. // Causal mask:
  916. // xxx-------
  917. // xxxx------
  918. // xxxxx-----
  919. // Non-causal mask:
  920. // xxxxx-----
  921. // xxxxx-----
  922. // xxxxx-----
  923. // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
  924. // TODO: optimize this section
  925. for (uint32_t h = 0; h < 1; ++h) {
  926. for (uint32_t s = 0; s < n_stream; ++s) {
  927. for (uint32_t ii = 0; ii < n_tps; ++ii) {
  928. const uint32_t i = s*n_tps + ii;
  929. const llama_seq_id seq_id = ubatch->seq_id[i][0];
  930. const auto & cells = v_cells[seq_to_stream[seq_id]];
  931. const llama_pos p1 = ubatch->pos[i];
  932. const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
  933. for (uint32_t j = 0; j < n_kv; ++j) {
  934. if (cells.is_empty(j)) {
  935. continue;
  936. }
  937. // mask the token if not the same sequence
  938. if (!cells.seq_has(j, seq_id)) {
  939. continue;
  940. }
  941. const llama_pos p0 = cells.pos_get(j);
  942. // mask future tokens
  943. if (causal_attn && p0 > p1) {
  944. continue;
  945. }
  946. // apply SWA if any
  947. if (is_masked_swa(p0, p1)) {
  948. continue;
  949. }
  950. data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
  951. }
  952. }
  953. }
  954. }
  955. }
  956. void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
  957. const int64_t n_tokens = ubatch->n_tokens;
  958. GGML_ASSERT(n_stream == 1 && "TODO: support multiple streams");
  959. const auto & cells = v_cells[0];
  960. GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
  961. GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
  962. int32_t * data = (int32_t *) dst->data;
  963. const int32_t n_kv = dst->ne[0];
  964. for (int h = 0; h < 1; ++h) {
  965. for (int i = 0; i < n_tokens; ++i) {
  966. for (int j = 0; j < n_kv; ++j) {
  967. // the position when the cells is empty is irrelevant - it will be masked out later in the attention
  968. const llama_pos p0 = cells.is_empty(j) ? -1 : cells.pos_get(j);
  969. data[h*(n_kv*n_tokens) + i*n_kv + j] = llama_relative_position_bucket(p0, ubatch->pos[i], hparams.n_rel_attn_bkts, false);
  970. }
  971. }
  972. }
  973. }
  974. size_t llama_kv_cache::total_size() const {
  975. size_t size = 0;
  976. for (const auto & [_, buf] : ctxs_bufs) {
  977. size += ggml_backend_buffer_get_size(buf.get());
  978. }
  979. return size;
  980. }
  981. size_t llama_kv_cache::size_k_bytes() const {
  982. size_t size_k_bytes = 0;
  983. for (const auto & layer : layers) {
  984. size_k_bytes += ggml_nbytes(layer.k);
  985. }
  986. return size_k_bytes;
  987. }
  988. size_t llama_kv_cache::size_v_bytes() const {
  989. size_t size_v_bytes = 0;
  990. for (const auto & layer : layers) {
  991. size_v_bytes += ggml_nbytes(layer.v);
  992. }
  993. return size_v_bytes;
  994. }
  995. ggml_tensor * llama_kv_cache::build_rope_shift(
  996. const llama_cparams & cparams,
  997. ggml_context * ctx,
  998. ggml_tensor * cur,
  999. ggml_tensor * shift,
  1000. ggml_tensor * factors,
  1001. float freq_base,
  1002. float freq_scale) const {
  1003. const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
  1004. const auto & yarn_ext_factor = cparams.yarn_ext_factor;
  1005. const auto & yarn_beta_fast = cparams.yarn_beta_fast;
  1006. const auto & yarn_beta_slow = cparams.yarn_beta_slow;
  1007. const auto & n_rot = hparams.n_rot;
  1008. const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE
  1009. // @ngxson : this is a workaround
  1010. // for M-RoPE, we want to rotate the whole vector when doing KV shift
  1011. // a normal RoPE should work, we just need to use the correct ordering
  1012. // ref: https://github.com/ggml-org/llama.cpp/pull/13870
  1013. ? LLAMA_ROPE_TYPE_NEOX
  1014. : hparams.rope_type;
  1015. // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
  1016. // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
  1017. const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2
  1018. ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale))
  1019. : cparams.yarn_attn_factor;
  1020. ggml_tensor * tmp;
  1021. if (ggml_is_quantized(cur->type)) {
  1022. // dequantize to f32 -> RoPE -> quantize back
  1023. tmp = ggml_cast(ctx, cur, GGML_TYPE_F32);
  1024. tmp = ggml_rope_ext(ctx, tmp,
  1025. shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
  1026. yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
  1027. tmp = ggml_cpy(ctx, tmp, cur);
  1028. } else {
  1029. // we rotate only the first n_rot dimensions
  1030. tmp = ggml_rope_ext_inplace(ctx, cur,
  1031. shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
  1032. yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
  1033. }
  1034. return tmp;
  1035. }
  1036. class llm_graph_input_k_shift : public llm_graph_input_i {
  1037. public:
  1038. llm_graph_input_k_shift(const llama_kv_cache * kv_self) : kv_self(kv_self) {}
  1039. virtual ~llm_graph_input_k_shift() = default;
  1040. void set_input(const llama_ubatch * ubatch) override;
  1041. ggml_tensor * k_shift; // I32 [kv_size*n_stream]
  1042. const llama_kv_cache * kv_self;
  1043. };
  1044. void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
  1045. GGML_UNUSED(ubatch);
  1046. if (k_shift) {
  1047. kv_self->set_input_k_shift(k_shift);
  1048. }
  1049. }
  1050. ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
  1051. auto * ctx = res->get_ctx();
  1052. auto * gf = res->get_gf();
  1053. const auto & n_embd_head_k = hparams.n_embd_head_k;
  1054. //const auto & n_embd_head_v = hparams.n_embd_head_v;
  1055. auto inp = std::make_unique<llm_graph_input_k_shift>(this);
  1056. inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
  1057. ggml_set_input(inp->k_shift);
  1058. const auto & cparams = lctx->get_cparams();
  1059. for (const auto & layer : layers) {
  1060. const uint32_t il = layer.il;
  1061. const int64_t n_head_kv = hparams.n_head_kv(il);
  1062. const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
  1063. const float freq_base_l = model.get_rope_freq_base (cparams, il);
  1064. const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
  1065. ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
  1066. ggml_tensor * k =
  1067. ggml_view_3d(ctx, layer.k,
  1068. n_embd_head_k, n_head_kv, get_size()*n_stream,
  1069. ggml_row_size(layer.k->type, n_embd_head_k),
  1070. ggml_row_size(layer.k->type, n_embd_k_gqa),
  1071. 0);
  1072. ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
  1073. ggml_build_forward_expand(gf, cur);
  1074. }
  1075. res->add_input(std::move(inp));
  1076. return gf;
  1077. }
  1078. bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
  1079. return llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1);
  1080. }
  1081. void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
  1082. GGML_UNUSED(flags);
  1083. io.write(&n_stream, sizeof(n_stream));
  1084. for (uint32_t s = 0; s < n_stream; ++s) {
  1085. cell_ranges_t cr { s, {} };
  1086. uint32_t cell_count = 0;
  1087. const auto & cells = v_cells[s];
  1088. // Count the number of cells with the specified seq_id
  1089. // Find all the ranges of cells with this seq id (or all, when -1)
  1090. uint32_t cell_range_begin = cells.size();
  1091. for (uint32_t i = 0; i < cells.size(); ++i) {
  1092. if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
  1093. ++cell_count;
  1094. if (cell_range_begin == cells.size()) {
  1095. cell_range_begin = i;
  1096. }
  1097. } else {
  1098. if (cell_range_begin != cells.size()) {
  1099. cr.data.emplace_back(cell_range_begin, i);
  1100. cell_range_begin = cells.size();
  1101. }
  1102. }
  1103. }
  1104. if (cell_range_begin != cells.size()) {
  1105. cr.data.emplace_back(cell_range_begin, cells.size());
  1106. }
  1107. // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
  1108. uint32_t cell_count_check = 0;
  1109. for (const auto & range : cr.data) {
  1110. cell_count_check += range.second - range.first;
  1111. }
  1112. GGML_ASSERT(cell_count == cell_count_check);
  1113. io.write(&cell_count, sizeof(cell_count));
  1114. // skip empty streams
  1115. if (cell_count == 0) {
  1116. continue;
  1117. }
  1118. state_write_meta(io, cr, seq_id);
  1119. state_write_data(io, cr);
  1120. }
  1121. }
  1122. void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
  1123. GGML_UNUSED(flags);
  1124. GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
  1125. uint32_t n_stream_cur;
  1126. io.read_to(&n_stream_cur, sizeof(n_stream_cur));
  1127. if (n_stream_cur != n_stream) {
  1128. throw std::runtime_error("n_stream mismatch");
  1129. }
  1130. for (uint32_t s = 0; s < n_stream; ++s) {
  1131. uint32_t cell_count;
  1132. io.read_to(&cell_count, sizeof(cell_count));
  1133. if (cell_count == 0) {
  1134. continue;
  1135. }
  1136. const uint32_t strm = seq_id == -1 ? s : seq_to_stream[seq_id];
  1137. bool res = true;
  1138. res = res && state_read_meta(io, strm, cell_count, seq_id);
  1139. res = res && state_read_data(io, strm, cell_count);
  1140. if (!res) {
  1141. if (seq_id == -1) {
  1142. clear(true);
  1143. } else {
  1144. seq_rm(seq_id, -1, -1);
  1145. }
  1146. throw std::runtime_error("failed to restore kv cache");
  1147. }
  1148. }
  1149. }
  1150. void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const {
  1151. const auto & cells = v_cells[cr.strm];
  1152. for (const auto & range : cr.data) {
  1153. for (uint32_t i = range.first; i < range.second; ++i) {
  1154. std::vector<llama_seq_id> seq_ids;
  1155. for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) {
  1156. if (cur == seq_id || seq_id == -1) {
  1157. if (cells.seq_has(i, cur)) {
  1158. seq_ids.push_back(cur);
  1159. }
  1160. }
  1161. }
  1162. const llama_pos pos = cells.pos_get(i);
  1163. const uint32_t n_seq_id = seq_ids.size();
  1164. io.write(&pos, sizeof(pos));
  1165. io.write(&n_seq_id, sizeof(n_seq_id));
  1166. for (const auto & seq_id : seq_ids) {
  1167. io.write(&seq_id, sizeof(seq_id));
  1168. }
  1169. }
  1170. }
  1171. }
  1172. void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const {
  1173. const auto & cells = v_cells[cr.strm];
  1174. const uint32_t v_trans = this->v_trans ? 1 : 0;
  1175. const uint32_t n_layer = layers.size();
  1176. io.write(&v_trans, sizeof(v_trans));
  1177. io.write(&n_layer, sizeof(n_layer));
  1178. std::vector<uint8_t> tmp_buf;
  1179. // Iterate and write all the keys first, each row is a cell
  1180. // Get whole range at a time
  1181. for (const auto & layer : layers) {
  1182. const uint32_t il = layer.il;
  1183. const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
  1184. auto * k = layer.k_stream[cr.strm];
  1185. // Write key type
  1186. const int32_t k_type_i = (int32_t) k->type;
  1187. io.write(&k_type_i, sizeof(k_type_i));
  1188. // Write row size of key
  1189. const uint64_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
  1190. io.write(&k_size_row, sizeof(k_size_row));
  1191. // Read each range of cells of k_size length each into tmp_buf and write out
  1192. for (const auto & range : cr.data) {
  1193. const size_t range_size = range.second - range.first;
  1194. const size_t buf_size = range_size * k_size_row;
  1195. io.write_tensor(k, range.first * k_size_row, buf_size);
  1196. }
  1197. }
  1198. if (!v_trans) {
  1199. for (const auto & layer : layers) {
  1200. const uint32_t il = layer.il;
  1201. const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
  1202. auto * v = layer.v_stream[cr.strm];
  1203. // Write value type
  1204. const int32_t v_type_i = (int32_t) v->type;
  1205. io.write(&v_type_i, sizeof(v_type_i));
  1206. // Write row size of value
  1207. const uint64_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
  1208. io.write(&v_size_row, sizeof(v_size_row));
  1209. // Read each range of cells of v_size length each into tmp_buf and write out
  1210. for (const auto & range : cr.data) {
  1211. const size_t range_size = range.second - range.first;
  1212. const size_t buf_size = range_size * v_size_row;
  1213. io.write_tensor(v, range.first * v_size_row, buf_size);
  1214. }
  1215. }
  1216. } else {
  1217. // When v is transposed, we also need the element size and get the element ranges from each row
  1218. const uint32_t kv_size = cells.size();
  1219. for (const auto & layer : layers) {
  1220. const uint32_t il = layer.il;
  1221. const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
  1222. auto * v = layer.v_stream[cr.strm];
  1223. // Write value type
  1224. const int32_t v_type_i = (int32_t) v->type;
  1225. io.write(&v_type_i, sizeof(v_type_i));
  1226. // Write element size
  1227. const uint32_t v_size_el = ggml_type_size(v->type);
  1228. io.write(&v_size_el, sizeof(v_size_el));
  1229. // Write GQA embedding size
  1230. io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
  1231. // For each row, we get the element values of each cell
  1232. for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
  1233. // Read each range of cells of v_size_el length each into tmp_buf and write out
  1234. for (const auto & range : cr.data) {
  1235. const size_t range_size = range.second - range.first;
  1236. const size_t src_offset = (range.first + j * kv_size) * v_size_el;
  1237. const size_t buf_size = range_size * v_size_el;
  1238. io.write_tensor(v, src_offset, buf_size);
  1239. }
  1240. }
  1241. }
  1242. }
  1243. }
  // Read the metadata (position + sequence ids) of `cell_count` cells of stream `strm` from `io` and populate
  // the corresponding cells. Must consume exactly what state_write_meta() produced.
  //
  // dest_seq_id != -1 (single sequence): the sequence is first removed from the cache. Each serialized cell must
  //   carry exactly one sequence id, which is read and discarded (dest_seq_id is used instead). The cells are
  //   placed with find_slot(ubatch, true) + apply_ubatch(), and the stream head is set to the first cell of the
  //   slot so that state_read_data() writes the K/V data there.
  // dest_seq_id == -1 (whole restore): clear(true) is called, then cells [0, cell_count) are filled in order and
  //   the head is reset to 0.
  //   NOTE(review): state_read() calls this once per non-empty stream; if clear() resets every stream, restoring a
  //   multi-stream state would wipe the streams restored earlier in that loop - confirm.
  //
  // Returns false on malformed input or when no slot is found; the cache may then be partially modified and the
  // caller (state_read()) is responsible for cleaning up.
  bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) {
      auto & cells = v_cells[strm];
      auto & head  = v_heads[strm];

      if (dest_seq_id != -1) {
          // single sequence
          seq_rm(dest_seq_id, -1, -1);

          llama_batch_allocr balloc(hparams.n_pos_per_embd());

          // scratch ubatch with room for cell_count tokens, all tagged with dest_seq_id
          llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);

          ubatch.seq_id_unq[0] = dest_seq_id;

          for (uint32_t i = 0; i < cell_count; ++i) {
              llama_pos pos;
              uint32_t n_seq_id;

              io.read_to(&pos,      sizeof(pos));
              io.read_to(&n_seq_id, sizeof(n_seq_id));

              if (n_seq_id != 1) {
                  LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
                  return false;
              }

              // read the sequence id, but directly discard it - we will use dest_seq_id instead
              {
                  llama_seq_id seq_id;
                  io.read_to(&seq_id, sizeof(seq_id));
              }

              ubatch.pos[i]      = pos;
              ubatch.n_seq_id[i] = n_seq_id;
              ubatch.seq_id[i]   = &dest_seq_id; // points at the parameter, which outlives the ubatch
          }

          // cont == true: the cells must form one contiguous run (relied upon by state_read_data())
          const auto sinfo = find_slot(ubatch, true);
          if (sinfo.empty()) {
              LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
              return false;
          }

          apply_ubatch(sinfo, ubatch);

          const auto head_cur = sinfo.head();

          // keep the head at the old position because we will read the KV data into it in state_read_data()
          // (apply_ubatch() moved the head past the slot, so rewind it to the slot's first cell)
          head = head_cur;

          LLAMA_LOG_DEBUG("%s: head_cur = %d, head = %d, cell_count = %d, dest_seq_id = %d\n", __func__, head_cur, head, cell_count, dest_seq_id);

          // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
          // Assume that this is one contiguous block of cells
          GGML_ASSERT(head_cur + cell_count <= cells.size());
          GGML_ASSERT(cells.pos_get(head_cur) == ubatch.pos[0]);
          GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
          GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id));
          GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
      } else {
          // whole KV cache restore

          if (cell_count > cells.size()) {
              LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
              return false;
          }

          clear(true);

          for (uint32_t i = 0; i < cell_count; ++i) {
              llama_pos pos;
              uint32_t  n_seq_id;

              io.read_to(&pos,      sizeof(pos));
              io.read_to(&n_seq_id, sizeof(n_seq_id));

              // the position is set before the ids are validated; on a bad id the caller clears the cache
              cells.pos_set(i, pos);

              for (uint32_t j = 0; j < n_seq_id; ++j) {
                  llama_seq_id seq_id;
                  io.read_to(&seq_id, sizeof(seq_id));

                  if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
                      LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max);
                      return false;
                  }

                  cells.seq_add(i, seq_id);
              }
          }

          // the restored cells start at 0, which is where state_read_data() will write the K/V data
          head = 0;
      }

      return true;
  }
  1315. bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) {
  1316. auto & cells = v_cells[strm];
  1317. auto & head = v_heads[strm];
  1318. uint32_t v_trans;
  1319. uint32_t n_layer;
  1320. io.read_to(&v_trans, sizeof(v_trans));
  1321. io.read_to(&n_layer, sizeof(n_layer));
  1322. if (n_layer != layers.size()) {
  1323. LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
  1324. return false;
  1325. }
  1326. if (cell_count > cells.size()) {
  1327. LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size());
  1328. return false;
  1329. }
  1330. if (this->v_trans != (bool) v_trans) {
  1331. LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
  1332. return false;
  1333. }
  1334. // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
  1335. for (const auto & layer : layers) {
  1336. const uint32_t il = layer.il;
  1337. const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
  1338. auto * k = layer.k_stream[strm];
  1339. // Read type of key
  1340. int32_t k_type_i_ref;
  1341. io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
  1342. const int32_t k_type_i = (int32_t) k->type;
  1343. if (k_type_i != k_type_i_ref) {
  1344. LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
  1345. return false;
  1346. }
  1347. // Read row size of key
  1348. uint64_t k_size_row_ref;
  1349. io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
  1350. const size_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
  1351. if (k_size_row != k_size_row_ref) {
  1352. LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
  1353. return false;
  1354. }
  1355. if (cell_count) {
  1356. // Read and set the keys for the whole cell range
  1357. ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
  1358. }
  1359. }
  1360. if (!this->v_trans) {
  1361. for (const auto & layer : layers) {
  1362. const uint32_t il = layer.il;
  1363. const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
  1364. auto * v = layer.v_stream[strm];
  1365. // Read type of value
  1366. int32_t v_type_i_ref;
  1367. io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
  1368. const int32_t v_type_i = (int32_t) v->type;
  1369. if (v_type_i != v_type_i_ref) {
  1370. LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
  1371. return false;
  1372. }
  1373. // Read row size of value
  1374. uint64_t v_size_row_ref;
  1375. io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
  1376. const size_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
  1377. if (v_size_row != v_size_row_ref) {
  1378. LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
  1379. return false;
  1380. }
  1381. if (cell_count) {
  1382. // Read and set the values for the whole cell range
  1383. ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
  1384. }
  1385. }
  1386. } else {
  1387. // For each layer, read the values for each cell (transposed)
  1388. for (const auto & layer : layers) {
  1389. const uint32_t il = layer.il;
  1390. const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
  1391. auto * v = layer.v_stream[strm];
  1392. // Read type of value
  1393. int32_t v_type_i_ref;
  1394. io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
  1395. const int32_t v_type_i = (int32_t) v->type;
  1396. if (v_type_i != v_type_i_ref) {
  1397. LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
  1398. return false;
  1399. }
  1400. // Read element size of value
  1401. uint32_t v_size_el_ref;
  1402. io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
  1403. const size_t v_size_el = ggml_type_size(v->type);
  1404. if (v_size_el != v_size_el_ref) {
  1405. LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
  1406. return false;
  1407. }
  1408. // Read GQA embedding size
  1409. uint32_t n_embd_v_gqa_ref;
  1410. io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
  1411. if (n_embd_v_gqa != n_embd_v_gqa_ref) {
  1412. LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
  1413. return false;
  1414. }
  1415. if (cell_count) {
  1416. // For each row in the transposed matrix, read the values for the whole cell range
  1417. for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
  1418. const size_t dst_offset = (head + j * cells.size()) * v_size_el;
  1419. ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
  1420. }
  1421. }
  1422. }
  1423. }
  1424. return true;
  1425. }
  1426. //
  1427. // llama_kv_cache_context
  1428. //
  1429. llama_kv_cache_context::llama_kv_cache_context(llama_memory_status status) : status(status) {}
  1430. llama_kv_cache_context::llama_kv_cache_context(
  1431. llama_kv_cache * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
  1432. n_kv = kv->get_size();
  1433. const uint32_t n_stream = kv->get_n_stream();
  1434. // create a dummy slot info - the actual data is irrelevant. we just need to build the graph
  1435. sinfos.resize(1);
  1436. sinfos[0].s0 = 0;
  1437. sinfos[0].s1 = n_stream - 1;
  1438. sinfos[0].idxs.resize(n_stream);
  1439. for (uint32_t s = 0; s < n_stream; ++s) {
  1440. sinfos[0].strm.push_back(s);
  1441. sinfos[0].idxs[s].resize(1, 0);
  1442. }
  1443. }
  1444. llama_kv_cache_context::llama_kv_cache_context(
  1445. llama_kv_cache * kv,
  1446. llama_context * lctx,
  1447. bool do_shift,
  1448. stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), sc_info(std::move(sc_info)) {
  1449. if (!do_shift && this->sc_info.empty()) {
  1450. status = LLAMA_MEMORY_STATUS_NO_UPDATE;
  1451. }
  1452. }
  1453. llama_kv_cache_context::llama_kv_cache_context(
  1454. llama_kv_cache * kv,
  1455. llama_kv_cache::slot_info_vec_t sinfos,
  1456. std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
  1457. }
  1458. llama_kv_cache_context::~llama_kv_cache_context() = default;
  1459. bool llama_kv_cache_context::next() {
  1460. assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
  1461. if (++i_cur >= ubatches.size()) {
  1462. return false;
  1463. }
  1464. return true;
  1465. }
  1466. bool llama_kv_cache_context::apply() {
  1467. assert(!llama_memory_status_is_fail(status));
  1468. // no ubatches -> this is a KV cache update
  1469. if (ubatches.empty()) {
  1470. kv->update(lctx, do_shift, sc_info);
  1471. return true;
  1472. }
  1473. kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
  1474. n_kv = kv->get_n_kv(sinfos[i_cur]);
  1475. return true;
  1476. }
  1477. llama_memory_status llama_kv_cache_context::get_status() const {
  1478. return status;
  1479. }
  1480. const llama_ubatch & llama_kv_cache_context::get_ubatch() const {
  1481. assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
  1482. return ubatches[i_cur];
  1483. }
  1484. uint32_t llama_kv_cache_context::get_n_kv() const {
  1485. return n_kv;
  1486. }
  1487. ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const {
  1488. return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
  1489. }
  1490. ggml_tensor * llama_kv_cache_context::get_v(ggml_context * ctx, int32_t il) const {
  1491. return kv->get_v(ctx, il, n_kv, sinfos[i_cur]);
  1492. }
  1493. ggml_tensor * llama_kv_cache_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
  1494. return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
  1495. }
  1496. ggml_tensor * llama_kv_cache_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
  1497. return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
  1498. }
  1499. ggml_tensor * llama_kv_cache_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
  1500. return kv->build_input_k_idxs(ctx, ubatch);
  1501. }
  1502. ggml_tensor * llama_kv_cache_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
  1503. return kv->build_input_v_idxs(ctx, ubatch);
  1504. }
  1505. void llama_kv_cache_context::set_input_k_shift(ggml_tensor * dst) const {
  1506. kv->set_input_k_shift(dst);
  1507. }
  1508. void llama_kv_cache_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
  1509. kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
  1510. }
  1511. void llama_kv_cache_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
  1512. kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
  1513. }
  1514. void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
  1515. kv->set_input_kq_mask(dst, ubatch, causal_attn);
  1516. }
  1517. void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
  1518. kv->set_input_pos_bucket(dst, ubatch);
  1519. }
  1520. uint32_t llama_kv_cache::get_padding(const llama_cparams & cparams) {
  1521. // the FA kernels require padding to avoid extra runtime boundary checks
  1522. return cparams.flash_attn ? 256u : 32u;
  1523. }