// llama-kv-cache.cpp

#include "llama-kv-cache.h"

#include "llama-impl.h"
#include "llama-batch.h"
#include "llama-cparams.h"
#include "llama-model.h"
#include "llama-context.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>
#include <map>
#include <stdexcept>

//
// llama_kv_cache_unified
//
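
// Returns the cell-count granularity that the cache size must be padded to.
// Flash-attention kernels need a larger padding (256) so they can skip runtime
// boundary checks; the regular attention path only needs 32.
//
// Illustrative caller-side use (a sketch with hypothetical variable names - the
// exact call site lives in llama-context.cpp):
//
//     const uint32_t padding = llama_kv_cache_unified::get_padding(cparams);
//     const uint32_t kv_size = GGML_PAD(cparams.n_ctx, padding);
//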
uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
    // the FA kernels require padding to avoid extra runtime boundary checks
    return cparams.flash_attn ? 256u : 32u;
}
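
// The constructor allocates one K tensor and one V tensor per (non-filtered)
// layer. Tensors that live on the same backend buffer type share a single ggml
// context, so the cache ends up with one backend buffer per device.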
llama_kv_cache_unified::llama_kv_cache_unified(
        const llama_model & model,
        layer_filter_cb && filter,
        ggml_type type_k,
        ggml_type type_v,
        bool v_trans,
        bool offload,
        uint32_t kv_size,
        uint32_t padding,
        uint32_t n_swa,
        llama_swa_type swa_type) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding), n_swa(n_swa), swa_type(swa_type) {
    GGML_ASSERT(kv_size % padding == 0 && "kv_size must be a multiple of padding");

    this->type_k = type_k;
    this->type_v = type_v;

    // create a context for each buffer type
    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
    auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
        auto it = ctx_map.find(buft);
        if (it == ctx_map.end()) {
            ggml_init_params params = {
                /*.mem_size   =*/ size_t(2u*hparams.n_layer*ggml_tensor_overhead()),
                /*.mem_buffer =*/ NULL,
                /*.no_alloc   =*/ true,
            };

            ggml_context * ctx = ggml_init(params);
            if (!ctx) {
                return nullptr;
            }

            ctx_map[buft] = ctx;
            ctxs.emplace_back(ctx);

            return ctx;
        }

        return it->second;
    };

    head = 0;
    size = kv_size;
    used = 0;

    cells.resize(kv_size);

    for (uint32_t il = 0; il < hparams.n_layer; il++) {
        if (filter && !filter(il)) {
            LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
            continue;
        }

        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();

        const char * dev_name = "CPU";

        ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();

        if (offload) {
            auto * dev = model.dev_layer(il);
            buft = ggml_backend_dev_buffer_type(dev);

            dev_name = ggml_backend_dev_name(dev);
        }

        LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, il, dev_name);

        ggml_context * ctx = ctx_for_buft(buft);
        if (!ctx) {
            throw std::runtime_error("failed to create ggml context for kv cache");
        }

        ggml_tensor * k;
        ggml_tensor * v;

        k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
        v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);

        ggml_format_name(k, "cache_k_l%d", il);
        ggml_format_name(v, "cache_v_l%d", il);

        map_layer_ids[il] = layers.size();
        layers.push_back({ il, k, v });
    }

    // allocate tensors and initialize the buffers to avoid NaNs in the padding
    for (auto it : ctx_map) {
        auto * buft = it.first;
        auto * ctx  = it.second;

        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
        if (!buf) {
            throw std::runtime_error("failed to allocate buffer for kv cache");
        }

        LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);

        ggml_backend_buffer_clear(buf, 0);
        bufs.emplace_back(buf);
    }

    {
        const size_t memory_size_k = size_k_bytes();
        const size_t memory_size_v = size_v_bytes();

        LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6d cells, %3d layers), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(),
                ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
    }
}

void llama_kv_cache_unified::clear() {
    for (uint32_t i = 0; i < size; ++i) {
        cells[i].pos = -1;
        cells[i].seq_id.clear();
    }

    head = 0;
    used = 0;

    for (auto & buf : bufs) {
        ggml_backend_buffer_clear(buf.get(), 0);
    }
}

bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    uint32_t new_head = size;

    if (p0 < 0) {
        p0 = 0;
    }
    if (p1 < 0) {
        p1 = std::numeric_limits<llama_pos>::max();
    }

    for (uint32_t i = 0; i < size; ++i) {
        if (cells[i].pos >= p0 && cells[i].pos < p1) {
            if (seq_id < 0) {
                cells[i].seq_id.clear();
            } else if (cells[i].has_seq_id(seq_id)) {
                cells[i].seq_id.erase(seq_id);
            } else {
                continue;
            }

            if (cells[i].is_empty()) {
                // keep count of the number of used cells
                if (cells[i].pos >= 0) {
                    used--;
                }

                cells[i].pos = -1;

                if (new_head == size) {
                    new_head = i;
                }
            }
        }
    }

    // If we freed up a slot, set head to it so searching can start there.
    if (new_head != size && new_head < head) {
        head = new_head;
    }

    return true;
}

void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
    if (seq_id_src == seq_id_dst) {
        return;
    }

    if (p0 < 0) {
        p0 = 0;
    }
    if (p1 < 0) {
        p1 = std::numeric_limits<llama_pos>::max();
    }
    // for a Transformer-like (non-recurrent) KV cache, copying a sequence only
    // requires tagging the matching cells with the destination seq_id
    head = 0;

    for (uint32_t i = 0; i < size; ++i) {
        if (cells[i].has_seq_id(seq_id_src) && cells[i].pos >= p0 && cells[i].pos < p1) {
            cells[i].seq_id.insert(seq_id_dst);
        }
    }
}

void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
    uint32_t new_head = size;

    for (uint32_t i = 0; i < size; ++i) {
        if (!cells[i].has_seq_id(seq_id)) {
            if (cells[i].pos >= 0) {
                used--;
            }

            cells[i].pos = -1;
            cells[i].seq_id.clear();

            if (new_head == size) {
                new_head = i;
            }
        } else {
            cells[i].seq_id.clear();
            cells[i].seq_id.insert(seq_id);
        }
    }

    // If we freed up a slot, set head to it so searching can start there.
    if (new_head != size && new_head < head) {
        head = new_head;
    }
}

void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
    if (delta == 0) {
        return;
    }

    uint32_t new_head = size;

    if (p0 < 0) {
        p0 = 0;
    }
    if (p1 < 0) {
        p1 = std::numeric_limits<llama_pos>::max();
    }
    // If there is no range then return early to avoid looping over the cache.
    if (p0 == p1) {
        return;
    }

    for (uint32_t i = 0; i < size; ++i) {
        if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
            has_shift = true;

            cells[i].pos   += delta;
            cells[i].delta += delta;

            if (cells[i].pos < 0) {
                if (!cells[i].is_empty()) {
                    used--;
                }

                cells[i].pos = -1;
                cells[i].seq_id.clear();

                if (new_head == size) {
                    new_head = i;
                }
            }
        }
    }

    // If we freed up a slot, set head to it so searching can start there.
    // Otherwise we just start the next search from the beginning.
    head = new_head != size ? new_head : 0;
}

void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
    if (d == 1) {
        return;
    }

    if (p0 < 0) {
        p0 = 0;
    }
    if (p1 < 0) {
        p1 = std::numeric_limits<llama_pos>::max();
    }

    // If there is no range then return early to avoid looping over the cache.
    if (p0 == p1) {
        return;
    }

    for (uint32_t i = 0; i < size; ++i) {
        if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
            has_shift = true;

            {
                llama_pos p_old = cells[i].pos;

                cells[i].pos   /= d;
                cells[i].delta += cells[i].pos - p_old;
            }
        }
    }
}

llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
    llama_pos result = std::numeric_limits<llama_pos>::max();

    for (uint32_t i = 0; i < size; ++i) {
        if (cells[i].has_seq_id(seq_id)) {
            result = std::min(result, cells[i].pos);
        }
    }

    if (result == std::numeric_limits<llama_pos>::max()) {
        result = -1;
    }

    return result;
}

llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
    llama_pos result = -1;

    for (uint32_t i = 0; i < size; ++i) {
        if (cells[i].has_seq_id(seq_id)) {
            result = std::max(result, cells[i].pos);
        }
    }

    return result;
}

void llama_kv_cache_unified::restore() {
    for (const auto & [id, cell] : recovery.cells) {
        // TODO: move to new `struct kv_cells`
        const bool is_empty0 = cells[id].is_empty();
        const bool is_empty1 = cell.is_empty();

        if (!is_empty0 && is_empty1) {
            used--;
        } else if (is_empty0 && !is_empty1) {
            used++;
        }

        cells[id] = cell;
    }

    recovery.clear();
}

void llama_kv_cache_unified::commit() {
    if (recovery.cells.empty()) {
        LLAMA_LOG_WARN("%s: the recovery information upon a commit was empty - might indicate a bug (ref: %s)\n",
                __func__, "https://github.com/ggml-org/llama.cpp/pull/13194");
        return;
    }

    recovery.clear();
}

bool llama_kv_cache_unified::update(llama_context & lctx) {
    bool need_reserve = false;

    auto * sched = lctx.get_sched();

    if (has_shift) {
        if (!get_can_shift()) {
            GGML_ABORT("The current KV cache / model configuration does not support K-shift");
        }

        LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);

        // apply K-shift if needed
        if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
            ggml_backend_sched_reset(sched);

            auto * gf = lctx.graph_init();

            auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf);

            ggml_backend_sched_alloc_graph(sched, gf);

            res->set_inputs(nullptr);

            lctx.graph_compute(gf, false);

            need_reserve = true;
        }

        {
            has_shift = false;

            for (uint32_t i = 0; i < size; ++i) {
                cells[i].delta = 0;
            }
        }
    }

    if (do_defrag) {
        LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);

        if (defrag_prepare(lctx.graph_max_nodes())) {
            ggml_backend_sched_reset(sched);

            auto * gf = lctx.graph_init();

            auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);

            ggml_backend_sched_alloc_graph(sched, gf);

            res->set_inputs(nullptr);

            lctx.graph_compute(gf, false);

            need_reserve = true;
        }

        do_defrag = false;
    }

    return need_reserve;
}

void llama_kv_cache_unified::defrag_sched(float thold) {
    // - do not defrag small contexts (i.e. < 2048 tokens)
    // - count the padding towards the number of used tokens
    const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + padding)/n)) : 0.0f;

    // queue defragmentation for next llama_kv_cache_update
    if (fragmentation > thold) {
        LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);

        do_defrag = true;
    }
}

void llama_kv_cache_unified::set_full() {
    n = size;

    // when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not
    // affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views.
    // we should only guarantee that the head position won't cause out-of-bounds view of the K, V tensors, so
    // setting it to 0 is the simplest way to achieve that
    // ref: https://github.com/ggml-org/llama.cpp/issues/13359
    head = 0;
}

llama_sbatch llama_kv_cache_unified::sbatch_init(const llama_batch & batch, bool logits_all) {
    return llama_sbatch(batch, hparams.n_embd, true, logits_all);
}

llama_ubatch llama_kv_cache_unified::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
    GGML_UNUSED(embd_pooled);
    return sbatch.split_simple(n_ubatch);
}
bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
    const uint32_t n_tokens = ubatch.n_tokens;

    // if we have enough unused cells before the current head ->
    //   better to start searching from the beginning of the cache, hoping to fill it
    if (head > used + 2*ubatch.n_tokens) {
        head = 0;
    }

    // otherwise, one cell per token.
    if (n_tokens > size) {
        LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size);
        return false;
    }

//#define FIND_SLOT_DEBUG 1
#if FIND_SLOT_DEBUG
    LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);

    // for debugging
    {
        std::string ss;
        if (n_swa > 0) {
            for (uint32_t i = 0; i < size; ++i) {
                if (cells[i].pos == -1) {
                    ss += '.';
                } else {
                    ss += std::to_string(*cells[i].seq_id.begin());
                }
                if (i%256 == 255) {
                    ss += '\n';
                }
            }
        }
        LLAMA_LOG_WARN("\n%s\n", ss.c_str());
    }
#endif

    uint32_t n_tested = 0;

    while (true) {
        if (head + n_tokens > size) {
            n_tested += size - head;
            head = 0;
            continue;
        }

        bool found = true;
        for (uint32_t i = 0; i < n_tokens; i++) {
            if (cells[head + i].pos >= 0) {
                found = false;
                head     += i + 1;
                n_tested += i + 1;
                break;
            }
        }

        if (found) {
            break;
        }

        if (n_tested >= size) {
            //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
            return false;
        }
    }

    for (uint32_t i = 0; i < n_tokens; ++i) {
        // remember the original state
        if (recovery.cells.find(head + i) == recovery.cells.end()) {
            recovery.cells[head + i] = cells[head + i];
        }

        cells[head + i].pos = ubatch.pos[i];

        for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
            cells[head + i].seq_id.insert(ubatch.seq_id[i][j]);
        }
    }

    used += n_tokens;

    // a heuristic, to avoid attending the full cache if it is not yet utilized
    // after enough generations, the benefit from this heuristic disappears
    // if we start defragmenting the cache, the benefit from this will be more important
    n = std::min(size, std::max(padding, GGML_PAD(cell_max(), padding)));

#ifdef FIND_SLOT_DEBUG
    LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
#endif

    return true;
}

int32_t llama_kv_cache_unified::get_n_tokens() const {
    int32_t result = 0;

    for (uint32_t i = 0; i < size; i++) {
        result += cells[i].seq_id.size();
    }

    return result;
}

int32_t llama_kv_cache_unified::get_used_cells() const {
    return used;
}

bool llama_kv_cache_unified::get_can_shift() const {
    return true;
}

uint32_t llama_kv_cache_unified::get_n() const {
    return n;
}

uint32_t llama_kv_cache_unified::get_size() const {
    return size;
}

ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il) const {
    const int32_t ikv = map_layer_ids.at(il);

    auto * k = layers[ikv].k;

    return ggml_view_3d(ctx, k,
            hparams.n_embd_head_k, hparams.n_head_kv(il), n,
            ggml_row_size(k->type, hparams.n_embd_head_k),
            ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
            0);
}

ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il) const {
    const int32_t ikv = map_layer_ids.at(il);

    auto * v = layers[ikv].v;

    if (!v_trans) {
        // note: v->nb[1] <= v->nb[2]
        return ggml_view_3d(ctx, v,
                hparams.n_embd_head_v, hparams.n_head_kv(il), n,
                ggml_row_size(v->type, hparams.n_embd_head_v),    // v->nb[1]
                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
                0);
    }

    // note: v->nb[1] > v->nb[2]
    return ggml_view_3d(ctx, v,
            n, hparams.n_head_kv(il), hparams.n_embd_head_v,
            ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
            ggml_row_size(v->type, v->ne[1]),                       // v->nb[2]
            0);
}

ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
    const int32_t ikv = map_layer_ids.at(il);

    auto * k = layers[ikv].k;

    const int64_t n_tokens = k_cur->ne[2];

    ggml_tensor * k_view = ggml_view_1d(ctx, k,
            n_tokens*hparams.n_embd_k_gqa(il),
            ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head);

    return ggml_cpy(ctx, k_cur, k_view);
}

ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
    const int32_t ikv = map_layer_ids.at(il);

    auto * v = layers[ikv].v;

    const int64_t n_tokens = v_cur->ne[2];

    v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens);

    ggml_tensor * v_view = nullptr;

    if (!v_trans) {
        v_view = ggml_view_1d(ctx, v,
                n_tokens*hparams.n_embd_v_gqa(il),
                ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head);
    } else {
        // note: the V cache is transposed when not using flash attention
        v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
                (v->ne[1])*ggml_element_size(v),
                (    head)*ggml_element_size(v));

        v_cur = ggml_transpose(ctx, v_cur);
    }

    return ggml_cpy(ctx, v_cur, v_view);
}
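
// Drops cells that can no longer be attended to under the sliding-window
// policy: every cell whose position is SWA-masked with respect to pmax is
// removed from the sequence (and freed once no sequence references it). It also
// warns when fewer cells remain attendable from pmin than the window size would
// normally allow, which indicates possible loss of information.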
void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax) {
    // no pruning is needed when the cache does not use SWA
    GGML_ASSERT(swa_type != LLAMA_SWA_TYPE_NONE && "do not prune non-SWA cache");

    int n_attended = 0;

    for (uint32_t i = 0; i < size; ++i) {
        const llama_pos p0 = cells[i].pos;

        if (p0 <= pmin && !is_masked_swa(p0, pmin)) {
            n_attended++;
        }

        if (is_masked_swa(p0, pmax)) {
            if (seq_id < 0) {
                cells[i].seq_id.clear();
            } else if (cells[i].has_seq_id(seq_id)) {
                cells[i].seq_id.erase(seq_id);
            } else {
                continue;
            }

            if (cells[i].is_empty()) {
                // keep count of the number of used cells
                if (cells[i].pos >= 0) {
                    used--;
                }

                cells[i].pos = -1;
            }
        }
    }

    if (n_attended < std::min<int>(n_swa, pmin)) {
        LLAMA_LOG_WARN("%s: partial SWA cache detected - possible loss of information, pmin = %d, n_attended = %d, n_swa = %d\n", __func__, pmin, n_attended, n_swa);
    }
}

void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
    const int64_t n_tokens     = ubatch->n_tokens;
    const int64_t n_seq_tokens = ubatch->n_seq_tokens;
    const int64_t n_seqs       = ubatch->n_seqs;

    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
    float * data = (float *) dst->data;

    const int64_t n_kv = n;

    // Use only the previous KV cells of the correct sequence for each token of the ubatch.
    // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
    // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
    //   Causal mask:
    //      xxx-------
    //      xxxx------
    //      xxxxx-----
    //   Non-causal mask:
    //      xxxxx-----
    //      xxxxx-----
    //      xxxxx-----
    // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
    for (int h = 0; h < 1; ++h) {
        for (int s = 0; s < n_seqs; ++s) {
            const llama_seq_id seq_id = ubatch->seq_id[s][0];

            for (int j = 0; j < n_seq_tokens; ++j) {
                const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];

                for (int i = 0; i < n_kv; ++i) {
                    const llama_pos p0 = cells[i].pos;

                    bool masked = false;

                    // mask the token if not the same sequence
                    masked = masked || (!cells[i].has_seq_id(seq_id));

                    // mask future tokens
                    masked = masked || (causal_attn && p0 > p1);

                    // apply SWA if any
                    masked = masked || (is_masked_swa(p0, p1));

                    float f = 0.0f;

                    if (masked) {
                        f = -INFINITY;
                    } else if (hparams.use_alibi) {
                        f = -std::abs(p0 - p1);
                    }

                    data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
                }
            }
        }

        // mask padded tokens
        if (data) {
            for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
                for (int j = 0; j < n_kv; ++j) {
                    data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
                }
            }
        }
    }
}

void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));

    int32_t * data = (int32_t *) dst->data;

    for (uint32_t i = 0; i < size; ++i) {
        data[i] = cells[i].delta;
    }
}

void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
    const int64_t n_tokens = ubatch->n_tokens;

    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
    GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing

    int32_t * data = (int32_t *) dst->data;

    const int64_t n_kv = n;

    for (int h = 0; h < 1; ++h) {
        for (int j = 0; j < n_tokens; ++j) {
            for (int i = 0; i < n_kv; ++i) {
                data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
            }
        }
    }
}

llama_pos llama_kv_cache_unified::get_pos_max() const {
    llama_pos pos_max = -1;

    for (const auto & cell : cells) {
        pos_max = std::max(pos_max, cell.pos);
    }

    return pos_max;
}

size_t llama_kv_cache_unified::total_size() const {
    size_t size = 0;

    for (const auto & buf : bufs) {
        size += ggml_backend_buffer_get_size(buf.get());
    }

    return size;
}

size_t llama_kv_cache_unified::size_k_bytes() const {
    size_t size_k_bytes = 0;

    for (const auto & layer : layers) {
        size_k_bytes += ggml_nbytes(layer.k);
    }

    return size_k_bytes;
}

size_t llama_kv_cache_unified::size_v_bytes() const {
    size_t size_v_bytes = 0;

    for (const auto & layer : layers) {
        size_v_bytes += ggml_nbytes(layer.v);
    }

    return size_v_bytes;
}
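
// Builds the graph operation that re-rotates a K-cache view by the per-cell
// position deltas (the "K-shift"). Quantized caches are dequantized to F32,
// rotated with RoPE, and copied back; F16/F32 caches are rotated in place.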
ggml_tensor * llama_kv_cache_unified::build_rope_shift(
        const llama_cparams & cparams,
        ggml_context * ctx,
        ggml_tensor * cur,
        ggml_tensor * shift,
        ggml_tensor * factors,
        float freq_base,
        float freq_scale) const {
    const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;

    const auto & yarn_ext_factor = cparams.yarn_ext_factor;
    const auto & yarn_beta_fast  = cparams.yarn_beta_fast;
    const auto & yarn_beta_slow  = cparams.yarn_beta_slow;

    const auto & n_rot     = hparams.n_rot;
    const auto & rope_type = hparams.rope_type;

    // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
    // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
    const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;

    ggml_tensor * tmp;

    if (ggml_is_quantized(cur->type)) {
        // dequantize to f32 -> RoPE -> quantize back
        tmp = ggml_cast(ctx, cur, GGML_TYPE_F32);

        tmp = ggml_rope_ext(ctx, tmp,
                shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);

        tmp = ggml_cpy(ctx, tmp, cur);
    } else {
        // we rotate only the first n_rot dimensions
        tmp = ggml_rope_ext_inplace(ctx, cur,
                shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
    }

    return tmp;
}

class llm_graph_input_k_shift : public llm_graph_input_i {
public:
    llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
    virtual ~llm_graph_input_k_shift() = default;

    void set_input(const llama_ubatch * ubatch) override;

    ggml_tensor * k_shift; // I32 [kv_size]

    const llama_kv_cache_unified * kv_self;
};

void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
    GGML_UNUSED(ubatch);

    if (k_shift) {
        kv_self->set_input_k_shift(k_shift);
    }
}

llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
        const llama_cparams & cparams,
        ggml_context * ctx,
        ggml_cgraph * gf) const {
    auto res = std::make_unique<llm_graph_result>();

    const auto & n_embd_head_k = hparams.n_embd_head_k;
    //const auto & n_embd_head_v = hparams.n_embd_head_v;

    //GGML_ASSERT(kv_self->size == n_ctx);

    auto inp = std::make_unique<llm_graph_input_k_shift>(this);

    inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx);
    ggml_set_input(inp->k_shift);

    for (const auto & layer : layers) {
        const uint32_t il = layer.il;

        const int64_t n_head_kv    = hparams.n_head_kv(il);
        const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);

        const float freq_base_l  = model.get_rope_freq_base (cparams, il);
        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);

        ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);

        ggml_tensor * k =
            ggml_view_3d(ctx, layer.k,
                n_embd_head_k, n_head_kv, size,
                ggml_row_size(layer.k->type, n_embd_head_k),
                ggml_row_size(layer.k->type, n_embd_k_gqa),
                0);

        ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);

        ggml_build_forward_expand(gf, cur);
    }

    res->add_input(std::move(inp));

    return res;
}

llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
        const llama_cparams & cparams,
        ggml_context * ctx,
        ggml_cgraph * gf) const {
    auto res = std::make_unique<llm_graph_result>();

    const auto & ids = defrag_info.ids;

#if 0
    // CPU defrag
    //
    // TODO: optimizations are possible:
    //       - multiple threads
    //       - avoid copying to the host memory when already there
    //
    // likely not worth the effort, as we have ggml_graph based defrag
    //

    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();

    const uint32_t kv_size = size;

    std::vector<uint8_t> buf_k;
    std::vector<uint8_t> buf_v;

    for (uint32_t il = 0; il < n_layer; ++il) {
        const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
        const size_t k_size     = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);

        const size_t v_size_el = ggml_type_size(v_l[il]->type);
        const size_t v_size    = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);

        buf_k.resize(k_size);
        buf_v.resize(v_size);

        ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
        ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());

        // batch move [i, i+nm) to [id, id+nm)
        // note: cells can move only to a lower index
        for (uint32_t i = 0; i < n_kv; ++i) {
            const uint32_t id = ids[i];

            if (i == id || id == n_kv) {
                continue;
            }

            uint32_t nm = 1;

            while (i + nm < n_kv && ids[i + nm] == id + nm) {
                nm++;
            }

            // move keys
            {
                const int64_t os =  i*k_size_row;
                const int64_t od = id*k_size_row;

                memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
            }

            // move values (note: they are transposed)
            {
                const int64_t os =  i;
                const int64_t od = id;

                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
                    memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
                }
            }

            i += nm - 1;
        }

        ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
        ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
    }
#else
    for (uint32_t i = 0; i < ids.size(); ++i) {
        const uint32_t id = ids[i];

        if (i == id || id == ids.size()) {
            continue;
        }

        uint32_t nm = 1;

        while (i + nm < ids.size() && ids[i + nm] == id + nm) {
            nm++;
        }

        for (const auto & layer : layers) {
            const uint32_t il = layer.il;

            const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
            const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

            ggml_tensor * view_k_src = ggml_view_2d(ctx, layer.k,
                    n_embd_k_gqa, nm,
                    ggml_row_size(layer.k->type, n_embd_k_gqa),
                    ggml_row_size(layer.k->type, n_embd_k_gqa*i));

            ggml_tensor * view_k_dst = ggml_view_2d(ctx, layer.k,
                    n_embd_k_gqa, nm,
                    ggml_row_size(layer.k->type, n_embd_k_gqa),
                    ggml_row_size(layer.k->type, n_embd_k_gqa*id));

            ggml_tensor * view_v_src;
            ggml_tensor * view_v_dst;

            if (cparams.flash_attn) {
                // NOTE: the V cache is not transposed when using flash attention
                view_v_src = ggml_view_2d(ctx, layer.v,
                        n_embd_v_gqa, nm,
                        ggml_row_size(layer.v->type, n_embd_v_gqa),
                        ggml_row_size(layer.v->type, n_embd_v_gqa*i));

                view_v_dst = ggml_view_2d(ctx, layer.v,
                        n_embd_v_gqa, nm,
                        ggml_row_size(layer.v->type, n_embd_v_gqa),
                        ggml_row_size(layer.v->type, n_embd_v_gqa*id));
            } else {
                view_v_src = ggml_view_2d(ctx, layer.v,
                        nm, n_embd_v_gqa,
                        ggml_row_size(layer.v->type, size),
                        ggml_row_size(layer.v->type, i));

                view_v_dst = ggml_view_2d(ctx, layer.v,
                        nm, n_embd_v_gqa,
                        ggml_row_size(layer.v->type, size),
                        ggml_row_size(layer.v->type, id));
            }

            ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst));
            ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst));
        }

        i += nm - 1;
    }

    //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
#endif

    return res;
}

bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
    const uint32_t n_layer = layers.size();

    const uint32_t n_kv   = cell_max();
    const uint32_t n_used = used;

    assert(n_used <= n_kv);

    //const int64_t t_start = ggml_time_us();

    // number of cells moved
    uint32_t n_moves = 0;

    // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag)
    //   - source view, destination view, copy operation
    //   - x2 for keys and values
    //const uint32_t max_moves = max_nodes()/(6*n_layer);
    // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
    const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);

    // determine which KV cells to move where
    //
    // cell i moves to ids[i]
    //
    // if ids[i] == i || ids[i] == n_kv, then cell i is not moved
    //
    auto & ids = defrag_info.ids;

    ids.clear();
    ids.resize(n_kv, n_kv);

    for (uint32_t i0 = 0; i0 < n_used; ++i0) {
        const auto & cell0 = cells[i0];

        if (!cell0.is_empty()) {
            ids[i0] = i0;
            continue;
        }

        // found a hole - fill it with data from the end of the cache

        uint32_t nh = 1;

        // determine the size of the hole
        while (i0 + nh < n_used && cells[i0 + nh].is_empty()) {
            nh++;
        }

        uint32_t nf = 0;
        uint32_t is = n_kv - 1;

        // starting from the end, find nh non-empty cells
        for (; is > i0; --is) {
            const auto & cell1 = cells[is];

            if (cell1.is_empty() || ids[is] != n_kv) {
                continue;
            }

            // non-empty cell which is not yet moved
            nf++;

            if (nf == nh) {
                break;
            }
        }

        // this can only happen if `n_used` is not accurate, which would be a bug
        GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh");

        nf = 0;

        uint32_t i1 = is;

        // are we moving a continuous block of memory?
        bool cont = false;

        // should we stop searching for the next move?
        bool stop = false;

        // go back and move the nf cells to the hole
        for (; i1 < n_kv; ++i1) {
            auto & cell1 = cells[i1];

            if (cell1.is_empty() || ids[i1] != n_kv) {
                if (n_moves == max_moves) {
                    stop = true;
                    break;
                }

                cont = false;
                continue;
            }

            // this cell goes to (i0 + nf)
            ids[i1] = i0 + nf;

            // move the cell meta data
            cells[i0 + nf] = cell1;

            // clear the old cell and move the head there
            cell1 = kv_cell();
            head = n_used;

            if (!cont) {
                n_moves++;
                cont = true;
            }

            nf++;

            if (nf == nh) {
                break;
            }
        }

        if (stop || n_moves == max_moves) {
            break;
        }

        //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);

        i0 += nh - 1;
    }

    if (n_moves == 0) {
        return false;
    }

    LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
    LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);

    return true;
}
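
// Returns one past the index of the last cell that currently holds data, i.e.
// the upper bound of the portion of the cache that attention has to cover.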
uint32_t llama_kv_cache_unified::cell_max() const {
    for (uint32_t i = size; i > 0; --i) {
        const kv_cell & cell = cells[i - 1];

        if (cell.pos >= 0 && !cell.is_empty()) {
            return i;
        }
    }

    return 0;
}
bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
    if (p0 < 0) {
        return true;
    }

    switch (swa_type) {
        case LLAMA_SWA_TYPE_NONE:
            {
            } break;
        case LLAMA_SWA_TYPE_STANDARD:
            {
                if (p1 - p0 >= (int32_t) n_swa) {
                    return true;
                }
            } break;
        case LLAMA_SWA_TYPE_CHUNKED:
            {
                const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;

                if (p0 < pos_chunk_start) {
                    return true;
                }
            } break;
    }

    return false;
}
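
// Session-file serialization. The layout written by state_write() is:
//   - the number of cells that belong to seq_id (or to any sequence when -1)
//   - per-cell metadata: position and, for full saves, the set of seq ids
//   - per-layer K rows for those cells, then per-layer V data (row-wise when V
//     is not transposed, otherwise element-wise per embedding row)
// state_read() expects the same layout and restores it into freshly found slots.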
void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
    std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
    uint32_t cell_count = 0;

    // Count the number of cells with the specified seq_id
    // Find all the ranges of cells with this seq id (or all, when -1)
    uint32_t cell_range_begin = size;
    for (uint32_t i = 0; i < size; ++i) {
        const auto & cell = cells[i];
        if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
            ++cell_count;
            if (cell_range_begin == size) {
                cell_range_begin = i;
            }
        } else {
            if (cell_range_begin != size) {
                cell_ranges.emplace_back(cell_range_begin, i);
                cell_range_begin = size;
            }
        }
    }
    if (cell_range_begin != size) {
        cell_ranges.emplace_back(cell_range_begin, size);
    }

    // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
    uint32_t cell_count_check = 0;
    for (const auto & range : cell_ranges) {
        cell_count_check += range.second - range.first;
    }
    GGML_ASSERT(cell_count == cell_count_check);

    io.write(&cell_count, sizeof(cell_count));

    state_write_meta(io, cell_ranges, seq_id);
    state_write_data(io, cell_ranges);
}

void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
    uint32_t cell_count;
    io.read_to(&cell_count, sizeof(cell_count));

    bool res = true;
    res = res && state_read_meta(io, cell_count, seq_id);
    res = res && state_read_data(io, cell_count);

    if (!res) {
        if (seq_id == -1) {
            clear();
        } else {
            seq_rm(seq_id, -1, -1);
        }
        throw std::runtime_error("failed to restore kv cache");
    }
}

void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
    for (const auto & range : cell_ranges) {
        for (uint32_t i = range.first; i < range.second; ++i) {
            const auto & cell = cells[i];
            const llama_pos pos      = cell.pos;
            const uint32_t  n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;

            io.write(&pos,      sizeof(pos));
            io.write(&n_seq_id, sizeof(n_seq_id));

            if (n_seq_id) {
                for (auto seq_id : cell.seq_id) {
                    io.write(&seq_id, sizeof(seq_id));
                }
            }
        }
    }
}

void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
    const uint32_t v_trans = this->v_trans ? 1 : 0;
    const uint32_t n_layer = layers.size();

    io.write(&v_trans, sizeof(v_trans));
    io.write(&n_layer, sizeof(n_layer));

    std::vector<uint8_t> tmp_buf;

    // Iterate and write all the keys first, each row is a cell
    // Get whole range at a time
    for (const auto & layer : layers) {
        const uint32_t il = layer.il;

        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();

        // Write key type
        const int32_t k_type_i = (int32_t)layer.k->type;
        io.write(&k_type_i, sizeof(k_type_i));

        // Write row size of key
        const uint64_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa);
        io.write(&k_size_row, sizeof(k_size_row));

        // Read each range of cells of k_size length each into tmp_buf and write out
        for (const auto & range : cell_ranges) {
            const size_t range_size = range.second - range.first;
            const size_t buf_size = range_size * k_size_row;
            io.write_tensor(layer.k, range.first * k_size_row, buf_size);
        }
    }

    if (!v_trans) {
        for (const auto & layer : layers) {
            const uint32_t il = layer.il;

            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();

            // Write value type
            const int32_t v_type_i = (int32_t)layer.v->type;
            io.write(&v_type_i, sizeof(v_type_i));

            // Write row size of value
            const uint64_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa);
            io.write(&v_size_row, sizeof(v_size_row));

            // Read each range of cells of v_size length each into tmp_buf and write out
            for (const auto & range : cell_ranges) {
                const size_t range_size = range.second - range.first;
                const size_t buf_size = range_size * v_size_row;
                io.write_tensor(layer.v, range.first * v_size_row, buf_size);
            }
        }
    } else {
        // When v is transposed, we also need the element size and get the element ranges from each row
        const uint32_t kv_size = size;

        for (const auto & layer : layers) {
            const uint32_t il = layer.il;

            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();

            // Write value type
            const int32_t v_type_i = (int32_t)layer.v->type;
            io.write(&v_type_i, sizeof(v_type_i));

            // Write element size
            const uint32_t v_size_el = ggml_type_size(layer.v->type);
            io.write(&v_size_el, sizeof(v_size_el));

            // Write GQA embedding size
            io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));

            // For each row, we get the element values of each cell
            for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
                // Read each range of cells of v_size_el length each into tmp_buf and write out
                for (const auto & range : cell_ranges) {
                    const size_t range_size = range.second - range.first;
                    const size_t src_offset = (range.first + j * kv_size) * v_size_el;
                    const size_t buf_size = range_size * v_size_el;
                    io.write_tensor(layer.v, src_offset, buf_size);
                }
            }
        }
    }
}

bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
    if (dest_seq_id != -1) {
        // single sequence

        seq_rm(dest_seq_id, -1, -1);

        llama_sbatch sbatch;
        llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);

        batch.n_tokens = cell_count;

        for (uint32_t i = 0; i < cell_count; ++i) {
            llama_pos pos;
            uint32_t  n_seq_id;

            io.read_to(&pos,      sizeof(pos));
            io.read_to(&n_seq_id, sizeof(n_seq_id));

            if (n_seq_id != 0) {
                LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
                return false;
            }

            batch.pos[i]      = pos;
            batch.n_seq_id[i] = 1;
            batch.seq_id[i]   = &dest_seq_id;
        }

        if (!find_slot(batch)) {
            LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
            return false;
        }

        commit();

        // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
        // Assume that this is one contiguous block of cells
        GGML_ASSERT(head + cell_count <= size);
        GGML_ASSERT(cells[head].pos == batch.pos[0]);
        GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
        GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
        GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
    } else {
        // whole KV cache restore

        if (cell_count > size) {
            LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
            return false;
        }

        clear();

        for (uint32_t i = 0; i < cell_count; ++i) {
            kv_cell & cell = cells[i];

            llama_pos pos;
            uint32_t  n_seq_id;

            io.read_to(&pos,      sizeof(pos));
            io.read_to(&n_seq_id, sizeof(n_seq_id));

            cell.pos = pos;

            for (uint32_t j = 0; j < n_seq_id; ++j) {
                llama_seq_id seq_id;
                io.read_to(&seq_id, sizeof(seq_id));

                // TODO: llama_kv_cache_unified should have a notion of max sequences
                //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
                if (seq_id < 0) {
                    //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
                    LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id);
                    return false;
                }

                cell.seq_id.insert(seq_id);
            }
        }

        head = 0;
        used = cell_count;
    }

    return true;
}
  1161. bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
  1162. uint32_t v_trans;
  1163. uint32_t n_layer;
  1164. io.read_to(&v_trans, sizeof(v_trans));
  1165. io.read_to(&n_layer, sizeof(n_layer));
  1166. if (n_layer != layers.size()) {
  1167. LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
  1168. return false;
  1169. }
  1170. if (cell_count > size) {
  1171. LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
  1172. return false;
  1173. }
  1174. if (this->v_trans != (bool) v_trans) {
  1175. LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
  1176. return false;
  1177. }
  1178. // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
  1179. for (const auto & layer : layers) {
  1180. const uint32_t il = layer.il;
  1181. const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
  1182. // Read type of key
  1183. int32_t k_type_i_ref;
  1184. io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
  1185. const int32_t k_type_i = (int32_t) layer.k->type;
  1186. if (k_type_i != k_type_i_ref) {
  1187. LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
  1188. return false;
  1189. }
  1190. // Read row size of key
  1191. uint64_t k_size_row_ref;
  1192. io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
  1193. const size_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa);
  1194. if (k_size_row != k_size_row_ref) {
  1195. LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
  1196. return false;
  1197. }
  1198. if (cell_count) {
  1199. // Read and set the keys for the whole cell range
  1200. ggml_backend_tensor_set(layer.k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
  1201. }
  1202. }
  1203. if (!this->v_trans) {
  1204. for (const auto & layer : layers) {
  1205. const uint32_t il = layer.il;
  1206. const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
  1207. // Read type of value
  1208. int32_t v_type_i_ref;
  1209. io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
  1210. const int32_t v_type_i = (int32_t)layer.v->type;
  1211. if (v_type_i != v_type_i_ref) {
  1212. LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
  1213. return false;
  1214. }
  1215. // Read row size of value
  1216. uint64_t v_size_row_ref;
  1217. io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
  1218. const size_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa);
  1219. if (v_size_row != v_size_row_ref) {
  1220. LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
  1221. return false;
  1222. }
  1223. if (cell_count) {
  1224. // Read and set the values for the whole cell range
  1225. ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
  1226. }
  1227. }
  1228. } else {
  1229. // For each layer, read the values for each cell (transposed)
  1230. for (const auto & layer : layers) {
  1231. const uint32_t il = layer.il;
  1232. const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
  1233. // Read type of value
  1234. int32_t v_type_i_ref;
  1235. io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
  1236. const int32_t v_type_i = (int32_t)layer.v->type;
  1237. if (v_type_i != v_type_i_ref) {
  1238. LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
  1239. return false;
  1240. }
  1241. // Read element size of value
  1242. uint32_t v_size_el_ref;
  1243. io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
  1244. const size_t v_size_el = ggml_type_size(layer.v->type);
  1245. if (v_size_el != v_size_el_ref) {
  1246. LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
  1247. return false;
  1248. }
  1249. // Read GQA embedding size
  1250. uint32_t n_embd_v_gqa_ref;
  1251. io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
  1252. if (n_embd_v_gqa != n_embd_v_gqa_ref) {
  1253. LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
  1254. return false;
  1255. }
  1256. if (cell_count) {
  1257. // For each row in the transposed matrix, read the values for the whole cell range
  1258. for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
  1259. const size_t dst_offset = (head + j * size) * v_size_el;
  1260. ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
  1261. }
  1262. }
  1263. }
  1264. }
  1265. return true;
  1266. }
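// Illustrative example of the transposed-V offset arithmetic above (hypothetical values): the V
// tensor stores element (cell, dim j) at flat index cell + j*size, so restoring cell_count cells
// starting at `head` writes one contiguous span per dim j:
//
//   // size = 8 cells, head = 2, cell_count = 3, v_size_el = 4 bytes (f32)
//   // dst_offset = (2 + j*8) * 4  ->  8, 40, 72, ... for j = 0, 1, 2, ...
//   // each ggml_backend_tensor_set() call copies cell_count * v_size_el = 12 bytes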
  1267. //
  1268. // llama_kv_cache_unified_iswa
  1269. //
  1270. llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
  1271. const llama_model & model,
  1272. ggml_type type_k,
  1273. ggml_type type_v,
  1274. bool v_trans,
  1275. bool offload,
  1276. uint32_t kv_size,
  1277. bool swa_full,
  1278. uint32_t n_seq_max,
  1279. uint32_t n_batch,
  1280. uint32_t padding) : hparams(model.hparams) {
  1281. llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
  1282. llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
  1283. const uint32_t size_base = kv_size;
  1284. uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, padding));
  1285. // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size and disable pruning
  1286. if (swa_full) {
  1287. LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
  1288. __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
  1289. size_swa = size_base;
  1290. do_prune = false;
  1291. }
  1292. LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
  1293. kv_base = std::make_unique<llama_kv_cache_unified>(
  1294. model, std::move(filter_base), type_k, type_v,
  1295. v_trans, offload, size_base, padding,
  1296. 0, LLAMA_SWA_TYPE_NONE);
  1297. LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
  1298. kv_swa = std::make_unique<llama_kv_cache_unified>(
  1299. model, std::move(filter_swa), type_k, type_v,
  1300. v_trans, offload, size_swa, padding,
  1301. hparams.n_swa, hparams.swa_type);
  1302. }
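// Illustrative sizing example (hypothetical parameters; GGML_PAD rounds up to a multiple of its
// second argument):
//
//   // kv_size = 32768, n_swa = 4096, n_seq_max = 2, n_batch = 2048, padding = 256
//   // size_swa = min(32768, GGML_PAD(4096*2 + 2048, 256)) = min(32768, 10240) = 10240
//   // with swa_full == true the SWA cache is instead sized like the base cache (32768)
//   // and pruning is disabled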
  1303. void llama_kv_cache_unified_iswa::clear() {
  1304. kv_base->clear();
  1305. kv_swa ->clear();
  1306. }
  1307. bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
  1308. bool res = true;
  1309. res = res & kv_base->seq_rm(seq_id, p0, p1);
  1310. res = res & kv_swa ->seq_rm(seq_id, p0, p1);
  1311. return res;
  1312. }
  1313. void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
  1314. kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
  1315. kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
  1316. }
  1317. void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
  1318. kv_base->seq_keep(seq_id);
  1319. kv_swa ->seq_keep(seq_id);
  1320. }
  1321. void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
  1322. kv_base->seq_add(seq_id, p0, p1, delta);
  1323. kv_swa ->seq_add(seq_id, p0, p1, delta);
  1324. }
  1325. void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
  1326. kv_base->seq_div(seq_id, p0, p1, d);
  1327. kv_swa ->seq_div(seq_id, p0, p1, d);
  1328. }
  1329. llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const {
1330. // the SWA cache is the one that prunes old positions, so its minimum is the effective minimum for the whole cache
  1331. return kv_swa->seq_pos_min(seq_id);
  1332. }
  1333. llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
  1334. return kv_swa->seq_pos_max(seq_id);
  1335. }
  1336. void llama_kv_cache_unified_iswa::restore() {
  1337. kv_base->restore();
  1338. kv_swa ->restore();
  1339. }
  1340. void llama_kv_cache_unified_iswa::commit() {
  1341. kv_base->commit();
  1342. kv_swa ->commit();
  1343. // slide the attention window, forgetting/pruning old tokens that are outside the window
  1344. if (do_prune) {
  1345. for (const auto & [seq_id, entry] : pending.pos) {
  1346. kv_swa->prune_swa(seq_id, entry.pmin, entry.pmax);
  1347. }
  1348. }
  1349. pending.clear();
  1350. }
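// Illustrative example (hypothetical values): pending.pos holds, per sequence, the [pmin, pmax]
// position range of the batch that was just committed; prune_swa() (defined earlier in this file)
// uses it to drop SWA cells that fell out of the attention window. With n_swa = 4096, committing
// a batch covering positions 8000..8191 of sequence 0 makes roughly the cells of that sequence
// with pos below 8191 - 4096 eligible for pruning (the exact cutoff is decided inside prune_swa()).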
  1351. bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
  1352. bool res = true;
  1353. res = res & kv_base->update(lctx);
  1354. res = res & kv_swa ->update(lctx);
  1355. return res;
  1356. }
  1357. void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
  1358. kv_base->defrag_sched(thold);
  1359. kv_swa ->defrag_sched(thold);
  1360. }
  1361. void llama_kv_cache_unified_iswa::set_full() {
  1362. kv_base->set_full();
  1363. kv_swa ->set_full();
  1364. }
  1365. llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
  1366. pending.clear();
  1367. if (do_prune) {
  1368. for (int i = 0; i < batch.n_tokens; ++i) {
  1369. for (int s = 0; s < batch.n_seq_id[i]; ++s) {
  1370. const llama_seq_id seq_id = batch.seq_id[i][s];
  1371. const llama_pos pos = batch.pos[i];
  1372. if (pending.pos.find(seq_id) == pending.pos.end()) {
  1373. pending.pos[seq_id].pmin = pos;
  1374. pending.pos[seq_id].pmax = pos;
  1375. } else {
  1376. pending.pos[seq_id].pmin = std::min(pending.pos[seq_id].pmin, pos);
  1377. pending.pos[seq_id].pmax = std::max(pending.pos[seq_id].pmax, pos);
  1378. }
  1379. }
  1380. }
  1381. }
  1382. return llama_sbatch(batch, hparams.n_embd, true, logits_all);
  1383. }
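// Illustrative example (hypothetical batch): for a batch with tokens at positions 10, 11, 12 for
// sequence 0 and position 5 for sequence 1, the loop above leaves pending.pos[0] = {pmin 10, pmax 12}
// and pending.pos[1] = {pmin 5, pmax 5}; the pruning itself happens later, in commit().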
  1384. llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
  1385. GGML_UNUSED(embd_pooled);
  1386. return sbatch.split_simple(n_ubatch);
  1387. }
  1388. bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {
  1389. bool res = true;
  1390. res = res & kv_base->find_slot(batch);
  1391. res = res & kv_swa ->find_slot(batch);
  1392. return res;
  1393. }
  1394. int32_t llama_kv_cache_unified_iswa::get_n_tokens() const {
  1395. return kv_base->get_n_tokens();
  1396. }
  1397. int32_t llama_kv_cache_unified_iswa::get_used_cells() const {
  1398. return kv_base->get_used_cells();
  1399. }
  1400. llama_pos llama_kv_cache_unified_iswa::get_pos_max() const {
  1401. return kv_base->get_pos_max();
  1402. }
  1403. bool llama_kv_cache_unified_iswa::get_can_shift() const {
  1404. return kv_base->get_size() == kv_swa->get_size();
  1405. }
  1406. void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
  1407. kv_base->state_write(io, seq_id);
  1408. kv_swa ->state_write(io, seq_id);
  1409. }
  1410. void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
  1411. kv_base->state_read(io, seq_id);
  1412. kv_swa ->state_read(io, seq_id);
  1413. }
  1414. llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_base() const {
  1415. return kv_base.get();
  1416. }
  1417. llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_swa() const {
  1418. return kv_swa.get();
  1419. }
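// Illustrative construction sketch (parameter values are hypothetical, not taken from any model):
//
//   // auto kv = std::make_unique<llama_kv_cache_unified_iswa>(
//   //     model,
//   //     GGML_TYPE_F16, GGML_TYPE_F16,  // type_k, type_v
//   //     /*v_trans  =*/ true,
//   //     /*offload  =*/ true,
//   //     /*kv_size  =*/ 32768,
//   //     /*swa_full =*/ false,
//   //     /*n_seq_max=*/ 1,
//   //     /*n_batch  =*/ 2048,
//   //     /*padding  =*/ 256);
//
// This builds two llama_kv_cache_unified instances: kv_base for the non-SWA layers and a smaller,
// prunable kv_swa for the SWA layers; the sequence operations above fan out to both.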
  1420. //
  1421. // llama_kv_cache_recurrent
  1422. //
  1423. llama_kv_cache_recurrent::llama_kv_cache_recurrent(
  1424. const llama_model & model,
  1425. ggml_type type_k,
  1426. ggml_type type_v,
  1427. bool offload,
  1428. uint32_t kv_size) : hparams(model.hparams) {
  1429. const int32_t n_layer = hparams.n_layer;
  1430. LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d\n",
  1431. __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
  1432. head = 0;
  1433. size = kv_size;
  1434. used = 0;
  1435. this->type_k = type_k;
  1436. this->type_v = type_v;
  1437. cells.clear();
  1438. cells.resize(kv_size);
  1439. // create a context for each buffer type
  1440. std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
  1441. auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
  1442. auto it = ctx_map.find(buft);
  1443. if (it == ctx_map.end()) {
  1444. ggml_init_params params = {
  1445. /*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
  1446. /*.mem_buffer =*/ NULL,
  1447. /*.no_alloc =*/ true,
  1448. };
  1449. ggml_context * ctx = ggml_init(params);
  1450. if (!ctx) {
  1451. return nullptr;
  1452. }
  1453. ctx_map[buft] = ctx;
  1454. ctxs.emplace_back(ctx);
  1455. return ctx;
  1456. }
  1457. return it->second;
  1458. };
  1459. k_l.reserve(n_layer);
  1460. v_l.reserve(n_layer);
  1461. for (int i = 0; i < n_layer; i++) {
  1462. const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
  1463. const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
  1464. const char * dev_name = "CPU";
  1465. ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
  1466. if (offload) {
  1467. auto * dev = model.dev_layer(i);
  1468. buft = ggml_backend_dev_buffer_type(dev);
  1469. dev_name = ggml_backend_dev_name(dev);
  1470. }
  1471. LLAMA_LOG_DEBUG("%s, layer %3d: dev = %s\n", __func__, i, dev_name);
  1472. ggml_context * ctx = ctx_for_buft(buft);
  1473. if (!ctx) {
  1474. throw std::runtime_error("failed to create ggml context for kv cache");
  1475. }
  1476. ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
  1477. ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
  1478. ggml_format_name(k, "cache_k_l%d", i);
  1479. ggml_format_name(v, "cache_v_l%d", i);
  1480. k_l.push_back(k);
  1481. v_l.push_back(v);
  1482. }
  1483. // allocate tensors and initialize the buffers to avoid NaNs in the padding
  1484. for (auto it : ctx_map) {
  1485. auto * buft = it.first;
  1486. auto * ctx = it.second;
  1487. ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
  1488. if (!buf) {
  1489. throw std::runtime_error("failed to allocate buffer for kv cache");
  1490. }
  1491. ggml_backend_buffer_clear(buf, 0);
  1492. LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
  1493. bufs.emplace_back(buf);
  1494. }
  1495. {
  1496. const size_t memory_size_k = size_k_bytes();
  1497. const size_t memory_size_v = size_v_bytes();
  1498. LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
  1499. (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
  1500. ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
  1501. ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
  1502. }
  1503. }
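// Illustrative sizing note (hypothetical numbers): each of the kv_size cells stores one whole
// per-sequence recurrent state, and the per-layer k/v tensors are flat 1D buffers of
// n_embd_k_gqa*kv_size and n_embd_v_gqa*kv_size elements. For example, with n_embd_k_gqa = 1024
// and kv_size = 4, cache_k_l0 holds 4096 elements and cell 2's state occupies elements [2048, 3072).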
  1504. void llama_kv_cache_recurrent::clear() {
  1505. for (int32_t i = 0; i < (int32_t) size; ++i) {
  1506. cells[i].pos = -1;
  1507. cells[i].seq_id.clear();
  1508. cells[i].src = -1;
  1509. cells[i].tail = -1;
  1510. }
  1511. head = 0;
  1512. used = 0;
  1513. for (auto & buf : bufs) {
  1514. ggml_backend_buffer_clear(buf.get(), 0);
  1515. }
  1516. }
  1517. bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
  1518. uint32_t new_head = size;
  1519. if (p0 < 0) {
  1520. p0 = 0;
  1521. }
  1522. if (p1 < 0) {
  1523. p1 = std::numeric_limits<llama_pos>::max();
  1524. }
  1525. // models like Mamba or RWKV can't have a state partially erased
  1526. if (seq_id >= (int64_t) size) {
  1527. // could be fatal
  1528. return false;
  1529. }
  1530. if (0 <= seq_id) {
  1531. int32_t & tail_id = cells[seq_id].tail;
  1532. if (tail_id >= 0) {
  1533. const kv_cell & cell = cells[tail_id];
  1534. // partial intersection is invalid
  1535. if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
  1536. return false;
  1537. }
  1538. // invalidate tails which will be cleared
  1539. if (p0 <= cell.pos && cell.pos < p1) {
  1540. tail_id = -1;
  1541. }
  1542. }
  1543. } else {
1544. // when seq_id is negative, the range must include either everything or nothing
  1545. if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
  1546. return false;
  1547. }
  1548. }
  1549. for (uint32_t i = 0; i < size; ++i) {
  1550. if (cells[i].pos >= p0 && cells[i].pos < p1) {
  1551. if (seq_id < 0) {
  1552. cells[i].seq_id.clear();
  1553. } else if (cells[i].has_seq_id(seq_id)) {
  1554. cells[i].seq_id.erase(seq_id);
  1555. } else {
  1556. continue;
  1557. }
  1558. if (cells[i].is_empty()) {
  1559. // keep count of the number of used cells
  1560. if (cells[i].pos >= 0) {
  1561. used--;
  1562. }
  1563. cells[i].pos = -1;
  1564. cells[i].src = -1;
  1565. if (new_head == size) {
  1566. new_head = i;
  1567. }
  1568. }
  1569. }
  1570. }
  1571. // If we freed up a slot, set head to it so searching can start there.
  1572. if (new_head != size && new_head < head) {
  1573. head = new_head;
  1574. }
  1575. return true;
  1576. }
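// Illustrative example (hypothetical positions): since a recurrent cell holds a single rolled-up
// state rather than per-token entries, only removals that clear the tail entirely or miss it
// completely are accepted. With the tail of sequence 0 at pos 25: seq_rm(0, 10, -1) is rejected
// (it would cut the state mid-sequence), seq_rm(0, -1, -1) clears it, and seq_rm(0, 30, -1) is a no-op.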
  1577. void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
  1578. if (seq_id_src == seq_id_dst) {
  1579. return;
  1580. }
  1581. if (p0 < 0) {
  1582. p0 = 0;
  1583. }
  1584. if (p1 < 0) {
  1585. p1 = std::numeric_limits<llama_pos>::max();
  1586. }
  1587. if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
  1588. kv_cell & tail_src = cells[seq_id_src];
  1589. kv_cell & tail_dst = cells[seq_id_dst];
  1590. if (tail_dst.tail >= 0) {
  1591. // clear destination seq_id if it wasn't empty
  1592. kv_cell & cell_dst = cells[tail_dst.tail];
  1593. cell_dst.seq_id.erase(seq_id_dst);
  1594. tail_dst.tail = -1;
  1595. if (cell_dst.seq_id.empty()) {
  1596. cell_dst.pos = -1;
  1597. cell_dst.src = -1;
  1598. used -= 1;
  1599. }
  1600. }
  1601. if (tail_src.tail >= 0) {
  1602. kv_cell & cell_src = cells[tail_src.tail];
  1603. cell_src.seq_id.insert(seq_id_dst);
  1604. tail_dst.tail = tail_src.tail;
  1605. }
  1606. }
  1607. }
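// Illustrative example: copying a recurrent sequence is tail aliasing. After seq_cp(0, 1, -1, -1),
// cells[1].tail == cells[0].tail and that cell's seq_id set contains both 0 and 1; the state data
// itself is only duplicated later (roughly, once find_slot() gives the destination sequence its
// own cell and the graph copies the source state via s_copy()).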
  1608. void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
  1609. uint32_t new_head = size;
  1610. for (uint32_t i = 0; i < size; ++i) {
  1611. if ((llama_seq_id) i != seq_id) {
  1612. cells[i].tail = -1;
  1613. }
  1614. if (!cells[i].has_seq_id(seq_id)) {
  1615. if (cells[i].pos >= 0) {
  1616. used--;
  1617. }
  1618. cells[i].pos = -1;
  1619. cells[i].src = -1;
  1620. cells[i].seq_id.clear();
  1621. if (new_head == size){
  1622. new_head = i;
  1623. }
  1624. } else {
  1625. cells[i].seq_id.clear();
  1626. cells[i].seq_id.insert(seq_id);
  1627. }
  1628. }
  1629. // If we freed up a slot, set head to it so searching can start there.
  1630. if (new_head != size && new_head < head) {
  1631. head = new_head;
  1632. }
  1633. }
  1634. void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
  1635. if (delta == 0) {
  1636. return;
  1637. }
  1638. if (p0 < 0) {
  1639. p0 = 0;
  1640. }
  1641. if (p1 < 0) {
  1642. p1 = std::numeric_limits<llama_pos>::max();
  1643. }
1644. // If there is no range then return early to avoid looping over the cache.
  1645. if (p0 == p1) {
  1646. return;
  1647. }
  1648. // for Mamba-like or RWKV models, only the pos needs to be shifted
  1649. if (0 <= seq_id && seq_id < (int64_t) size) {
  1650. const int32_t tail_id = cells[seq_id].tail;
  1651. if (tail_id >= 0) {
  1652. kv_cell & cell = cells[tail_id];
  1653. if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
  1654. cell.pos += delta;
  1655. }
  1656. }
  1657. }
  1658. }
  1659. void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
  1660. if (d == 1) {
  1661. return;
  1662. }
  1663. if (p0 < 0) {
  1664. p0 = 0;
  1665. }
  1666. if (p1 < 0) {
  1667. p1 = std::numeric_limits<llama_pos>::max();
  1668. }
  1669. // If there is no range then return early to avoid looping over the cache.
  1670. if (p0 == p1) {
  1671. return;
  1672. }
  1673. // for Mamba-like or RWKV models, only the pos needs to be changed
  1674. if (0 <= seq_id && seq_id < (int64_t) size) {
  1675. const int32_t tail_id = cells[seq_id].tail;
  1676. if (tail_id >= 0) {
  1677. kv_cell & cell = cells[tail_id];
  1678. if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
  1679. cell.pos /= d;
  1680. }
  1681. }
  1682. }
  1683. }
  1684. llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const {
  1685. llama_pos result = std::numeric_limits<llama_pos>::max();
  1686. for (uint32_t i = 0; i < size; ++i) {
  1687. if (cells[i].has_seq_id(seq_id)) {
  1688. result = std::min(result, cells[i].pos);
  1689. }
  1690. }
  1691. if (result == std::numeric_limits<llama_pos>::max()) {
  1692. result = -1;
  1693. }
  1694. return result;
  1695. }
  1696. llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
  1697. llama_pos result = -1;
  1698. for (uint32_t i = 0; i < size; ++i) {
  1699. if (cells[i].has_seq_id(seq_id)) {
  1700. result = std::max(result, cells[i].pos);
  1701. }
  1702. }
  1703. return result;
  1704. }
  1705. void llama_kv_cache_recurrent::restore() {
  1706. if (pending.ranges.empty()) {
  1707. return;
  1708. }
  1709. seq_rm(-1, -1, -1);
  1710. }
  1711. void llama_kv_cache_recurrent::commit() {
  1712. pending.ranges.clear();
  1713. }
  1714. bool llama_kv_cache_recurrent::update(llama_context & lctx) {
  1715. GGML_UNUSED(lctx);
  1716. return false;
  1717. }
  1718. void llama_kv_cache_recurrent::defrag_sched(float thold) {
  1719. GGML_UNUSED(thold);
  1720. // noop
  1721. }
  1722. void llama_kv_cache_recurrent::set_full() {
  1723. n = size;
  1724. head = 0;
  1725. }
  1726. llama_sbatch llama_kv_cache_recurrent::sbatch_init(
  1727. const llama_batch & batch,
  1728. bool logits_all) {
  1729. return llama_sbatch(batch, hparams.n_embd, false, logits_all);
  1730. }
  1731. llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
  1732. if (embd_pooled) {
  1733. // Pooled embeddings cannot be split across ubatches (yet)
  1734. return sbatch.split_seq(n_ubatch);
  1735. }
  1736. return sbatch.split_equal(n_ubatch);
  1737. }
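// Note: find_slot() below asserts ubatch.equal_seqs, i.e. every sequence in a ubatch contributes
// the same number of new tokens, which is the property split_equal() provides; split_seq() is used
// for pooled embeddings so that a sequence is not split across ubatches (see the comment above).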
  1738. bool llama_kv_cache_recurrent::find_slot(
  1739. const llama_ubatch & ubatch) {
  1740. const uint32_t n_tokens = ubatch.n_tokens;
  1741. const uint32_t n_seqs = ubatch.n_seqs;
  1742. const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
  1743. // if we have enough unused cells before the current head ->
  1744. // better to start searching from the beginning of the cache, hoping to fill it
  1745. if (head > used + 2*n_tokens) {
  1746. head = 0;
  1747. }
  1748. // For recurrent state architectures (like Mamba or RWKV),
  1749. // each cache cell can store the state for a whole sequence.
1750. // A slot should always be contiguous.
  1751. // can only process batches with an equal number of new tokens in each sequence
  1752. GGML_ASSERT(ubatch.equal_seqs);
  1753. int32_t min = size - 1;
  1754. int32_t max = 0;
  1755. // everything should fit if all seq_ids are smaller than the max
  1756. for (uint32_t s = 0; s < n_seqs; ++s) {
  1757. const uint32_t n_seq_id = ubatch.n_seq_id[s];
  1758. for (uint32_t j = 0; j < n_seq_id; ++j) {
  1759. const llama_seq_id seq_id = ubatch.seq_id[s][j];
  1760. if (seq_id < 0 || (uint32_t) seq_id >= size) {
  1761. // too big seq_id
  1762. // TODO: would it be possible to resize the cache instead?
1763. LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d. Try using a bigger --parallel value\n", __func__, seq_id, size);
  1764. return false;
  1765. }
  1766. if (j > 0) {
  1767. kv_cell & seq = cells[seq_id];
  1768. if (seq.tail >= 0) {
  1769. kv_cell & cell = cells[seq.tail];
  1770. // clear cells from seq_ids that become shared
  1771. // (should not normally happen, but let's handle it anyway)
  1772. cell.seq_id.erase(seq_id);
  1773. seq.tail = -1;
  1774. if (cell.seq_id.empty()) {
  1775. cell.pos = -1;
  1776. cell.src = -1;
  1777. used -= 1;
  1778. }
  1779. }
  1780. }
  1781. }
  1782. }
  1783. #ifndef NDEBUG
  1784. {
  1785. std::vector<int32_t> tails_verif;
  1786. tails_verif.assign(size, -1);
  1787. for (uint32_t i = 0; i < size; ++i) {
  1788. kv_cell & cell = cells[i];
  1789. for (llama_seq_id seq_id : cell.seq_id) {
  1790. if (tails_verif[seq_id] != -1) {
  1791. LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
  1792. }
  1793. tails_verif[seq_id] = i;
  1794. }
  1795. }
  1796. for (uint32_t i = 0; i < size; ++i) {
  1797. if (tails_verif[i] != cells[i].tail) {
1798. LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]);
  1799. }
  1800. }
  1801. }
  1802. #endif
  1803. // find next empty cell
  1804. uint32_t next_empty_cell = head;
  1805. for (uint32_t i = 0; i < size; ++i) {
  1806. if (next_empty_cell >= size) { next_empty_cell -= size; }
  1807. kv_cell & cell = cells[next_empty_cell];
  1808. if (cell.is_empty()) { break; }
  1809. next_empty_cell += 1;
  1810. }
  1811. // find usable cell range
  1812. for (uint32_t s = 0; s < n_seqs; ++s) {
  1813. const llama_seq_id seq_id = ubatch.seq_id[s][0];
  1814. kv_cell & seq_meta = cells[seq_id];
  1815. bool has_cell = false;
  1816. if (seq_meta.tail >= 0) {
  1817. kv_cell & cell = cells[seq_meta.tail];
  1818. GGML_ASSERT(cell.has_seq_id(seq_id));
  1819. // does this seq_id "own" the cell?
  1820. if (cell.seq_id.size() == 1) { has_cell = true; }
  1821. }
  1822. if (!has_cell) {
  1823. kv_cell & empty_cell = cells[next_empty_cell];
  1824. GGML_ASSERT(empty_cell.is_empty());
  1825. // copy old tail into the empty cell
  1826. if (seq_meta.tail >= 0) {
  1827. kv_cell & orig_cell = cells[seq_meta.tail];
  1828. empty_cell.pos = orig_cell.pos;
  1829. empty_cell.src = orig_cell.src;
  1830. orig_cell.seq_id.erase(seq_id);
  1831. empty_cell.seq_id.insert(seq_id); // will be overwritten
  1832. }
  1833. seq_meta.tail = next_empty_cell;
  1834. // find next empty cell
  1835. if (s + 1 < n_seqs) {
  1836. next_empty_cell += 1;
  1837. for (uint32_t i = 0; i < size; ++i) {
  1838. if (next_empty_cell >= size) { next_empty_cell -= size; }
  1839. kv_cell & cell = cells[next_empty_cell];
  1840. if (cell.is_empty()) { break; }
  1841. next_empty_cell += 1;
  1842. }
  1843. }
  1844. }
  1845. if (min > seq_meta.tail) { min = seq_meta.tail; }
  1846. if (max < seq_meta.tail) { max = seq_meta.tail; }
  1847. }
  1848. // gather and re-order
  1849. for (uint32_t s = 0; s < n_seqs; ++s) {
  1850. int32_t dst_id = s + min;
  1851. int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
  1852. if (dst_id != src_id) {
  1853. kv_cell & dst_cell = cells[dst_id];
  1854. kv_cell & src_cell = cells[src_id];
  1855. std::swap(dst_cell.pos, src_cell.pos);
  1856. std::swap(dst_cell.src, src_cell.src);
  1857. std::swap(dst_cell.seq_id, src_cell.seq_id);
  1858. // swap tails (assuming they NEVER overlap)
  1859. for (const llama_seq_id seq_id : src_cell.seq_id) {
  1860. cells[seq_id].tail = src_id;
  1861. }
  1862. for (const llama_seq_id seq_id : dst_cell.seq_id) {
  1863. cells[seq_id].tail = dst_id;
  1864. }
  1865. }
  1866. }
  1867. // update the pos of the used seqs
  1868. for (uint32_t s = 0; s < n_seqs; ++s) {
  1869. const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
  1870. int32_t cell_id = s + min;
  1871. kv_cell & cell = cells[cell_id];
  1872. if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
  1873. // What should happen when the pos backtracks or skips a value?
  1874. // Clearing the state mid-batch would require special-casing which isn't done.
  1875. LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
  1876. __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
  1877. }
  1878. cell.pos = last_pos;
  1879. cell.seq_id.clear();
  1880. for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
  1881. const llama_seq_id seq_id = ubatch.seq_id[s][j];
  1882. cell.seq_id.insert(seq_id);
  1883. cells[seq_id].tail = cell_id;
  1884. }
  1885. }
  1886. // allow getting the range of used cells, from head to head + n
  1887. head = min;
  1888. n = max - min + 1;
  1889. used = std::count_if(cells.begin(), cells.end(),
  1890. [](const kv_cell & cell){ return !cell.is_empty(); });
  1891. // sanity check
  1892. return n >= n_seqs;
  1893. }
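// Illustrative example (hypothetical cell ids): after find_slot() the sequences of the ubatch
// occupy a contiguous block of cells. If the three sequences of a ubatch had their tails at cells
// 4, 2 and 7, then min = 2 and the gather step swaps them into cells 2, 3 and 4, so the graph can
// address the batch as the range [head, head + n) with head = 2 (n still spans up to the original
// max tail, hence the `n >= n_seqs` sanity check).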
  1894. int32_t llama_kv_cache_recurrent::get_n_tokens() const {
  1895. int32_t result = 0;
  1896. for (uint32_t i = 0; i < size; i++) {
  1897. result += cells[i].seq_id.size();
  1898. }
  1899. return result;
  1900. }
  1901. int32_t llama_kv_cache_recurrent::get_used_cells() const {
  1902. return used;
  1903. }
  1904. llama_pos llama_kv_cache_recurrent::get_pos_max() const {
  1905. llama_pos pos_max = -1;
  1906. for (const auto & cell : cells) {
  1907. pos_max = std::max(pos_max, cell.pos);
  1908. }
  1909. return pos_max;
  1910. }
  1911. bool llama_kv_cache_recurrent::get_can_shift() const {
  1912. return false;
  1913. }
  1914. int32_t llama_kv_cache_recurrent::s_copy(int i) const {
  1915. const uint32_t cell_id = i + head;
  1916. //////////////////////////////////////////////
  1917. // TODO: this should not mutate the KV cache !
  1918. kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
  1919. // prevent out-of-bound sources
  1920. if (cell.src < 0 || (uint32_t) cell.src >= size) {
  1921. cell.src = cell_id;
  1922. }
  1923. int32_t res = cell.src;
  1924. // TODO: do not mutate the KV cache
  1925. // ensure copy only happens once
  1926. if (cell.src != (int32_t) cell_id) {
  1927. cell.src = cell_id;
  1928. }
  1929. return res;
  1930. }
  1931. float llama_kv_cache_recurrent::s_mask(int i) const {
  1932. const uint32_t cell_id = i + head;
  1933. //////////////////////////////////////////////
  1934. // TODO: this should not mutate the KV cache !
  1935. kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
  1936. float res = (float) (cell.src >= 0);
  1937. // only clear once
  1938. if (cell.src < 0) {
  1939. cell.src = cell_id;
  1940. }
  1941. return res;
  1942. }
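// Note (rough description): s_copy(i) and s_mask(i) are consumed when building the graph inputs
// for the recurrent layers, indexed relative to `head`. s_copy(i) names the cell whose state
// should be copied into cell head + i (its own id once the copy has already happened), and
// s_mask(i) is 0.0f when that cell's state must be treated as cleared and 1.0f otherwise.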
  1943. uint32_t llama_kv_cache_recurrent::cell_max() const {
  1944. for (uint32_t i = size; i > 0; --i) {
  1945. const kv_cell & cell = cells[i - 1];
  1946. if (cell.pos >= 0 && !cell.is_empty()) {
  1947. return i;
  1948. }
  1949. }
  1950. return 0;
  1951. }
  1952. size_t llama_kv_cache_recurrent::total_size() const {
  1953. size_t size = 0;
  1954. for (const auto & buf : bufs) {
  1955. size += ggml_backend_buffer_get_size(buf.get());
  1956. }
  1957. return size;
  1958. }
  1959. size_t llama_kv_cache_recurrent::size_k_bytes() const {
  1960. size_t size_k_bytes = 0;
  1961. for (const auto & k : k_l) {
  1962. size_k_bytes += ggml_nbytes(k);
  1963. }
  1964. return size_k_bytes;
  1965. }
  1966. size_t llama_kv_cache_recurrent::size_v_bytes() const {
  1967. size_t size_v_bytes = 0;
  1968. for (const auto & v : v_l) {
  1969. size_v_bytes += ggml_nbytes(v);
  1970. }
  1971. return size_v_bytes;
  1972. }
  1973. void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
  1974. std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
  1975. uint32_t cell_count = 0;
  1976. // Count the number of cells with the specified seq_id
  1977. // Find all the ranges of cells with this seq id (or all, when -1)
  1978. uint32_t cell_range_begin = size;
  1979. for (uint32_t i = 0; i < size; ++i) {
  1980. const auto & cell = cells[i];
  1981. if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
  1982. ++cell_count;
  1983. if (cell_range_begin == size) {
  1984. cell_range_begin = i;
  1985. }
  1986. } else {
  1987. if (cell_range_begin != size) {
  1988. cell_ranges.emplace_back(cell_range_begin, i);
  1989. cell_range_begin = size;
  1990. }
  1991. }
  1992. }
  1993. if (cell_range_begin != size) {
  1994. cell_ranges.emplace_back(cell_range_begin, size);
  1995. }
  1996. // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
  1997. uint32_t cell_count_check = 0;
  1998. for (const auto & range : cell_ranges) {
  1999. cell_count_check += range.second - range.first;
  2000. }
  2001. GGML_ASSERT(cell_count == cell_count_check);
  2002. io.write(&cell_count, sizeof(cell_count));
  2003. state_write_meta(io, cell_ranges, seq_id);
  2004. state_write_data(io, cell_ranges);
  2005. }
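// Rough sketch of the byte stream produced above (field types as written by io.write):
//
//   // uint32_t cell_count
//   // cell_count x { llama_pos pos; uint32_t n_seq_id; llama_seq_id seq_id[n_seq_id] }  // meta
//   // uint32_t v_trans (always 0 here); uint32_t n_layer                                // data
//   // n_layer x { int32_t k_type; uint64_t k_size_row; raw key rows of the saved cells }
//   // n_layer x { int32_t v_type; uint64_t v_size_row; raw value rows of the saved cells }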
  2006. void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
  2007. uint32_t cell_count;
  2008. io.read_to(&cell_count, sizeof(cell_count));
  2009. bool res = true;
  2010. res = res && state_read_meta(io, cell_count, seq_id);
  2011. res = res && state_read_data(io, cell_count);
  2012. if (!res) {
  2013. if (seq_id == -1) {
  2014. clear();
  2015. } else {
  2016. seq_rm(seq_id, -1, -1);
  2017. }
  2018. throw std::runtime_error("failed to restore kv cache");
  2019. }
  2020. }
  2021. void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
  2022. for (const auto & range : cell_ranges) {
  2023. for (uint32_t i = range.first; i < range.second; ++i) {
  2024. const auto & cell = cells[i];
  2025. const llama_pos pos = cell.pos;
  2026. const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
  2027. io.write(&pos, sizeof(pos));
  2028. io.write(&n_seq_id, sizeof(n_seq_id));
  2029. if (n_seq_id) {
  2030. for (auto seq_id : cell.seq_id) {
  2031. io.write(&seq_id, sizeof(seq_id));
  2032. }
  2033. }
  2034. }
  2035. }
  2036. }
  2037. void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
  2038. const uint32_t v_trans = 0;
  2039. const uint32_t n_layer = hparams.n_layer;
  2040. io.write(&v_trans, sizeof(v_trans));
  2041. io.write(&n_layer, sizeof(n_layer));
2043. // Iterate and write all the keys first; each row is one cell
2044. // and each contiguous range of cells is written in a single call
  2045. for (uint32_t il = 0; il < n_layer; ++il) {
  2046. const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
  2047. // Write key type
  2048. const int32_t k_type_i = (int32_t)k_l[il]->type;
  2049. io.write(&k_type_i, sizeof(k_type_i));
  2050. // Write row size of key
  2051. const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
  2052. io.write(&k_size_row, sizeof(k_size_row));
2053. // Write each range of cells directly from the key tensor, k_size_row bytes per cell
  2054. for (const auto & range : cell_ranges) {
  2055. const size_t range_size = range.second - range.first;
  2056. const size_t buf_size = range_size * k_size_row;
  2057. io.write_tensor(k_l[il], range.first * k_size_row, buf_size);
  2058. }
  2059. }
  2060. if (!v_trans) {
  2061. for (uint32_t il = 0; il < n_layer; ++il) {
  2062. const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
  2063. // Write value type
  2064. const int32_t v_type_i = (int32_t)v_l[il]->type;
  2065. io.write(&v_type_i, sizeof(v_type_i));
  2066. // Write row size of value
  2067. const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
  2068. io.write(&v_size_row, sizeof(v_size_row));
2069. // Write each range of cells directly from the value tensor, v_size_row bytes per cell
  2070. for (const auto & range : cell_ranges) {
  2071. const size_t range_size = range.second - range.first;
  2072. const size_t buf_size = range_size * v_size_row;
  2073. io.write_tensor(v_l[il], range.first * v_size_row, buf_size);
  2074. }
  2075. }
  2076. } else {
  2077. // When v is transposed, we also need the element size and get the element ranges from each row
  2078. const uint32_t kv_size = size;
  2079. for (uint32_t il = 0; il < n_layer; ++il) {
  2080. const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
  2081. // Write value type
  2082. const int32_t v_type_i = (int32_t)v_l[il]->type;
  2083. io.write(&v_type_i, sizeof(v_type_i));
  2084. // Write element size
  2085. const uint32_t v_size_el = ggml_type_size(v_l[il]->type);
  2086. io.write(&v_size_el, sizeof(v_size_el));
  2087. // Write GQA embedding size
  2088. io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
  2089. // For each row, we get the element values of each cell
  2090. for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
2091. // Write each range of elements for this row directly from the value tensor, v_size_el bytes per element
  2092. for (const auto & range : cell_ranges) {
  2093. const size_t range_size = range.second - range.first;
  2094. const size_t src_offset = (range.first + j * kv_size) * v_size_el;
  2095. const size_t buf_size = range_size * v_size_el;
  2096. io.write_tensor(v_l[il], src_offset, buf_size);
  2097. }
  2098. }
  2099. }
  2100. }
  2101. }
  2102. bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
  2103. if (dest_seq_id != -1) {
  2104. // single sequence
  2105. seq_rm(dest_seq_id, -1, -1);
  2106. llama_sbatch sbatch;
  2107. llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
  2108. batch.n_tokens = cell_count;
  2109. batch.n_seq_tokens = cell_count;
  2110. batch.n_seqs = 1;
  2111. for (uint32_t i = 0; i < cell_count; ++i) {
  2112. llama_pos pos;
  2113. uint32_t n_seq_id;
  2114. io.read_to(&pos, sizeof(pos));
  2115. io.read_to(&n_seq_id, sizeof(n_seq_id));
  2116. if (n_seq_id != 0) {
  2117. LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
  2118. return false;
  2119. }
  2120. batch.pos[i] = pos;
  2121. }
  2122. batch.n_seq_id[0] = 1;
  2123. batch.seq_id[0] = &dest_seq_id;
  2124. if (!find_slot(batch)) {
  2125. LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
  2126. return false;
  2127. }
  2128. commit();
  2129. // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
  2130. // Assume that this is one contiguous block of cells
  2131. GGML_ASSERT(head + cell_count <= size);
  2132. GGML_ASSERT(cells[head].pos == batch.pos[0]);
  2133. GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
  2134. GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
  2135. GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
  2136. } else {
  2137. // whole KV cache restore
  2138. if (cell_count > size) {
  2139. LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
  2140. return false;
  2141. }
  2142. clear();
  2143. for (uint32_t i = 0; i < cell_count; ++i) {
  2144. kv_cell & cell = cells[i];
  2145. llama_pos pos;
  2146. uint32_t n_seq_id;
  2147. io.read_to(&pos, sizeof(pos));
  2148. io.read_to(&n_seq_id, sizeof(n_seq_id));
  2149. cell.pos = pos;
  2150. for (uint32_t j = 0; j < n_seq_id; ++j) {
  2151. llama_seq_id seq_id;
  2152. io.read_to(&seq_id, sizeof(seq_id));
  2153. // TODO: llama_kv_cache_recurrent should have a notion of max sequences
  2154. //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
  2155. if (seq_id < 0) {
  2156. //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
  2157. LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id);
  2158. return false;
  2159. }
  2160. cell.seq_id.insert(seq_id);
  2161. int32_t & tail = cells[seq_id].tail;
  2162. if (tail != -1) {
  2163. LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
  2164. return false;
  2165. }
  2166. tail = i;
  2167. }
  2168. }
  2169. head = 0;
  2170. used = cell_count;
  2171. }
  2172. for (uint32_t i = 0; i < cell_count; ++i) {
  2173. uint32_t cell_id = head + i;
  2174. // make sure the recurrent states will keep their restored state
  2175. cells[cell_id].src = cell_id;
  2176. }
  2177. return true;
  2178. }
  2179. bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
  2180. uint32_t v_trans;
  2181. uint32_t n_layer;
  2182. io.read_to(&v_trans, sizeof(v_trans));
  2183. io.read_to(&n_layer, sizeof(n_layer));
  2184. if (n_layer != hparams.n_layer) {
  2185. LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
  2186. return false;
  2187. }
  2188. if (cell_count > size) {
  2189. LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
  2190. return false;
  2191. }
  2192. if (false != (bool) v_trans) {
  2193. LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
  2194. return false;
  2195. }
2196. // For each layer, read the keys for each cell; one row is one cell, read as one contiguous block
  2197. for (uint32_t il = 0; il < n_layer; ++il) {
  2198. const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
  2199. // Read type of key
  2200. int32_t k_type_i_ref;
  2201. io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
  2202. const int32_t k_type_i = (int32_t) k_l[il]->type;
  2203. if (k_type_i != k_type_i_ref) {
  2204. LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
  2205. return false;
  2206. }
  2207. // Read row size of key
  2208. uint64_t k_size_row_ref;
  2209. io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
  2210. const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
  2211. if (k_size_row != k_size_row_ref) {
  2212. LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
  2213. return false;
  2214. }
  2215. if (cell_count) {
  2216. // Read and set the keys for the whole cell range
  2217. ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
  2218. }
  2219. }
  2220. if (!v_trans) {
  2221. for (uint32_t il = 0; il < n_layer; ++il) {
  2222. const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
  2223. // Read type of value
  2224. int32_t v_type_i_ref;
  2225. io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
  2226. const int32_t v_type_i = (int32_t)v_l[il]->type;
  2227. if (v_type_i != v_type_i_ref) {
  2228. LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
  2229. return false;
  2230. }
  2231. // Read row size of value
  2232. uint64_t v_size_row_ref;
  2233. io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
  2234. const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
  2235. if (v_size_row != v_size_row_ref) {
  2236. LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
  2237. return false;
  2238. }
  2239. if (cell_count) {
  2240. // Read and set the values for the whole cell range
  2241. ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
  2242. }
  2243. }
  2244. } else {
  2245. // For each layer, read the values for each cell (transposed)
  2246. for (uint32_t il = 0; il < n_layer; ++il) {
  2247. const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
  2248. // Read type of value
  2249. int32_t v_type_i_ref;
  2250. io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
  2251. const int32_t v_type_i = (int32_t)v_l[il]->type;
  2252. if (v_type_i != v_type_i_ref) {
  2253. LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
  2254. return false;
  2255. }
  2256. // Read element size of value
  2257. uint32_t v_size_el_ref;
  2258. io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
  2259. const size_t v_size_el = ggml_type_size(v_l[il]->type);
  2260. if (v_size_el != v_size_el_ref) {
  2261. LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
  2262. return false;
  2263. }
  2264. // Read GQA embedding size
  2265. uint32_t n_embd_v_gqa_ref;
  2266. io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
  2267. if (n_embd_v_gqa != n_embd_v_gqa_ref) {
  2268. LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
  2269. return false;
  2270. }
  2271. if (cell_count) {
  2272. // For each row in the transposed matrix, read the values for the whole cell range
  2273. for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
  2274. const size_t dst_offset = (head + j * size) * v_size_el;
  2275. ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
  2276. }
  2277. }
  2278. }
  2279. }
  2280. return true;
  2281. }