llama-kv-cache.cpp 64 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012
  1. #include "llama-kv-cache.h"
  2. #include "llama-impl.h"
  3. #include "llama-io.h"
  4. #include "llama-model.h"
  5. #include "llama-context.h"
  6. #include <algorithm>
  7. #include <cassert>
  8. #include <cmath>
  9. #include <limits>
  10. #include <map>
  11. #include <stdexcept>
  12. //
  13. // llama_kv_cache
  14. //
  15. llama_kv_cache::llama_kv_cache(
  16. const llama_model & model,
  17. ggml_type type_k,
  18. ggml_type type_v,
  19. bool v_trans,
  20. bool offload,
  21. bool unified,
  22. uint32_t kv_size,
  23. uint32_t n_seq_max,
  24. uint32_t n_pad,
  25. uint32_t n_swa,
  26. llama_swa_type swa_type,
  27. const layer_filter_cb & filter,
  28. const layer_reuse_cb & reuse) :
  29. model(model), hparams(model.hparams), v_trans(v_trans),
  30. n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
  31. GGML_ASSERT(kv_size % n_pad == 0);
  32. const uint32_t n_layer_kv = hparams.n_layer_kv();
  33. // create a context for each buffer type
  34. std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
  35. auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
  36. auto it = ctx_map.find(buft);
  37. if (it == ctx_map.end()) {
  38. ggml_init_params params = {
  39. /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()),
  40. /*.mem_buffer =*/ NULL,
  41. /*.no_alloc =*/ true,
  42. };
  43. ggml_context * ctx = ggml_init(params);
  44. if (!ctx) {
  45. return nullptr;
  46. }
  47. ctx_map[buft] = ctx;
  48. ctxs.emplace_back(ctx);
  49. return ctx;
  50. }
  51. return it->second;
  52. };
  53. GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);
  54. v_heads.resize(n_stream);
  55. for (uint32_t s = 0; s < n_stream; ++s) {
  56. v_heads[s] = 0;
  57. }
  58. v_cells.resize(n_stream);
  59. for (uint32_t s = 0; s < n_stream; ++s) {
  60. v_cells[s].resize(kv_size);
  61. }
  62. // by default, all sequence ids are mapped to the 0th stream
  63. seq_to_stream.resize(LLAMA_MAX_SEQ, 0);
  64. if (n_stream > 1) {
  65. seq_to_stream.resize(n_stream, 0);
  66. for (uint32_t s = 0; s < n_stream; ++s) {
  67. seq_to_stream[s] = s;
  68. }
  69. }
  70. // [TAG_V_CACHE_VARIABLE]
  71. if (v_trans && hparams.is_n_embd_v_gqa_variable()) {
  72. LLAMA_LOG_WARN("%s: the V embeddings have different sizes across layers and FA is not enabled - padding V cache to %d\n",
  73. __func__, hparams.n_embd_v_gqa_max());
  74. }
  75. for (uint32_t il = 0; il < hparams.n_layer; il++) {
  76. if (!hparams.has_kv(il)) {
  77. LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il);
  78. continue;
  79. }
  80. if (filter && !filter(il)) {
  81. LLAMA_LOG_DEBUG("%s: layer %3d: filtered\n", __func__, il);
  82. continue;
  83. }
  84. // [TAG_V_CACHE_VARIABLE]
  85. const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
  86. const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();
  87. const char * dev_name = "CPU";
  88. ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
  89. if (offload) {
  90. auto * dev = model.dev_layer(il);
  91. buft = ggml_backend_dev_buffer_type(dev);
  92. dev_name = ggml_backend_dev_name(dev);
  93. }
  94. LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, il, dev_name);
  95. ggml_context * ctx = ctx_for_buft(buft);
  96. if (!ctx) {
  97. throw std::runtime_error("failed to create ggml context for kv cache");
  98. }
  99. ggml_tensor * k;
  100. ggml_tensor * v;
  101. k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream);
  102. v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream);
  103. ggml_format_name(k, "cache_k_l%d", il);
  104. ggml_format_name(v, "cache_v_l%d", il);
  105. std::vector<ggml_tensor *> k_stream;
  106. std::vector<ggml_tensor *> v_stream;
  107. for (uint32_t s = 0; s < n_stream; ++s) {
  108. k_stream.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]));
  109. v_stream.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]));
  110. }
  111. map_layer_ids[il] = layers.size();
  112. layers.push_back({ il, k, v, k_stream, v_stream, });
  113. }
  114. if (reuse) {
  115. LLAMA_LOG_DEBUG("%s: reusing layers:\n", __func__);
  116. for (uint32_t il = 0; il < hparams.n_layer; il++) {
  117. const int32_t il_reuse = reuse(il);
  118. if (il_reuse < 0) {
  119. LLAMA_LOG_DEBUG("%s: - layer %3d: no reuse\n", __func__, il);
  120. continue;
  121. }
  122. if (filter && !filter(il)) {
  123. LLAMA_LOG_DEBUG("%s: - layer %3d: filtered\n", __func__, il);
  124. continue;
  125. }
  126. GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());
  127. map_layer_ids[il] = map_layer_ids[il_reuse];
  128. LLAMA_LOG_DEBUG("%s: - layer %3d: reuse layer %d, is_swa = %d\n", __func__, il, il_reuse, hparams.is_swa(il));
  129. }
  130. }
  131. // allocate tensors and initialize the buffers to avoid NaNs in the padding
  132. for (auto it : ctx_map) {
  133. auto * buft = it.first;
  134. auto * ctx = it.second;
  135. ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
  136. if (!buf) {
  137. throw std::runtime_error("failed to allocate buffer for kv cache");
  138. }
  139. LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
  140. ggml_backend_buffer_clear(buf, 0);
  141. bufs.emplace_back(buf);
  142. }
  143. {
  144. const size_t memory_size_k = size_k_bytes();
  145. const size_t memory_size_v = size_v_bytes();
  146. LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
  147. (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_stream,
  148. ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
  149. ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
  150. }
  151. const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
  152. debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
  153. }
  154. void llama_kv_cache::clear(bool data) {
  155. for (uint32_t s = 0; s < n_stream; ++s) {
  156. v_cells[s].reset();
  157. v_heads[s] = 0;
  158. }
  159. if (data) {
  160. for (auto & buf : bufs) {
  161. ggml_backend_buffer_clear(buf.get(), 0);
  162. }
  163. }
  164. }
  165. bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
  166. GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
  167. if (p0 < 0) {
  168. p0 = 0;
  169. }
  170. if (p1 < 0) {
  171. p1 = std::numeric_limits<llama_pos>::max();
  172. }
  173. if (seq_id >= 0) {
  174. auto & cells = v_cells[seq_to_stream[seq_id]];
  175. auto & head = v_heads[seq_to_stream[seq_id]];
  176. uint32_t new_head = cells.size();
  177. for (uint32_t i = 0; i < cells.size(); ++i) {
  178. if (!cells.pos_in(i, p0, p1)) {
  179. continue;
  180. }
  181. if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
  182. if (new_head == cells.size()) {
  183. new_head = i;
  184. }
  185. }
  186. }
  187. // If we freed up a slot, set head to it so searching can start there.
  188. if (new_head != cells.size() && new_head < head) {
  189. head = new_head;
  190. }
  191. } else {
  192. // match any sequence
  193. for (uint32_t s = 0; s < n_stream; ++s) {
  194. auto & cells = v_cells[s];
  195. auto & head = v_heads[s];
  196. uint32_t new_head = cells.size();
  197. for (uint32_t i = 0; i < cells.size(); ++i) {
  198. if (!cells.pos_in(i, p0, p1)) {
  199. continue;
  200. }
  201. cells.rm(i);
  202. if (new_head == cells.size()) {
  203. new_head = i;
  204. }
  205. }
  206. // If we freed up a slot, set head to it so searching can start there.
  207. if (new_head != cells.size() && new_head < head) {
  208. head = new_head;
  209. }
  210. }
  211. }
  212. return true;
  213. }
  214. void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
  215. GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size());
  216. GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size());
  217. const auto s0 = seq_to_stream[seq_id_src];
  218. const auto s1 = seq_to_stream[seq_id_dst];
  219. if (s0 == s1) {
  220. // since both sequences are in the same stream, no data copy is necessary
  221. // we just have to update the cells meta data
  222. auto & cells = v_cells[s0];
  223. if (seq_id_src == seq_id_dst) {
  224. return;
  225. }
  226. if (p0 < 0) {
  227. p0 = 0;
  228. }
  229. if (p1 < 0) {
  230. p1 = std::numeric_limits<llama_pos>::max();
  231. }
  232. for (uint32_t i = 0; i < cells.size(); ++i) {
  233. if (!cells.pos_in(i, p0, p1)) {
  234. continue;
  235. }
  236. if (cells.seq_has(i, seq_id_src)) {
  237. cells.seq_add(i, seq_id_dst);
  238. }
  239. }
  240. return;
  241. }
  242. // cross-stream sequence copies require to copy the actual buffer data
  243. bool is_full = true;
  244. if (p0 > 0 && p0 + 1 < (int) get_size()) {
  245. is_full = false;
  246. }
  247. if (p1 > 0 && p1 + 1 < (int) get_size()) {
  248. is_full = false;
  249. }
  250. GGML_ASSERT(is_full && "seq_cp() is only supported for full KV buffers");
  251. // enqueue the copy operation - the buffer copy will be performed during the next update
  252. sc_info.ssrc.push_back(s0);
  253. sc_info.sdst.push_back(s1);
  254. v_cells[s1].reset();
  255. for (uint32_t i = 0; i < v_cells[s0].size(); ++i) {
  256. if (v_cells[s0].seq_has(i, seq_id_src)) {
  257. llama_pos pos = v_cells[s0].pos_get(i);
  258. llama_pos shift = v_cells[s0].get_shift(i);
  259. if (shift != 0) {
  260. pos -= shift;
  261. assert(pos >= 0);
  262. }
  263. v_cells[s1].pos_set(i, pos);
  264. v_cells[s1].seq_add(i, seq_id_dst);
  265. if (shift != 0) {
  266. v_cells[s1].pos_add(i, shift);
  267. }
  268. }
  269. }
  270. v_heads[s1] = v_heads[s0];
  271. //for (uint32_t s = 0; s < n_stream; ++s) {
  272. // LLAMA_LOG_WARN("%s: seq %d: min = %d, max = %d\n", __func__, s, v_cells[s].seq_pos_min(s), v_cells[s].seq_pos_max(s));
  273. //}
  274. }
  275. void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
  276. GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
  277. auto & cells = v_cells[seq_to_stream[seq_id]];
  278. auto & head = v_heads[seq_to_stream[seq_id]];
  279. uint32_t new_head = cells.size();
  280. for (uint32_t i = 0; i < cells.size(); ++i) {
  281. if (cells.seq_keep(i, seq_id)) {
  282. if (new_head == cells.size()) {
  283. new_head = i;
  284. }
  285. }
  286. }
  287. // If we freed up a slot, set head to it so searching can start there.
  288. if (new_head != cells.size() && new_head < head) {
  289. head = new_head;
  290. }
  291. }
  292. void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
  293. GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
  294. auto & cells = v_cells[seq_to_stream[seq_id]];
  295. auto & head = v_heads[seq_to_stream[seq_id]];
  296. if (shift == 0) {
  297. return;
  298. }
  299. uint32_t new_head = cells.size();
  300. if (p0 < 0) {
  301. p0 = 0;
  302. }
  303. if (p1 < 0) {
  304. p1 = std::numeric_limits<llama_pos>::max();
  305. }
  306. // If there is no range then return early to avoid looping over all cells.
  307. if (p0 == p1) {
  308. return;
  309. }
  310. for (uint32_t i = 0; i < cells.size(); ++i) {
  311. if (!cells.pos_in(i, p0, p1)) {
  312. continue;
  313. }
  314. if (cells.seq_has(i, seq_id)) {
  315. if (cells.pos_add(i, shift)) {
  316. if (new_head == cells.size()) {
  317. new_head = i;
  318. }
  319. }
  320. }
  321. }
  322. // If we freed up a slot, set head to it so searching can start there.
  323. // Otherwise we just start the next search from the beginning.
  324. head = new_head != cells.size() ? new_head : 0;
  325. }
  326. void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
  327. GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
  328. auto & cells = v_cells[seq_to_stream[seq_id]];
  329. if (d == 1) {
  330. return;
  331. }
  332. if (p0 < 0) {
  333. p0 = 0;
  334. }
  335. if (p1 < 0) {
  336. p1 = std::numeric_limits<llama_pos>::max();
  337. }
  338. // If there is no range then return early to avoid looping over the cache.
  339. if (p0 == p1) {
  340. return;
  341. }
  342. for (uint32_t i = 0; i < cells.size(); ++i) {
  343. if (!cells.pos_in(i, p0, p1)) {
  344. continue;
  345. }
  346. if (cells.seq_has(i, seq_id)) {
  347. cells.pos_div(i, d);
  348. }
  349. }
  350. }
  351. llama_pos llama_kv_cache::seq_pos_min(llama_seq_id seq_id) const {
  352. GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
  353. const auto & cells = v_cells[seq_to_stream[seq_id]];
  354. return cells.seq_pos_min(seq_id);
  355. }
  356. llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {
  357. GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
  358. const auto & cells = v_cells[seq_to_stream[seq_id]];
  359. return cells.seq_pos_max(seq_id);
  360. }
  361. llama_memory_context_ptr llama_kv_cache::init_batch(
  362. llama_batch_allocr & balloc,
  363. uint32_t n_ubatch,
  364. bool embd_all) {
  365. GGML_UNUSED(embd_all);
  366. do {
  367. balloc.split_reset();
  368. std::vector<llama_ubatch> ubatches;
  369. while (true) {
  370. auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true);
  371. if (ubatch.n_tokens == 0) {
  372. break;
  373. }
  374. ubatches.push_back(std::move(ubatch)); // NOLINT
  375. }
  376. if (balloc.get_n_used() < balloc.get_n_tokens()) {
  377. // failed to find a suitable split
  378. break;
  379. }
  380. auto sinfos = prepare(ubatches);
  381. if (sinfos.empty()) {
  382. break;
  383. }
  384. return std::make_unique<llama_kv_cache_context>(
  385. this, std::move(sinfos), std::move(ubatches));
  386. } while (false);
  387. return std::make_unique<llama_kv_cache_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
  388. }
  389. llama_memory_context_ptr llama_kv_cache::init_full() {
  390. return std::make_unique<llama_kv_cache_context>(this);
  391. }
  392. llama_memory_context_ptr llama_kv_cache::init_update(llama_context * lctx, bool optimize) {
  393. GGML_UNUSED(optimize);
  394. bool do_shift = get_has_shift();
  395. return std::make_unique<llama_kv_cache_context>(this, lctx, do_shift, std::move(sc_info));
  396. }
  397. llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_ubatch> & ubatches) {
  398. llama_kv_cache::slot_info_vec_t res;
  399. struct state_t {
  400. slot_info sinfo; // slot info for the ubatch
  401. std::vector<uint32_t> v_heads_old; // old positions of the heads, before placing the ubatch
  402. std::vector<llama_kv_cells> v_cells; // copy of the old cells, before placing the ubatch
  403. };
  404. // remember the old state of the cells so we can restore it in the end
  405. std::vector<state_t> states;
  406. bool success = true;
  407. for (const auto & ubatch : ubatches) {
  408. // only find a suitable slot for the ubatch. don't modify the cells yet
  409. const auto sinfo_new = find_slot(ubatch, false);
  410. if (sinfo_new.empty()) {
  411. success = false;
  412. break;
  413. }
  414. // remeber the position that we found
  415. res.push_back(sinfo_new);
  416. // store the old state of the cells in the recovery stack
  417. {
  418. state_t state = { sinfo_new, v_heads, {} };
  419. for (uint32_t s = 0; s < sinfo_new.n_stream(); ++s) {
  420. auto & cells = v_cells[sinfo_new.strm[s]];
  421. state.v_cells.push_back(cells.cp(sinfo_new.idxs[s]));
  422. }
  423. states.push_back(std::move(state));
  424. }
  425. // now emplace the ubatch
  426. apply_ubatch(sinfo_new, ubatch);
  427. }
  428. GGML_ASSERT(!states.empty() || !success);
  429. // iterate backwards and restore the cells to their original state
  430. for (auto it = states.rbegin(); it != states.rend(); ++it) {
  431. const auto & sinfo = it->sinfo;
  432. for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
  433. auto & cells = v_cells[sinfo.strm[s]];
  434. auto & head = v_heads[sinfo.strm[s]];
  435. cells.set(sinfo.idxs[s], it->v_cells[s]);
  436. head = it->v_heads_old[s];
  437. }
  438. }
  439. if (!success) {
  440. return {};
  441. }
  442. return res;
  443. }
  444. bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info) {
  445. bool updated = false;
  446. auto * sched = lctx->get_sched();
  447. if (!sc_info.empty()) {
  448. assert(n_stream > 1 && "stream copy should never happen with a single stream");
  449. llama_synchronize(lctx);
  450. const size_t n_copy = sc_info.ssrc.size();
  451. for (size_t i = 0; i < n_copy; ++i) {
  452. const auto ssrc = sc_info.ssrc[i];
  453. const auto sdst = sc_info.sdst[i];
  454. assert(ssrc < n_stream);
  455. assert(sdst < n_stream);
  456. LLAMA_LOG_DEBUG("%s: copying KV buffer: stream %d to stream %d\n", __func__, ssrc, sdst);
  457. assert(ssrc != sdst);
  458. for (uint32_t il = 0; il < layers.size(); ++il) {
  459. const auto & layer = layers[il];
  460. ggml_backend_tensor_copy(layer.k_stream[ssrc], layer.k_stream[sdst]);
  461. ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]);
  462. }
  463. }
  464. }
  465. if (do_shift) {
  466. if (!get_can_shift()) {
  467. GGML_ABORT("The current KV cache / model configuration does not support K-shift");
  468. }
  469. LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
  470. // apply K-shift if needed
  471. if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
  472. ggml_backend_sched_reset(sched);
  473. auto * res = lctx->get_gf_res_reserve();
  474. res->reset();
  475. auto * gf = build_graph_shift(res, lctx);
  476. if (!ggml_backend_sched_alloc_graph(sched, gf)) {
  477. LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__);
  478. return updated;
  479. }
  480. res->set_inputs(nullptr);
  481. if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
  482. LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
  483. return updated;
  484. }
  485. updated = true;
  486. }
  487. for (uint32_t s = 0; s < n_stream; ++s) {
  488. auto & cells = v_cells[s];
  489. cells.reset_shift();
  490. }
  491. }
  492. return updated;
  493. }
  494. llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch, bool cont) const {
  495. if (debug > 0) {
  496. for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
  497. const auto seq_id = ubatch.seq_id_unq[s];
  498. const auto stream_id = seq_to_stream[seq_id];
  499. const auto & cells = v_cells[stream_id];
  500. const uint32_t head_cur = v_heads[stream_id];
  501. LLAMA_LOG_DEBUG("%s: stream[%d], n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n",
  502. __func__, stream_id, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa);
  503. if ((debug == 2 && n_swa > 0) || debug > 2) {
  504. std::string ss;
  505. for (uint32_t i = 0; i < cells.size(); ++i) {
  506. if (cells.is_empty(i)) {
  507. ss += '.';
  508. } else {
  509. assert(cells.seq_count(i) >= 1);
  510. if (cells.seq_count(i) == 1) {
  511. ss += std::to_string(cells.seq_get(i));
  512. } else {
  513. ss += 'M';
  514. }
  515. }
  516. if (i%256 == 255) {
  517. ss += " *";
  518. ss += '\n';
  519. }
  520. }
  521. LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
  522. }
  523. if ((debug == 2 && n_swa > 0) || debug > 2) {
  524. std::string ss;
  525. for (uint32_t i = 0; i < cells.size(); ++i) {
  526. std::string cur;
  527. if (cells.is_empty(i)) {
  528. cur = '.';
  529. } else {
  530. cur = std::to_string(cells.pos_get(i));
  531. }
  532. const int n = cur.size();
  533. for (int j = 0; j < 5 - n; ++j) {
  534. cur += ' ';
  535. }
  536. ss += cur;
  537. if (i%256 == 255) {
  538. ss += " *";
  539. }
  540. if (i%64 == 63) {
  541. ss += '\n';
  542. }
  543. }
  544. LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
  545. }
  546. for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
  547. if (cells.seq_pos_min(s) < 0) {
  548. continue;
  549. }
  550. LLAMA_LOG_DEBUG("%s: stream[%d] min[%d] = %5d, max[%d] = %5d\n", __func__, stream_id, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s));
  551. }
  552. }
  553. }
  554. uint32_t n_tokens = ubatch.n_tokens;
  555. uint32_t n_seqs = 1;
  556. if (n_stream > 1) {
  557. GGML_ASSERT(n_tokens % ubatch.n_seqs_unq == 0);
  558. n_seqs = ubatch.n_seqs_unq;
  559. n_tokens = n_tokens / n_seqs;
  560. }
  561. slot_info res = {
  562. /*.s0 =*/ LLAMA_MAX_SEQ,
  563. /*.s1 =*/ 0,
  564. /*.strm =*/ { },
  565. /*.idxs =*/ { },
  566. };
  567. res.resize(n_seqs);
  568. for (uint32_t s = 0; s < n_seqs; ++s) {
  569. const auto seq_id = ubatch.seq_id_unq[s];
  570. if (n_stream > 1) {
  571. GGML_ASSERT(ubatch.n_seq_id[s*n_tokens] == 1);
  572. GGML_ASSERT(ubatch.seq_id [s*n_tokens][0] == seq_id);
  573. }
  574. res.s0 = std::min<uint32_t>(res.s0, seq_to_stream[seq_id]);
  575. res.s1 = std::max<uint32_t>(res.s1, seq_to_stream[seq_id]);
  576. res.strm[s] = seq_to_stream[seq_id];
  577. res.idxs[s].reserve(n_tokens);
  578. const auto & cells = v_cells[seq_to_stream[seq_id]];
  579. uint32_t head_cur = v_heads[seq_to_stream[seq_id]];
  580. // if we have enough unused cells before the current head ->
  581. // better to start searching from the beginning of the cache, hoping to fill it
  582. if (head_cur > cells.get_used() + 2*n_tokens) {
  583. head_cur = 0;
  584. }
  585. if (n_tokens > cells.size()) {
  586. LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
  587. return { };
  588. }
  589. uint32_t n_tested = 0;
  590. // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head
  591. // for non-continuous slots, we test the tokens one by one
  592. const uint32_t n_test = cont ? n_tokens : 1;
  593. while (true) {
  594. if (head_cur + n_test > cells.size()) {
  595. n_tested += cells.size() - head_cur;
  596. head_cur = 0;
  597. continue;
  598. }
  599. for (uint32_t i = 0; i < n_test; i++) {
  600. const auto idx = head_cur;
  601. head_cur++;
  602. n_tested++;
  603. //const llama_pos pos = ubatch.pos[i];
  604. //const llama_seq_id seq_id = ubatch.seq_id[i][0];
  605. // can we use this cell? either:
  606. // - the cell is empty
  607. // - the cell is occupied only by one sequence:
  608. // - (disabled) mask causally, if the sequence is the same as the one we are inserting
  609. // - mask SWA, using current max pos for that sequence in the cache
  610. // always insert in the cell with minimum pos
  611. bool can_use = cells.is_empty(idx);
  612. if (!can_use && cells.seq_count(idx) == 1) {
  613. const llama_pos pos_cell = cells.pos_get(idx);
  614. // (disabled) causal mask
  615. // note: it's better to purge any "future" tokens beforehand
  616. //if (cells.seq_has(idx, seq_id)) {
  617. // can_use = pos_cell >= pos;
  618. //}
  619. if (!can_use) {
  620. const llama_seq_id seq_id_cell = cells.seq_get(idx);
  621. // SWA mask
  622. if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
  623. can_use = true;
  624. }
  625. }
  626. }
  627. if (can_use) {
  628. res.idxs[s].push_back(idx);
  629. } else {
  630. if (cont) {
  631. break;
  632. }
  633. }
  634. }
  635. if (res.idxs[s].size() == n_tokens) {
  636. break;
  637. }
  638. if (cont) {
  639. res.idxs[s].clear();
  640. }
  641. if (n_tested >= cells.size()) {
  642. //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
  643. return { };
  644. }
  645. }
  646. // we didn't find a suitable slot - return empty result
  647. if (res.idxs[s].size() < n_tokens) {
  648. return { };
  649. }
  650. }
  651. assert(res.s1 >= res.s0);
  652. return res;
  653. }
  654. void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
  655. // keep track of the max sequence position that we would overwrite with this ubatch
  656. // for non-SWA cache, this would be always empty
  657. llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
  658. for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
  659. seq_pos_max_rm[s] = -1;
  660. }
  661. assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size());
  662. for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
  663. for (uint32_t ii = 0; ii < sinfo.size(); ++ii) {
  664. const uint32_t i = s*sinfo.size() + ii;
  665. auto & cells = v_cells[sinfo.strm[s]];
  666. const auto idx = sinfo.idxs[s][ii];
  667. if (!cells.is_empty(idx)) {
  668. assert(cells.seq_count(idx) == 1);
  669. const llama_seq_id seq_id = cells.seq_get(idx);
  670. const llama_pos pos = cells.pos_get(idx);
  671. seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
  672. cells.rm(idx);
  673. }
  674. cells.pos_set(idx, ubatch.pos[i]);
  675. for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
  676. cells.seq_add(idx, ubatch.seq_id[i][s]);
  677. }
  678. }
  679. }
  680. // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
  681. // will be present in the cache. so we have to purge any position which is less than those we would overwrite
  682. // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
  683. for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
  684. if (seq_pos_max_rm[s] == -1) {
  685. continue;
  686. }
  687. GGML_ASSERT(s < seq_to_stream.size());
  688. auto & cells = v_cells[seq_to_stream[s]];
  689. if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
  690. LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
  691. __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
  692. seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
  693. }
  694. }
  695. // move the head at the end of the slot
  696. for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
  697. auto & head = v_heads[sinfo.strm[s]];
  698. head = sinfo.idxs[s].back() + 1;
  699. }
  700. }
  701. bool llama_kv_cache::get_can_shift() const {
  702. return true;
  703. }
  704. uint32_t llama_kv_cache::get_size() const {
  705. const auto & cells = v_cells[seq_to_stream[0]];
  706. return cells.size();
  707. }
  708. uint32_t llama_kv_cache::get_n_stream() const {
  709. return n_stream;
  710. }
  711. bool llama_kv_cache::get_has_shift() const {
  712. bool result = false;
  713. for (uint32_t s = 0; s < n_stream; ++s) {
  714. result |= v_cells[s].get_has_shift();
  715. }
  716. return result;
  717. }
  718. uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
  719. uint32_t result = 0;
  720. for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
  721. const auto & cells = v_cells[sinfo.strm[s]];
  722. result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
  723. }
  724. return result;
  725. }
  726. ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
  727. const int32_t ikv = map_layer_ids.at(il);
  728. auto * k = layers[ikv].k;
  729. const uint64_t kv_size = get_size();
  730. const uint64_t n_embd_k_gqa = k->ne[0];
  731. assert(n_embd_k_gqa == hparams.n_embd_k_gqa(il));
  732. const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
  733. return ggml_view_4d(ctx, k,
  734. hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns,
  735. ggml_row_size(k->type, hparams.n_embd_head_k),
  736. ggml_row_size(k->type, n_embd_k_gqa),
  737. ggml_row_size(k->type, n_embd_k_gqa*kv_size),
  738. ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0);
  739. }
  740. ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
  741. const int32_t ikv = map_layer_ids.at(il);
  742. auto * v = layers[ikv].v;
  743. const uint64_t kv_size = get_size();
  744. const uint64_t n_embd_v_gqa = v->ne[0];
  745. // [TAG_V_CACHE_VARIABLE]
  746. assert(n_embd_v_gqa >= hparams.n_embd_v_gqa(il));
  747. const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
  748. if (!v_trans) {
  749. // note: v->nb[1] <= v->nb[2]
  750. return ggml_view_4d(ctx, v,
  751. hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
  752. ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
  753. ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2]
  754. ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3]
  755. ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0);
  756. }
  757. // note: v->nb[1] > v->nb[2]
  758. return ggml_view_4d(ctx, v,
  759. n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
  760. ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1]
  761. ggml_row_size(v->type, kv_size), // v->nb[2]
  762. ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3]
  763. ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
  764. }
  765. ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
  766. GGML_UNUSED(sinfo);
  767. const int32_t ikv = map_layer_ids.at(il);
  768. ggml_tensor * k = layers[ikv].k;
  769. const int64_t n_embd_head = k_cur->ne[0];
  770. const int64_t n_head = k_cur->ne[1];
  771. const int64_t n_tokens = k_cur->ne[2];
  772. const int64_t n_embd_gqa = n_embd_head*n_head;
  773. // we can merge dims 0 and 1
  774. // TODO: add ggml helper function for this?
  775. GGML_ASSERT(ggml_row_size(k_cur->type, n_embd_head) == k_cur->nb[1]);
  776. k_cur = ggml_view_2d(ctx, k_cur, n_embd_gqa, n_tokens, k_cur->nb[2], 0);
  777. const int64_t n_stream = k->ne[2];
  778. if (n_stream > 1) {
  779. const int64_t kv_size = get_size();
  780. assert(n_embd_gqa == k->ne[0]);
  781. assert(kv_size == k->ne[1]);
  782. // merge the buffer across all streams because the idxs are global
  783. k = ggml_reshape_2d(ctx, k, n_embd_gqa, kv_size*n_stream);
  784. }
  785. // store the current K values into the cache
  786. return ggml_set_rows(ctx, k, k_cur, k_idxs);
  787. }
// Builds the op that stores the current V values (v_cur) of layer `il` into the V cache at the
// destination indices v_idxs (filled by set_input_v_idxs). `sinfo` is unused: the indices are global.
//
// - regular layout (!v_trans): v_cur becomes one row of n_embd_gqa values per token and the rows are
//   scattered with ggml_set_rows, in the same way as cpy_k
// - transposed layout (v_trans): both the cache and v_cur are viewed as 1-element rows, so v_idxs must
//   contain one destination index per (token, embedding element)
ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
    GGML_UNUSED(sinfo);

    const int32_t ikv = map_layer_ids.at(il);

    auto * v = layers[ikv].v;

    const int64_t n_embd_head = v_cur->ne[0];
    const int64_t n_head      = v_cur->ne[1];
    const int64_t n_tokens    = v_cur->ne[2];

    const int64_t n_embd_gqa = n_embd_head*n_head;

    // we can merge dims 0 and 1
    GGML_ASSERT(ggml_row_size(v_cur->type, n_embd_head) == v_cur->nb[1]);

    const int64_t n_stream = v->ne[2];

    // take this branch when FA is enabled (the V cache is not transposed)
    if (!v_trans) {
        v_cur = ggml_view_2d(ctx, v_cur, n_embd_gqa, n_tokens, v_cur->nb[2], 0);

        if (n_stream > 1) {
            const int64_t kv_size = get_size();

            assert(n_embd_gqa == v->ne[0]);
            assert(kv_size    == v->ne[1]);

            // merge the buffer across all streams because the idxs are global
            v = ggml_reshape_2d(ctx, v, n_embd_gqa, kv_size*n_stream);
        }

        return ggml_set_rows(ctx, v, v_cur, v_idxs);
    }

    if (ggml_row_size(v_cur->type, n_embd_gqa) == v_cur->nb[2]) {
        // we can merge dims 0, 1 and 2
        v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_gqa, n_tokens);
    } else {
        // otherwise -> make a copy to get contiguous data
        v_cur = ggml_cont_2d   (ctx, v_cur, n_embd_gqa, n_tokens);
    }

    // [TAG_V_CACHE_VARIABLE]
    // the transposed cache rows may be wider than this layer's n_embd_gqa; pad v_cur up to v->ne[0] so
    // that its element count lines up with the per-element indices in v_idxs
    if (n_embd_gqa < v->ne[0]) {
        v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_gqa, 0, 0, 0);
    }

    // in this branch the v_idxs are constructed in such a way that each row is a single head element
    ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, ggml_nelements(v));

    v_cur = ggml_reshape_2d(ctx, v_cur, 1, ggml_nelements(v_cur));

    return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
}
  827. ggml_tensor * llama_kv_cache::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
  828. const uint32_t n_tokens = ubatch.n_tokens;
  829. ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
  830. ggml_set_input(k_idxs);
  831. return k_idxs;
  832. }
  833. ggml_tensor * llama_kv_cache::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
  834. const uint32_t n_tokens = ubatch.n_tokens;
  835. ggml_tensor * v_idxs;
  836. if (!v_trans) {
  837. v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
  838. } else {
  839. v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa_max());
  840. }
  841. ggml_set_input(v_idxs);
  842. return v_idxs;
  843. }
  844. void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
  845. const uint32_t n_tokens = ubatch->n_tokens;
  846. GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
  847. GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
  848. int64_t * data = (int64_t *) dst->data;
  849. for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
  850. const int64_t offs = sinfo.strm[s]*get_size();
  851. for (uint32_t i = 0; i < sinfo.size(); ++i) {
  852. data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
  853. }
  854. }
  855. }
// Fills `dst` with the V-cache destination indices consumed by cpy_v via ggml_set_rows.
//
// - regular layout: one index per token, the cell index shifted by the start of its stream
//   (the same values set_input_k_idxs produces)
// - transposed layout: one index per (token, element j), for j in [0, n_embd_v_gqa_max). Inside a stream,
//   element j of cell c sits at j*kv_size + c; each stream spans kv_size*n_embd_v_gqa_max elements.
void llama_kv_cache::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
    const uint32_t n_tokens = ubatch->n_tokens;
    GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());

    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
    int64_t * data = (int64_t *) dst->data;

    if (!v_trans) {
        for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
            const int64_t offs = sinfo.strm[s]*get_size();

            for (uint32_t i = 0; i < sinfo.size(); ++i) {
                data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
            }
        }
    } else {
        // note: the V cache is transposed when not using flash attention
        const int64_t kv_size = get_size();

        const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa_max();

        for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
            // start of this stream in the flattened (1-element-row) view of the V cache
            const int64_t offs = sinfo.strm[s]*kv_size*n_embd_v_gqa;

            for (uint32_t i = 0; i < sinfo.size(); ++i) {
                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
                    data[s*sinfo.size()*n_embd_v_gqa + i*n_embd_v_gqa + j] = offs + j*kv_size + sinfo.idxs[s][i];
                }
            }
        }
    }
}
  882. void llama_kv_cache::set_input_k_shift(ggml_tensor * dst) const {
  883. GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
  884. int32_t * data = (int32_t *) dst->data;
  885. for (uint32_t s = 0; s < n_stream; ++s) {
  886. const auto & cells = v_cells[s];
  887. for (uint32_t i = 0; i < cells.size(); ++i) {
  888. data[s*cells.size() + i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
  889. }
  890. }
  891. }
// Fills the KQ mask input `dst` for `ubatch`.
//
// Every entry starts at -INFINITY (masked). Entry (token i, cell j) is set to 0.0f (or to the ALiBi bias
// -|p0 - p1| when hparams.use_alibi) when all of the following hold:
//   - cell j is non-empty
//   - cell j belongs to the token's first sequence id
//   - cell j is not in the token's future (only checked when causal_attn)
//   - cell j is not excluded by the sliding window
//
// Layout, as implied by the indexing below: rows of n_kv floats (n_kv = dst->ne[0]). The n_tps tokens of
// stream s use rows [s*n_tps_pad, s*n_tps_pad + n_tps), where n_tps_pad is n_tps rounded up to
// GGML_KQ_MASK_PAD. The padding rows are left at -INFINITY.
void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
    const uint32_t n_tokens = ubatch->n_tokens;

    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
    float * data = (float *) dst->data;

    const int64_t n_kv     = dst->ne[0];
    const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch

    GGML_ASSERT(n_tokens%n_stream == 0);

    // n_tps == n_tokens_per_stream
    const int64_t n_tps     = n_tokens/n_stream;
    const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD);

    std::fill(data, data + ggml_nelements(dst), -INFINITY);

    // Use only the previous KV cells of the correct sequence for each token of the ubatch.
    // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
    // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
    //   Causal mask:
    //      xxx-------
    //      xxxx------
    //      xxxxx-----
    //   Non-causal mask:
    //      xxxxx-----
    //      xxxxx-----
    //      xxxxx-----
    // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
    // TODO: optimize this section
    // note: h only takes the value 0
    for (uint32_t h = 0; h < 1; ++h) {
        for (uint32_t s = 0; s < n_stream; ++s) {
            for (uint32_t ii = 0; ii < n_tps; ++ii) {
                // token index inside the ubatch: tokens are grouped stream after stream
                const uint32_t i = s*n_tps + ii;

                const llama_seq_id seq_id = ubatch->seq_id[i][0];

                // the candidate cells come from the stream that holds this token's sequence
                const auto & cells = v_cells[seq_to_stream[seq_id]];

                const llama_pos p1 = ubatch->pos[i];

                // offset of this token's mask row
                const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);

                for (uint32_t j = 0; j < n_kv; ++j) {
                    if (cells.is_empty(j)) {
                        continue;
                    }

                    // mask the token if not the same sequence
                    if (!cells.seq_has(j, seq_id)) {
                        continue;
                    }

                    const llama_pos p0 = cells.pos_get(j);

                    // mask future tokens
                    if (causal_attn && p0 > p1) {
                        continue;
                    }

                    // apply SWA if any
                    if (is_masked_swa(p0, p1)) {
                        continue;
                    }

                    data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
                }
            }
        }
    }
}
  947. void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
  948. const int64_t n_tokens = ubatch->n_tokens;
  949. GGML_ASSERT(n_stream == 1 && "TODO: support multiple streams");
  950. const auto & cells = v_cells[0];
  951. GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
  952. GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
  953. int32_t * data = (int32_t *) dst->data;
  954. const int32_t n_kv = dst->ne[0];
  955. for (int h = 0; h < 1; ++h) {
  956. for (int i = 0; i < n_tokens; ++i) {
  957. for (int j = 0; j < n_kv; ++j) {
  958. // the position when the cells is empty is irrelevant - it will be masked out later in the attention
  959. const llama_pos p0 = cells.is_empty(j) ? -1 : cells.pos_get(j);
  960. data[h*(n_kv*n_tokens) + i*n_kv + j] = llama_relative_position_bucket(p0, ubatch->pos[i], hparams.n_rel_attn_bkts, false);
  961. }
  962. }
  963. }
  964. }
  965. size_t llama_kv_cache::total_size() const {
  966. size_t size = 0;
  967. for (const auto & buf : bufs) {
  968. size += ggml_backend_buffer_get_size(buf.get());
  969. }
  970. return size;
  971. }
  972. size_t llama_kv_cache::size_k_bytes() const {
  973. size_t size_k_bytes = 0;
  974. for (const auto & layer : layers) {
  975. size_k_bytes += ggml_nbytes(layer.k);
  976. }
  977. return size_k_bytes;
  978. }
  979. size_t llama_kv_cache::size_v_bytes() const {
  980. size_t size_v_bytes = 0;
  981. for (const auto & layer : layers) {
  982. size_v_bytes += ggml_nbytes(layer.v);
  983. }
  984. return size_v_bytes;
  985. }
  986. ggml_tensor * llama_kv_cache::build_rope_shift(
  987. const llama_cparams & cparams,
  988. ggml_context * ctx,
  989. ggml_tensor * cur,
  990. ggml_tensor * shift,
  991. ggml_tensor * factors,
  992. float freq_base,
  993. float freq_scale) const {
  994. const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
  995. const auto & yarn_ext_factor = cparams.yarn_ext_factor;
  996. const auto & yarn_beta_fast = cparams.yarn_beta_fast;
  997. const auto & yarn_beta_slow = cparams.yarn_beta_slow;
  998. const auto & n_rot = hparams.n_rot;
  999. const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE
  1000. // @ngxson : this is a workaround
  1001. // for M-RoPE, we want to rotate the whole vector when doing KV shift
  1002. // a normal RoPE should work, we just need to use the correct ordering
  1003. // ref: https://github.com/ggml-org/llama.cpp/pull/13870
  1004. ? LLAMA_ROPE_TYPE_NEOX
  1005. : hparams.rope_type;
  1006. // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
  1007. // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
  1008. const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2
  1009. ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale))
  1010. : cparams.yarn_attn_factor;
  1011. ggml_tensor * tmp;
  1012. if (ggml_is_quantized(cur->type)) {
  1013. // dequantize to f32 -> RoPE -> quantize back
  1014. tmp = ggml_cast(ctx, cur, GGML_TYPE_F32);
  1015. tmp = ggml_rope_ext(ctx, tmp,
  1016. shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
  1017. yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
  1018. tmp = ggml_cpy(ctx, tmp, cur);
  1019. } else {
  1020. // we rotate only the first n_rot dimensions
  1021. tmp = ggml_rope_ext_inplace(ctx, cur,
  1022. shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
  1023. yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
  1024. }
  1025. return tmp;
  1026. }
  1027. class llm_graph_input_k_shift : public llm_graph_input_i {
  1028. public:
  1029. llm_graph_input_k_shift(const llama_kv_cache * kv_self) : kv_self(kv_self) {}
  1030. virtual ~llm_graph_input_k_shift() = default;
  1031. void set_input(const llama_ubatch * ubatch) override;
  1032. ggml_tensor * k_shift; // I32 [kv_size*n_stream]
  1033. const llama_kv_cache * kv_self;
  1034. };
  1035. void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
  1036. GGML_UNUSED(ubatch);
  1037. if (k_shift) {
  1038. kv_self->set_input_k_shift(k_shift);
  1039. }
  1040. }
  1041. ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
  1042. auto * ctx = res->get_ctx();
  1043. auto * gf = res->get_gf();
  1044. const auto & n_embd_head_k = hparams.n_embd_head_k;
  1045. //const auto & n_embd_head_v = hparams.n_embd_head_v;
  1046. auto inp = std::make_unique<llm_graph_input_k_shift>(this);
  1047. inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
  1048. ggml_set_input(inp->k_shift);
  1049. const auto & cparams = lctx->get_cparams();
  1050. for (const auto & layer : layers) {
  1051. const uint32_t il = layer.il;
  1052. const int64_t n_head_kv = hparams.n_head_kv(il);
  1053. const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
  1054. const float freq_base_l = model.get_rope_freq_base (cparams, il);
  1055. const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
  1056. ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
  1057. ggml_tensor * k =
  1058. ggml_view_3d(ctx, layer.k,
  1059. n_embd_head_k, n_head_kv, get_size()*n_stream,
  1060. ggml_row_size(layer.k->type, n_embd_head_k),
  1061. ggml_row_size(layer.k->type, n_embd_k_gqa),
  1062. 0);
  1063. ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
  1064. ggml_build_forward_expand(gf, cur);
  1065. }
  1066. res->add_input(std::move(inp));
  1067. return gf;
  1068. }
  1069. bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
  1070. return llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1);
  1071. }
  1072. void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
  1073. GGML_UNUSED(flags);
  1074. io.write(&n_stream, sizeof(n_stream));
  1075. for (uint32_t s = 0; s < n_stream; ++s) {
  1076. cell_ranges_t cr { s, {} };
  1077. uint32_t cell_count = 0;
  1078. const auto & cells = v_cells[s];
  1079. // Count the number of cells with the specified seq_id
  1080. // Find all the ranges of cells with this seq id (or all, when -1)
  1081. uint32_t cell_range_begin = cells.size();
  1082. for (uint32_t i = 0; i < cells.size(); ++i) {
  1083. if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
  1084. ++cell_count;
  1085. if (cell_range_begin == cells.size()) {
  1086. cell_range_begin = i;
  1087. }
  1088. } else {
  1089. if (cell_range_begin != cells.size()) {
  1090. cr.data.emplace_back(cell_range_begin, i);
  1091. cell_range_begin = cells.size();
  1092. }
  1093. }
  1094. }
  1095. if (cell_range_begin != cells.size()) {
  1096. cr.data.emplace_back(cell_range_begin, cells.size());
  1097. }
  1098. // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
  1099. uint32_t cell_count_check = 0;
  1100. for (const auto & range : cr.data) {
  1101. cell_count_check += range.second - range.first;
  1102. }
  1103. GGML_ASSERT(cell_count == cell_count_check);
  1104. io.write(&cell_count, sizeof(cell_count));
  1105. // skip empty streams
  1106. if (cell_count == 0) {
  1107. continue;
  1108. }
  1109. state_write_meta(io, cr, seq_id);
  1110. state_write_data(io, cr);
  1111. }
  1112. }
  1113. void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
  1114. GGML_UNUSED(flags);
  1115. GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
  1116. uint32_t n_stream_cur;
  1117. io.read_to(&n_stream_cur, sizeof(n_stream_cur));
  1118. if (n_stream_cur != n_stream) {
  1119. throw std::runtime_error("n_stream mismatch");
  1120. }
  1121. for (uint32_t s = 0; s < n_stream; ++s) {
  1122. uint32_t cell_count;
  1123. io.read_to(&cell_count, sizeof(cell_count));
  1124. if (cell_count == 0) {
  1125. continue;
  1126. }
  1127. const uint32_t strm = seq_id == -1 ? s : seq_to_stream[seq_id];
  1128. bool res = true;
  1129. res = res && state_read_meta(io, strm, cell_count, seq_id);
  1130. res = res && state_read_data(io, strm, cell_count);
  1131. if (!res) {
  1132. if (seq_id == -1) {
  1133. clear(true);
  1134. } else {
  1135. seq_rm(seq_id, -1, -1);
  1136. }
  1137. throw std::runtime_error("failed to restore kv cache");
  1138. }
  1139. }
  1140. }
  1141. void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const {
  1142. const auto & cells = v_cells[cr.strm];
  1143. for (const auto & range : cr.data) {
  1144. for (uint32_t i = range.first; i < range.second; ++i) {
  1145. std::vector<llama_seq_id> seq_ids;
  1146. for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) {
  1147. if (cur == seq_id || seq_id == -1) {
  1148. if (cells.seq_has(i, cur)) {
  1149. seq_ids.push_back(cur);
  1150. }
  1151. }
  1152. }
  1153. const llama_pos pos = cells.pos_get(i);
  1154. const uint32_t n_seq_id = seq_ids.size();
  1155. io.write(&pos, sizeof(pos));
  1156. io.write(&n_seq_id, sizeof(n_seq_id));
  1157. for (const auto & seq_id : seq_ids) {
  1158. io.write(&seq_id, sizeof(seq_id));
  1159. }
  1160. }
  1161. }
  1162. }
  1163. void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const {
  1164. const auto & cells = v_cells[cr.strm];
  1165. const uint32_t v_trans = this->v_trans ? 1 : 0;
  1166. const uint32_t n_layer = layers.size();
  1167. io.write(&v_trans, sizeof(v_trans));
  1168. io.write(&n_layer, sizeof(n_layer));
  1169. std::vector<uint8_t> tmp_buf;
  1170. // Iterate and write all the keys first, each row is a cell
  1171. // Get whole range at a time
  1172. for (const auto & layer : layers) {
  1173. const uint32_t il = layer.il;
  1174. const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
  1175. auto * k = layer.k_stream[cr.strm];
  1176. // Write key type
  1177. const int32_t k_type_i = (int32_t) k->type;
  1178. io.write(&k_type_i, sizeof(k_type_i));
  1179. // Write row size of key
  1180. const uint64_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
  1181. io.write(&k_size_row, sizeof(k_size_row));
  1182. // Read each range of cells of k_size length each into tmp_buf and write out
  1183. for (const auto & range : cr.data) {
  1184. const size_t range_size = range.second - range.first;
  1185. const size_t buf_size = range_size * k_size_row;
  1186. io.write_tensor(k, range.first * k_size_row, buf_size);
  1187. }
  1188. }
  1189. if (!v_trans) {
  1190. for (const auto & layer : layers) {
  1191. const uint32_t il = layer.il;
  1192. const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
  1193. auto * v = layer.v_stream[cr.strm];
  1194. // Write value type
  1195. const int32_t v_type_i = (int32_t) v->type;
  1196. io.write(&v_type_i, sizeof(v_type_i));
  1197. // Write row size of value
  1198. const uint64_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
  1199. io.write(&v_size_row, sizeof(v_size_row));
  1200. // Read each range of cells of v_size length each into tmp_buf and write out
  1201. for (const auto & range : cr.data) {
  1202. const size_t range_size = range.second - range.first;
  1203. const size_t buf_size = range_size * v_size_row;
  1204. io.write_tensor(v, range.first * v_size_row, buf_size);
  1205. }
  1206. }
  1207. } else {
  1208. // When v is transposed, we also need the element size and get the element ranges from each row
  1209. const uint32_t kv_size = cells.size();
  1210. for (const auto & layer : layers) {
  1211. const uint32_t il = layer.il;
  1212. const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
  1213. auto * v = layer.v_stream[cr.strm];
  1214. // Write value type
  1215. const int32_t v_type_i = (int32_t) v->type;
  1216. io.write(&v_type_i, sizeof(v_type_i));
  1217. // Write element size
  1218. const uint32_t v_size_el = ggml_type_size(v->type);
  1219. io.write(&v_size_el, sizeof(v_size_el));
  1220. // Write GQA embedding size
  1221. io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
  1222. // For each row, we get the element values of each cell
  1223. for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
  1224. // Read each range of cells of v_size_el length each into tmp_buf and write out
  1225. for (const auto & range : cr.data) {
  1226. const size_t range_size = range.second - range.first;
  1227. const size_t src_offset = (range.first + j * kv_size) * v_size_el;
  1228. const size_t buf_size = range_size * v_size_el;
  1229. io.write_tensor(v, src_offset, buf_size);
  1230. }
  1231. }
  1232. }
  1233. }
  1234. }
  1235. bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) {
  1236. auto & cells = v_cells[strm];
  1237. auto & head = v_heads[strm];
  1238. if (dest_seq_id != -1) {
  1239. // single sequence
  1240. seq_rm(dest_seq_id, -1, -1);
  1241. llama_batch_allocr balloc(hparams.n_pos_per_embd());
  1242. llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
  1243. ubatch.seq_id_unq[0] = dest_seq_id;
  1244. for (uint32_t i = 0; i < cell_count; ++i) {
  1245. llama_pos pos;
  1246. uint32_t n_seq_id;
  1247. io.read_to(&pos, sizeof(pos));
  1248. io.read_to(&n_seq_id, sizeof(n_seq_id));
  1249. if (n_seq_id != 1) {
  1250. LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
  1251. return false;
  1252. }
  1253. // read the sequence id, but directly discard it - we will use dest_seq_id instead
  1254. {
  1255. llama_seq_id seq_id;
  1256. io.read_to(&seq_id, sizeof(seq_id));
  1257. }
  1258. ubatch.pos[i] = pos;
  1259. ubatch.n_seq_id[i] = n_seq_id;
  1260. ubatch.seq_id[i] = &dest_seq_id;
  1261. }
  1262. const auto sinfo = find_slot(ubatch, true);
  1263. if (sinfo.empty()) {
  1264. LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
  1265. return false;
  1266. }
  1267. apply_ubatch(sinfo, ubatch);
  1268. const auto head_cur = sinfo.head();
  1269. // keep the head at the old position because we will read the KV data into it in state_read_data()
  1270. head = head_cur;
  1271. LLAMA_LOG_DEBUG("%s: head_cur = %d, head = %d, cell_count = %d, dest_seq_id = %d\n", __func__, head_cur, head, cell_count, dest_seq_id);
  1272. // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
  1273. // Assume that this is one contiguous block of cells
  1274. GGML_ASSERT(head_cur + cell_count <= cells.size());
  1275. GGML_ASSERT(cells.pos_get(head_cur) == ubatch.pos[0]);
  1276. GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
  1277. GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id));
  1278. GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
  1279. } else {
  1280. // whole KV cache restore
  1281. if (cell_count > cells.size()) {
  1282. LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
  1283. return false;
  1284. }
  1285. clear(true);
  1286. for (uint32_t i = 0; i < cell_count; ++i) {
  1287. llama_pos pos;
  1288. uint32_t n_seq_id;
  1289. io.read_to(&pos, sizeof(pos));
  1290. io.read_to(&n_seq_id, sizeof(n_seq_id));
  1291. cells.pos_set(i, pos);
  1292. for (uint32_t j = 0; j < n_seq_id; ++j) {
  1293. llama_seq_id seq_id;
  1294. io.read_to(&seq_id, sizeof(seq_id));
  1295. if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
  1296. LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max);
  1297. return false;
  1298. }
  1299. cells.seq_add(i, seq_id);
  1300. }
  1301. }
  1302. head = 0;
  1303. }
  1304. return true;
  1305. }
  1306. bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) {
  1307. auto & cells = v_cells[strm];
  1308. auto & head = v_heads[strm];
  1309. uint32_t v_trans;
  1310. uint32_t n_layer;
  1311. io.read_to(&v_trans, sizeof(v_trans));
  1312. io.read_to(&n_layer, sizeof(n_layer));
  1313. if (n_layer != layers.size()) {
  1314. LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
  1315. return false;
  1316. }
  1317. if (cell_count > cells.size()) {
  1318. LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size());
  1319. return false;
  1320. }
  1321. if (this->v_trans != (bool) v_trans) {
  1322. LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
  1323. return false;
  1324. }
  1325. // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
  1326. for (const auto & layer : layers) {
  1327. const uint32_t il = layer.il;
  1328. const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
  1329. auto * k = layer.k_stream[strm];
  1330. // Read type of key
  1331. int32_t k_type_i_ref;
  1332. io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
  1333. const int32_t k_type_i = (int32_t) k->type;
  1334. if (k_type_i != k_type_i_ref) {
  1335. LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
  1336. return false;
  1337. }
  1338. // Read row size of key
  1339. uint64_t k_size_row_ref;
  1340. io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
  1341. const size_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
  1342. if (k_size_row != k_size_row_ref) {
  1343. LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
  1344. return false;
  1345. }
  1346. if (cell_count) {
  1347. // Read and set the keys for the whole cell range
  1348. ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
  1349. }
  1350. }
  1351. if (!this->v_trans) {
  1352. for (const auto & layer : layers) {
  1353. const uint32_t il = layer.il;
  1354. const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
  1355. auto * v = layer.v_stream[strm];
  1356. // Read type of value
  1357. int32_t v_type_i_ref;
  1358. io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
  1359. const int32_t v_type_i = (int32_t) v->type;
  1360. if (v_type_i != v_type_i_ref) {
  1361. LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
  1362. return false;
  1363. }
  1364. // Read row size of value
  1365. uint64_t v_size_row_ref;
  1366. io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
  1367. const size_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
  1368. if (v_size_row != v_size_row_ref) {
  1369. LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
  1370. return false;
  1371. }
  1372. if (cell_count) {
  1373. // Read and set the values for the whole cell range
  1374. ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
  1375. }
  1376. }
  1377. } else {
  1378. // For each layer, read the values for each cell (transposed)
  1379. for (const auto & layer : layers) {
  1380. const uint32_t il = layer.il;
  1381. const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
  1382. auto * v = layer.v_stream[strm];
  1383. // Read type of value
  1384. int32_t v_type_i_ref;
  1385. io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
  1386. const int32_t v_type_i = (int32_t) v->type;
  1387. if (v_type_i != v_type_i_ref) {
  1388. LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
  1389. return false;
  1390. }
  1391. // Read element size of value
  1392. uint32_t v_size_el_ref;
  1393. io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
  1394. const size_t v_size_el = ggml_type_size(v->type);
  1395. if (v_size_el != v_size_el_ref) {
  1396. LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
  1397. return false;
  1398. }
  1399. // Read GQA embedding size
  1400. uint32_t n_embd_v_gqa_ref;
  1401. io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
  1402. if (n_embd_v_gqa != n_embd_v_gqa_ref) {
  1403. LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
  1404. return false;
  1405. }
  1406. if (cell_count) {
  1407. // For each row in the transposed matrix, read the values for the whole cell range
  1408. for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
  1409. const size_t dst_offset = (head + j * cells.size()) * v_size_el;
  1410. ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
  1411. }
  1412. }
  1413. }
  1414. }
  1415. return true;
  1416. }
  1417. //
  1418. // llama_kv_cache_context
  1419. //
  1420. llama_kv_cache_context::llama_kv_cache_context(llama_memory_status status) : status(status) {}
  1421. llama_kv_cache_context::llama_kv_cache_context(
  1422. llama_kv_cache * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
  1423. n_kv = kv->get_size();
  1424. const uint32_t n_stream = kv->get_n_stream();
  1425. // create a dummy slot info - the actual data is irrelevant. we just need to build the graph
  1426. sinfos.resize(1);
  1427. sinfos[0].s0 = 0;
  1428. sinfos[0].s1 = n_stream - 1;
  1429. sinfos[0].idxs.resize(n_stream);
  1430. for (uint32_t s = 0; s < n_stream; ++s) {
  1431. sinfos[0].strm.push_back(s);
  1432. sinfos[0].idxs[s].resize(1, 0);
  1433. }
  1434. }
  1435. llama_kv_cache_context::llama_kv_cache_context(
  1436. llama_kv_cache * kv,
  1437. llama_context * lctx,
  1438. bool do_shift,
  1439. stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), sc_info(std::move(sc_info)) {
  1440. if (!do_shift && this->sc_info.empty()) {
  1441. status = LLAMA_MEMORY_STATUS_NO_UPDATE;
  1442. }
  1443. }
  1444. llama_kv_cache_context::llama_kv_cache_context(
  1445. llama_kv_cache * kv,
  1446. llama_kv_cache::slot_info_vec_t sinfos,
  1447. std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
  1448. }
  1449. llama_kv_cache_context::~llama_kv_cache_context() = default;
  1450. bool llama_kv_cache_context::next() {
  1451. assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
  1452. if (++i_cur >= ubatches.size()) {
  1453. return false;
  1454. }
  1455. return true;
  1456. }
  1457. bool llama_kv_cache_context::apply() {
  1458. assert(!llama_memory_status_is_fail(status));
  1459. // no ubatches -> this is a KV cache update
  1460. if (ubatches.empty()) {
  1461. kv->update(lctx, do_shift, sc_info);
  1462. return true;
  1463. }
  1464. kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
  1465. n_kv = kv->get_n_kv(sinfos[i_cur]);
  1466. return true;
  1467. }
  1468. llama_memory_status llama_kv_cache_context::get_status() const {
  1469. return status;
  1470. }
  1471. const llama_ubatch & llama_kv_cache_context::get_ubatch() const {
  1472. assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
  1473. return ubatches[i_cur];
  1474. }
  1475. uint32_t llama_kv_cache_context::get_n_kv() const {
  1476. return n_kv;
  1477. }
  1478. ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const {
  1479. return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
  1480. }
  1481. ggml_tensor * llama_kv_cache_context::get_v(ggml_context * ctx, int32_t il) const {
  1482. return kv->get_v(ctx, il, n_kv, sinfos[i_cur]);
  1483. }
  1484. ggml_tensor * llama_kv_cache_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
  1485. return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
  1486. }
  1487. ggml_tensor * llama_kv_cache_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
  1488. return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
  1489. }
  1490. ggml_tensor * llama_kv_cache_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
  1491. return kv->build_input_k_idxs(ctx, ubatch);
  1492. }
  1493. ggml_tensor * llama_kv_cache_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
  1494. return kv->build_input_v_idxs(ctx, ubatch);
  1495. }
  1496. void llama_kv_cache_context::set_input_k_shift(ggml_tensor * dst) const {
  1497. kv->set_input_k_shift(dst);
  1498. }
  1499. void llama_kv_cache_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
  1500. kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
  1501. }
  1502. void llama_kv_cache_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
  1503. kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
  1504. }
  1505. void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
  1506. kv->set_input_kq_mask(dst, ubatch, causal_attn);
  1507. }
  1508. void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
  1509. kv->set_input_pos_bucket(dst, ubatch);
  1510. }
  1511. uint32_t llama_kv_cache::get_padding(const llama_cparams & cparams) {
  1512. // the FA kernels require padding to avoid extra runtime boundary checks
  1513. return cparams.flash_attn ? 256u : 32u;
  1514. }