llama-kv-cache-unified.cpp 58 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757175817591760176117621763176417651766176717681769177017711772177317741775177617771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841
  1. #include "llama-kv-cache-unified.h"
  2. #include "llama-impl.h"
  3. #include "llama-io.h"
  4. #include "llama-model.h"
  5. #include "llama-context.h"
  6. #include <algorithm>
  7. #include <cassert>
  8. #include <cmath>
  9. #include <limits>
  10. #include <map>
  11. #include <stdexcept>
  12. //
  13. // llama_kv_cache_unified
  14. //
  15. llama_kv_cache_unified::llama_kv_cache_unified(
  16. const llama_model & model,
  17. layer_filter_cb && filter,
  18. ggml_type type_k,
  19. ggml_type type_v,
  20. bool v_trans,
  21. bool offload,
  22. uint32_t kv_size,
  23. uint32_t n_seq_max,
  24. uint32_t n_pad,
  25. uint32_t n_swa,
  26. llama_swa_type swa_type) :
  27. model(model), hparams(model.hparams), v_trans(v_trans),
  28. n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
  29. GGML_ASSERT(kv_size % n_pad == 0);
  30. // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
  31. auto n_layer_cache = hparams.n_layer;
  32. if (model.arch == LLM_ARCH_GEMMA3N) {
  33. n_layer_cache = 20;
  34. }
  35. // create a context for each buffer type
  36. std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
  37. auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
  38. auto it = ctx_map.find(buft);
  39. if (it == ctx_map.end()) {
  40. ggml_init_params params = {
  41. /*.mem_size =*/ size_t(2u*n_layer_cache*ggml_tensor_overhead()),
  42. /*.mem_buffer =*/ NULL,
  43. /*.no_alloc =*/ true,
  44. };
  45. ggml_context * ctx = ggml_init(params);
  46. if (!ctx) {
  47. return nullptr;
  48. }
  49. ctx_map[buft] = ctx;
  50. ctxs.emplace_back(ctx);
  51. return ctx;
  52. }
  53. return it->second;
  54. };
  55. head = 0;
  56. cells.resize(kv_size);
  57. for (uint32_t il = 0; il < n_layer_cache; il++) {
  58. if (filter && !filter(il)) {
  59. LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
  60. continue;
  61. }
  62. const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
  63. const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
  64. const char * dev_name = "CPU";
  65. ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
  66. if (offload) {
  67. auto * dev = model.dev_layer(il);
  68. buft = ggml_backend_dev_buffer_type(dev);
  69. dev_name = ggml_backend_dev_name(dev);
  70. }
  71. LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, il, dev_name);
  72. ggml_context * ctx = ctx_for_buft(buft);
  73. if (!ctx) {
  74. throw std::runtime_error("failed to create ggml context for kv cache");
  75. }
  76. ggml_tensor * k;
  77. ggml_tensor * v;
  78. k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
  79. v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
  80. ggml_format_name(k, "cache_k_l%d", il);
  81. ggml_format_name(v, "cache_v_l%d", il);
  82. map_layer_ids[il] = layers.size();
  83. layers.push_back({ il, k, v });
  84. }
  85. // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
  86. if (model.arch == LLM_ARCH_GEMMA3N) {
  87. LLAMA_LOG_DEBUG("%s: GEMMA3N: reuse layers [%d, %d]\n", __func__, n_layer_cache, hparams.n_layer - 1);
  88. for (uint32_t il = n_layer_cache; il < hparams.n_layer; il++) {
  89. if (filter && !filter(il)) {
  90. LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
  91. continue;
  92. }
  93. const bool is_swa = hparams.is_swa(il);
  94. const uint32_t il_reuse = n_layer_cache - (is_swa ? 2 : 1);
  95. GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());
  96. map_layer_ids[il] = map_layer_ids[il_reuse];
  97. LLAMA_LOG_DEBUG("%s: layer %3d: reuse layer %d, isw = %d\n", __func__, il, il_reuse, is_swa);
  98. }
  99. }
  100. // allocate tensors and initialize the buffers to avoid NaNs in the padding
  101. for (auto it : ctx_map) {
  102. auto * buft = it.first;
  103. auto * ctx = it.second;
  104. ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
  105. if (!buf) {
  106. throw std::runtime_error("failed to allocate buffer for kv cache");
  107. }
  108. LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
  109. ggml_backend_buffer_clear(buf, 0);
  110. bufs.emplace_back(buf);
  111. }
  112. {
  113. const size_t memory_size_k = size_k_bytes();
  114. const size_t memory_size_v = size_v_bytes();
  115. LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
  116. (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max,
  117. ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
  118. ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
  119. }
  120. const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
  121. debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
  122. }
  123. void llama_kv_cache_unified::clear(bool data) {
  124. cells.reset();
  125. head = 0;
  126. if (data) {
  127. for (auto & buf : bufs) {
  128. ggml_backend_buffer_clear(buf.get(), 0);
  129. }
  130. }
  131. }
  132. bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
  133. uint32_t new_head = cells.size();
  134. if (p0 < 0) {
  135. p0 = 0;
  136. }
  137. if (p1 < 0) {
  138. p1 = std::numeric_limits<llama_pos>::max();
  139. }
  140. if (seq_id >= 0) {
  141. for (uint32_t i = 0; i < cells.size(); ++i) {
  142. if (!cells.pos_in(i, p0, p1)) {
  143. continue;
  144. }
  145. if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
  146. if (new_head == cells.size()) {
  147. new_head = i;
  148. }
  149. }
  150. }
  151. } else {
  152. // match any sequence
  153. for (uint32_t i = 0; i < cells.size(); ++i) {
  154. if (!cells.pos_in(i, p0, p1)) {
  155. continue;
  156. }
  157. cells.rm(i);
  158. if (new_head == cells.size()) {
  159. new_head = i;
  160. }
  161. }
  162. }
  163. // If we freed up a slot, set head to it so searching can start there.
  164. if (new_head != cells.size() && new_head < head) {
  165. head = new_head;
  166. }
  167. return true;
  168. }
  169. void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
  170. if (seq_id_src == seq_id_dst) {
  171. return;
  172. }
  173. if (p0 < 0) {
  174. p0 = 0;
  175. }
  176. if (p1 < 0) {
  177. p1 = std::numeric_limits<llama_pos>::max();
  178. }
  179. for (uint32_t i = 0; i < cells.size(); ++i) {
  180. if (!cells.pos_in(i, p0, p1)) {
  181. continue;
  182. }
  183. if (cells.seq_has(i, seq_id_src)) {
  184. cells.seq_add(i, seq_id_dst);
  185. }
  186. }
  187. }
  188. void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
  189. uint32_t new_head = cells.size();
  190. for (uint32_t i = 0; i < cells.size(); ++i) {
  191. if (cells.seq_keep(i, seq_id)) {
  192. if (new_head == cells.size()) {
  193. new_head = i;
  194. }
  195. }
  196. }
  197. // If we freed up a slot, set head to it so searching can start there.
  198. if (new_head != cells.size() && new_head < head) {
  199. head = new_head;
  200. }
  201. }
  202. void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
  203. if (shift == 0) {
  204. return;
  205. }
  206. uint32_t new_head = cells.size();
  207. if (p0 < 0) {
  208. p0 = 0;
  209. }
  210. if (p1 < 0) {
  211. p1 = std::numeric_limits<llama_pos>::max();
  212. }
  213. // If there is no range then return early to avoid looping over all cells.
  214. if (p0 == p1) {
  215. return;
  216. }
  217. for (uint32_t i = 0; i < cells.size(); ++i) {
  218. if (!cells.pos_in(i, p0, p1)) {
  219. continue;
  220. }
  221. if (cells.seq_has(i, seq_id)) {
  222. if (cells.pos_add(i, shift)) {
  223. if (new_head == cells.size()) {
  224. new_head = i;
  225. }
  226. }
  227. }
  228. }
  229. // If we freed up a slot, set head to it so searching can start there.
  230. // Otherwise we just start the next search from the beginning.
  231. head = new_head != cells.size() ? new_head : 0;
  232. }
  233. void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
  234. if (d == 1) {
  235. return;
  236. }
  237. if (p0 < 0) {
  238. p0 = 0;
  239. }
  240. if (p1 < 0) {
  241. p1 = std::numeric_limits<llama_pos>::max();
  242. }
  243. // If there is no range then return early to avoid looping over the cache.
  244. if (p0 == p1) {
  245. return;
  246. }
  247. for (uint32_t i = 0; i < cells.size(); ++i) {
  248. if (!cells.pos_in(i, p0, p1)) {
  249. continue;
  250. }
  251. if (cells.seq_has(i, seq_id)) {
  252. cells.pos_div(i, d);
  253. }
  254. }
  255. }
  256. llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
  257. return cells.seq_pos_min(seq_id);
  258. }
  259. llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
  260. return cells.seq_pos_max(seq_id);
  261. }
  262. llama_memory_context_ptr llama_kv_cache_unified::init_batch(
  263. llama_batch_allocr & balloc,
  264. uint32_t n_ubatch,
  265. bool embd_all) {
  266. GGML_UNUSED(embd_all);
  267. do {
  268. balloc.split_reset();
  269. std::vector<llama_ubatch> ubatches;
  270. while (true) {
  271. auto ubatch = balloc.split_simple(n_ubatch);
  272. if (ubatch.n_tokens == 0) {
  273. break;
  274. }
  275. ubatches.push_back(std::move(ubatch)); // NOLINT
  276. }
  277. auto heads = prepare(ubatches);
  278. if (heads.empty()) {
  279. break;
  280. }
  281. return std::make_unique<llama_kv_cache_unified_context>(
  282. this, std::move(heads), std::move(ubatches));
  283. } while (false);
  284. return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
  285. }
  286. llama_memory_context_ptr llama_kv_cache_unified::init_full() {
  287. return std::make_unique<llama_kv_cache_unified_context>(this);
  288. }
  289. llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
  290. bool do_shift = get_has_shift();
  291. defrag_info dinfo;
  292. // see if we need to defrag
  293. {
  294. bool do_defrag = optimize;
  295. const auto thold = lctx->get_cparams().defrag_thold;
  296. if (!do_defrag && thold > 0.0f) {
  297. const auto n_kv = cells.used_max_p1();
  298. // - do not defrag small contexts (i.e. < 2048 tokens)
  299. // - count the padding towards the number of used tokens
  300. const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
  301. if (fragmentation > thold) {
  302. LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
  303. do_defrag = true;
  304. }
  305. }
  306. if (do_defrag) {
  307. dinfo = defrag_prepare(lctx->graph_max_nodes());
  308. }
  309. }
  310. return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
  311. }
  312. llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
  313. llama_kv_cache_unified::ubatch_heads res;
  314. struct state {
  315. uint32_t head_old; // old position of the head, before placing the ubatch
  316. uint32_t head_new; // new position of the head, after placing the ubatch
  317. llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch
  318. };
  319. // remember the old state of the cells so we can restore it in the end
  320. std::vector<state> states;
  321. bool success = true;
  322. for (const auto & ubatch : ubatches) {
  323. // only find a suitable slot for the ubatch. don't modify the cells yet
  324. const int32_t head_new = find_slot(ubatch);
  325. if (head_new < 0) {
  326. success = false;
  327. break;
  328. }
  329. // remeber the position that we found
  330. res.push_back(head_new);
  331. // store the old state of the cells in the recovery stack
  332. states.push_back({head, (uint32_t) head_new, cells.cp(head_new, ubatch.n_tokens)});
  333. // now emplace the ubatch
  334. apply_ubatch(head_new, ubatch);
  335. }
  336. // iterate backwards and restore the cells to their original state
  337. for (auto it = states.rbegin(); it != states.rend(); ++it) {
  338. cells.set(it->head_new, it->cells);
  339. head = it->head_old;
  340. }
  341. if (!success) {
  342. return {};
  343. }
  344. return res;
  345. }
  346. bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo) {
  347. bool updated = false;
  348. auto * sched = lctx->get_sched();
  349. if (do_shift) {
  350. if (!get_can_shift()) {
  351. GGML_ABORT("The current KV cache / model configuration does not support K-shift");
  352. }
  353. LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
  354. // apply K-shift if needed
  355. if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
  356. ggml_backend_sched_reset(sched);
  357. auto * gf = lctx->graph_init();
  358. auto res = build_graph_shift(lctx->get_cparams(), lctx->get_ctx_compute(), gf);
  359. if (!res) {
  360. LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__);
  361. return updated;
  362. }
  363. if (!ggml_backend_sched_alloc_graph(sched, gf)) {
  364. LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__);
  365. return updated;
  366. }
  367. res->set_inputs(nullptr);
  368. if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
  369. LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
  370. return updated;
  371. }
  372. updated = true;
  373. }
  374. cells.reset_shift();
  375. }
  376. if (!dinfo.empty()) {
  377. LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
  378. // apply moves:
  379. {
  380. const auto n_kv = dinfo.ids.size();
  381. for (uint32_t i = 0; i < n_kv; ++i) {
  382. assert(dinfo.ids[i] <= n_kv);
  383. if (dinfo.ids[i] == n_kv || dinfo.ids[i] == i) {
  384. continue;
  385. }
  386. cells.mv(i, dinfo.ids[i]);
  387. }
  388. // reset the head so we can find the first free slot during the next ubatch
  389. head = 0;
  390. }
  391. ggml_backend_sched_reset(sched);
  392. auto * gf = lctx->graph_init();
  393. auto res = build_graph_defrag(lctx->get_cparams(), lctx->get_ctx_compute(), gf, dinfo);
  394. if (!res) {
  395. LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
  396. return updated;
  397. }
  398. if (!ggml_backend_sched_alloc_graph(sched, gf)) {
  399. LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
  400. return updated;
  401. }
  402. res->set_inputs(nullptr);
  403. if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
  404. LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
  405. return updated;
  406. }
  407. updated = true;
  408. }
  409. return updated;
  410. }
  411. int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
  412. const uint32_t n_tokens = ubatch.n_tokens;
  413. uint32_t head_cur = this->head;
  414. // if we have enough unused cells before the current head ->
  415. // better to start searching from the beginning of the cache, hoping to fill it
  416. if (head_cur > cells.get_used() + 2*ubatch.n_tokens) {
  417. head_cur = 0;
  418. }
  419. if (n_tokens > cells.size()) {
  420. LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
  421. return -1;
  422. }
  423. if (debug > 0) {
  424. LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa);
  425. if ((debug == 2 && n_swa > 0) || debug > 2) {
  426. std::string ss;
  427. for (uint32_t i = 0; i < cells.size(); ++i) {
  428. if (cells.is_empty(i)) {
  429. ss += '.';
  430. } else {
  431. assert(cells.seq_count(i) >= 1);
  432. if (cells.seq_count(i) == 1) {
  433. ss += std::to_string(cells.seq_get(i));
  434. } else {
  435. ss += 'M';
  436. }
  437. }
  438. if (i%256 == 255) {
  439. ss += " *";
  440. ss += '\n';
  441. }
  442. }
  443. LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
  444. }
  445. if ((debug == 2 && n_swa > 0) || debug > 2) {
  446. std::string ss;
  447. for (uint32_t i = 0; i < cells.size(); ++i) {
  448. std::string cur;
  449. if (cells.is_empty(i)) {
  450. cur = '.';
  451. } else {
  452. cur = std::to_string(cells.pos_get(i));
  453. }
  454. const int n = cur.size();
  455. for (int j = 0; j < 5 - n; ++j) {
  456. cur += ' ';
  457. }
  458. ss += cur;
  459. if (i%256 == 255) {
  460. ss += " *";
  461. }
  462. if (i%64 == 63) {
  463. ss += '\n';
  464. }
  465. }
  466. LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
  467. }
  468. for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
  469. if (cells.seq_pos_min(s) < 0) {
  470. continue;
  471. }
  472. LLAMA_LOG_DEBUG("%s: min[%d] = %5d, max[%d] = %5d\n", __func__, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s));
  473. }
  474. }
  475. uint32_t n_tested = 0;
  476. while (true) {
  477. if (head_cur + n_tokens > cells.size()) {
  478. n_tested += cells.size() - head_cur;
  479. head_cur = 0;
  480. continue;
  481. }
  482. bool found = true;
  483. for (uint32_t i = 0; i < n_tokens; i++) {
  484. //const llama_pos pos = ubatch.pos[i];
  485. //const llama_seq_id seq_id = ubatch.seq_id[i][0];
  486. // can we use this cell? either:
  487. // - the cell is empty
  488. // - the cell is occupied only by one sequence:
  489. // - (disabled) mask causally, if the sequence is the same as the one we are inserting
  490. // - mask SWA, using current max pos for that sequence in the cache
  491. // always insert in the cell with minimum pos
  492. bool can_use = cells.is_empty(head_cur + i);
  493. if (!can_use && cells.seq_count(head_cur + i) == 1) {
  494. const llama_pos pos_cell = cells.pos_get(head_cur + i);
  495. // (disabled) causal mask
  496. // note: it's better to purge any "future" tokens beforehand
  497. //if (cells.seq_has(head_cur + i, seq_id)) {
  498. // can_use = pos_cell >= pos;
  499. //}
  500. if (!can_use) {
  501. const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
  502. // SWA mask
  503. if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
  504. can_use = true;
  505. }
  506. }
  507. }
  508. if (!can_use) {
  509. found = false;
  510. head_cur += i + 1;
  511. n_tested += i + 1;
  512. break;
  513. }
  514. }
  515. if (found) {
  516. break;
  517. }
  518. if (n_tested >= cells.size()) {
  519. //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
  520. return -1;
  521. }
  522. }
  523. return head_cur;
  524. }
  525. void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
  526. // keep track of the max sequence position that we would overwrite with this ubatch
  527. // for non-SWA cache, this would be always empty
  528. llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
  529. for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
  530. seq_pos_max_rm[s] = -1;
  531. }
  532. for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
  533. if (!cells.is_empty(head_cur + i)) {
  534. assert(cells.seq_count(head_cur + i) == 1);
  535. const llama_seq_id seq_id = cells.seq_get(head_cur + i);
  536. const llama_pos pos = cells.pos_get(head_cur + i);
  537. seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
  538. cells.rm(head_cur + i);
  539. }
  540. cells.pos_set(head_cur + i, ubatch.pos[i]);
  541. for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
  542. cells.seq_add(head_cur + i, ubatch.seq_id[i][s]);
  543. }
  544. }
  545. // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
  546. // will be present in the cache. so we have to purge any position which is less than those we would overwrite
  547. // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
  548. for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
  549. if (seq_pos_max_rm[s] == -1) {
  550. continue;
  551. }
  552. if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
  553. LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
  554. __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
  555. seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
  556. }
  557. }
  558. // move the head at the end of the slot
  559. head = head_cur + ubatch.n_tokens;
  560. }
  561. bool llama_kv_cache_unified::get_can_shift() const {
  562. return true;
  563. }
  564. uint32_t llama_kv_cache_unified::get_size() const {
  565. return cells.size();
  566. }
  567. bool llama_kv_cache_unified::get_has_shift() const {
  568. return cells.get_has_shift();
  569. }
  570. uint32_t llama_kv_cache_unified::get_n_kv() const {
  571. return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
  572. }
  573. ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
  574. const int32_t ikv = map_layer_ids.at(il);
  575. auto * k = layers[ikv].k;
  576. return ggml_view_3d(ctx, k,
  577. hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv,
  578. ggml_row_size(k->type, hparams.n_embd_head_k),
  579. ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
  580. 0);
  581. }
  582. ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
  583. const int32_t ikv = map_layer_ids.at(il);
  584. auto * v = layers[ikv].v;
  585. if (!v_trans) {
  586. // note: v->nb[1] <= v->nb[2]
  587. return ggml_view_3d(ctx, v,
  588. hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
  589. ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
  590. ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
  591. 0);
  592. }
  593. // note: v->nb[1] > v->nb[2]
  594. return ggml_view_3d(ctx, v,
  595. n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v,
  596. ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
  597. ggml_row_size(v->type, v->ne[1]), // v->nb[2]
  598. 0);
  599. }
  600. ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const {
  601. const int32_t ikv = map_layer_ids.at(il);
  602. auto * k = layers[ikv].k;
  603. const int64_t n_tokens = k_cur->ne[2];
  604. ggml_tensor * k_view = ggml_view_1d(ctx, k,
  605. n_tokens*hparams.n_embd_k_gqa(il),
  606. ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur);
  607. return ggml_cpy(ctx, k_cur, k_view);
  608. }
  609. ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const {
  610. const int32_t ikv = map_layer_ids.at(il);
  611. auto * v = layers[ikv].v;
  612. const int64_t n_tokens = v_cur->ne[2];
  613. v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens);
  614. ggml_tensor * v_view = nullptr;
  615. if (!v_trans) {
  616. v_view = ggml_view_1d(ctx, v,
  617. n_tokens*hparams.n_embd_v_gqa(il),
  618. ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur);
  619. } else {
  620. // note: the V cache is transposed when not using flash attention
  621. v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
  622. (v->ne[1])*ggml_element_size(v),
  623. (head_cur)*ggml_element_size(v));
  624. v_cur = ggml_transpose(ctx, v_cur);
  625. }
  626. return ggml_cpy(ctx, v_cur, v_view);
  627. }
  628. void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
  629. const uint32_t n_tokens = ubatch->n_tokens;
  630. GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
  631. float * data = (float *) dst->data;
  632. const int64_t n_kv = dst->ne[0];
  633. // Use only the previous KV cells of the correct sequence for each token of the ubatch.
  634. // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
  635. // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
  636. // Causal mask:
  637. // xxx-------
  638. // xxxx------
  639. // xxxxx-----
  640. // Non-causal mask:
  641. // xxxxx-----
  642. // xxxxx-----
  643. // xxxxx-----
  644. // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
  645. for (uint32_t h = 0; h < 1; ++h) {
  646. for (uint32_t i = 0; i < n_tokens; ++i) {
  647. const llama_seq_id seq_id = ubatch->seq_id[i][0];
  648. const llama_pos p1 = ubatch->pos[i];
  649. for (uint32_t j = 0; j < n_kv; ++j) {
  650. float f = 0.0f;
  651. bool masked = false;
  652. if (cells.is_empty(j)) {
  653. masked = true;
  654. } else {
  655. const llama_pos p0 = cells.pos_get(j);
  656. // mask the token if not the same sequence
  657. masked = masked || (!cells.seq_has(j, seq_id));
  658. // mask future tokens
  659. masked = masked || (causal_attn && p0 > p1);
  660. // apply SWA if any
  661. masked = masked || (is_masked_swa(p0, p1));
  662. if (!masked && hparams.use_alibi) {
  663. f = -std::abs(p0 - p1);
  664. }
  665. }
  666. if (masked) {
  667. f = -INFINITY;
  668. }
  669. data[h*(n_kv*n_tokens) + i*n_kv + j] = f;
  670. }
  671. }
  672. // mask padded tokens
  673. if (data) {
  674. for (uint32_t i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
  675. for (uint32_t j = 0; j < n_kv; ++j) {
  676. data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
  677. }
  678. }
  679. }
  680. }
  681. }
  682. void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
  683. GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
  684. int32_t * data = (int32_t *) dst->data;
  685. for (uint32_t i = 0; i < cells.size(); ++i) {
  686. data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
  687. }
  688. }
  689. void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
  690. const int64_t n_tokens = ubatch->n_tokens;
  691. GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
  692. GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
  693. int32_t * data = (int32_t *) dst->data;
  694. const int32_t n_kv = dst->ne[0];
  695. for (int h = 0; h < 1; ++h) {
  696. for (int i = 0; i < n_tokens; ++i) {
  697. for (int j = 0; j < n_kv; ++j) {
  698. // the position when the cells is empty is irrelevant - it will be masked out later in the attention
  699. const llama_pos p0 = cells.is_empty(j) ? -1 : cells.pos_get(j);
  700. data[h*(n_kv*n_tokens) + i*n_kv + j] = llama_relative_position_bucket(p0, ubatch->pos[i], hparams.n_rel_attn_bkts, false);
  701. }
  702. }
  703. }
  704. }
  705. size_t llama_kv_cache_unified::total_size() const {
  706. size_t size = 0;
  707. for (const auto & buf : bufs) {
  708. size += ggml_backend_buffer_get_size(buf.get());
  709. }
  710. return size;
  711. }
  712. size_t llama_kv_cache_unified::size_k_bytes() const {
  713. size_t size_k_bytes = 0;
  714. for (const auto & layer : layers) {
  715. size_k_bytes += ggml_nbytes(layer.k);
  716. }
  717. return size_k_bytes;
  718. }
  719. size_t llama_kv_cache_unified::size_v_bytes() const {
  720. size_t size_v_bytes = 0;
  721. for (const auto & layer : layers) {
  722. size_v_bytes += ggml_nbytes(layer.v);
  723. }
  724. return size_v_bytes;
  725. }
  726. ggml_tensor * llama_kv_cache_unified::build_rope_shift(
  727. const llama_cparams & cparams,
  728. ggml_context * ctx,
  729. ggml_tensor * cur,
  730. ggml_tensor * shift,
  731. ggml_tensor * factors,
  732. float freq_base,
  733. float freq_scale) const {
  734. const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
  735. const auto & yarn_ext_factor = cparams.yarn_ext_factor;
  736. const auto & yarn_beta_fast = cparams.yarn_beta_fast;
  737. const auto & yarn_beta_slow = cparams.yarn_beta_slow;
  738. const auto & n_rot = hparams.n_rot;
  739. const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE
  740. // @ngxson : this is a workaround
  741. // for M-RoPE, we want to rotate the whole vector when doing KV shift
  742. // a normal RoPE should work, we just need to use the correct ordering
  743. // ref: https://github.com/ggml-org/llama.cpp/pull/13870
  744. ? LLAMA_ROPE_TYPE_NEOX
  745. : hparams.rope_type;
  746. // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
  747. // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
  748. const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2
  749. ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale))
  750. : cparams.yarn_attn_factor;
  751. ggml_tensor * tmp;
  752. if (ggml_is_quantized(cur->type)) {
  753. // dequantize to f32 -> RoPE -> quantize back
  754. tmp = ggml_cast(ctx, cur, GGML_TYPE_F32);
  755. tmp = ggml_rope_ext(ctx, tmp,
  756. shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
  757. yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
  758. tmp = ggml_cpy(ctx, tmp, cur);
  759. } else {
  760. // we rotate only the first n_rot dimensions
  761. tmp = ggml_rope_ext_inplace(ctx, cur,
  762. shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
  763. yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
  764. }
  765. return tmp;
  766. }
  767. class llm_graph_input_k_shift : public llm_graph_input_i {
  768. public:
  769. llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
  770. virtual ~llm_graph_input_k_shift() = default;
  771. void set_input(const llama_ubatch * ubatch) override;
  772. ggml_tensor * k_shift; // I32 [kv_size]
  773. const llama_kv_cache_unified * kv_self;
  774. };
  775. void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
  776. GGML_UNUSED(ubatch);
  777. if (k_shift) {
  778. kv_self->set_input_k_shift(k_shift);
  779. }
  780. }
  781. llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
  782. const llama_cparams & cparams,
  783. ggml_context * ctx,
  784. ggml_cgraph * gf) const {
  785. auto res = std::make_unique<llm_graph_result>();
  786. const auto & n_embd_head_k = hparams.n_embd_head_k;
  787. //const auto & n_embd_head_v = hparams.n_embd_head_v;
  788. auto inp = std::make_unique<llm_graph_input_k_shift>(this);
  789. inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cells.size());
  790. ggml_set_input(inp->k_shift);
  791. for (const auto & layer : layers) {
  792. const uint32_t il = layer.il;
  793. const int64_t n_head_kv = hparams.n_head_kv(il);
  794. const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
  795. const float freq_base_l = model.get_rope_freq_base (cparams, il);
  796. const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
  797. ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
  798. ggml_tensor * k =
  799. ggml_view_3d(ctx, layer.k,
  800. n_embd_head_k, n_head_kv, cells.size(),
  801. ggml_row_size(layer.k->type, n_embd_head_k),
  802. ggml_row_size(layer.k->type, n_embd_k_gqa),
  803. 0);
  804. ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
  805. ggml_build_forward_expand(gf, cur);
  806. }
  807. res->add_input(std::move(inp));
  808. return res;
  809. }
  810. llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
  811. const llama_cparams & cparams,
  812. ggml_context * ctx,
  813. ggml_cgraph * gf,
  814. const defrag_info & dinfo) const {
  815. auto res = std::make_unique<llm_graph_result>();
  816. const auto & ids = dinfo.ids;
  817. #if 0
  818. // CPU defrag
  819. //
  820. // TODO: optimizations are possible:
  821. // - multiple threads
  822. // - avoid copying to the host memory when already there
  823. //
  824. // likely not worth the effort, as we have ggml_graph based defrag
  825. //
  826. const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
  827. const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
  828. const uint32_t kv_size = size;
  829. std::vector<uint8_t> buf_k;
  830. std::vector<uint8_t> buf_v;
  831. for (uint32_t il = 0; il < n_layer; ++il) {
  832. const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
  833. const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
  834. const size_t v_size_el = ggml_type_size(v_l[il]->type);
  835. const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
  836. buf_k.resize(k_size);
  837. buf_v.resize(v_size);
  838. ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
  839. ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
  840. // batch move [i, i+nm) to [id, id+nm)
  841. // note: cells can move only to a lower index
  842. for (uint32_t i = 0; i < n_kv; ++i) {
  843. const uint32_t id = ids[i];
  844. if (i == id || id == n_kv) {
  845. continue;
  846. }
  847. uint32_t nm = 1;
  848. while (i + nm < n_kv && ids[i + nm] == id + nm) {
  849. nm++;
  850. }
  851. // move keys
  852. {
  853. const int64_t os = i*k_size_row;
  854. const int64_t od = id*k_size_row;
  855. memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
  856. }
  857. // move values (note: they are transposed)
  858. {
  859. const int64_t os = i;
  860. const int64_t od = id;
  861. for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
  862. memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
  863. }
  864. }
  865. i += nm - 1;
  866. }
  867. ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
  868. ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
  869. }
  870. #else
  871. for (uint32_t i = 0; i < ids.size(); ++i) {
  872. const uint32_t id = ids[i];
  873. if (i == id || id == ids.size()) {
  874. continue;
  875. }
  876. uint32_t nm = 1;
  877. while (i + nm < ids.size() && ids[i + nm] == id + nm) {
  878. nm++;
  879. }
  880. for (const auto & layer : layers) {
  881. const uint32_t il = layer.il;
  882. const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
  883. const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
  884. ggml_tensor * view_k_src = ggml_view_2d(ctx, layer.k,
  885. n_embd_k_gqa, nm,
  886. ggml_row_size(layer.k->type, n_embd_k_gqa),
  887. ggml_row_size(layer.k->type, n_embd_k_gqa*i));
  888. ggml_tensor * view_k_dst = ggml_view_2d(ctx, layer.k,
  889. n_embd_k_gqa, nm,
  890. ggml_row_size(layer.k->type, n_embd_k_gqa),
  891. ggml_row_size(layer.k->type, n_embd_k_gqa*id));
  892. ggml_tensor * view_v_src;
  893. ggml_tensor * view_v_dst;
  894. if (cparams.flash_attn) {
  895. // NOTE: the V cache is not transposed when using flash attention
  896. view_v_src = ggml_view_2d(ctx, layer.v,
  897. n_embd_v_gqa, nm,
  898. ggml_row_size(layer.v->type, n_embd_v_gqa),
  899. ggml_row_size(layer.v->type, n_embd_v_gqa*i));
  900. view_v_dst = ggml_view_2d(ctx, layer.v,
  901. n_embd_v_gqa, nm,
  902. ggml_row_size(layer.v->type, n_embd_v_gqa),
  903. ggml_row_size(layer.v->type, n_embd_v_gqa*id));
  904. } else {
  905. view_v_src = ggml_view_2d(ctx, layer.v,
  906. nm, n_embd_v_gqa,
  907. ggml_row_size(layer.v->type, cells.size()),
  908. ggml_row_size(layer.v->type, i));
  909. view_v_dst = ggml_view_2d(ctx, layer.v,
  910. nm, n_embd_v_gqa,
  911. ggml_row_size(layer.v->type, cells.size()),
  912. ggml_row_size(layer.v->type, id));
  913. }
  914. ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst));
  915. ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst));
  916. }
  917. i += nm - 1;
  918. }
  919. //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
  920. #endif
  921. return res;
  922. }
  923. llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
  924. const uint32_t n_layer = layers.size();
  925. const uint32_t n_kv = cells.used_max_p1();
  926. const uint32_t n_used = cells.get_used();
  927. assert(n_used <= n_kv);
  928. //const int64_t t_start = ggml_time_us();
  929. // number of cells moved
  930. uint32_t n_moves = 0;
  931. // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag)
  932. // - source view, destination view, copy operation
  933. // - x2 for keys and values
  934. //const uint32_t max_moves = max_nodes()/(6*n_layer);
  935. // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
  936. const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
  937. // determine which KV cells to move where
  938. defrag_info res;
  939. auto & ids = res.ids;
  940. ids.resize(n_kv, n_kv);
  941. for (uint32_t i0 = 0; i0 < n_used; ++i0) {
  942. if (!cells.is_empty(i0)) {
  943. ids[i0] = i0;
  944. continue;
  945. }
  946. // found a hole - fill it with data from the end of the cache
  947. uint32_t nh = 1;
  948. // determine the size of the hole
  949. while (i0 + nh < n_used && cells.is_empty(i0 + nh)) {
  950. nh++;
  951. }
  952. uint32_t nf = 0;
  953. uint32_t is = n_kv - 1;
  954. // starting from the end, find nh non-empty cells
  955. for (; is > i0; --is) {
  956. if (cells.is_empty(is) || ids[is] != n_kv) {
  957. continue;
  958. }
  959. // non-empty cell which is not yet moved
  960. nf++;
  961. if (nf == nh) {
  962. break;
  963. }
  964. }
  965. // this can only happen if `n_used` is not accurate, which would be a bug
  966. GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh");
  967. nf = 0;
  968. uint32_t i1 = is;
  969. // are we moving a continuous block of memory?
  970. bool cont = false;
  971. // should we stop searching for the next move?
  972. bool stop = false;
  973. // go back and move the nf cells to the hole
  974. for (; i1 < n_kv; ++i1) {
  975. if (cells.is_empty(i1) || ids[i1] != n_kv) {
  976. if (n_moves == max_moves) {
  977. stop = true;
  978. break;
  979. }
  980. cont = false;
  981. continue;
  982. }
  983. // this cell goes to (i0 + nf)
  984. ids[i1] = i0 + nf;
  985. if (!cont) {
  986. n_moves++;
  987. cont = true;
  988. }
  989. nf++;
  990. if (nf == nh) {
  991. break;
  992. }
  993. }
  994. if (stop || n_moves == max_moves) {
  995. break;
  996. }
  997. //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
  998. i0 += nh - 1;
  999. }
  1000. if (n_moves == 0) {
  1001. return {};
  1002. }
  1003. LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
  1004. LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
  1005. return res;
  1006. }
  1007. bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
  1008. assert(p0 >= 0 && p1 >= 0);
  1009. switch (swa_type) {
  1010. case LLAMA_SWA_TYPE_NONE:
  1011. {
  1012. } break;
  1013. case LLAMA_SWA_TYPE_STANDARD:
  1014. {
  1015. if (p1 - p0 >= (int32_t) n_swa) {
  1016. return true;
  1017. }
  1018. } break;
  1019. case LLAMA_SWA_TYPE_CHUNKED:
  1020. {
  1021. const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
  1022. if (p0 < pos_chunk_start) {
  1023. return true;
  1024. }
  1025. } break;
  1026. }
  1027. return false;
  1028. }
  1029. void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
  1030. std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
  1031. uint32_t cell_count = 0;
  1032. // Count the number of cells with the specified seq_id
  1033. // Find all the ranges of cells with this seq id (or all, when -1)
  1034. uint32_t cell_range_begin = cells.size();
  1035. for (uint32_t i = 0; i < cells.size(); ++i) {
  1036. if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
  1037. ++cell_count;
  1038. if (cell_range_begin == cells.size()) {
  1039. cell_range_begin = i;
  1040. }
  1041. } else {
  1042. if (cell_range_begin != cells.size()) {
  1043. cell_ranges.emplace_back(cell_range_begin, i);
  1044. cell_range_begin = cells.size();
  1045. }
  1046. }
  1047. }
  1048. if (cell_range_begin != cells.size()) {
  1049. cell_ranges.emplace_back(cell_range_begin, cells.size());
  1050. }
  1051. // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
  1052. uint32_t cell_count_check = 0;
  1053. for (const auto & range : cell_ranges) {
  1054. cell_count_check += range.second - range.first;
  1055. }
  1056. GGML_ASSERT(cell_count == cell_count_check);
  1057. io.write(&cell_count, sizeof(cell_count));
  1058. state_write_meta(io, cell_ranges, seq_id);
  1059. state_write_data(io, cell_ranges);
  1060. }
  1061. void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
  1062. uint32_t cell_count;
  1063. io.read_to(&cell_count, sizeof(cell_count));
  1064. bool res = true;
  1065. res = res && state_read_meta(io, cell_count, seq_id);
  1066. res = res && state_read_data(io, cell_count);
  1067. if (!res) {
  1068. if (seq_id == -1) {
  1069. clear(true);
  1070. } else {
  1071. seq_rm(seq_id, -1, -1);
  1072. }
  1073. throw std::runtime_error("failed to restore kv cache");
  1074. }
  1075. }
  1076. void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
  1077. for (const auto & range : cell_ranges) {
  1078. for (uint32_t i = range.first; i < range.second; ++i) {
  1079. std::vector<llama_seq_id> seq_ids;
  1080. for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) {
  1081. if (cur == seq_id || seq_id == -1) {
  1082. if (cells.seq_has(i, cur)) {
  1083. seq_ids.push_back(cur);
  1084. }
  1085. }
  1086. }
  1087. const llama_pos pos = cells.pos_get(i);
  1088. const uint32_t n_seq_id = seq_ids.size();
  1089. io.write(&pos, sizeof(pos));
  1090. io.write(&n_seq_id, sizeof(n_seq_id));
  1091. for (const auto & seq_id : seq_ids) {
  1092. io.write(&seq_id, sizeof(seq_id));
  1093. }
  1094. }
  1095. }
  1096. }
  1097. void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
  1098. const uint32_t v_trans = this->v_trans ? 1 : 0;
  1099. const uint32_t n_layer = layers.size();
  1100. io.write(&v_trans, sizeof(v_trans));
  1101. io.write(&n_layer, sizeof(n_layer));
  1102. std::vector<uint8_t> tmp_buf;
  1103. // Iterate and write all the keys first, each row is a cell
  1104. // Get whole range at a time
  1105. for (const auto & layer : layers) {
  1106. const uint32_t il = layer.il;
  1107. const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
  1108. // Write key type
  1109. const int32_t k_type_i = (int32_t)layer.k->type;
  1110. io.write(&k_type_i, sizeof(k_type_i));
  1111. // Write row size of key
  1112. const uint64_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa);
  1113. io.write(&k_size_row, sizeof(k_size_row));
  1114. // Read each range of cells of k_size length each into tmp_buf and write out
  1115. for (const auto & range : cell_ranges) {
  1116. const size_t range_size = range.second - range.first;
  1117. const size_t buf_size = range_size * k_size_row;
  1118. io.write_tensor(layer.k, range.first * k_size_row, buf_size);
  1119. }
  1120. }
  1121. if (!v_trans) {
  1122. for (const auto & layer : layers) {
  1123. const uint32_t il = layer.il;
  1124. const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
  1125. // Write value type
  1126. const int32_t v_type_i = (int32_t)layer.v->type;
  1127. io.write(&v_type_i, sizeof(v_type_i));
  1128. // Write row size of value
  1129. const uint64_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa);
  1130. io.write(&v_size_row, sizeof(v_size_row));
  1131. // Read each range of cells of v_size length each into tmp_buf and write out
  1132. for (const auto & range : cell_ranges) {
  1133. const size_t range_size = range.second - range.first;
  1134. const size_t buf_size = range_size * v_size_row;
  1135. io.write_tensor(layer.v, range.first * v_size_row, buf_size);
  1136. }
  1137. }
  1138. } else {
  1139. // When v is transposed, we also need the element size and get the element ranges from each row
  1140. const uint32_t kv_size = cells.size();
  1141. for (const auto & layer : layers) {
  1142. const uint32_t il = layer.il;
  1143. const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
  1144. // Write value type
  1145. const int32_t v_type_i = (int32_t)layer.v->type;
  1146. io.write(&v_type_i, sizeof(v_type_i));
  1147. // Write element size
  1148. const uint32_t v_size_el = ggml_type_size(layer.v->type);
  1149. io.write(&v_size_el, sizeof(v_size_el));
  1150. // Write GQA embedding size
  1151. io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
  1152. // For each row, we get the element values of each cell
  1153. for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
  1154. // Read each range of cells of v_size_el length each into tmp_buf and write out
  1155. for (const auto & range : cell_ranges) {
  1156. const size_t range_size = range.second - range.first;
  1157. const size_t src_offset = (range.first + j * kv_size) * v_size_el;
  1158. const size_t buf_size = range_size * v_size_el;
  1159. io.write_tensor(layer.v, src_offset, buf_size);
  1160. }
  1161. }
  1162. }
  1163. }
  1164. }
  1165. bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
  1166. if (dest_seq_id != -1) {
  1167. // single sequence
  1168. seq_rm(dest_seq_id, -1, -1);
  1169. llama_batch_allocr balloc(hparams.n_pos_per_embd());
  1170. llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
  1171. for (uint32_t i = 0; i < cell_count; ++i) {
  1172. llama_pos pos;
  1173. uint32_t n_seq_id;
  1174. io.read_to(&pos, sizeof(pos));
  1175. io.read_to(&n_seq_id, sizeof(n_seq_id));
  1176. if (n_seq_id != 1) {
  1177. LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
  1178. return false;
  1179. }
  1180. // read the sequence id, but directly discard it - we will use dest_seq_id instead
  1181. {
  1182. llama_seq_id seq_id;
  1183. io.read_to(&seq_id, sizeof(seq_id));
  1184. }
  1185. ubatch.pos[i] = pos;
  1186. ubatch.n_seq_id[i] = n_seq_id;
  1187. ubatch.seq_id[i] = &dest_seq_id;
  1188. }
  1189. const auto head_cur = find_slot(ubatch);
  1190. if (head_cur < 0) {
  1191. LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
  1192. return false;
  1193. }
  1194. apply_ubatch(head_cur, ubatch);
  1195. // keep the head at the old position because we will read the KV data into it in state_read_data()
  1196. head = head_cur;
  1197. // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
  1198. // Assume that this is one contiguous block of cells
  1199. GGML_ASSERT(head_cur + cell_count <= cells.size());
  1200. GGML_ASSERT(cells.pos_get(head_cur) == ubatch.pos[0]);
  1201. GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
  1202. GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id));
  1203. GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
  1204. } else {
  1205. // whole KV cache restore
  1206. if (cell_count > cells.size()) {
  1207. LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
  1208. return false;
  1209. }
  1210. clear(true);
  1211. for (uint32_t i = 0; i < cell_count; ++i) {
  1212. llama_pos pos;
  1213. uint32_t n_seq_id;
  1214. io.read_to(&pos, sizeof(pos));
  1215. io.read_to(&n_seq_id, sizeof(n_seq_id));
  1216. cells.pos_set(i, pos);
  1217. for (uint32_t j = 0; j < n_seq_id; ++j) {
  1218. llama_seq_id seq_id;
  1219. io.read_to(&seq_id, sizeof(seq_id));
  1220. if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
  1221. LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max);
  1222. return false;
  1223. }
  1224. cells.seq_add(i, seq_id);
  1225. }
  1226. }
  1227. head = 0;
  1228. }
  1229. return true;
  1230. }
  1231. bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
  1232. uint32_t v_trans;
  1233. uint32_t n_layer;
  1234. io.read_to(&v_trans, sizeof(v_trans));
  1235. io.read_to(&n_layer, sizeof(n_layer));
  1236. if (n_layer != layers.size()) {
  1237. LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
  1238. return false;
  1239. }
  1240. if (cell_count > cells.size()) {
  1241. LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size());
  1242. return false;
  1243. }
  1244. if (this->v_trans != (bool) v_trans) {
  1245. LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
  1246. return false;
  1247. }
  1248. // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
  1249. for (const auto & layer : layers) {
  1250. const uint32_t il = layer.il;
  1251. const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
  1252. // Read type of key
  1253. int32_t k_type_i_ref;
  1254. io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
  1255. const int32_t k_type_i = (int32_t) layer.k->type;
  1256. if (k_type_i != k_type_i_ref) {
  1257. LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
  1258. return false;
  1259. }
  1260. // Read row size of key
  1261. uint64_t k_size_row_ref;
  1262. io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
  1263. const size_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa);
  1264. if (k_size_row != k_size_row_ref) {
  1265. LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
  1266. return false;
  1267. }
  1268. if (cell_count) {
  1269. // Read and set the keys for the whole cell range
  1270. ggml_backend_tensor_set(layer.k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
  1271. }
  1272. }
  1273. if (!this->v_trans) {
  1274. for (const auto & layer : layers) {
  1275. const uint32_t il = layer.il;
  1276. const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
  1277. // Read type of value
  1278. int32_t v_type_i_ref;
  1279. io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
  1280. const int32_t v_type_i = (int32_t)layer.v->type;
  1281. if (v_type_i != v_type_i_ref) {
  1282. LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
  1283. return false;
  1284. }
  1285. // Read row size of value
  1286. uint64_t v_size_row_ref;
  1287. io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
  1288. const size_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa);
  1289. if (v_size_row != v_size_row_ref) {
  1290. LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
  1291. return false;
  1292. }
  1293. if (cell_count) {
  1294. // Read and set the values for the whole cell range
  1295. ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
  1296. }
  1297. }
  1298. } else {
  1299. // For each layer, read the values for each cell (transposed)
  1300. for (const auto & layer : layers) {
  1301. const uint32_t il = layer.il;
  1302. const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
  1303. // Read type of value
  1304. int32_t v_type_i_ref;
  1305. io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
  1306. const int32_t v_type_i = (int32_t)layer.v->type;
  1307. if (v_type_i != v_type_i_ref) {
  1308. LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
  1309. return false;
  1310. }
  1311. // Read element size of value
  1312. uint32_t v_size_el_ref;
  1313. io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
  1314. const size_t v_size_el = ggml_type_size(layer.v->type);
  1315. if (v_size_el != v_size_el_ref) {
  1316. LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
  1317. return false;
  1318. }
  1319. // Read GQA embedding size
  1320. uint32_t n_embd_v_gqa_ref;
  1321. io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
  1322. if (n_embd_v_gqa != n_embd_v_gqa_ref) {
  1323. LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
  1324. return false;
  1325. }
  1326. if (cell_count) {
  1327. // For each row in the transposed matrix, read the values for the whole cell range
  1328. for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
  1329. const size_t dst_offset = (head + j * cells.size()) * v_size_el;
  1330. ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
  1331. }
  1332. }
  1333. }
  1334. }
  1335. return true;
  1336. }
  1337. //
  1338. // llama_kv_cache_unified_context
  1339. //
  1340. llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_status status) : status(status) {}
  1341. llama_kv_cache_unified_context::llama_kv_cache_unified_context(
  1342. llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
  1343. n_kv = kv->get_size();
  1344. head = 0;
  1345. }
  1346. llama_kv_cache_unified_context::llama_kv_cache_unified_context(
  1347. llama_kv_cache_unified * kv,
  1348. llama_context * lctx,
  1349. bool do_shift,
  1350. defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) {
  1351. if (!do_shift && this->dinfo.empty()) {
  1352. status = LLAMA_MEMORY_STATUS_NO_UPDATE;
  1353. }
  1354. }
  1355. llama_kv_cache_unified_context::llama_kv_cache_unified_context(
  1356. llama_kv_cache_unified * kv,
  1357. llama_kv_cache_unified::ubatch_heads heads,
  1358. std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
  1359. }
  1360. llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
  1361. bool llama_kv_cache_unified_context::next() {
  1362. assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
  1363. if (++i_next >= ubatches.size()) {
  1364. return false;
  1365. }
  1366. return true;
  1367. }
  1368. bool llama_kv_cache_unified_context::apply() {
  1369. assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
  1370. // no ubatches -> this is a KV cache update
  1371. if (ubatches.empty()) {
  1372. kv->update(lctx, do_shift, dinfo);
  1373. return true;
  1374. }
  1375. kv->apply_ubatch(heads[i_next], ubatches[i_next]);
  1376. n_kv = kv->get_n_kv();
  1377. head = heads[i_next];
  1378. return true;
  1379. }
  1380. llama_memory_status llama_kv_cache_unified_context::get_status() const {
  1381. return status;
  1382. }
  1383. const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
  1384. assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
  1385. return ubatches[i_next];
  1386. }
  1387. uint32_t llama_kv_cache_unified_context::get_n_kv() const {
  1388. return n_kv;
  1389. }
  1390. ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
  1391. return kv->get_k(ctx, il, n_kv);
  1392. }
  1393. ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
  1394. return kv->get_v(ctx, il, n_kv);
  1395. }
  1396. ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
  1397. return kv->cpy_k(ctx, k_cur, il, head);
  1398. }
  1399. ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
  1400. return kv->cpy_v(ctx, v_cur, il, head);
  1401. }
  1402. void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
  1403. kv->set_input_k_shift(dst);
  1404. }
  1405. void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
  1406. kv->set_input_kq_mask(dst, ubatch, causal_attn);
  1407. }
  1408. void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
  1409. kv->set_input_pos_bucket(dst, ubatch);
  1410. }
  1411. uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
  1412. // the FA kernels require padding to avoid extra runtime boundary checks
  1413. return cparams.flash_attn ? 256u : 32u;
  1414. }