llama-kv-cache-unified.cpp 63 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990
  1. #include "llama-kv-cache-unified.h"
  2. #include "llama-impl.h"
  3. #include "llama-io.h"
  4. #include "llama-model.h"
  5. #include "llama-context.h"
  6. #include <algorithm>
  7. #include <cassert>
  8. #include <cmath>
  9. #include <limits>
  10. #include <map>
  11. #include <stdexcept>
  12. //
  13. // llama_kv_cache_unified
  14. //
  15. llama_kv_cache_unified::llama_kv_cache_unified(
  16. const llama_model & model,
  17. layer_filter_cb && filter,
  18. ggml_type type_k,
  19. ggml_type type_v,
  20. bool v_trans,
  21. bool offload,
  22. uint32_t kv_size,
  23. uint32_t n_seq_max,
  24. uint32_t n_pad,
  25. uint32_t n_swa,
  26. llama_swa_type swa_type) :
  27. model(model), hparams(model.hparams), v_trans(v_trans),
  28. n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
  29. GGML_ASSERT(kv_size % n_pad == 0);
  30. // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
  31. auto n_layer_cache = hparams.n_layer;
  32. if (model.arch == LLM_ARCH_GEMMA3N) {
  33. n_layer_cache = 20;
  34. }
  35. // create a context for each buffer type
  36. std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
  37. auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
  38. auto it = ctx_map.find(buft);
  39. if (it == ctx_map.end()) {
  40. ggml_init_params params = {
  41. /*.mem_size =*/ size_t(2u*n_layer_cache*ggml_tensor_overhead()),
  42. /*.mem_buffer =*/ NULL,
  43. /*.no_alloc =*/ true,
  44. };
  45. ggml_context * ctx = ggml_init(params);
  46. if (!ctx) {
  47. return nullptr;
  48. }
  49. ctx_map[buft] = ctx;
  50. ctxs.emplace_back(ctx);
  51. return ctx;
  52. }
  53. return it->second;
  54. };
  55. head = 0;
  56. cells.resize(kv_size);
  57. for (uint32_t il = 0; il < n_layer_cache; il++) {
  58. if (filter && !filter(il)) {
  59. LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
  60. continue;
  61. }
  62. const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
  63. const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
  64. const char * dev_name = "CPU";
  65. ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
  66. if (offload) {
  67. auto * dev = model.dev_layer(il);
  68. buft = ggml_backend_dev_buffer_type(dev);
  69. dev_name = ggml_backend_dev_name(dev);
  70. }
  71. LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, il, dev_name);
  72. ggml_context * ctx = ctx_for_buft(buft);
  73. if (!ctx) {
  74. throw std::runtime_error("failed to create ggml context for kv cache");
  75. }
  76. ggml_tensor * k;
  77. ggml_tensor * v;
  78. k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
  79. v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
  80. ggml_format_name(k, "cache_k_l%d", il);
  81. ggml_format_name(v, "cache_v_l%d", il);
  82. map_layer_ids[il] = layers.size();
  83. layers.push_back({ il, k, v });
  84. }
  85. // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
  86. if (model.arch == LLM_ARCH_GEMMA3N) {
  87. LLAMA_LOG_DEBUG("%s: GEMMA3N: reuse layers [%d, %d]\n", __func__, n_layer_cache, hparams.n_layer - 1);
  88. for (uint32_t il = n_layer_cache; il < hparams.n_layer; il++) {
  89. if (filter && !filter(il)) {
  90. LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
  91. continue;
  92. }
  93. const bool is_swa = hparams.is_swa(il);
  94. const uint32_t il_reuse = n_layer_cache - (is_swa ? 2 : 1);
  95. GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());
  96. map_layer_ids[il] = map_layer_ids[il_reuse];
  97. LLAMA_LOG_DEBUG("%s: layer %3d: reuse layer %d, isw = %d\n", __func__, il, il_reuse, is_swa);
  98. }
  99. }
  100. // allocate tensors and initialize the buffers to avoid NaNs in the padding
  101. for (auto it : ctx_map) {
  102. auto * buft = it.first;
  103. auto * ctx = it.second;
  104. ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
  105. if (!buf) {
  106. throw std::runtime_error("failed to allocate buffer for kv cache");
  107. }
  108. LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
  109. ggml_backend_buffer_clear(buf, 0);
  110. bufs.emplace_back(buf);
  111. }
  112. {
  113. const size_t memory_size_k = size_k_bytes();
  114. const size_t memory_size_v = size_v_bytes();
  115. LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
  116. (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max,
  117. ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
  118. ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
  119. }
  120. const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
  121. debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
  122. const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
  123. supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
  124. if (!supports_set_rows) {
  125. LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__);
  126. }
  127. }
  128. void llama_kv_cache_unified::clear(bool data) {
  129. cells.reset();
  130. head = 0;
  131. if (data) {
  132. for (auto & buf : bufs) {
  133. ggml_backend_buffer_clear(buf.get(), 0);
  134. }
  135. }
  136. }
  137. bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
  138. uint32_t new_head = cells.size();
  139. if (p0 < 0) {
  140. p0 = 0;
  141. }
  142. if (p1 < 0) {
  143. p1 = std::numeric_limits<llama_pos>::max();
  144. }
  145. if (seq_id >= 0) {
  146. for (uint32_t i = 0; i < cells.size(); ++i) {
  147. if (!cells.pos_in(i, p0, p1)) {
  148. continue;
  149. }
  150. if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
  151. if (new_head == cells.size()) {
  152. new_head = i;
  153. }
  154. }
  155. }
  156. } else {
  157. // match any sequence
  158. for (uint32_t i = 0; i < cells.size(); ++i) {
  159. if (!cells.pos_in(i, p0, p1)) {
  160. continue;
  161. }
  162. cells.rm(i);
  163. if (new_head == cells.size()) {
  164. new_head = i;
  165. }
  166. }
  167. }
  168. // If we freed up a slot, set head to it so searching can start there.
  169. if (new_head != cells.size() && new_head < head) {
  170. head = new_head;
  171. }
  172. return true;
  173. }
  174. void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
  175. if (seq_id_src == seq_id_dst) {
  176. return;
  177. }
  178. if (p0 < 0) {
  179. p0 = 0;
  180. }
  181. if (p1 < 0) {
  182. p1 = std::numeric_limits<llama_pos>::max();
  183. }
  184. for (uint32_t i = 0; i < cells.size(); ++i) {
  185. if (!cells.pos_in(i, p0, p1)) {
  186. continue;
  187. }
  188. if (cells.seq_has(i, seq_id_src)) {
  189. cells.seq_add(i, seq_id_dst);
  190. }
  191. }
  192. }
  193. void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
  194. uint32_t new_head = cells.size();
  195. for (uint32_t i = 0; i < cells.size(); ++i) {
  196. if (cells.seq_keep(i, seq_id)) {
  197. if (new_head == cells.size()) {
  198. new_head = i;
  199. }
  200. }
  201. }
  202. // If we freed up a slot, set head to it so searching can start there.
  203. if (new_head != cells.size() && new_head < head) {
  204. head = new_head;
  205. }
  206. }
  207. void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
  208. if (shift == 0) {
  209. return;
  210. }
  211. uint32_t new_head = cells.size();
  212. if (p0 < 0) {
  213. p0 = 0;
  214. }
  215. if (p1 < 0) {
  216. p1 = std::numeric_limits<llama_pos>::max();
  217. }
  218. // If there is no range then return early to avoid looping over all cells.
  219. if (p0 == p1) {
  220. return;
  221. }
  222. for (uint32_t i = 0; i < cells.size(); ++i) {
  223. if (!cells.pos_in(i, p0, p1)) {
  224. continue;
  225. }
  226. if (cells.seq_has(i, seq_id)) {
  227. if (cells.pos_add(i, shift)) {
  228. if (new_head == cells.size()) {
  229. new_head = i;
  230. }
  231. }
  232. }
  233. }
  234. // If we freed up a slot, set head to it so searching can start there.
  235. // Otherwise we just start the next search from the beginning.
  236. head = new_head != cells.size() ? new_head : 0;
  237. }
  238. void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
  239. if (d == 1) {
  240. return;
  241. }
  242. if (p0 < 0) {
  243. p0 = 0;
  244. }
  245. if (p1 < 0) {
  246. p1 = std::numeric_limits<llama_pos>::max();
  247. }
  248. // If there is no range then return early to avoid looping over the cache.
  249. if (p0 == p1) {
  250. return;
  251. }
  252. for (uint32_t i = 0; i < cells.size(); ++i) {
  253. if (!cells.pos_in(i, p0, p1)) {
  254. continue;
  255. }
  256. if (cells.seq_has(i, seq_id)) {
  257. cells.pos_div(i, d);
  258. }
  259. }
  260. }
  261. llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
  262. return cells.seq_pos_min(seq_id);
  263. }
  264. llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
  265. return cells.seq_pos_max(seq_id);
  266. }
  267. llama_memory_context_ptr llama_kv_cache_unified::init_batch(
  268. llama_batch_allocr & balloc,
  269. uint32_t n_ubatch,
  270. bool embd_all) {
  271. GGML_UNUSED(embd_all);
  272. do {
  273. balloc.split_reset();
  274. std::vector<llama_ubatch> ubatches;
  275. while (true) {
  276. auto ubatch = balloc.split_simple(n_ubatch);
  277. if (ubatch.n_tokens == 0) {
  278. break;
  279. }
  280. ubatches.push_back(std::move(ubatch)); // NOLINT
  281. }
  282. if (balloc.get_n_used() < balloc.get_n_tokens()) {
  283. // failed to find a suitable split
  284. break;
  285. }
  286. auto sinfos = prepare(ubatches);
  287. if (sinfos.empty()) {
  288. break;
  289. }
  290. return std::make_unique<llama_kv_cache_unified_context>(
  291. this, std::move(sinfos), std::move(ubatches));
  292. } while (false);
  293. return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
  294. }
  295. llama_memory_context_ptr llama_kv_cache_unified::init_full() {
  296. return std::make_unique<llama_kv_cache_unified_context>(this);
  297. }
  298. llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
  299. bool do_shift = get_has_shift();
  300. defrag_info dinfo;
  301. // see if we need to defrag
  302. {
  303. bool do_defrag = optimize;
  304. const auto thold = lctx->get_cparams().defrag_thold;
  305. if (!do_defrag && thold > 0.0f) {
  306. const auto n_kv = cells.used_max_p1();
  307. // - do not defrag small contexts (i.e. < 2048 tokens)
  308. // - count the padding towards the number of used tokens
  309. const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
  310. if (fragmentation > thold) {
  311. LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
  312. do_defrag = true;
  313. }
  314. }
  315. if (do_defrag) {
  316. dinfo = defrag_prepare(lctx->graph_max_nodes());
  317. }
  318. }
  319. return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
  320. }
  321. llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
  322. llama_kv_cache_unified::slot_info_vec_t res;
  323. struct state {
  324. uint32_t head_old; // old position of the head, before placing the ubatch
  325. slot_info sinfo; // slot info for the ubatch
  326. llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch
  327. };
  328. // remember the old state of the cells so we can restore it in the end
  329. std::vector<state> states;
  330. bool success = true;
  331. for (const auto & ubatch : ubatches) {
  332. // non-continuous slots require support for ggml_set_rows()
  333. const bool cont = supports_set_rows ? false : true;
  334. // only find a suitable slot for the ubatch. don't modify the cells yet
  335. const auto sinfo_new = find_slot(ubatch, cont);
  336. if (sinfo_new.empty()) {
  337. success = false;
  338. break;
  339. }
  340. // remeber the position that we found
  341. res.push_back(sinfo_new);
  342. // store the old state of the cells in the recovery stack
  343. states.push_back({head, sinfo_new, cells.cp(sinfo_new.idxs)});
  344. // now emplace the ubatch
  345. apply_ubatch(sinfo_new, ubatch);
  346. }
  347. // iterate backwards and restore the cells to their original state
  348. for (auto it = states.rbegin(); it != states.rend(); ++it) {
  349. cells.set(it->sinfo.idxs, it->cells);
  350. head = it->head_old;
  351. }
  352. if (!success) {
  353. return {};
  354. }
  355. return res;
  356. }
  357. bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo) {
  358. bool updated = false;
  359. auto * sched = lctx->get_sched();
  360. if (do_shift) {
  361. if (!get_can_shift()) {
  362. GGML_ABORT("The current KV cache / model configuration does not support K-shift");
  363. }
  364. LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
  365. // apply K-shift if needed
  366. if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
  367. ggml_backend_sched_reset(sched);
  368. auto * gf = lctx->graph_init();
  369. auto res = build_graph_shift(lctx->get_cparams(), lctx->get_ctx_compute(), gf);
  370. if (!res) {
  371. LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__);
  372. return updated;
  373. }
  374. if (!ggml_backend_sched_alloc_graph(sched, gf)) {
  375. LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__);
  376. return updated;
  377. }
  378. res->set_inputs(nullptr);
  379. if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
  380. LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
  381. return updated;
  382. }
  383. updated = true;
  384. }
  385. cells.reset_shift();
  386. }
  387. if (!dinfo.empty()) {
  388. LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
  389. // apply moves:
  390. {
  391. const auto n_kv = dinfo.ids.size();
  392. for (uint32_t i = 0; i < n_kv; ++i) {
  393. assert(dinfo.ids[i] <= n_kv);
  394. if (dinfo.ids[i] == n_kv || dinfo.ids[i] == i) {
  395. continue;
  396. }
  397. cells.mv(i, dinfo.ids[i]);
  398. }
  399. // reset the head so we can find the first free slot during the next ubatch
  400. head = 0;
  401. }
  402. ggml_backend_sched_reset(sched);
  403. auto * gf = lctx->graph_init();
  404. auto res = build_graph_defrag(lctx->get_cparams(), lctx->get_ctx_compute(), gf, dinfo);
  405. if (!res) {
  406. LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
  407. return updated;
  408. }
  409. if (!ggml_backend_sched_alloc_graph(sched, gf)) {
  410. LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
  411. return updated;
  412. }
  413. res->set_inputs(nullptr);
  414. if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
  415. LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
  416. return updated;
  417. }
  418. updated = true;
  419. }
  420. return updated;
  421. }
  422. llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
  423. const uint32_t n_tokens = ubatch.n_tokens;
  424. uint32_t head_cur = this->head;
  425. // if we have enough unused cells before the current head ->
  426. // better to start searching from the beginning of the cache, hoping to fill it
  427. if (head_cur > cells.get_used() + 2*ubatch.n_tokens) {
  428. head_cur = 0;
  429. }
  430. if (n_tokens > cells.size()) {
  431. LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
  432. return { };
  433. }
  434. if (debug > 0) {
  435. LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa);
  436. if ((debug == 2 && n_swa > 0) || debug > 2) {
  437. std::string ss;
  438. for (uint32_t i = 0; i < cells.size(); ++i) {
  439. if (cells.is_empty(i)) {
  440. ss += '.';
  441. } else {
  442. assert(cells.seq_count(i) >= 1);
  443. if (cells.seq_count(i) == 1) {
  444. ss += std::to_string(cells.seq_get(i));
  445. } else {
  446. ss += 'M';
  447. }
  448. }
  449. if (i%256 == 255) {
  450. ss += " *";
  451. ss += '\n';
  452. }
  453. }
  454. LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
  455. }
  456. if ((debug == 2 && n_swa > 0) || debug > 2) {
  457. std::string ss;
  458. for (uint32_t i = 0; i < cells.size(); ++i) {
  459. std::string cur;
  460. if (cells.is_empty(i)) {
  461. cur = '.';
  462. } else {
  463. cur = std::to_string(cells.pos_get(i));
  464. }
  465. const int n = cur.size();
  466. for (int j = 0; j < 5 - n; ++j) {
  467. cur += ' ';
  468. }
  469. ss += cur;
  470. if (i%256 == 255) {
  471. ss += " *";
  472. }
  473. if (i%64 == 63) {
  474. ss += '\n';
  475. }
  476. }
  477. LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
  478. }
  479. for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
  480. if (cells.seq_pos_min(s) < 0) {
  481. continue;
  482. }
  483. LLAMA_LOG_DEBUG("%s: min[%d] = %5d, max[%d] = %5d\n", __func__, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s));
  484. }
  485. }
  486. uint32_t n_tested = 0;
  487. // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head
  488. // for non-continuous slots, we test the tokens one by one
  489. const uint32_t n_test = cont ? n_tokens : 1;
  490. slot_info res;
  491. auto & idxs = res.idxs;
  492. idxs.reserve(n_tokens);
  493. while (true) {
  494. if (head_cur + n_test > cells.size()) {
  495. n_tested += cells.size() - head_cur;
  496. head_cur = 0;
  497. continue;
  498. }
  499. for (uint32_t i = 0; i < n_test; i++) {
  500. const auto idx = head_cur;
  501. //const llama_pos pos = ubatch.pos[i];
  502. //const llama_seq_id seq_id = ubatch.seq_id[i][0];
  503. // can we use this cell? either:
  504. // - the cell is empty
  505. // - the cell is occupied only by one sequence:
  506. // - (disabled) mask causally, if the sequence is the same as the one we are inserting
  507. // - mask SWA, using current max pos for that sequence in the cache
  508. // always insert in the cell with minimum pos
  509. bool can_use = cells.is_empty(idx);
  510. if (!can_use && cells.seq_count(idx) == 1) {
  511. const llama_pos pos_cell = cells.pos_get(idx);
  512. // (disabled) causal mask
  513. // note: it's better to purge any "future" tokens beforehand
  514. //if (cells.seq_has(idx, seq_id)) {
  515. // can_use = pos_cell >= pos;
  516. //}
  517. if (!can_use) {
  518. const llama_seq_id seq_id_cell = cells.seq_get(idx);
  519. // SWA mask
  520. if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
  521. can_use = true;
  522. }
  523. }
  524. }
  525. head_cur++;
  526. n_tested++;
  527. if (can_use) {
  528. idxs.push_back(idx);
  529. } else {
  530. break;
  531. }
  532. }
  533. if (idxs.size() == n_tokens) {
  534. break;
  535. }
  536. if (cont) {
  537. idxs.clear();
  538. }
  539. if (n_tested >= cells.size()) {
  540. //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
  541. return { };
  542. }
  543. }
  544. // we didn't find a suitable slot - return empty result
  545. if (idxs.size() < n_tokens) {
  546. res.clear();
  547. }
  548. return res;
  549. }
  550. void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
  551. // keep track of the max sequence position that we would overwrite with this ubatch
  552. // for non-SWA cache, this would be always empty
  553. llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
  554. for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
  555. seq_pos_max_rm[s] = -1;
  556. }
  557. assert(ubatch.n_tokens == sinfo.idxs.size());
  558. for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
  559. const auto idx = sinfo.idxs.at(i);
  560. if (!cells.is_empty(idx)) {
  561. assert(cells.seq_count(idx) == 1);
  562. const llama_seq_id seq_id = cells.seq_get(idx);
  563. const llama_pos pos = cells.pos_get(idx);
  564. seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
  565. cells.rm(idx);
  566. }
  567. cells.pos_set(idx, ubatch.pos[i]);
  568. for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
  569. cells.seq_add(idx, ubatch.seq_id[i][s]);
  570. }
  571. }
  572. // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
  573. // will be present in the cache. so we have to purge any position which is less than those we would overwrite
  574. // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
  575. for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
  576. if (seq_pos_max_rm[s] == -1) {
  577. continue;
  578. }
  579. if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
  580. LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
  581. __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
  582. seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
  583. }
  584. }
  585. // move the head at the end of the slot
  586. head = sinfo.idxs.back() + 1;
  587. }
  588. bool llama_kv_cache_unified::get_can_shift() const {
  589. return true;
  590. }
  591. uint32_t llama_kv_cache_unified::get_size() const {
  592. return cells.size();
  593. }
  594. bool llama_kv_cache_unified::get_has_shift() const {
  595. return cells.get_has_shift();
  596. }
  597. uint32_t llama_kv_cache_unified::get_n_kv() const {
  598. return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
  599. }
  600. ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
  601. const int32_t ikv = map_layer_ids.at(il);
  602. auto * k = layers[ikv].k;
  603. return ggml_view_3d(ctx, k,
  604. hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv,
  605. ggml_row_size(k->type, hparams.n_embd_head_k),
  606. ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
  607. 0);
  608. }
  609. ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
  610. const int32_t ikv = map_layer_ids.at(il);
  611. auto * v = layers[ikv].v;
  612. if (!v_trans) {
  613. // note: v->nb[1] <= v->nb[2]
  614. return ggml_view_3d(ctx, v,
  615. hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
  616. ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
  617. ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
  618. 0);
  619. }
  620. // note: v->nb[1] > v->nb[2]
  621. return ggml_view_3d(ctx, v,
  622. n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v,
  623. ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
  624. ggml_row_size(v->type, v->ne[1]), // v->nb[2]
  625. 0);
  626. }
  627. ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
  628. const int32_t ikv = map_layer_ids.at(il);
  629. auto * k = layers[ikv].k;
  630. const int64_t n_embd_k_gqa = k->ne[0];
  631. const int64_t n_tokens = k_cur->ne[2];
  632. k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
  633. if (k_idxs && supports_set_rows) {
  634. return ggml_set_rows(ctx, k, k_cur, k_idxs);
  635. }
  636. // TODO: fallback to old ggml_cpy() method for backwards compatibility
  637. // will be removed when ggml_set_rows() is adopted by all backends
  638. ggml_tensor * k_view = ggml_view_1d(ctx, k,
  639. n_tokens*n_embd_k_gqa,
  640. ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());
  641. return ggml_cpy(ctx, k_cur, k_view);
  642. }
  643. ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
  644. const int32_t ikv = map_layer_ids.at(il);
  645. auto * v = layers[ikv].v;
  646. const int64_t n_embd_v_gqa = v->ne[0];
  647. const int64_t n_tokens = v_cur->ne[2];
  648. v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
  649. if (v_idxs && supports_set_rows) {
  650. if (!v_trans) {
  651. return ggml_set_rows(ctx, v, v_cur, v_idxs);
  652. }
  653. // the row becomes a single element
  654. ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
  655. // note: the V cache is transposed when not using flash attention
  656. v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
  657. // note: we can be more explicit here at the cost of extra cont
  658. // however, above we take advantage that a row of single element is always continuous regardless of the row stride
  659. //v_cur = ggml_transpose(ctx, v_cur);
  660. //v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
  661. // we broadcast the KV indices n_embd_v_gqa times
  662. // v [1, n_kv, n_embd_v_gqa]
  663. // v_cur [1, n_tokens, n_embd_v_gqa]
  664. // v_idxs [n_tokens, 1, 1]
  665. return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
  666. }
  667. // TODO: fallback to old ggml_cpy() method for backwards compatibility
  668. // will be removed when ggml_set_rows() is adopted by all backends
  669. ggml_tensor * v_view = nullptr;
  670. if (!v_trans) {
  671. v_view = ggml_view_1d(ctx, v,
  672. n_tokens*n_embd_v_gqa,
  673. ggml_row_size(v->type, n_embd_v_gqa)*sinfo.head());
  674. } else {
  675. v_cur = ggml_transpose(ctx, v_cur);
  676. v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa,
  677. (v->ne[1] )*ggml_element_size(v),
  678. (sinfo.head())*ggml_element_size(v));
  679. }
  680. return ggml_cpy(ctx, v_cur, v_view);
  681. }
  682. ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
  683. const uint32_t n_tokens = ubatch.n_tokens;
  684. ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
  685. ggml_set_input(k_idxs);
  686. return k_idxs;
  687. }
  688. ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
  689. const uint32_t n_tokens = ubatch.n_tokens;
  690. ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
  691. ggml_set_input(v_idxs);
  692. return v_idxs;
  693. }
  694. void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
  695. if (!supports_set_rows) {
  696. return;
  697. }
  698. const uint32_t n_tokens = ubatch->n_tokens;
  699. GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
  700. int64_t * data = (int64_t *) dst->data;
  701. for (int64_t i = 0; i < n_tokens; ++i) {
  702. data[i] = sinfo.idxs.at(i);
  703. }
  704. }
  705. void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
  706. if (!supports_set_rows) {
  707. return;
  708. }
  709. const uint32_t n_tokens = ubatch->n_tokens;
  710. GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
  711. int64_t * data = (int64_t *) dst->data;
  712. for (int64_t i = 0; i < n_tokens; ++i) {
  713. data[i] = sinfo.idxs.at(i);
  714. }
  715. }
  716. void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
  717. const uint32_t n_tokens = ubatch->n_tokens;
  718. GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
  719. float * data = (float *) dst->data;
  720. const int64_t n_kv = dst->ne[0];
  721. // Use only the previous KV cells of the correct sequence for each token of the ubatch.
  722. // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
  723. // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
  724. // Causal mask:
  725. // xxx-------
  726. // xxxx------
  727. // xxxxx-----
  728. // Non-causal mask:
  729. // xxxxx-----
  730. // xxxxx-----
  731. // xxxxx-----
  732. // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
  733. for (uint32_t h = 0; h < 1; ++h) {
  734. for (uint32_t i = 0; i < n_tokens; ++i) {
  735. const llama_seq_id seq_id = ubatch->seq_id[i][0];
  736. const llama_pos p1 = ubatch->pos[i];
  737. for (uint32_t j = 0; j < n_kv; ++j) {
  738. float f = 0.0f;
  739. bool masked = false;
  740. if (cells.is_empty(j)) {
  741. masked = true;
  742. } else {
  743. const llama_pos p0 = cells.pos_get(j);
  744. // mask the token if not the same sequence
  745. masked = masked || (!cells.seq_has(j, seq_id));
  746. // mask future tokens
  747. masked = masked || (causal_attn && p0 > p1);
  748. // apply SWA if any
  749. masked = masked || (is_masked_swa(p0, p1));
  750. if (!masked && hparams.use_alibi) {
  751. f = -std::abs(p0 - p1);
  752. }
  753. }
  754. if (masked) {
  755. f = -INFINITY;
  756. }
  757. data[h*(n_kv*n_tokens) + i*n_kv + j] = f;
  758. }
  759. }
  760. // mask padded tokens
  761. if (data) {
  762. for (uint32_t i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
  763. for (uint32_t j = 0; j < n_kv; ++j) {
  764. data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
  765. }
  766. }
  767. }
  768. }
  769. }
  770. void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
  771. GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
  772. int32_t * data = (int32_t *) dst->data;
  773. for (uint32_t i = 0; i < cells.size(); ++i) {
  774. data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
  775. }
  776. }
  777. void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
  778. const int64_t n_tokens = ubatch->n_tokens;
  779. GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
  780. GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
  781. int32_t * data = (int32_t *) dst->data;
  782. const int32_t n_kv = dst->ne[0];
  783. for (int h = 0; h < 1; ++h) {
  784. for (int i = 0; i < n_tokens; ++i) {
  785. for (int j = 0; j < n_kv; ++j) {
  786. // the position when the cells is empty is irrelevant - it will be masked out later in the attention
  787. const llama_pos p0 = cells.is_empty(j) ? -1 : cells.pos_get(j);
  788. data[h*(n_kv*n_tokens) + i*n_kv + j] = llama_relative_position_bucket(p0, ubatch->pos[i], hparams.n_rel_attn_bkts, false);
  789. }
  790. }
  791. }
  792. }
  793. size_t llama_kv_cache_unified::total_size() const {
  794. size_t size = 0;
  795. for (const auto & buf : bufs) {
  796. size += ggml_backend_buffer_get_size(buf.get());
  797. }
  798. return size;
  799. }
  800. size_t llama_kv_cache_unified::size_k_bytes() const {
  801. size_t size_k_bytes = 0;
  802. for (const auto & layer : layers) {
  803. size_k_bytes += ggml_nbytes(layer.k);
  804. }
  805. return size_k_bytes;
  806. }
  807. size_t llama_kv_cache_unified::size_v_bytes() const {
  808. size_t size_v_bytes = 0;
  809. for (const auto & layer : layers) {
  810. size_v_bytes += ggml_nbytes(layer.v);
  811. }
  812. return size_v_bytes;
  813. }
  814. ggml_tensor * llama_kv_cache_unified::build_rope_shift(
  815. const llama_cparams & cparams,
  816. ggml_context * ctx,
  817. ggml_tensor * cur,
  818. ggml_tensor * shift,
  819. ggml_tensor * factors,
  820. float freq_base,
  821. float freq_scale) const {
  822. const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
  823. const auto & yarn_ext_factor = cparams.yarn_ext_factor;
  824. const auto & yarn_beta_fast = cparams.yarn_beta_fast;
  825. const auto & yarn_beta_slow = cparams.yarn_beta_slow;
  826. const auto & n_rot = hparams.n_rot;
  827. const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE
  828. // @ngxson : this is a workaround
  829. // for M-RoPE, we want to rotate the whole vector when doing KV shift
  830. // a normal RoPE should work, we just need to use the correct ordering
  831. // ref: https://github.com/ggml-org/llama.cpp/pull/13870
  832. ? LLAMA_ROPE_TYPE_NEOX
  833. : hparams.rope_type;
  834. // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
  835. // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
  836. const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2
  837. ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale))
  838. : cparams.yarn_attn_factor;
  839. ggml_tensor * tmp;
  840. if (ggml_is_quantized(cur->type)) {
  841. // dequantize to f32 -> RoPE -> quantize back
  842. tmp = ggml_cast(ctx, cur, GGML_TYPE_F32);
  843. tmp = ggml_rope_ext(ctx, tmp,
  844. shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
  845. yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
  846. tmp = ggml_cpy(ctx, tmp, cur);
  847. } else {
  848. // we rotate only the first n_rot dimensions
  849. tmp = ggml_rope_ext_inplace(ctx, cur,
  850. shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
  851. yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
  852. }
  853. return tmp;
  854. }
  855. class llm_graph_input_k_shift : public llm_graph_input_i {
  856. public:
  857. llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
  858. virtual ~llm_graph_input_k_shift() = default;
  859. void set_input(const llama_ubatch * ubatch) override;
  860. ggml_tensor * k_shift; // I32 [kv_size]
  861. const llama_kv_cache_unified * kv_self;
  862. };
  863. void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
  864. GGML_UNUSED(ubatch);
  865. if (k_shift) {
  866. kv_self->set_input_k_shift(k_shift);
  867. }
  868. }
  869. llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
  870. const llama_cparams & cparams,
  871. ggml_context * ctx,
  872. ggml_cgraph * gf) const {
  873. auto res = std::make_unique<llm_graph_result>();
  874. const auto & n_embd_head_k = hparams.n_embd_head_k;
  875. //const auto & n_embd_head_v = hparams.n_embd_head_v;
  876. auto inp = std::make_unique<llm_graph_input_k_shift>(this);
  877. inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cells.size());
  878. ggml_set_input(inp->k_shift);
  879. for (const auto & layer : layers) {
  880. const uint32_t il = layer.il;
  881. const int64_t n_head_kv = hparams.n_head_kv(il);
  882. const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
  883. const float freq_base_l = model.get_rope_freq_base (cparams, il);
  884. const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
  885. ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
  886. ggml_tensor * k =
  887. ggml_view_3d(ctx, layer.k,
  888. n_embd_head_k, n_head_kv, cells.size(),
  889. ggml_row_size(layer.k->type, n_embd_head_k),
  890. ggml_row_size(layer.k->type, n_embd_k_gqa),
  891. 0);
  892. ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
  893. ggml_build_forward_expand(gf, cur);
  894. }
  895. res->add_input(std::move(inp));
  896. return res;
  897. }
  898. llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
  899. const llama_cparams & cparams,
  900. ggml_context * ctx,
  901. ggml_cgraph * gf,
  902. const defrag_info & dinfo) const {
  903. auto res = std::make_unique<llm_graph_result>();
  904. const auto & ids = dinfo.ids;
  905. #if 0
  906. // CPU defrag
  907. //
  908. // TODO: optimizations are possible:
  909. // - multiple threads
  910. // - avoid copying to the host memory when already there
  911. //
  912. // likely not worth the effort, as we have ggml_graph based defrag
  913. //
  914. const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
  915. const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
  916. const uint32_t kv_size = size;
  917. std::vector<uint8_t> buf_k;
  918. std::vector<uint8_t> buf_v;
  919. for (uint32_t il = 0; il < n_layer; ++il) {
  920. const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
  921. const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
  922. const size_t v_size_el = ggml_type_size(v_l[il]->type);
  923. const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
  924. buf_k.resize(k_size);
  925. buf_v.resize(v_size);
  926. ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
  927. ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
  928. // batch move [i, i+nm) to [id, id+nm)
  929. // note: cells can move only to a lower index
  930. for (uint32_t i = 0; i < n_kv; ++i) {
  931. const uint32_t id = ids[i];
  932. if (i == id || id == n_kv) {
  933. continue;
  934. }
  935. uint32_t nm = 1;
  936. while (i + nm < n_kv && ids[i + nm] == id + nm) {
  937. nm++;
  938. }
  939. // move keys
  940. {
  941. const int64_t os = i*k_size_row;
  942. const int64_t od = id*k_size_row;
  943. memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
  944. }
  945. // move values (note: they are transposed)
  946. {
  947. const int64_t os = i;
  948. const int64_t od = id;
  949. for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
  950. memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
  951. }
  952. }
  953. i += nm - 1;
  954. }
  955. ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
  956. ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
  957. }
  958. #else
  959. for (uint32_t i = 0; i < ids.size(); ++i) {
  960. const uint32_t id = ids[i];
  961. if (i == id || id == ids.size()) {
  962. continue;
  963. }
  964. uint32_t nm = 1;
  965. while (i + nm < ids.size() && ids[i + nm] == id + nm) {
  966. nm++;
  967. }
  968. for (const auto & layer : layers) {
  969. const uint32_t il = layer.il;
  970. const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
  971. const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
  972. ggml_tensor * view_k_src = ggml_view_2d(ctx, layer.k,
  973. n_embd_k_gqa, nm,
  974. ggml_row_size(layer.k->type, n_embd_k_gqa),
  975. ggml_row_size(layer.k->type, n_embd_k_gqa*i));
  976. ggml_tensor * view_k_dst = ggml_view_2d(ctx, layer.k,
  977. n_embd_k_gqa, nm,
  978. ggml_row_size(layer.k->type, n_embd_k_gqa),
  979. ggml_row_size(layer.k->type, n_embd_k_gqa*id));
  980. ggml_tensor * view_v_src;
  981. ggml_tensor * view_v_dst;
  982. if (cparams.flash_attn) {
  983. // NOTE: the V cache is not transposed when using flash attention
  984. view_v_src = ggml_view_2d(ctx, layer.v,
  985. n_embd_v_gqa, nm,
  986. ggml_row_size(layer.v->type, n_embd_v_gqa),
  987. ggml_row_size(layer.v->type, n_embd_v_gqa*i));
  988. view_v_dst = ggml_view_2d(ctx, layer.v,
  989. n_embd_v_gqa, nm,
  990. ggml_row_size(layer.v->type, n_embd_v_gqa),
  991. ggml_row_size(layer.v->type, n_embd_v_gqa*id));
  992. } else {
  993. view_v_src = ggml_view_2d(ctx, layer.v,
  994. nm, n_embd_v_gqa,
  995. ggml_row_size(layer.v->type, cells.size()),
  996. ggml_row_size(layer.v->type, i));
  997. view_v_dst = ggml_view_2d(ctx, layer.v,
  998. nm, n_embd_v_gqa,
  999. ggml_row_size(layer.v->type, cells.size()),
  1000. ggml_row_size(layer.v->type, id));
  1001. }
  1002. ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst));
  1003. ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst));
  1004. }
  1005. i += nm - 1;
  1006. }
  1007. //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
  1008. #endif
  1009. return res;
  1010. }
  1011. llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
  1012. const uint32_t n_layer = layers.size();
  1013. const uint32_t n_kv = cells.used_max_p1();
  1014. const uint32_t n_used = cells.get_used();
  1015. assert(n_used <= n_kv);
  1016. //const int64_t t_start = ggml_time_us();
  1017. // number of cells moved
  1018. uint32_t n_moves = 0;
  1019. // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag)
  1020. // - source view, destination view, copy operation
  1021. // - x2 for keys and values
  1022. //const uint32_t max_moves = max_nodes()/(6*n_layer);
  1023. // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
  1024. const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
  1025. // determine which KV cells to move where
  1026. defrag_info res;
  1027. auto & ids = res.ids;
  1028. ids.resize(n_kv, n_kv);
  1029. for (uint32_t i0 = 0; i0 < n_used; ++i0) {
  1030. if (!cells.is_empty(i0)) {
  1031. ids[i0] = i0;
  1032. continue;
  1033. }
  1034. // found a hole - fill it with data from the end of the cache
  1035. uint32_t nh = 1;
  1036. // determine the size of the hole
  1037. while (i0 + nh < n_used && cells.is_empty(i0 + nh)) {
  1038. nh++;
  1039. }
  1040. uint32_t nf = 0;
  1041. uint32_t is = n_kv - 1;
  1042. // starting from the end, find nh non-empty cells
  1043. for (; is > i0; --is) {
  1044. if (cells.is_empty(is) || ids[is] != n_kv) {
  1045. continue;
  1046. }
  1047. // non-empty cell which is not yet moved
  1048. nf++;
  1049. if (nf == nh) {
  1050. break;
  1051. }
  1052. }
  1053. // this can only happen if `n_used` is not accurate, which would be a bug
  1054. GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh");
  1055. nf = 0;
  1056. uint32_t i1 = is;
  1057. // are we moving a continuous block of memory?
  1058. bool cont = false;
  1059. // should we stop searching for the next move?
  1060. bool stop = false;
  1061. // go back and move the nf cells to the hole
  1062. for (; i1 < n_kv; ++i1) {
  1063. if (cells.is_empty(i1) || ids[i1] != n_kv) {
  1064. if (n_moves == max_moves) {
  1065. stop = true;
  1066. break;
  1067. }
  1068. cont = false;
  1069. continue;
  1070. }
  1071. // this cell goes to (i0 + nf)
  1072. ids[i1] = i0 + nf;
  1073. if (!cont) {
  1074. n_moves++;
  1075. cont = true;
  1076. }
  1077. nf++;
  1078. if (nf == nh) {
  1079. break;
  1080. }
  1081. }
  1082. if (stop || n_moves == max_moves) {
  1083. break;
  1084. }
  1085. //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
  1086. i0 += nh - 1;
  1087. }
  1088. if (n_moves == 0) {
  1089. return {};
  1090. }
  1091. LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
  1092. LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
  1093. return res;
  1094. }
  1095. bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
  1096. assert(p0 >= 0 && p1 >= 0);
  1097. switch (swa_type) {
  1098. case LLAMA_SWA_TYPE_NONE:
  1099. {
  1100. } break;
  1101. case LLAMA_SWA_TYPE_STANDARD:
  1102. {
  1103. if (p1 - p0 >= (int32_t) n_swa) {
  1104. return true;
  1105. }
  1106. } break;
  1107. case LLAMA_SWA_TYPE_CHUNKED:
  1108. {
  1109. const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
  1110. if (p0 < pos_chunk_start) {
  1111. return true;
  1112. }
  1113. } break;
  1114. }
  1115. return false;
  1116. }
  1117. void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
  1118. std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
  1119. uint32_t cell_count = 0;
  1120. // Count the number of cells with the specified seq_id
  1121. // Find all the ranges of cells with this seq id (or all, when -1)
  1122. uint32_t cell_range_begin = cells.size();
  1123. for (uint32_t i = 0; i < cells.size(); ++i) {
  1124. if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
  1125. ++cell_count;
  1126. if (cell_range_begin == cells.size()) {
  1127. cell_range_begin = i;
  1128. }
  1129. } else {
  1130. if (cell_range_begin != cells.size()) {
  1131. cell_ranges.emplace_back(cell_range_begin, i);
  1132. cell_range_begin = cells.size();
  1133. }
  1134. }
  1135. }
  1136. if (cell_range_begin != cells.size()) {
  1137. cell_ranges.emplace_back(cell_range_begin, cells.size());
  1138. }
  1139. // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
  1140. uint32_t cell_count_check = 0;
  1141. for (const auto & range : cell_ranges) {
  1142. cell_count_check += range.second - range.first;
  1143. }
  1144. GGML_ASSERT(cell_count == cell_count_check);
  1145. io.write(&cell_count, sizeof(cell_count));
  1146. state_write_meta(io, cell_ranges, seq_id);
  1147. state_write_data(io, cell_ranges);
  1148. }
  1149. void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
  1150. uint32_t cell_count;
  1151. io.read_to(&cell_count, sizeof(cell_count));
  1152. bool res = true;
  1153. res = res && state_read_meta(io, cell_count, seq_id);
  1154. res = res && state_read_data(io, cell_count);
  1155. if (!res) {
  1156. if (seq_id == -1) {
  1157. clear(true);
  1158. } else {
  1159. seq_rm(seq_id, -1, -1);
  1160. }
  1161. throw std::runtime_error("failed to restore kv cache");
  1162. }
  1163. }
  1164. void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
  1165. for (const auto & range : cell_ranges) {
  1166. for (uint32_t i = range.first; i < range.second; ++i) {
  1167. std::vector<llama_seq_id> seq_ids;
  1168. for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) {
  1169. if (cur == seq_id || seq_id == -1) {
  1170. if (cells.seq_has(i, cur)) {
  1171. seq_ids.push_back(cur);
  1172. }
  1173. }
  1174. }
  1175. const llama_pos pos = cells.pos_get(i);
  1176. const uint32_t n_seq_id = seq_ids.size();
  1177. io.write(&pos, sizeof(pos));
  1178. io.write(&n_seq_id, sizeof(n_seq_id));
  1179. for (const auto & seq_id : seq_ids) {
  1180. io.write(&seq_id, sizeof(seq_id));
  1181. }
  1182. }
  1183. }
  1184. }
  1185. void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
  1186. const uint32_t v_trans = this->v_trans ? 1 : 0;
  1187. const uint32_t n_layer = layers.size();
  1188. io.write(&v_trans, sizeof(v_trans));
  1189. io.write(&n_layer, sizeof(n_layer));
  1190. std::vector<uint8_t> tmp_buf;
  1191. // Iterate and write all the keys first, each row is a cell
  1192. // Get whole range at a time
  1193. for (const auto & layer : layers) {
  1194. const uint32_t il = layer.il;
  1195. const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
  1196. // Write key type
  1197. const int32_t k_type_i = (int32_t)layer.k->type;
  1198. io.write(&k_type_i, sizeof(k_type_i));
  1199. // Write row size of key
  1200. const uint64_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa);
  1201. io.write(&k_size_row, sizeof(k_size_row));
  1202. // Read each range of cells of k_size length each into tmp_buf and write out
  1203. for (const auto & range : cell_ranges) {
  1204. const size_t range_size = range.second - range.first;
  1205. const size_t buf_size = range_size * k_size_row;
  1206. io.write_tensor(layer.k, range.first * k_size_row, buf_size);
  1207. }
  1208. }
  1209. if (!v_trans) {
  1210. for (const auto & layer : layers) {
  1211. const uint32_t il = layer.il;
  1212. const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
  1213. // Write value type
  1214. const int32_t v_type_i = (int32_t)layer.v->type;
  1215. io.write(&v_type_i, sizeof(v_type_i));
  1216. // Write row size of value
  1217. const uint64_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa);
  1218. io.write(&v_size_row, sizeof(v_size_row));
  1219. // Read each range of cells of v_size length each into tmp_buf and write out
  1220. for (const auto & range : cell_ranges) {
  1221. const size_t range_size = range.second - range.first;
  1222. const size_t buf_size = range_size * v_size_row;
  1223. io.write_tensor(layer.v, range.first * v_size_row, buf_size);
  1224. }
  1225. }
  1226. } else {
  1227. // When v is transposed, we also need the element size and get the element ranges from each row
  1228. const uint32_t kv_size = cells.size();
  1229. for (const auto & layer : layers) {
  1230. const uint32_t il = layer.il;
  1231. const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
  1232. // Write value type
  1233. const int32_t v_type_i = (int32_t)layer.v->type;
  1234. io.write(&v_type_i, sizeof(v_type_i));
  1235. // Write element size
  1236. const uint32_t v_size_el = ggml_type_size(layer.v->type);
  1237. io.write(&v_size_el, sizeof(v_size_el));
  1238. // Write GQA embedding size
  1239. io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
  1240. // For each row, we get the element values of each cell
  1241. for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
  1242. // Read each range of cells of v_size_el length each into tmp_buf and write out
  1243. for (const auto & range : cell_ranges) {
  1244. const size_t range_size = range.second - range.first;
  1245. const size_t src_offset = (range.first + j * kv_size) * v_size_el;
  1246. const size_t buf_size = range_size * v_size_el;
  1247. io.write_tensor(layer.v, src_offset, buf_size);
  1248. }
  1249. }
  1250. }
  1251. }
  1252. }
  1253. bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
  1254. if (dest_seq_id != -1) {
  1255. // single sequence
  1256. seq_rm(dest_seq_id, -1, -1);
  1257. llama_batch_allocr balloc(hparams.n_pos_per_embd());
  1258. llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
  1259. for (uint32_t i = 0; i < cell_count; ++i) {
  1260. llama_pos pos;
  1261. uint32_t n_seq_id;
  1262. io.read_to(&pos, sizeof(pos));
  1263. io.read_to(&n_seq_id, sizeof(n_seq_id));
  1264. if (n_seq_id != 1) {
  1265. LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
  1266. return false;
  1267. }
  1268. // read the sequence id, but directly discard it - we will use dest_seq_id instead
  1269. {
  1270. llama_seq_id seq_id;
  1271. io.read_to(&seq_id, sizeof(seq_id));
  1272. }
  1273. ubatch.pos[i] = pos;
  1274. ubatch.n_seq_id[i] = n_seq_id;
  1275. ubatch.seq_id[i] = &dest_seq_id;
  1276. }
  1277. const auto sinfo = find_slot(ubatch, true);
  1278. if (sinfo.empty()) {
  1279. LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
  1280. return false;
  1281. }
  1282. apply_ubatch(sinfo, ubatch);
  1283. const auto head_cur = sinfo.head();
  1284. // keep the head at the old position because we will read the KV data into it in state_read_data()
  1285. head = head_cur;
  1286. // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
  1287. // Assume that this is one contiguous block of cells
  1288. GGML_ASSERT(head_cur + cell_count <= cells.size());
  1289. GGML_ASSERT(cells.pos_get(head_cur) == ubatch.pos[0]);
  1290. GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
  1291. GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id));
  1292. GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
  1293. } else {
  1294. // whole KV cache restore
  1295. if (cell_count > cells.size()) {
  1296. LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
  1297. return false;
  1298. }
  1299. clear(true);
  1300. for (uint32_t i = 0; i < cell_count; ++i) {
  1301. llama_pos pos;
  1302. uint32_t n_seq_id;
  1303. io.read_to(&pos, sizeof(pos));
  1304. io.read_to(&n_seq_id, sizeof(n_seq_id));
  1305. cells.pos_set(i, pos);
  1306. for (uint32_t j = 0; j < n_seq_id; ++j) {
  1307. llama_seq_id seq_id;
  1308. io.read_to(&seq_id, sizeof(seq_id));
  1309. if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
  1310. LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max);
  1311. return false;
  1312. }
  1313. cells.seq_add(i, seq_id);
  1314. }
  1315. }
  1316. head = 0;
  1317. }
  1318. return true;
  1319. }
  // Reads K/V tensor contents in the layout produced by state_write_data() and writes them into cache cells
  // [head, head + cell_count), where `head` was positioned by state_read_meta().
  // Every header field is validated against this cache (layer count, V transposition, per-layer types and
  // row/element sizes); on the first mismatch an error is logged and false is returned.
  // The reads below must stay in exactly this order - it is the wire format.
  bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
      // note: this local shadows the v_trans member; the member is accessed explicitly as this->v_trans
      uint32_t v_trans;
      uint32_t n_layer;
      io.read_to(&v_trans, sizeof(v_trans));
      io.read_to(&n_layer, sizeof(n_layer));
      if (n_layer != layers.size()) {
          LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
          return false;
      }
      if (cell_count > cells.size()) {
          LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size());
          return false;
      }
      if (this->v_trans != (bool) v_trans) {
          LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
          return false;
      }
      // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
      for (const auto & layer : layers) {
          const uint32_t il = layer.il;
          const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
          // Read type of key
          int32_t k_type_i_ref;
          io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
          const int32_t k_type_i = (int32_t) layer.k->type;
          if (k_type_i != k_type_i_ref) {
              // NOTE(review): il is uint32_t but printed with %d
              LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
              return false;
          }
          // Read row size of key
          uint64_t k_size_row_ref;
          io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
          const size_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa);
          if (k_size_row != k_size_row_ref) {
              LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
              return false;
          }
          if (cell_count) {
              // Read and set the keys for the whole cell range
              // (the saved rows are contiguous in the stream and land contiguously starting at cell `head`)
              ggml_backend_tensor_set(layer.k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
          }
      }
      if (!this->v_trans) {
          // V stored row-per-cell: same procedure as for K
          for (const auto & layer : layers) {
              const uint32_t il = layer.il;
              const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
              // Read type of value
              int32_t v_type_i_ref;
              io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
              const int32_t v_type_i = (int32_t)layer.v->type;
              if (v_type_i != v_type_i_ref) {
                  LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
                  return false;
              }
              // Read row size of value
              uint64_t v_size_row_ref;
              io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
              const size_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa);
              if (v_size_row != v_size_row_ref) {
                  LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
                  return false;
              }
              if (cell_count) {
                  // Read and set the values for the whole cell range
                  ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
              }
          }
      } else {
          // For each layer, read the values for each cell (transposed)
          for (const auto & layer : layers) {
              const uint32_t il = layer.il;
              const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
              // Read type of value
              int32_t v_type_i_ref;
              io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
              const int32_t v_type_i = (int32_t)layer.v->type;
              if (v_type_i != v_type_i_ref) {
                  LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
                  return false;
              }
              // Read element size of value
              uint32_t v_size_el_ref;
              io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
              const size_t v_size_el = ggml_type_size(layer.v->type);
              if (v_size_el != v_size_el_ref) {
                  LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
                  return false;
              }
              // Read GQA embedding size
              uint32_t n_embd_v_gqa_ref;
              io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
              if (n_embd_v_gqa != n_embd_v_gqa_ref) {
                  LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
                  return false;
              }
              if (cell_count) {
                  // For each row in the transposed matrix, read the values for the whole cell range
                  // (row j starts at element j * cells.size(); the cells of this restore start at column `head`)
                  for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
                      const size_t dst_offset = (head + j * cells.size()) * v_size_el;
                      ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
                  }
              }
          }
      }
      return true;
  }
  1426. //
  1427. // llama_kv_cache_unified_context
  1428. //
  1429. llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_status status) : status(status) {}
  1430. llama_kv_cache_unified_context::llama_kv_cache_unified_context(
  1431. llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
  1432. n_kv = kv->get_size();
  1433. // create a dummy slot info - the actual data is irrelevant. we just need to build the graph
  1434. sinfos.resize(1);
  1435. sinfos[0].idxs.resize(1);
  1436. sinfos[0].idxs[0] = 0;
  1437. }
  1438. llama_kv_cache_unified_context::llama_kv_cache_unified_context(
  1439. llama_kv_cache_unified * kv,
  1440. llama_context * lctx,
  1441. bool do_shift,
  1442. defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) {
  1443. if (!do_shift && this->dinfo.empty()) {
  1444. status = LLAMA_MEMORY_STATUS_NO_UPDATE;
  1445. }
  1446. }
  1447. llama_kv_cache_unified_context::llama_kv_cache_unified_context(
  1448. llama_kv_cache_unified * kv,
  1449. llama_kv_cache_unified::slot_info_vec_t sinfos,
  1450. std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
  1451. }
  1452. llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
  1453. bool llama_kv_cache_unified_context::next() {
  1454. assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
  1455. if (++i_cur >= ubatches.size()) {
  1456. return false;
  1457. }
  1458. return true;
  1459. }
  1460. bool llama_kv_cache_unified_context::apply() {
  1461. assert(!llama_memory_status_is_fail(status));
  1462. // no ubatches -> this is a KV cache update
  1463. if (ubatches.empty()) {
  1464. kv->update(lctx, do_shift, dinfo);
  1465. return true;
  1466. }
  1467. kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
  1468. n_kv = kv->get_n_kv();
  1469. return true;
  1470. }
  1471. llama_memory_status llama_kv_cache_unified_context::get_status() const {
  1472. return status;
  1473. }
  1474. const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
  1475. assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
  1476. return ubatches[i_cur];
  1477. }
  1478. uint32_t llama_kv_cache_unified_context::get_n_kv() const {
  1479. return n_kv;
  1480. }
  1481. ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
  1482. return kv->get_k(ctx, il, n_kv);
  1483. }
  1484. ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
  1485. return kv->get_v(ctx, il, n_kv);
  1486. }
  1487. ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
  1488. return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
  1489. }
  1490. ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
  1491. return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
  1492. }
  1493. ggml_tensor * llama_kv_cache_unified_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
  1494. return kv->build_input_k_idxs(ctx, ubatch);
  1495. }
  1496. ggml_tensor * llama_kv_cache_unified_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
  1497. return kv->build_input_v_idxs(ctx, ubatch);
  1498. }
  1499. void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
  1500. kv->set_input_k_shift(dst);
  1501. }
  1502. void llama_kv_cache_unified_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
  1503. kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
  1504. }
  1505. void llama_kv_cache_unified_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
  1506. kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
  1507. }
  1508. void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
  1509. kv->set_input_kq_mask(dst, ubatch, causal_attn);
  1510. }
  1511. void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
  1512. kv->set_input_pos_bucket(dst, ubatch);
  1513. }
  1514. uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
  1515. // the FA kernels require padding to avoid extra runtime boundary checks
  1516. return cparams.flash_attn ? 256u : 32u;
  1517. }