#include "llama-context.h"

#include "llama-arch.h"
#include "llama-impl.h"
#include "llama-batch.h"
#include "llama-io.h"
#include "llama-memory.h"
#include "llama-mmap.h"
#include "llama-model.h"

#include <cinttypes>
#include <cmath>
#include <cstring>
#include <limits>
#include <stdexcept>

//
// llama_context
//

llama_context::llama_context(
        const llama_model & model,
        llama_context_params params) :
    model(model),
    balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
    // TODO warning when creating llama_context with awkward ctx size that is not a power of 2,
    // may need to be backend-dependent
    LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);

    t_start_us = model.t_start_us;
    t_load_us  = model.t_load_us;

    const auto & hparams = model.hparams;

    cparams.n_seq_max = std::max(1u, params.n_seq_max);
    if (cparams.n_seq_max > LLAMA_MAX_SEQ) {
        throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
    }

    cparams.n_threads        = params.n_threads;
    cparams.n_threads_batch  = params.n_threads_batch;
    cparams.yarn_ext_factor  = params.yarn_ext_factor  >= 0.0f ? params.yarn_ext_factor  : hparams.yarn_ext_factor;
    cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor;
    cparams.yarn_beta_fast   = params.yarn_beta_fast   >= 0.0f ? params.yarn_beta_fast   : hparams.yarn_beta_fast;
    cparams.yarn_beta_slow   = params.yarn_beta_slow   >= 0.0f ? params.yarn_beta_slow   : hparams.yarn_beta_slow;
    cparams.embeddings       = params.embeddings;
    cparams.offload_kqv      = params.offload_kqv;
    cparams.no_perf          = params.no_perf;
    cparams.pooling_type     = params.pooling_type;
    cparams.warmup           = false;

    cparams.n_ctx           = params.n_ctx           == 0    ? hparams.n_ctx_train           : params.n_ctx;
    cparams.rope_freq_base  = params.rope_freq_base  == 0.0f ? hparams.rope_freq_base_train  : params.rope_freq_base;
    cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;

    cparams.n_ctx_orig_yarn = params.yarn_orig_ctx    != 0 ? params.yarn_orig_ctx    :
                              hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn :
                                                             hparams.n_ctx_train;

    cparams.cb_eval           = params.cb_eval;
    cparams.cb_eval_user_data = params.cb_eval_user_data;

    // Initialize backend samplers here so they are part of the sampling graph
    // before the reserve passes run later in this function. This avoids a later
    // re-reserve when graph nodes change.
    if (params.samplers != nullptr && params.n_samplers > 0) {
        for (size_t i = 0; i < params.n_samplers; ++i) {
            const auto & config = params.samplers[i];
            if (llama_sampler_chain_get(config.sampler, -1) == nullptr) {
                throw std::runtime_error("the backend samplers must be of type llama_sampler_chain");
            }
            if (set_sampler(config.seq_id, config.sampler)) {
                const int n_samplers = llama_sampler_chain_n(config.sampler);
                LLAMA_LOG_INFO("%s: setting backend sampler for seq_id %d (n = %d)\n", __func__, config.seq_id, n_samplers);
            }
        }
    }

    auto rope_scaling_type = params.rope_scaling_type;
    if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
        rope_scaling_type = hparams.rope_scaling_type_train;
    }

    if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_NONE) {
        cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none
    }

    if (cparams.yarn_ext_factor < 0.0f) { // negative indicates 'not set'
        cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
    }

    if (cparams.yarn_ext_factor != 0) {
        static auto get_mscale = [](float scale, float mscale) {
            return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
        };

        const float factor = 1.0f / cparams.rope_freq_scale;

        // ref: https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348
        if (hparams.rope_yarn_log_mul != 0.0f) {
            // note: here we assume `mscale == 1.0f`
            // TODO: start reading the actual value of mscale and handle the case where it is not 1.0f
            float mscale = 1.0f;

            const float mscale_all_dims = hparams.rope_yarn_log_mul;

            // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
            // special-case DEEPSEEK v2:
            // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43
            if (model.arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) {
                mscale = mscale_all_dims;
            }

            cparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);

            LLAMA_LOG_WARN("%s: setting new yarn_attn_factor = %.4f (mscale == %.1f, mscale_all_dim = %.1f)\n",
                    __func__, cparams.yarn_attn_factor, mscale, mscale_all_dims);
        } else {
            cparams.yarn_attn_factor = get_mscale(factor, 1.0f);
        }

        // when YARN is applied with yarn_ext_factor != 0.0f, we need to cancel this factor:
        // https://github.com/ggml-org/llama.cpp/blob/a81a569577cc38b32558958b048228150be63eae/ggml/src/ggml-cpu/ops.cpp#L5541-L5544
        //
        // ref: https://github.com/ggml-org/llama.cpp/discussions/7416
        //      https://github.com/ggml-org/llama.cpp/pull/17945
        cparams.yarn_attn_factor *= 1.0f / (1.0f + 0.1f * logf(factor));
    }

    cparams.yarn_attn_factor *= hparams.rope_attn_factor;

    if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
        if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
            cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
        } else {
            cparams.pooling_type = hparams.pooling_type;
        }
    }

    if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
        cparams.causal_attn = hparams.causal_attn;
    } else {
        cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
    }
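
    // flash_attn starts out enabled unless it was explicitly disabled; when the type is AUTO,
    // auto_fa records that the final decision is deferred until the reserve pass (sched_reserve)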
    cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED;
    cparams.auto_fa    = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO;

    // with causal attention, the batch size is limited by the context size
    cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;

    cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);

    cparams.op_offload = params.op_offload;
    cparams.kv_unified = params.kv_unified;

    // initialized later
    cparams.pipeline_parallel = false;

    {
        const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE");
        graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable;

        if (graph_reuse_disable) {
            LLAMA_LOG_WARN("%s: graph reuse disabled\n", __func__);
        }
    }

    // ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732
    cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256);

    if (cparams.kv_unified) {
        cparams.n_ctx_seq = cparams.n_ctx;
    } else {
        cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max;
        cparams.n_ctx_seq = GGML_PAD(cparams.n_ctx_seq, 256);

        if (cparams.n_ctx_seq == 0) {
            throw std::runtime_error("n_ctx_seq == 0");
        }

        if (cparams.n_ctx != cparams.n_ctx_seq * cparams.n_seq_max) {
            cparams.n_ctx = cparams.n_ctx_seq * cparams.n_seq_max;
            LLAMA_LOG_WARN("%s: n_ctx is not divisible by n_seq_max - rounding down to %u\n", __func__, cparams.n_ctx);
        }
    }

    LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
    LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
    LLAMA_LOG_INFO("%s: n_ctx_seq = %u\n", __func__, cparams.n_ctx_seq);
    LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
    LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
    LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
    LLAMA_LOG_INFO("%s: flash_attn = %s\n", __func__, llama_flash_attn_type_name(params.flash_attn_type));
    LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false");
    LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
    LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);

    if (cparams.n_ctx_seq < hparams.n_ctx_train) {
        LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
                __func__, cparams.n_ctx_seq, hparams.n_ctx_train);
    }

    if (cparams.n_ctx_seq > hparams.n_ctx_train) {
        LLAMA_LOG_WARN("%s: n_ctx_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
                __func__, cparams.n_ctx_seq, hparams.n_ctx_train);
    }

    if (!hparams.vocab_only) {
        // GPU backends
        for (auto * dev : model.devices) {
            ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
            if (backend == nullptr) {
                throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev)));
            }
            backends.emplace_back(backend);
        }

        // add ACCEL backends (such as BLAS)
        for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
            if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
                ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
                if (backend == nullptr) {
                    throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev)));
                }
                backends.emplace_back(backend);
            }
        }

        // add CPU backend
        backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
        if (backend_cpu == nullptr) {
            throw std::runtime_error("failed to initialize CPU backend");
        }
        backends.emplace_back(backend_cpu);

        // create a list of the set_n_threads functions in the backends
        for (auto & backend : backends) {
            ggml_backend_dev_t dev = ggml_backend_get_device(backend.get());
            ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
            if (reg) {
                auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
                if (ggml_backend_set_n_threads_fn) {
                    set_n_threads_fns.emplace_back(backend.get(), ggml_backend_set_n_threads_fn);
                }
            }
        }

        llama_set_abort_callback(this, params.abort_callback, params.abort_callback_data);

        // graph outputs buffer
        {
            // resized during inference when a batch uses more outputs
            // Create a dummy batch for initialization.
            llama_batch dummy_batch = {};
            dummy_batch.n_tokens = 0;

            if (output_reserve(params.n_seq_max, dummy_batch) < params.n_seq_max) {
                throw std::runtime_error("failed to reserve initial output buffer");
            }

            LLAMA_LOG_INFO("%s: %10s output buffer size = %8.2f MiB\n", __func__,
                    ggml_backend_buffer_name    (buf_output.get()),
                    ggml_backend_buffer_get_size(buf_output.get()) / 1024.0 / 1024.0);
        }
    }

    // init the memory module
    if (!hparams.vocab_only) {
        llama_memory_params params_mem = {
            /*.type_k   =*/ params.type_k,
            /*.type_v   =*/ params.type_v,
            /*.swa_full =*/ params.swa_full,
        };

        memory.reset(model.create_memory(params_mem, cparams));
    }

    // init backends
    if (!hparams.vocab_only) {
        LLAMA_LOG_DEBUG("%s: enumerating backends\n", __func__);

        backend_buft.clear();
        backend_ptrs.clear();
        backend_buf_exp_size.clear();

        for (auto & backend : backends) {
            auto * buft = ggml_backend_get_default_buffer_type(backend.get());
            auto backend_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));

            if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model.devices.empty()) {
                // use the host buffer of the first device CPU for faster transfer of the intermediate state
                auto * dev = model.devices[0];
                auto * host_buft = ggml_backend_dev_host_buffer_type(dev);
                if (host_buft) {
                    buft = host_buft;
                }
            }

            backend_buft.push_back(buft);
            backend_ptrs.push_back(backend.get());
            backend_buf_exp_size.push_back(0);
        }

        LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size());

        // TODO: move these checks to ggml_backend_sched
        // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
        bool pipeline_parallel =
            model.n_devices() > 1 &&
            model.n_gpu_layers() > model.hparams.n_layer &&
            model.split_mode() == LLAMA_SPLIT_MODE_LAYER &&
            cparams.offload_kqv &&
            !model.has_tensor_overrides();

        // pipeline parallelism requires support for async compute and events in all devices
        if (pipeline_parallel) {
            for (auto & backend : backends) {
                auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
                if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) {
                    // ignore CPU backend
                    continue;
                }
                auto * dev = ggml_backend_get_device(backend.get());
                ggml_backend_dev_props props;
                ggml_backend_dev_get_props(dev, &props);
                if (!props.caps.async || !props.caps.events) {
                    // device does not support async compute or events
                    pipeline_parallel = false;
                    break;
                }
            }
        }

        cparams.pipeline_parallel = pipeline_parallel;

        if (cparams.pipeline_parallel) {
            LLAMA_LOG_INFO("%s: pipeline parallelism enabled\n", __func__);
        }

        sched_reserve();

        if (!cparams.flash_attn) {
            if (ggml_is_quantized(params.type_v)) {
                throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
            }
        }
    }

    // Initialize the full vocabulary token ids for backend samplers.
    {
        const int n_vocab = model.vocab.n_tokens();
        sampling.token_ids_full_vocab.resize(n_vocab);
        for (int i = 0; i < n_vocab; ++i) {
            sampling.token_ids_full_vocab[i] = i;
        }
    }
}
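
// on destruction, compare the actual scheduler compute buffer sizes against the sizes
// estimated during the reserve pass (skipped when the model was loaded with no_alloc)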
llama_context::~llama_context() {
    if (!model.hparams.no_alloc) {
        for (size_t i = 0; i < backend_ptrs.size(); ++i) {
            ggml_backend_t backend = backend_ptrs[i];
            ggml_backend_buffer_type_t buft = backend_buft[i];

            const size_t size_exp = backend_buf_exp_size[i];
            const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend);

            if (size_exp == size_act) {
                LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n",
                        __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
            } else {
                LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n",
                        __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
            }
        }
    }

    ggml_opt_free(opt_ctx);
}
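
// (re)create the scheduler and reserve worst-case compute graphs so that all compute
// buffers are allocated once up front; also resolves the automatic Flash Attention choice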
void llama_context::sched_reserve() {
    if (!sched_need_reserve) {
        return;
    }

    sched_need_reserve = false;

    LLAMA_LOG_INFO("%s: reserving ...\n", __func__);

    synchronize();

    const int64_t t_start_us = ggml_time_us();

    const uint32_t n_seqs = cparams.n_seq_max;
    const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

    const size_t max_nodes = this->graph_max_nodes(n_tokens);

    LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);

    gf_res_prev.reset(new llm_graph_result(max_nodes));
    gf_res_reserve.reset(new llm_graph_result(max_nodes));

    sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, cparams.pipeline_parallel, cparams.op_offload));

    llama_memory_context_ptr mctx;
    if (memory) {
        LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);

        mctx = memory->init_full();
        if (!mctx) {
            throw std::runtime_error("failed to initialize memory module");
        }
    }

    // avoid reserving graphs with zero outputs - assume one output per sequence
    const int n_outputs = n_seqs;

    LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);

    // resolve automatic Flash Attention use
    if (cparams.auto_fa) {
        auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
        if (!gf) {
            throw std::runtime_error("failed to split graph for Flash Attention check");
        }

        const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;

        bool fa_device_mismatch = false;
        for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
            ggml_tensor * n = ggml_graph_node(gf, i);
            if (n->op != GGML_OP_FLASH_ATTN_EXT) {
                continue;
            }
            ggml_backend_dev_t device_fa = ggml_backend_get_device(
                ggml_backend_sched_get_tensor_backend(sched.get(), n));

            // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
            GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
            const int il = std::stoi(n->name + prefix_len);
            ggml_backend_dev_t device_kv = model.dev_layer(il);
            if (device_fa != device_kv) {
                LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
                        "is assigned to device %s (usually due to missing support)\n",
                        __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
                // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
                fa_device_mismatch = true;
                break;
            }
        }

        if (fa_device_mismatch) {
            cparams.flash_attn = false;
            LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
        } else {
            cparams.flash_attn = true;
            LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
        }

        cparams.auto_fa = false;
    }

    // reserve worst-case graph
    int n_splits_pp = -1;
    int n_nodes_pp  = -1;

    int n_splits_tg = -1;
    int n_nodes_tg  = -1;

    // reserve pp (prompt processing) graph first so that buffers are only allocated once
    {
        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(),
                model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr);
        if (!gf) {
            if (cparams.pipeline_parallel) {
                LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__);
                cparams.pipeline_parallel = false;
                sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload));
                gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
            }
            if (!gf) {
                throw std::runtime_error("failed to allocate compute pp buffers");
            }
        }

        n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
        n_nodes_pp  = ggml_graph_n_nodes(gf);
    }

    // reserve with tg (token generation) graph to get the number of splits and nodes
    {
        auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc);
        if (!gf) {
            throw std::runtime_error("failed to allocate compute tg buffers");
        }

        n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
        n_nodes_tg  = ggml_graph_n_nodes(gf);
    }

    // reserve again with pp graph to avoid ggml-alloc reallocations during inference
    {
        // TODO: not sure if the following graph would be the worst case for multi-stream KV caches:
        //
        // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
        //
        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc);
        if (!gf) {
            throw std::runtime_error("failed to allocate compute pp buffers");
        }
    }

    for (size_t i = 0; i < backend_ptrs.size(); ++i) {
        ggml_backend_t backend = backend_ptrs[i];
        ggml_backend_buffer_type_t buft = backend_buft[i];

        if (!model.hparams.no_alloc) {
            backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend);
        }

        if (backend_buf_exp_size[i] > 1) {
            LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
                    ggml_backend_buft_name(buft),
                    backend_buf_exp_size[i] / 1024.0 / 1024.0);
        }
    }

    if (n_nodes_pp == n_nodes_tg) {
        LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp);
    } else {
        LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg);
    }

    if (n_splits_pp == n_splits_tg) {
        LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp);
    } else {
        LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
    }

    const int64_t t_end_us = ggml_time_us();

    LLAMA_LOG_INFO("%s: reserve took %.2f ms, sched copies = %d\n",
            __func__, (t_end_us - t_start_us)/1000.0, ggml_backend_sched_get_n_copies(sched.get()));
}
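
// block until all scheduled backend work has finished and fold the queued token count
// into the prompt/eval timing statistics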
void llama_context::synchronize() {
    if (!sched) {
        return;
    }

    ggml_backend_sched_synchronize(sched.get());

    // FIXME: if multiple single tokens are evaluated without a synchronization,
    //        the stats will be added to the prompt evaluation stats
    //        this should only happen when using batch size 1 to evaluate a batch

    // add the evaluation to the stats
    if (n_queued_tokens == 1) {
        if (!cparams.no_perf) {
            t_eval_us += ggml_time_us() - t_compute_start_us;
        }
        n_eval++;
    } else if (n_queued_tokens > 1) {
        if (!cparams.no_perf) {
            t_p_eval_us += ggml_time_us() - t_compute_start_us;
        }
        n_p_eval += n_queued_tokens;
    }

    // get a more accurate load time, upon first eval
    if (n_queued_tokens > 0 && !has_evaluated_once) {
        t_load_us = ggml_time_us() - t_start_us;
        has_evaluated_once = true;
    }

    n_queued_tokens = 0;
    t_compute_start_us = 0;
}

const llama_model & llama_context::get_model() const {
    return model;
}

const llama_cparams & llama_context::get_cparams() const {
    return cparams;
}

ggml_backend_sched_t llama_context::get_sched() const {
    return sched.get();
}

uint32_t llama_context::n_ctx() const {
    return cparams.n_ctx;
}

uint32_t llama_context::n_ctx_seq() const {
    return cparams.n_ctx_seq;
}

uint32_t llama_context::n_batch() const {
    return cparams.n_batch;
}

uint32_t llama_context::n_ubatch() const {
    return cparams.n_ubatch;
}

uint32_t llama_context::n_seq_max() const {
    return cparams.n_seq_max;
}

uint32_t llama_context::n_threads() const {
    return cparams.n_threads;
}

uint32_t llama_context::n_threads_batch() const {
    return cparams.n_threads_batch;
}

llama_memory_t llama_context::get_memory() const {
    return memory.get();
}
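
// apply any pending update of the memory module and, if an update was performed,
// reserve a new worst-case graph afterwards; returns false when nothing was done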
bool llama_context::memory_update(bool optimize) {
    if (!memory) {
        return false;
    }

    {
        const auto mctx = memory->init_update(this, optimize);
        switch (mctx->get_status()) {
            case LLAMA_MEMORY_STATUS_SUCCESS:
                {
                    // noop
                } break;
            case LLAMA_MEMORY_STATUS_NO_UPDATE:
                {
                    // no updates need to be performed
                    return false;
                }
            case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
                {
                    LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__);
                    return false;
                }
        }

        // reset the previous graph result to make sure that it won't be reused
        // TODO: change the mctx->apply() to return information if a graph reserve is needed
        //       reset the graph result only if the memory module did reset the scheduler
        gf_res_prev->reset();

        if (!mctx->apply()) {
            LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
        }
    }

    // if the memory module did any computation, we have to reserve a new worst-case graph
    {
        const auto mctx = memory->init_full();
        if (!mctx) {
            throw std::runtime_error("failed to initialize memory context");
        }

        const uint32_t n_seqs = cparams.n_seq_max;
        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
        if (!gf) {
            LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
        }
    }

    return true;
}

enum llama_pooling_type llama_context::pooling_type() const {
    return cparams.pooling_type;
}

float * llama_context::get_logits() {
    output_reorder();

    return logits;
}
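
// translate a batch token index (negative values count back from the last output) into a
// row of the output buffers via output_ids; throws if the index is out of range or the
// token was not marked for output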
int64_t llama_context::output_resolve_row(int32_t i) const {
    int64_t j = -1;

    // support negative indices (last output row)
    if (i < 0) {
        j = n_outputs + i;
        if (j < 0) {
            throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
        }
    } else if ((size_t) i >= output_ids.size()) {
        throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
    } else {
        // use output_ids to translate the batch token index into a row number
        // that holds this token's data.
        j = output_ids[i];
    }

    if (j < 0) {
        // the batch token was not configured to output anything
        throw std::runtime_error(format("batch.logits[%d] != true", i));
    }

    if (j >= n_outputs) {
        throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
    }

    return j;
}

float * llama_context::get_logits_ith(int32_t i) {
    int64_t j = -1;

    output_reorder();

    try {
        if (logits == nullptr) {
            throw std::runtime_error("no logits");
        }

        // TODO: use output_resolve_row()
        if (i < 0) {
            j = n_outputs + i;
            if (j < 0) {
                throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
            }
        } else if ((size_t) i >= output_ids.size()) {
            throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
        } else {
            j = output_ids[i];
        }

        if (j < 0) {
            throw std::runtime_error(format("batch.logits[%d] != true", i));
        }
        if (j >= n_outputs) {
            // This should not happen
            throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
        }

        return logits + j*model.vocab.n_tokens();
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
#ifndef NDEBUG
        GGML_ABORT("fatal error");
#else
        return nullptr;
#endif
    }
}

float * llama_context::get_embeddings() {
    output_reorder();

    return embd;
}

llama_token * llama_context::get_sampled_tokens() const {
    return sampling.sampled;
}

float * llama_context::get_embeddings_ith(int32_t i) {
    int64_t j = -1;

    output_reorder();

    try {
        if (embd == nullptr) {
            throw std::runtime_error("no embeddings");
        }

        // TODO: use output_resolve_row()
        if (i < 0) {
            j = n_outputs + i;
            if (j < 0) {
                throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
            }
        } else if ((size_t) i >= output_ids.size()) {
            throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
        } else {
            j = output_ids[i];
        }

        if (j < 0) {
            throw std::runtime_error(format("batch.logits[%d] != true", i));
        }
        if (j >= n_outputs) {
            // This should not happen
            throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
        }

        const uint32_t n_embd_out = model.hparams.get_n_embd_out();

        return embd + j*n_embd_out;
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
#ifndef NDEBUG
        GGML_ABORT("fatal error");
#else
        return nullptr;
#endif
    }
}

float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
    auto it = embd_seq.find(seq_id);
    if (it == embd_seq.end()) {
        return nullptr;
    }

    return it->second.data();
}
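
// accessors for results produced by offloaded backend samplers: sampled tokens, probs,
// logits and candidate ids are looked up per output row and fall back to safe defaults
// (LLAMA_TOKEN_NULL, nullptr, the full vocab list) when the buffers were not populated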
llama_token llama_context::get_sampled_token_ith(int32_t idx) {
    output_reorder();

    if (sampling.sampled == nullptr) {
        return LLAMA_TOKEN_NULL;
    }

    try {
        const int64_t row = output_resolve_row(idx);
        GGML_ASSERT(row < (int64_t) sampling.sampled_size);
        return sampling.sampled[row];
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what());
        return LLAMA_TOKEN_NULL;
    }
}

float * llama_context::get_sampled_probs_ith(int32_t idx) {
    output_reorder();

    if (sampling.probs == nullptr) {
        return nullptr;
    }

    try {
        const int64_t row = output_resolve_row(idx);
        if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) {
            return nullptr;
        }
        return sampling.probs + row*model.vocab.n_tokens();
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what());
        return nullptr;
    }
}

float * llama_context::get_sampled_logits_ith(int32_t idx) {
    output_reorder();

    if (sampling.logits == nullptr) {
        return nullptr;
    }

    try {
        const int64_t row = output_resolve_row(idx);
        if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) {
            return nullptr;
        }
        return sampling.logits + row*model.vocab.n_tokens();
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what());
        return nullptr;
    }
}

const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) {
    output_reorder();

    try {
        const int64_t row = output_resolve_row(idx);
        if (sampling.candidates != nullptr &&
            (size_t) row < sampling.candidates_count.size() &&
            sampling.candidates_count[row] > 0) {
            return sampling.candidates + row*model.vocab.n_tokens();
        }
    } catch (const std::exception & err) {
        // fallback to full vocab list
    }

    return sampling.token_ids_full_vocab.data();
}

size_t llama_context::get_sampled_candidates_count(int32_t idx) {
    output_reorder();

    if (sampling.candidates == nullptr) {
        return 0;
    }

    try {
        const int64_t row = output_resolve_row(idx);
        if ((size_t) row >= sampling.candidates_count.size()) {
            return 0;
        }
        return sampling.candidates_count[row];
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: invalid backend sampled candidates count id %d, reason: %s\n", __func__, idx, err.what());
        return 0;
    }
}

size_t llama_context::get_sampled_logits_count(int32_t idx) {
    output_reorder();

    if (sampling.logits == nullptr) {
        return model.vocab.n_tokens();
    }

    try {
        const int64_t row = output_resolve_row(idx);
        if ((size_t) row >= sampling.logits_count.size()) {
            return 0;
        }
        return sampling.logits_count[row];
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: invalid backend sampled logits count id %d, reason: %s\n", __func__, idx, err.what());
        return 0;
    }
}

size_t llama_context::get_sampled_probs_count(int32_t idx) {
    output_reorder();

    if (sampling.probs == nullptr) {
        return 0;
    }

    try {
        const int64_t row = output_resolve_row(idx);
        if ((size_t) row >= sampling.probs_count.size()) {
            return 0;
        }
        return sampling.probs_count[row];
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: invalid backend sampled probs count id %d, reason: %s\n", __func__, idx, err.what());
        return 0;
    }
}

void llama_context::attach_threadpool(
        ggml_threadpool_t threadpool,
        ggml_threadpool_t threadpool_batch) {
    LLAMA_LOG_DEBUG("%s: call\n", __func__);

    this->threadpool       = threadpool;
    this->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool;
}

void llama_context::detach_threadpool() {
    LLAMA_LOG_DEBUG("%s: call\n", __func__);

    this->threadpool       = nullptr;
    this->threadpool_batch = nullptr;
}

void llama_context::set_n_threads(int32_t n_threads, int32_t n_threads_batch) {
    LLAMA_LOG_DEBUG("%s: n_threads = %d, n_threads_batch = %d\n", __func__, n_threads, n_threads_batch);

    cparams.n_threads       = n_threads;
    cparams.n_threads_batch = n_threads_batch;
}

void llama_context::set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data) {
    LLAMA_LOG_DEBUG("%s: call\n", __func__);

    this->abort_callback      = abort_callback;
    this->abort_callback_data = abort_callback_data;

    for (auto & backend : backends) {
        auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get()));
        auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback");
        if (set_abort_callback_fn) {
            set_abort_callback_fn(backend.get(), this->abort_callback, this->abort_callback_data);
        }
    }
}

void llama_context::set_embeddings(bool value) {
    LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);

    cparams.embeddings = value;

    // TODO: not sure yet if we want to reserve here
    //sched_need_reserve = true;
}

void llama_context::set_causal_attn(bool value) {
    LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);

    if (cparams.causal_attn == value) {
        return;
    }

    cparams.causal_attn = value;

    sched_need_reserve = true;
}

void llama_context::set_warmup(bool value) {
    LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);

    if (cparams.warmup == value) {
        return;
    }

    cparams.warmup = value;

    // warmups are usually with small batches, so no need to reserve
    //sched_need_reserve = true;
}
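
// attach (or detach, when sampler == nullptr) a backend sampler chain for a sequence;
// only non-empty chains whose interface provides backend_init/backend_apply can be
// offloaded, and any change to the set of samplers forces a scheduler re-reserve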
bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
    if (!sampler && sampling.samplers.count(seq_id) == 0) {
        return true;
    }

    LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);

    const bool can_offload =
        sampler &&
        sampler->iface->backend_init &&
        sampler->iface->backend_apply &&
        llama_sampler_chain_n(sampler) > 0;

    if (sampler && can_offload) {
        ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(model.dev_output());

        auto * host_buft = ggml_backend_dev_host_buffer_type(model.dev_output());
        if (host_buft) {
            buft = host_buft;
        }

        sampler->iface->backend_init(sampler, buft);

        sampling.samplers[seq_id] = sampler;

        sched_need_reserve = true;

        return true;
    }

    if (sampler && !can_offload) {
        LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler), seq_id);

        if (sampling.samplers.count(seq_id) > 0) {
            sched_need_reserve = true;
        }

        sampling.samplers.erase(seq_id);

        return false;
    }

    sampling.samplers.erase(seq_id);

    sched_need_reserve = true;

    return true;
}

void llama_context::set_adapter_lora(
        llama_adapter_lora * adapter,
        float scale) {
    LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale);

    if (auto it = loras.find(adapter); it != loras.end()) {
        if (it->second == scale) {
            return;
        }
    }

    loras[adapter] = scale;

    sched_need_reserve = true;
}

bool llama_context::rm_adapter_lora(
        llama_adapter_lora * adapter) {
    LLAMA_LOG_DEBUG("%s: adapter = %p\n", __func__, (void *) adapter);

    auto it = loras.find(adapter);
    if (it != loras.end()) {
        loras.erase(it);
        sched_need_reserve = true;
        return true;
    }

    return false;
}

void llama_context::clear_adapter_lora() {
    LLAMA_LOG_DEBUG("%s: call\n", __func__);

    if (loras.empty()) {
        return;
    }

    loras.clear();

    sched_need_reserve = true;
}

bool llama_context::apply_adapter_cvec(
        const float * data,
        size_t len,
        int32_t n_embd,
        int32_t il_start,
        int32_t il_end) {
    LLAMA_LOG_DEBUG("%s: il_start = %d, il_end = %d\n", __func__, il_start, il_end);

    // TODO: should we reserve?
    return cvec.apply(model, data, len, n_embd, il_start, il_end);
}
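
// core execution path for a single ubatch: apply the memory context, reuse the previous
// graph if its parameters match (otherwise rebuild it and allocate it through the
// scheduler), set the input tensors and run the computation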
llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
    if (mctx && !mctx->apply()) {
        LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
        ret = GGML_STATUS_FAILED;
        return nullptr;
    }

    auto * res = gf_res_prev.get();
    auto * gf  = res->get_gf();

    // the new graph parameters
    // in order to correctly reuse a graph, its full topology has to be uniquely determined by these parameters
    const auto gparams = graph_params(res, ubatch, mctx, gtype);

    if (!graph_reuse_disable && res->can_reuse(gparams)) {
        //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);
        n_reused++;
    } else {
        res->reset();

        ggml_backend_sched_reset(sched.get());
        ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

        //const auto t_start_us = ggml_time_us();

        gf = model.build_graph(gparams);

        //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);

        if (!gf) {
            LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
            ret = GGML_STATUS_FAILED;
            return nullptr;
        }

        if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
            LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
            ret = GGML_STATUS_ALLOC_FAILED;
            return nullptr;
        }
    }

    // set the input data for the input tensors
    {
        //const auto t_start_us = ggml_time_us();

        res->set_inputs(&ubatch);

        //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
    }

    const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1);
    if (status != GGML_STATUS_SUCCESS) {
        LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
        ret = status;
        return nullptr;
    }

    ret = GGML_STATUS_SUCCESS;

    return res;
}
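
// run the encoder part of the model over a full batch: the batch is processed in a single
// non-causal pass (n_ubatch must cover all tokens), outputs are produced for every token,
// and for enc-dec models (e.g. T5) the encoder output is stashed for cross attention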
int llama_context::encode(const llama_batch & batch_inp) {
    GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT

    if (batch_inp.n_tokens == 0) {
        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
        return -1;
    }

    const auto & hparams = model.hparams;

    const int64_t n_embd  = hparams.n_embd_inp();
    const int64_t n_vocab = model.vocab.n_tokens();

    // note: during encode, we always pass the full sequence starting from pos = 0
    if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
        LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
        return -1;
    }

    const uint32_t n_tokens = balloc->get_n_tokens();

    // [TAG_NO_CACHE_PAD]
    // TODO: add new split mode where we pad the input sequences so that ubatch.equal_seqs == true
    const llama_ubatch ubatch = balloc->split_simple(n_tokens);

    // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
    GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");

    if (t_compute_start_us == 0) {
        t_compute_start_us = ggml_time_us();
    }

    // TODO: this clear of the buffer can easily be forgotten - need something better
    embd_seq.clear();

    sched_reserve();

    n_queued_tokens += n_tokens;

    // reserve output buffer
    if (output_reserve(n_tokens, batch_inp) < n_tokens) {
        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
        return -2;
    }

    for (uint32_t i = 0; i < n_tokens; ++i) {
        output_ids[i] = i;
    }

    n_outputs = n_tokens;

    const auto causal_attn_org = cparams.causal_attn;

    // always use non-causal attention for encoder graphs
    // TODO: this is a tmp solution until we have a proper way to support enc-dec models
    //       ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
    cparams.causal_attn = false;

    ggml_status status;
    const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);

    cparams.causal_attn = causal_attn_org;

    if (!res) {
        switch (status) {
            case GGML_STATUS_ABORTED:      return  2;
            case GGML_STATUS_ALLOC_FAILED: return -2;
            case GGML_STATUS_FAILED:       return -3;
            case GGML_STATUS_SUCCESS:      GGML_ABORT("should not happen");
        }
    }

    auto * t_logits = res->get_logits();
    auto * t_embd   = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();

    // extract logits
    if (logits && t_logits) {
        ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
        GGML_ASSERT(backend_res != nullptr);
        GGML_ASSERT(logits != nullptr);

        ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float));
    }

    // extract embeddings
    if (embd && t_embd) {
        ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
        GGML_ASSERT(backend_embd != nullptr);

        switch (cparams.pooling_type) {
            case LLAMA_POOLING_TYPE_NONE:
                {
                    // extract token embeddings
                    GGML_ASSERT(embd != nullptr);

                    const uint32_t n_embd_out = hparams.get_n_embd_out();

                    GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd_size);
                    ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd_out*sizeof(float));
                } break;
            case LLAMA_POOLING_TYPE_MEAN:
            case LLAMA_POOLING_TYPE_CLS:
            case LLAMA_POOLING_TYPE_LAST:
                {
                    // extract sequence embeddings
                    auto & embd_seq_out = embd_seq;

                    for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
                        const llama_seq_id seq_id  = ubatch.seq_id_unq[s];
                        const int32_t      seq_idx = ubatch.seq_idx[seq_id];

                        embd_seq_out[seq_id].resize(n_embd);
                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
                    }
                } break;
            case LLAMA_POOLING_TYPE_RANK:
                {
                    // extract the rerank score - n_cls_out floats per sequence
                    auto & embd_seq_out = embd_seq;

                    const uint32_t n_cls_out = hparams.n_cls_out;

                    for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
                        const llama_seq_id seq_id  = ubatch.seq_id_unq[s];
                        const int32_t      seq_idx = ubatch.seq_idx[seq_id];

                        embd_seq_out[seq_id].resize(n_cls_out);
                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
                    }
                } break;
            case LLAMA_POOLING_TYPE_UNSPECIFIED:
                {
                    GGML_ABORT("unknown pooling type");
                }
        }
    }

    // TODO: hacky solution
    if (model.arch == LLM_ARCH_T5 && t_embd) {
        //cross.t_embd = t_embd;

        synchronize();

        cross.n_embd = t_embd->ne[0];
        cross.n_enc  = t_embd->ne[1];
        cross.v_embd.resize(cross.n_embd*cross.n_enc);
        memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));

        const auto & batch = balloc->get_batch();

        // remember the sequence ids used during the encoding - needed for cross attention later
        cross.seq_ids_enc.resize(n_tokens);
        for (uint32_t i = 0; i < n_tokens; i++) {
            cross.seq_ids_enc[i].clear();

            for (int s = 0; s < batch.n_seq_id[i]; s++) {
                const llama_seq_id seq_id = batch.seq_id[i][s];

                cross.seq_ids_enc[i].insert(seq_id);
            }
        }
    }

    return 0;
  1076. }
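// A minimal usage sketch (illustrative only, not part of this file) of driving
// encode() through the public API, assuming an encoder/embedding model and a
// context created with embeddings and pooling enabled:
//
//   std::vector<llama_token> toks = /* tokenized input */;
//   llama_batch batch = llama_batch_get_one(toks.data(), (int32_t) toks.size());
//   if (llama_encode(ctx, batch) != 0) { /* handle error */ }
//   const float * emb = llama_get_embeddings_seq(ctx, 0); // pooled embedding of seq 0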
  1077. static std::map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset) {
  1078. std::map<llama_seq_id, uint32_t> seq_to_row;
  1079. // how many output tokens we have seen so far for this ubatch.
  1080. uint32_t local = 0;
  1081. for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
  1082. // skip tokens that are not output.
  1083. if (!ubatch.output[i]) {
  1084. continue;
  1085. }
  1086. const llama_seq_id seq_id = ubatch.seq_id[i][0];
  1087. // row_offset is the number of output tokens before this ubatch.
  1088. seq_to_row[seq_id] = row_offset + local;
  1089. ++local;
  1090. }
  1091. return seq_to_row;
  1092. }
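// Worked example (illustrative): for a ubatch with seq_id[i][0] = {0, 1, 1, 2},
// output[i] = {1, 0, 1, 1} and row_offset = 4 (outputs already produced by earlier
// ubatches), the returned map is {0 -> 4, 1 -> 5, 2 -> 6}.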
  1093. static void copy_tensor_async_ints(
  1094. const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
  1095. llama_token * sampled,
  1096. size_t sampled_size,
  1097. const std::map<llama_seq_id, uint32_t> & seq_to_row,
  1098. ggml_backend_sched_t sched) {
  1099. if (sampled == nullptr) {
  1100. return;
  1101. }
  1102. for (const auto & [seq_id, tensor] : tensor_map) {
  1103. auto it = seq_to_row.find(seq_id);
  1104. if (it == seq_to_row.end()) {
  1105. continue;
  1106. }
  1107. const uint32_t row = it->second;
  1108. GGML_ASSERT(row < sampled_size);
  1109. GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy");
  1110. ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
  1111. ggml_backend_tensor_get_async(backend, tensor, sampled + row, 0, sizeof(sampled[row]));
  1112. }
  1113. }
  1114. static void copy_tensor_async_floats(
  1115. const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
  1116. float * dst,
  1117. size_t stride,
  1118. std::vector<uint32_t> & counts,
  1119. const std::map<llama_seq_id, uint32_t> & seq_to_row,
  1120. ggml_backend_sched_t sched) {
  1121. if (dst == nullptr) {
  1122. return;
  1123. }
  1124. for (const auto & [seq_id, tensor] : tensor_map) {
  1125. auto it = seq_to_row.find(seq_id);
  1126. if (it == seq_to_row.end()) {
  1127. continue;
  1128. }
  1129. const uint32_t row = it->second;
  1130. GGML_ASSERT(row < counts.size());
  1131. GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy");
  1132. ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
  1133. float * row_ptr = dst + (size_t) row * stride;
  1134. ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
  1135. // Update the actual number of logits/probabilities that were written for this row.
  1136. counts[row] = ggml_nelements(tensor);
  1137. }
  1138. }
  1139. static void copy_tensor_async_candidates(
  1140. const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
  1141. llama_token * dst,
  1142. size_t stride,
  1143. std::vector<uint32_t> & counts,
  1144. const std::map<llama_seq_id, uint32_t> & seq_to_row,
  1145. ggml_backend_sched_t sched) {
  1146. if (dst == nullptr) {
  1147. return;
  1148. }
  1149. for (const auto & [seq_id, tensor] : tensor_map) {
  1150. auto it = seq_to_row.find(seq_id);
  1151. if (it == seq_to_row.end()) {
  1152. continue;
  1153. }
  1154. const uint32_t row = it->second;
  1155. GGML_ASSERT(row < counts.size());
  1156. GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy");
  1157. ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
  1158. llama_token * row_ptr = dst + (size_t) row * stride;
  1159. ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
1160. // Update the actual number of candidates that were written for this row.
  1161. counts[row] = ggml_nelements(tensor);
  1162. }
  1163. }
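// The three helpers above share one pattern: for each sequence that has a backend
// tensor, look up its output row via seq_to_row and issue an async backend->host
// copy into that row. copy_tensor_async_ints copies a single sampled token id per
// row, while the float/candidate variants copy the full contiguous tensor
// (ggml_nbytes) and record the number of elements written in counts[row];
// `stride` is the row pitch of the destination buffer (the caller passes n_vocab).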
  1164. int llama_context::decode(const llama_batch & batch_inp) {
  1165. GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
  1166. if (!memory) {
  1167. LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
  1168. return encode(batch_inp);
  1169. }
  1170. if (batch_inp.n_tokens == 0) {
  1171. LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
  1172. return -1;
  1173. }
  1174. const auto & vocab = model.vocab;
  1175. const auto & hparams = model.hparams;
  1176. const int64_t n_vocab = vocab.n_tokens();
  1177. const int64_t n_embd = hparams.n_embd_inp();
  1178. // when computing embeddings, all tokens are output
  1179. const bool output_all = cparams.embeddings;
  1180. const bool has_samplers = !sampling.samplers.empty();
  1181. const uint32_t n_seq_max = cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max;
  1182. // TODO: avoid this workaround in the future
  1183. if (has_samplers && batch_inp.logits) {
  1184. std::vector<int32_t> seq_output_count(n_seq_max, 0);
  1185. for (int32_t i = 0; i < batch_inp.n_tokens; ++i) {
  1186. if (batch_inp.logits[i] == 0) {
  1187. continue;
  1188. }
  1189. const int ns = batch_inp.n_seq_id ? batch_inp.n_seq_id[i] : 1;
  1190. for (int32_t s = 0; s < ns; ++s) {
  1191. const llama_seq_id seq_id = batch_inp.seq_id ? batch_inp.seq_id[i][s] : 0;
  1192. seq_output_count[seq_id]++;
  1193. if (seq_output_count[seq_id] > 1) {
  1194. LLAMA_LOG_ERROR("%s: backend sampling requires at most one output token per sequence (seq_id %d had %d)\n",
  1195. __func__, seq_id, seq_output_count[seq_id]);
  1196. return -1;
  1197. }
  1198. }
  1199. }
  1200. }
  1201. if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, n_seq_max, output_all)) {
  1202. LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
  1203. return -1;
  1204. }
  1205. const uint32_t n_tokens_all = balloc->get_n_tokens();
  1206. const uint32_t n_outputs_all = balloc->get_n_outputs();
  1207. if (output_all) {
  1208. // require that all tokens are output
  1209. if (n_outputs_all != n_tokens_all) {
  1210. LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
  1211. __func__, n_outputs_all, n_tokens_all);
  1212. return -1;
  1213. }
  1214. }
  1215. GGML_ASSERT(n_tokens_all <= cparams.n_batch);
  1216. GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
  1217. if (t_compute_start_us == 0) {
  1218. t_compute_start_us = ggml_time_us();
  1219. }
  1220. n_queued_tokens += n_tokens_all;
  1221. // TODO: this clear of the buffer can easily be forgotten - need something better
  1222. embd_seq.clear();
  1223. output_swaps.clear();
  1224. sched_reserve();
  1225. bool did_optimize = false;
  1226. // handle any pending shifts/copies
  1227. memory_update(false);
  1228. llama_memory_context_ptr mctx;
  1229. while (true) {
  1230. mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
  1231. if (!mctx) {
  1232. return -2;
  1233. }
  1234. switch (mctx->get_status()) {
  1235. case LLAMA_MEMORY_STATUS_SUCCESS:
  1236. {
  1237. } break;
  1238. case LLAMA_MEMORY_STATUS_NO_UPDATE:
  1239. {
  1240. LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status());
  1241. return -2;
  1242. }
  1243. case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
  1244. {
  1245. if (!did_optimize) {
  1246. did_optimize = true;
  1247. if (memory_update(true)) {
  1248. LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens());
  1249. continue;
  1250. }
  1251. }
  1252. LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, balloc->get_n_tokens());
  1253. return 1;
  1254. }
  1255. case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
  1256. {
  1257. LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, balloc->get_n_tokens());
  1258. return -2;
  1259. }
  1260. }
  1261. break;
  1262. }
  1263. // reserve output buffer
  1264. if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) {
  1265. LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
  1266. return -2;
1267. }
  1268. int64_t n_outputs_prev = 0;
  1269. do {
  1270. const auto & ubatch = mctx->get_ubatch();
  1271. // count the outputs in this ubatch
  1272. {
  1273. int32_t n_outputs_new = 0;
  1274. if (n_outputs_all == n_tokens_all) {
  1275. n_outputs_new = ubatch.n_tokens;
  1276. } else {
  1277. for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
  1278. n_outputs_new += (int32_t) (ubatch.output[i] != 0);
  1279. }
  1280. }
  1281. // needs to happen before the graph is built
  1282. n_outputs = n_outputs_new;
  1283. }
  1284. ggml_status status;
  1285. const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
  1286. if (!res) {
  1287. // the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module
  1288. llama_pos pos_min[LLAMA_MAX_SEQ];
  1289. for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
  1290. pos_min[s] = std::numeric_limits<llama_pos>::max();
  1291. }
  1292. for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
  1293. const auto & seq_id = ubatch.seq_id[i][0];
  1294. pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
  1295. }
  1296. for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
  1297. if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
  1298. continue;
  1299. }
  1300. LLAMA_LOG_WARN("%s: removing memory module entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
  1301. memory->seq_rm(s, pos_min[s], -1);
  1302. }
  1303. switch (status) {
  1304. case GGML_STATUS_ABORTED: return 2;
  1305. case GGML_STATUS_ALLOC_FAILED: return -2;
  1306. case GGML_STATUS_FAILED: return -3;
  1307. case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen");
  1308. }
  1309. }
  1310. // plot the computation graph in dot format (for debugging purposes)
  1311. //if (n_past%100 == 0) {
  1312. // ggml_graph_dump_dot(gf, NULL, "llama.dot");
  1313. //}
  1314. auto * t_logits = res->get_logits();
  1315. auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
  1316. if (t_embd && res->get_embd_pooled()) {
  1317. t_embd = res->get_embd_pooled();
  1318. }
  1319. // extract logits
1320. // For multi-sequence batches that mix backend samplers with the CPU sampler,
1321. // this is currently inefficient, as we copy all logits even for the
1322. // backend-sampled tokens.
  1323. if (logits && t_logits && n_outputs > 0) {
  1324. ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
  1325. GGML_ASSERT(backend_res != nullptr);
  1326. GGML_ASSERT(logits != nullptr);
  1327. float * logits_out = logits + n_outputs_prev*n_vocab;
  1328. if (n_outputs) {
  1329. GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
  1330. GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
  1331. ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
  1332. }
  1333. }
  1334. // extract embeddings
  1335. if (embd && t_embd && n_outputs > 0) {
  1336. ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
  1337. GGML_ASSERT(backend_embd != nullptr);
  1338. switch (cparams.pooling_type) {
  1339. case LLAMA_POOLING_TYPE_NONE:
  1340. {
  1341. // extract token embeddings
  1342. GGML_ASSERT(embd != nullptr);
  1343. const uint32_t n_embd_out = hparams.get_n_embd_out();
  1344. float * embd_out = embd + n_outputs_prev*n_embd_out;
  1345. if (n_outputs) {
  1346. GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
  1347. GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd_size);
  1348. ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd_out*sizeof(float));
  1349. }
  1350. } break;
  1351. case LLAMA_POOLING_TYPE_MEAN:
  1352. case LLAMA_POOLING_TYPE_CLS:
  1353. case LLAMA_POOLING_TYPE_LAST:
  1354. {
  1355. // extract sequence embeddings (cleared before processing each batch)
  1356. auto & embd_seq_out = embd_seq;
  1357. for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
  1358. const llama_seq_id seq_id = ubatch.seq_id_unq[s];
  1359. const int32_t seq_idx = ubatch.seq_idx[seq_id];
  1360. embd_seq_out[seq_id].resize(n_embd);
  1361. ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
  1362. }
  1363. } break;
  1364. case LLAMA_POOLING_TYPE_RANK:
  1365. {
  1366. // extract the rerank score - n_cls_out floats per sequence
  1367. auto & embd_seq_out = embd_seq;
  1368. const uint32_t n_cls_out = hparams.n_cls_out;
  1369. for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
  1370. const llama_seq_id seq_id = ubatch.seq_id_unq[s];
  1371. const int32_t seq_idx = ubatch.seq_idx[seq_id];
  1372. embd_seq_out[seq_id].resize(n_cls_out);
  1373. ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
  1374. }
  1375. } break;
  1376. case LLAMA_POOLING_TYPE_UNSPECIFIED:
  1377. {
  1378. GGML_ABORT("unknown pooling type");
  1379. }
  1380. }
  1381. }
1382. // This flag indicates whether a backend sampler has actually sampled a specific
1383. // token, or has only produced probabilities. If true, we can skip the normal copying of logits and embeddings.
  1384. const bool has_sampled = !res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty();
  1385. if (has_samplers && has_sampled) {
  1386. const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev);
  1387. const auto stride = n_vocab;
  1388. // async copy the sampling data from the backend to the host
  1389. copy_tensor_async_ints(res->t_sampled, sampling.sampled, sampling.sampled_size, seq_to_output_row, sched.get());
  1390. copy_tensor_async_floats (res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, seq_to_output_row, sched.get());
  1391. copy_tensor_async_floats (res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, seq_to_output_row, sched.get());
  1392. copy_tensor_async_candidates(res->t_candidates, sampling.candidates, stride, sampling.candidates_count, seq_to_output_row, sched.get());
  1393. }
  1394. n_outputs_prev += n_outputs;
  1395. } while (mctx->next());
  1396. // set to total number of outputs in the batch, for use in llama_get_logits_ith
  1397. n_outputs = n_outputs_all;
  1398. // set output mappings
  1399. if (n_outputs > 0) {
  1400. bool sorted_output = true;
  1401. auto & out_ids = balloc->get_out_ids();
  1402. GGML_ASSERT(out_ids.size() == (size_t) n_outputs);
  1403. for (int64_t i = 0; i < n_outputs; ++i) {
  1404. int64_t out_id = out_ids[i];
  1405. output_ids[out_id] = i;
  1406. if (out_id != i) {
  1407. sorted_output = false;
  1408. }
  1409. }
  1410. // make the outputs have the same order they had in the user-provided batch
  1411. // note: this is mostly relevant for recurrent models atm
  1412. if (!sorted_output && n_outputs > 1) {
  1413. GGML_ASSERT((size_t) n_outputs == out_ids.size());
  1414. // TODO: is there something more efficient which also minimizes swaps?
  1415. // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
  1416. for (uint32_t i = 0; i < n_outputs - 1; ++i) {
  1417. uint32_t j_min = i;
  1418. for (uint32_t j = i + 1; j < n_outputs; ++j) {
  1419. if (out_ids[j] < out_ids[j_min]) {
  1420. j_min = j;
  1421. }
  1422. }
  1423. if (j_min == i) {
  1424. continue;
  1425. }
  1426. std::swap(out_ids[i], out_ids[j_min]);
  1427. // remember the swaps and apply them lazily upon logits/embeddings access
  1428. output_swaps.push_back({ i, j_min });
  1429. }
  1430. std::fill(output_ids.begin(), output_ids.end(), -1);
  1431. for (uint32_t i = 0; i < n_outputs; ++i) {
  1432. output_ids[out_ids[i]] = i;
  1433. }
  1434. }
  1435. }
  1436. // wait for the computation to finish (automatically done when obtaining the model output)
  1437. //synchronize();
  1438. return 0;
  1439. }
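// A minimal usage sketch (illustrative only) of driving decode() through the
// public API with CPU-side sampling:
//
//   llama_batch batch = llama_batch_get_one(toks.data(), (int32_t) toks.size());
//   const int ret = llama_decode(ctx, batch);
//   if (ret > 0) { /* no memory slot found - retry with a smaller batch */ }
//   if (ret < 0) { /* fatal error */ }
//   const float * logits = llama_get_logits_ith(ctx, -1); // logits of the last output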
  1440. //
  1441. // output
  1442. //
  1443. uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & batch) {
  1444. const auto & hparams = model.hparams;
  1445. const auto & vocab = model.vocab;
  1446. const int64_t n_outputs_max = std::max<int64_t>(n_outputs, n_seq_max());
  1447. const auto n_batch = cparams.n_batch;
  1448. const auto n_vocab = vocab.n_tokens();
  1449. const auto n_embd_out = hparams.get_n_embd_out();
  1450. bool has_logits = true;
  1451. bool has_embd = cparams.embeddings;
  1452. // TODO: hacky enc-dec support
  1453. if (model.arch == LLM_ARCH_T5) {
  1454. has_logits = true;
  1455. has_embd = true;
  1456. }
  1457. // Check which sampling modes are needed for the current batch.
  1458. // TODO: avoid this branching by working with the worst-case
  1459. bool has_sampling = false;
  1460. bool cpu_logits = false;
  1461. if (batch.logits) {
  1462. for (int32_t i = 0; i < batch.n_tokens; i++) {
  1463. if (!batch.logits[i]) {
  1464. continue;
  1465. }
  1466. for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
  1467. llama_seq_id seq_id = batch.seq_id[i][j];
  1468. if (sampling.samplers.find(seq_id) != sampling.samplers.end()) {
  1469. has_sampling = true;
  1470. } else {
  1471. cpu_logits = true;
  1472. }
  1473. }
  1474. }
  1475. } else {
  1476. // When batch.logits is nullptr (when loading state with a dummy batch),
  1477. // allocate CPU logits.
  1478. cpu_logits = true;
  1479. }
  1480. size_t backend_float_count = 0;
  1481. size_t backend_token_count = 0;
  1482. // Allocate CPU logits buffer only if needed by sequences in this batch
  1483. logits_size = (has_logits && cpu_logits) ? n_vocab*n_outputs_max : 0;
  1484. embd_size = has_embd ? n_embd_out*n_outputs_max : 0;
  1485. // TODO: avoid this branching by working with the worst-case
  1486. if (!has_sampling) {
  1487. sampling.logits_size = 0;
  1488. sampling.probs_size = 0;
  1489. sampling.sampled_size = 0;
  1490. sampling.candidates_size = 0;
  1491. } else {
  1492. sampling.logits_size = n_vocab*n_outputs_max;
  1493. sampling.probs_size = n_vocab*n_outputs_max;
  1494. sampling.sampled_size = n_outputs_max;
  1495. sampling.candidates_size = n_vocab*n_outputs_max;
  1496. backend_float_count = sampling.logits_size + sampling.probs_size;
  1497. backend_token_count = sampling.sampled_size + sampling.candidates_size;
  1498. }
  1499. if (output_ids.empty()) {
  1500. // init, never resized afterwards
  1501. output_ids.resize(n_batch);
  1502. }
  1503. const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
  1504. const size_t new_size =
  1505. (logits_size + embd_size + backend_float_count) * sizeof(float) +
  1506. ( backend_token_count) * sizeof(llama_token);
  1507. // alloc only when more than the current capacity is required
  1508. // TODO: also consider shrinking the buffer
  1509. if (!buf_output || prev_size < new_size) {
  1510. if (buf_output) {
  1511. #ifndef NDEBUG
  1512. // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
  1513. LLAMA_LOG_DEBUG("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
  1514. #endif
  1515. synchronize();
  1516. // TODO: not needed?
  1517. buf_output = nullptr;
  1518. logits = nullptr;
  1519. embd = nullptr;
  1520. }
  1521. auto * buft = ggml_backend_cpu_buffer_type();
  1522. // try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory
  1523. auto * output_dev = model.dev_output();
  1524. auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr;
  1525. if (output_dev_host_buft) {
  1526. buft = output_dev_host_buft;
  1527. }
  1528. buf_output.reset(ggml_backend_buft_alloc_buffer(buft, new_size));
  1529. if (buf_output == nullptr) {
  1530. LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0));
  1531. return 0;
  1532. }
  1533. }
  1534. float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get());
  1535. logits = nullptr;
  1536. embd = nullptr;
  1537. size_t offset = 0;
  1538. uint8_t * base = (uint8_t *) output_base;
  1539. logits = (has_logits && cpu_logits) ? output_base : nullptr;
  1540. offset += logits_size * sizeof(float);
  1541. embd = has_embd ? (float *) (base + offset) : nullptr;
  1542. offset += embd_size * sizeof(float);
  1543. sampling.logits = nullptr;
  1544. sampling.probs = nullptr;
  1545. sampling.sampled = nullptr;
  1546. sampling.candidates = nullptr;
  1547. if (has_sampling) {
  1548. sampling.logits = (float *) (base + offset);
  1549. offset += sampling.logits_size * sizeof(float);
  1550. sampling.probs = (float *) (base + offset);
  1551. offset += sampling.probs_size * sizeof(float);
  1552. sampling.sampled = (llama_token *) (base + offset);
  1553. offset += sampling.sampled_size * sizeof(llama_token);
  1554. sampling.candidates = (llama_token *) (base + offset);
  1555. offset += sampling.candidates_size * sizeof(llama_token);
  1556. // The count vectors keep track of the actual number of logits/probs/candidates
  1557. // copied from the backend for each output row.
  1558. sampling.logits_count.resize(n_outputs_max);
  1559. sampling.probs_count.resize(n_outputs_max);
  1560. sampling.candidates_count.resize(n_outputs_max);
  1561. std::fill(sampling.logits_count.begin(), sampling.logits_count.end(), 0);
  1562. std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0);
  1563. std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
  1564. std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL);
  1565. }
  1566. // set all ids as invalid (negative)
  1567. std::fill(output_ids.begin(), output_ids.end(), -1);
  1568. this->n_outputs = 0;
  1569. return n_outputs_max;
  1570. }
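// Layout of buf_output as established above (regions are contiguous and optional):
//   [ CPU logits          : logits_size              x float       ]
//   [ embeddings          : embd_size                x float       ]
//   [ sampling.logits     : sampling.logits_size     x float       ] (has_sampling)
//   [ sampling.probs      : sampling.probs_size      x float       ] (has_sampling)
//   [ sampling.sampled    : sampling.sampled_size    x llama_token ] (has_sampling)
//   [ sampling.candidates : sampling.candidates_size x llama_token ] (has_sampling)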
  1571. void llama_context::output_reorder() {
  1572. const uint64_t n_vocab = model.vocab.n_tokens();
  1573. const uint64_t n_embd = model.hparams.n_embd;
  1574. for (size_t s = 0; s < output_swaps.size(); ++s) {
  1575. const uint64_t i0 = output_swaps[s].i0;
  1576. const uint64_t i1 = output_swaps[s].i1;
  1577. if (logits_size > 0) {
  1578. for (uint64_t k = 0; k < n_vocab; k++) {
  1579. std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
  1580. }
  1581. }
  1582. if (embd_size > 0) {
  1583. for (uint64_t k = 0; k < n_embd; k++) {
  1584. std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
  1585. }
  1586. }
  1587. if (sampling.logits && sampling.logits_size > 0) {
  1588. for (uint64_t k = 0; k < n_vocab; ++k) {
  1589. std::swap(sampling.logits[i0*n_vocab + k], sampling.logits[i1*n_vocab + k]);
  1590. }
  1591. }
  1592. if (sampling.probs && sampling.probs_size > 0) {
  1593. for (uint64_t k = 0; k < n_vocab; ++k) {
  1594. std::swap(sampling.probs[i0*n_vocab + k], sampling.probs[i1*n_vocab + k]);
  1595. }
  1596. }
  1597. if (sampling.candidates && sampling.candidates_size > 0) {
  1598. for (uint64_t k = 0; k < n_vocab; ++k) {
  1599. std::swap(sampling.candidates[i0*n_vocab + k], sampling.candidates[i1*n_vocab + k]);
  1600. }
  1601. }
  1602. if (sampling.sampled && sampling.sampled_size > 0) {
  1603. std::swap(sampling.sampled[i0], sampling.sampled[i1]);
  1604. }
  1605. if (!sampling.logits_count.empty()) {
  1606. std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
  1607. }
  1608. if (!sampling.probs_count.empty()) {
  1609. std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
  1610. }
  1611. if (!sampling.candidates_count.empty()) {
  1612. std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]);
  1613. }
  1614. }
  1615. output_swaps.clear();
  1616. }
  1617. //
  1618. // graph
  1619. //
  1620. uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const {
  1621. if (model.arch == LLM_ARCH_QWEN3NEXT) {
  1622. return std::max<uint32_t>(n_tokens * 40, 32u * model.n_tensors());
  1623. }
  1624. uint32_t res = std::max<uint32_t>(1024u, 8u*model.n_tensors());
  1625. for (const auto & lora : model.loras) {
  1626. res += lora->get_n_nodes();
  1627. }
  1628. return res;
  1629. }
  1630. llm_graph_result * llama_context::get_gf_res_reserve() const {
  1631. return static_cast<llm_graph_result *>(gf_res_reserve.get());
  1632. }
  1633. ggml_cgraph * llama_context::graph_reserve(
  1634. uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only, size_t * sizes) {
  1635. LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
  1636. GGML_ASSERT(n_outputs >= 1);
  1637. if (n_tokens % n_seqs != 0) {
  1638. n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
  1639. n_outputs = std::max(n_outputs, n_tokens);
  1640. LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
  1641. }
  1642. ggml_backend_sched_reset(sched.get());
1643. // when the scheduler is reset, we cannot reuse the old graph, so we reset the previous graph result to prevent that
  1644. gf_res_prev->reset();
  1645. // store the n_outputs as it is, and restore it afterwards
  1646. // TODO: not sure if needed, might simplify in the future by removing this
  1647. const auto save_n_outputs = this->n_outputs;
  1648. this->n_outputs = n_outputs;
  1649. llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
  1650. llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
  1651. // set one output token per sequence in order to activate all backend samplers
  1652. std::vector<llama_seq_id> seq_ids(n_seqs);
  1653. for (uint32_t i = 0; i < n_seqs; ++i) {
  1654. seq_ids[i] = i;
  1655. ubatch.n_seq_id[i] = 1;
  1656. ubatch.seq_id[i] = &seq_ids[i];
  1657. ubatch.output[i] = true;
  1658. }
  1659. auto * res = gf_res_reserve.get();
  1660. const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);
  1661. res->reset();
  1662. auto * gf = model.build_graph(gparams);
  1663. this->n_outputs = save_n_outputs;
  1664. // initialize scheduler with the specified graph
  1665. if (split_only) {
  1666. if (sizes) {
  1667. ggml_backend_sched_reserve_size(sched.get(), gf, sizes);
  1668. } else {
  1669. ggml_backend_sched_split_graph(sched.get(), gf);
  1670. }
  1671. } else if (!ggml_backend_sched_reserve(sched.get(), gf)) {
  1672. GGML_ASSERT(!sizes);
  1673. LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
  1674. return nullptr;
  1675. }
  1676. return gf;
  1677. }
  1678. llm_graph_params llama_context::graph_params(
  1679. llm_graph_result * res,
  1680. const llama_ubatch & ubatch,
  1681. const llama_memory_context_i * mctx,
  1682. llm_graph_type gtype) const {
  1683. return {
  1684. /*.arch =*/ model.arch,
  1685. /*.hparams =*/ model.hparams,
  1686. /*.cparams =*/ cparams,
  1687. /*.ubatch =*/ ubatch,
  1688. /*.gtype =*/ gtype,
  1689. /*.sched =*/ sched.get(),
  1690. /*.backend_cpu =*/ backend_cpu,
  1691. /*.cvec =*/ &cvec,
  1692. /*.loras =*/ &loras,
  1693. /*.mctx =*/ mctx,
  1694. /*.cross =*/ &cross,
  1695. /*.samplers =*/ sampling.samplers,
  1696. /*.n_outputs =*/ n_outputs,
  1697. /*.cb =*/ graph_get_cb(),
  1698. /*.res =*/ res,
  1699. };
  1700. }
  1701. ggml_status llama_context::graph_compute(
  1702. ggml_cgraph * gf,
  1703. bool batched) {
  1704. int n_threads = batched ? cparams.n_threads_batch : cparams.n_threads;
  1705. ggml_threadpool_t tp = batched ? threadpool_batch : threadpool;
  1706. if (backend_cpu != nullptr) {
  1707. auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu));
  1708. auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool");
  1709. if (set_threadpool_fn) {
  1710. set_threadpool_fn(backend_cpu, tp);
  1711. }
  1712. }
  1713. // set the number of threads for all the backends
  1714. for (const auto & set_n_threads_fn : set_n_threads_fns) {
  1715. set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
  1716. }
  1717. auto status = ggml_backend_sched_graph_compute_async(sched.get(), gf);
  1718. if (status != GGML_STATUS_SUCCESS) {
  1719. LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
  1720. }
  1721. // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched));
  1722. return status;
  1723. }
  1724. llm_graph_cb llama_context::graph_get_cb() const {
  1725. return [&](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) {
  1726. if (il >= 0) {
  1727. ggml_format_name(cur, "%s-%d", name, il);
  1728. } else {
  1729. ggml_set_name(cur, name);
  1730. }
  1731. if (!cparams.offload_kqv) {
  1732. if (strcmp(name, "kqv_merged_cont") == 0) {
  1733. // all nodes between the KV store and the attention output are run on the CPU
  1734. ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu);
  1735. }
  1736. }
  1737. // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
  1738. // FIXME: fix in ggml_backend_sched
  1739. const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer;
  1740. if (ubatch.n_tokens < 32 || full_offload) {
  1741. if (il != -1 && strcmp(name, "norm") == 0) {
  1742. const auto & dev_layer = model.dev_layer(il);
  1743. for (const auto & backend : backends) {
  1744. if (ggml_backend_get_device(backend.get()) == dev_layer) {
  1745. if (ggml_backend_supports_op(backend.get(), cur)) {
  1746. ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend.get());
  1747. }
  1748. }
  1749. }
  1750. }
  1751. }
  1752. };
  1753. }
  1754. //
  1755. // state save/load
  1756. //
  1757. class llama_io_write_dummy : public llama_io_write_i {
  1758. public:
  1759. llama_io_write_dummy() = default;
  1760. void write(const void * /* src */, size_t size) override {
  1761. size_written += size;
  1762. }
  1763. void write_tensor(const ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override {
  1764. size_written += size;
  1765. }
  1766. size_t n_bytes() override {
  1767. return size_written;
  1768. }
  1769. private:
  1770. size_t size_written = 0;
  1771. };
  1772. class llama_io_write_buffer : public llama_io_write_i {
  1773. public:
  1774. llama_io_write_buffer(
  1775. uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
  1776. void write(const void * src, size_t size) override {
  1777. if (size > buf_size) {
  1778. throw std::runtime_error("unexpectedly reached end of buffer");
  1779. }
  1780. memcpy(ptr, src, size);
  1781. ptr += size;
  1782. size_written += size;
  1783. buf_size -= size;
  1784. }
  1785. void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
  1786. if (size > buf_size) {
  1787. throw std::runtime_error("unexpectedly reached end of buffer");
  1788. }
  1789. ggml_backend_tensor_get(tensor, ptr, offset, size);
  1790. ptr += size;
  1791. size_written += size;
  1792. buf_size -= size;
  1793. }
  1794. size_t n_bytes() override {
  1795. return size_written;
  1796. }
  1797. private:
  1798. uint8_t * ptr;
  1799. size_t buf_size = 0;
  1800. size_t size_written = 0;
  1801. };
  1802. class llama_io_read_buffer : public llama_io_read_i {
  1803. public:
  1804. llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
  1805. const uint8_t * read(size_t size) override {
  1806. const uint8_t * base_ptr = ptr;
  1807. if (size > buf_size) {
  1808. throw std::runtime_error("unexpectedly reached end of buffer");
  1809. }
  1810. ptr += size;
  1811. size_read += size;
  1812. buf_size -= size;
  1813. return base_ptr;
  1814. }
  1815. void read_to(void * dst, size_t size) override {
  1816. memcpy(dst, read(size), size);
  1817. }
  1818. size_t n_bytes() override {
  1819. return size_read;
  1820. }
  1821. private:
  1822. const uint8_t * ptr;
  1823. size_t buf_size = 0;
  1824. size_t size_read = 0;
  1825. };
  1826. class llama_io_write_file : public llama_io_write_i {
  1827. public:
  1828. llama_io_write_file(llama_file * f) : file(f) {}
  1829. void write(const void * src, size_t size) override {
  1830. file->write_raw(src, size);
  1831. size_written += size;
  1832. }
  1833. void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
  1834. temp_buffer.resize(size);
  1835. ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size);
  1836. write(temp_buffer.data(), temp_buffer.size());
  1837. }
  1838. size_t n_bytes() override {
  1839. return size_written;
  1840. }
  1841. private:
  1842. llama_file * file;
  1843. size_t size_written = 0;
  1844. std::vector<uint8_t> temp_buffer;
  1845. };
  1846. class llama_io_read_file : public llama_io_read_i {
  1847. public:
  1848. llama_io_read_file(llama_file * f) : file(f) {}
  1849. void read_to(void * dst, size_t size) override {
  1850. file->read_raw(dst, size);
  1851. size_read += size;
  1852. }
  1853. const uint8_t * read(size_t size) override {
  1854. temp_buffer.resize(size);
  1855. read_to(temp_buffer.data(), size);
  1856. return temp_buffer.data();
  1857. }
  1858. size_t n_bytes() override {
  1859. return size_read;
  1860. }
  1861. private:
  1862. llama_file * file;
  1863. size_t size_read = 0;
  1864. std::vector<uint8_t> temp_buffer;
  1865. };
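// A minimal round-trip sketch (illustrative only) of the public state API that is
// built on top of these IO helpers:
//
//   const size_t n = llama_state_get_size(ctx);
//   std::vector<uint8_t> buf(n);
//   llama_state_get_data(ctx, buf.data(), buf.size()); // serializes via llama_io_write_buffer
//   // ... later, on a compatible context ...
//   llama_state_set_data(ctx, buf.data(), buf.size()); // restores via llama_io_read_buffer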
  1866. size_t llama_context::state_get_size() {
  1867. llama_io_write_dummy io;
  1868. try {
  1869. return state_write_data(io);
  1870. } catch (const std::exception & err) {
  1871. LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
  1872. return 0;
  1873. }
  1874. }
  1875. size_t llama_context::state_get_data(uint8_t * dst, size_t size) {
  1876. llama_io_write_buffer io(dst, size);
  1877. try {
  1878. return state_write_data(io);
  1879. } catch (const std::exception & err) {
  1880. LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
  1881. return 0;
  1882. }
  1883. }
  1884. size_t llama_context::state_set_data(const uint8_t * src, size_t size) {
  1885. llama_io_read_buffer io(src, size);
  1886. try {
  1887. return state_read_data(io);
  1888. } catch (const std::exception & err) {
  1889. LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
  1890. return 0;
  1891. }
  1892. }
  1893. size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags) {
  1894. llama_io_write_dummy io;
  1895. try {
  1896. return state_seq_write_data(io, seq_id, flags);
  1897. } catch (const std::exception & err) {
  1898. LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
  1899. return 0;
  1900. }
  1901. }
  1902. size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags) {
  1903. llama_io_write_buffer io(dst, size);
  1904. try {
  1905. return state_seq_write_data(io, seq_id, flags);
  1906. } catch (const std::exception & err) {
  1907. LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
  1908. return 0;
  1909. }
  1910. }
  1911. size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags) {
  1912. llama_io_read_buffer io(src, size);
  1913. try {
  1914. return state_seq_read_data(io, seq_id, flags);
  1915. } catch (const std::exception & err) {
  1916. LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
  1917. return 0;
  1918. }
  1919. }
  1920. bool llama_context::state_load_file(const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
  1921. llama_file file(filepath, "rb");
  1922. // sanity checks
  1923. {
  1924. const uint32_t magic = file.read_u32();
  1925. const uint32_t version = file.read_u32();
  1926. if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) {
  1927. LLAMA_LOG_ERROR("%s: unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
  1928. return false;
  1929. }
  1930. }
  1931. // load the prompt
  1932. {
  1933. const uint32_t n_token_count = file.read_u32();
  1934. if (n_token_count > n_token_capacity) {
  1935. LLAMA_LOG_ERROR("%s: token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
  1936. return false;
  1937. }
  1938. file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
  1939. *n_token_count_out = n_token_count;
  1940. }
  1941. // restore the context state
  1942. {
  1943. const size_t n_state_size_cur = file.size() - file.tell();
1944. llama_io_read_file io(&file);
  1945. const size_t n_read = state_read_data(io);
  1946. if (n_read != n_state_size_cur) {
  1947. LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read);
  1948. return false;
  1949. }
  1950. }
  1951. return true;
  1952. }
  1953. bool llama_context::state_save_file(const char * filepath, const llama_token * tokens, size_t n_token_count) {
  1954. llama_file file(filepath, "wb");
  1955. file.write_u32(LLAMA_SESSION_MAGIC);
  1956. file.write_u32(LLAMA_SESSION_VERSION);
  1957. // save the prompt
  1958. file.write_u32((uint32_t) n_token_count);
  1959. file.write_raw(tokens, sizeof(llama_token) * n_token_count);
  1960. // save the context state using stream saving
  1961. llama_io_write_file io(&file);
  1962. state_write_data(io);
  1963. return true;
  1964. }
  1965. size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
  1966. llama_file file(filepath, "rb");
  1967. // version checks
  1968. {
  1969. const uint32_t magic = file.read_u32();
  1970. const uint32_t version = file.read_u32();
  1971. if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) {
  1972. LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version);
  1973. return 0;
  1974. }
  1975. }
  1976. // load the prompt
  1977. {
  1978. const uint32_t n_token_count = file.read_u32();
  1979. if (n_token_count > n_token_capacity) {
  1980. LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
  1981. return 0;
  1982. }
  1983. file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
  1984. *n_token_count_out = n_token_count;
  1985. }
  1986. // restore the context state
  1987. {
  1988. const size_t state_size = file.size() - file.tell();
  1989. llama_io_read_file io(&file);
  1990. const size_t nread = state_seq_read_data(io, seq_id, 0);
  1991. if (!nread) {
  1992. LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
  1993. return 0;
  1994. }
  1995. GGML_ASSERT(nread <= state_size);
  1996. GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell());
  1997. }
  1998. return file.tell();
  1999. }
  2000. size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * filepath, const llama_token * tokens, size_t n_token_count) {
  2001. llama_file file(filepath, "wb");
  2002. file.write_u32(LLAMA_STATE_SEQ_MAGIC);
  2003. file.write_u32(LLAMA_STATE_SEQ_VERSION);
  2004. // save the prompt
  2005. file.write_u32((uint32_t) n_token_count);
  2006. file.write_raw(tokens, sizeof(llama_token) * n_token_count);
  2007. // save the context state using stream saving
  2008. llama_io_write_file io(&file);
  2009. state_seq_write_data(io, seq_id, 0);
  2010. const size_t res = file.tell();
  2011. GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());
  2012. return res;
  2013. }
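// On-disk layout produced by state_seq_save_file() above (and consumed by
// state_seq_load_file()):
//   uint32_t    LLAMA_STATE_SEQ_MAGIC
//   uint32_t    LLAMA_STATE_SEQ_VERSION
//   uint32_t    n_token_count
//   llama_token tokens[n_token_count]
//   <sequence state blob written by state_seq_write_data()>
// state_save_file()/state_load_file() use the same structure with
// LLAMA_SESSION_MAGIC/LLAMA_SESSION_VERSION for the full context state.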
  2014. size_t llama_context::state_write_data(llama_io_write_i & io) {
  2015. LLAMA_LOG_DEBUG("%s: writing state\n", __func__);
  2016. // write model info
  2017. {
  2018. LLAMA_LOG_DEBUG("%s: - writing model info\n", __func__);
  2019. const std::string arch_str = llm_arch_name(model.arch);
  2020. io.write_string(arch_str);
  2021. // TODO: add more model-specific info which should prevent loading the session file if not identical
  2022. }
  2023. // write output ids
  2024. {
  2025. LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);
  2026. const auto n_outputs = this->n_outputs;
  2027. const auto & output_ids = this->output_ids;
  2028. std::vector<int32_t> w_output_pos;
  2029. w_output_pos.resize(n_outputs);
  2030. // build a more compact representation of the output ids
  2031. for (size_t i = 0; i < n_batch(); ++i) {
  2032. // map an output id to a position in the batch
  2033. int64_t pos = output_ids[i];
  2034. if (pos >= 0) {
  2035. GGML_ASSERT(pos < n_outputs);
  2036. w_output_pos[pos] = i;
  2037. }
  2038. }
  2039. io.write(&n_outputs, sizeof(n_outputs));
  2040. if (n_outputs) {
  2041. io.write(w_output_pos.data(), n_outputs * sizeof(int32_t));
  2042. }
  2043. }
  2044. // write logits
  2045. {
  2046. LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__);
  2047. const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.vocab.n_tokens());
  2048. io.write(&logits_size, sizeof(logits_size));
  2049. if (logits_size) {
  2050. io.write(logits, logits_size * sizeof(float));
  2051. }
  2052. }
  2053. // write embeddings
  2054. {
  2055. LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__);
  2056. const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd);
  2057. io.write(&embd_size, sizeof(embd_size));
  2058. if (embd_size) {
  2059. io.write(embd, embd_size * sizeof(float));
  2060. }
  2061. }
  2062. // TODO: handle sampling buffers and samplers state ?
  2063. // https://github.com/ggml-org/llama.cpp/pull/17004
  2064. if (memory != nullptr) {
  2065. LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
  2066. memory->state_write(io);
  2067. }
  2068. return io.n_bytes();
  2069. }
  2070. size_t llama_context::state_read_data(llama_io_read_i & io) {
  2071. LLAMA_LOG_DEBUG("%s: reading state\n", __func__);
  2072. // read model info
  2073. {
  2074. LLAMA_LOG_DEBUG("%s: - reading model info\n", __func__);
  2075. const std::string cur_arch_str = llm_arch_name(model.arch);
  2076. std::string arch_str;
  2077. io.read_string(arch_str);
  2078. if (cur_arch_str != arch_str) {
  2079. throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str()));
  2080. }
  2081. // TODO: add more info which needs to be identical but which is not verified otherwise
  2082. }
  2083. // read output ids
  2084. {
  2085. LLAMA_LOG_DEBUG("%s: - reading output ids\n", __func__);
  2086. auto n_outputs = this->n_outputs;
  2087. io.read_to(&n_outputs, sizeof(n_outputs));
  2088. // Create a dummy batch for state loading.
  2089. llama_batch dummy_batch = {};
  2090. dummy_batch.n_tokens = 0;
  2091. if (n_outputs > output_reserve(n_outputs, dummy_batch)) {
  2092. throw std::runtime_error("could not reserve outputs");
  2093. }
  2094. std::vector<int32_t> output_pos;
  2095. if (n_outputs) {
  2096. output_pos.resize(n_outputs);
  2097. io.read_to(output_pos.data(), n_outputs * sizeof(int32_t));
  2098. for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
  2099. int32_t id = output_pos[i];
  2100. if ((uint32_t) id >= n_batch()) {
  2101. throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch()));
  2102. }
  2103. this->output_ids[id] = i;
  2104. }
  2105. this->n_outputs = n_outputs;
  2106. }
  2107. }
  2108. // read logits
  2109. {
  2110. LLAMA_LOG_DEBUG("%s: - reading logits\n", __func__);
  2111. uint64_t logits_size;
  2112. io.read_to(&logits_size, sizeof(logits_size));
  2113. if (this->logits_size < logits_size) {
  2114. throw std::runtime_error("logits buffer too small");
  2115. }
  2116. if (logits_size) {
  2117. io.read_to(this->logits, logits_size * sizeof(float));
  2118. }
  2119. }
  2120. // read embeddings
  2121. {
  2122. LLAMA_LOG_DEBUG("%s: - reading embeddings\n", __func__);
  2123. uint64_t embd_size;
  2124. io.read_to(&embd_size, sizeof(embd_size));
  2125. if (this->embd_size < embd_size) {
  2126. throw std::runtime_error("embeddings buffer too small");
  2127. }
  2128. if (embd_size) {
  2129. io.read_to(this->embd, embd_size * sizeof(float));
  2130. }
  2131. }
  2132. // TODO: handle sampling buffers and samplers state ?
  2133. // https://github.com/ggml-org/llama.cpp/pull/17004
  2134. if (memory) {
  2135. LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);
  2136. memory->state_read(io);
  2137. }
  2138. return io.n_bytes();
  2139. }
  2140. size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
  2141. GGML_UNUSED(seq_id);
  2142. if (memory) {
  2143. memory->state_write(io, seq_id, flags);
  2144. }
  2145. return io.n_bytes();
  2146. }
  2147. size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
  2148. GGML_UNUSED(seq_id);
  2149. if (memory) {
  2150. memory->state_read(io, seq_id, flags);
  2151. }
  2152. return io.n_bytes();
  2153. }
  2154. //
  2155. // perf
  2156. //
  2157. llama_perf_context_data llama_context::perf_get_data() const {
  2158. llama_perf_context_data data = {};
  2159. data.t_start_ms = 1e-3 * t_start_us;
  2160. data.t_load_ms = 1e-3 * t_load_us;
  2161. data.t_p_eval_ms = 1e-3 * t_p_eval_us;
  2162. data.t_eval_ms = 1e-3 * t_eval_us;
  2163. data.n_p_eval = std::max(1, n_p_eval);
  2164. data.n_eval = std::max(1, n_eval);
  2165. data.n_reused = std::max(0, n_reused);
  2166. return data;
  2167. }
  2168. void llama_context::perf_reset() {
  2169. t_start_us = ggml_time_us();
  2170. t_eval_us = n_eval = 0;
  2171. t_p_eval_us = n_p_eval = 0;
  2172. n_reused = 0;
  2173. }
  2174. std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> llama_context::memory_breakdown() const {
  2175. std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> ret;
  2176. for (const auto & [buft, size] : model.memory_breakdown()) {
  2177. ret[buft].model += size;
  2178. }
  2179. if (memory) {
  2180. for (const auto & [buft, size] : memory->memory_breakdown()) {
  2181. ret[buft].context += size;
  2182. }
  2183. }
  2184. if (model.hparams.no_alloc) {
  2185. for (size_t i = 0; i < backends.size(); ++i) {
  2186. ggml_backend_t backend = backends[i].get();
  2187. ggml_backend_buffer_type_t buft = ggml_backend_sched_get_buffer_type(sched.get(), backend);
  2188. ret[buft].compute += backend_buf_exp_size[i];
  2189. }
  2190. } else {
  2191. for (const auto & backend_ptr : backends) {
  2192. ggml_backend_t backend = backend_ptr.get();
  2193. ggml_backend_buffer_type_t buft = ggml_backend_sched_get_buffer_type(sched.get(), backend);
  2194. ret[buft].compute += ggml_backend_sched_get_buffer_size(sched.get(), backend);
  2195. }
  2196. }
  2197. return ret;
  2198. }
  2199. //
  2200. // training
  2201. //
  2202. static void llama_set_param(struct ggml_tensor * tensor, llama_opt_param_filter param_filter, void * userdata) {
  2203. if (!tensor || tensor->type != GGML_TYPE_F32) {
  2204. return;
  2205. }
  2206. if (!param_filter(tensor, userdata)) {
  2207. return;
  2208. }
  2209. if (strcmp(tensor->name, "token_embd.weight") == 0) {
  2210. return; // FIXME
  2211. }
  2212. if (strcmp(tensor->name, "rope_freqs.weight") == 0) {
  2213. return; // FIXME
  2214. }
  2215. ggml_set_param(tensor);
  2216. }
  2217. void llama_context::opt_init(struct llama_model * model, struct llama_opt_params lopt_params) {
  2218. GGML_ASSERT(!opt_ctx);
  2219. model->hparams.n_ctx_train = lopt_params.n_ctx_train > 0 ? lopt_params.n_ctx_train : n_ctx();
  2220. const uint32_t n_batch = std::min(this->n_batch(), model->hparams.n_ctx_train);
  2221. const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
  2222. GGML_ASSERT(model->hparams.n_ctx_train % n_batch == 0);
  2223. GGML_ASSERT(n_batch % n_ubatch == 0);
  2224. ggml_opt_params opt_params = ggml_opt_default_params(sched.get(), GGML_OPT_LOSS_TYPE_CROSS_ENTROPY);
  2225. opt_params.opt_period = n_batch / n_ubatch;
  2226. opt_params.get_opt_pars = lopt_params.get_opt_pars;
  2227. opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud;
  2228. opt_params.optimizer = lopt_params.optimizer_type;
  2229. opt_ctx = ggml_opt_init(opt_params);
  2230. llama_opt_param_filter param_filter = lopt_params.param_filter;
  2231. void * param_filter_ud = lopt_params.param_filter_ud;
  2232. //llama_set_param(model->tok_embd, param_filter, param_filter_ud); // FIXME
  2233. llama_set_param(model->type_embd, param_filter, param_filter_ud);
  2234. llama_set_param(model->pos_embd, param_filter, param_filter_ud);
  2235. llama_set_param(model->tok_norm, param_filter, param_filter_ud);
  2236. llama_set_param(model->tok_norm_b, param_filter, param_filter_ud);
  2237. llama_set_param(model->output_norm, param_filter, param_filter_ud);
  2238. llama_set_param(model->output_norm_b, param_filter, param_filter_ud);
  2239. llama_set_param(model->output, param_filter, param_filter_ud);
  2240. llama_set_param(model->output_b, param_filter, param_filter_ud);
  2241. llama_set_param(model->output_norm_enc, param_filter, param_filter_ud);
  2242. llama_set_param(model->cls, param_filter, param_filter_ud);
  2243. llama_set_param(model->cls_b, param_filter, param_filter_ud);
  2244. llama_set_param(model->cls_out, param_filter, param_filter_ud);
  2245. llama_set_param(model->cls_out_b, param_filter, param_filter_ud);
  2246. for (struct llama_layer & layer : model->layers) {
  2247. for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) {
  2248. llama_set_param(reinterpret_cast<struct ggml_tensor **>(&layer)[i], param_filter, param_filter_ud);
  2249. }
  2250. }
  2251. }
  2252. void llama_context::opt_epoch_iter(
  2253. ggml_opt_dataset_t dataset,
  2254. ggml_opt_result_t result,
  2255. const std::vector<llama_token> & tokens,
  2256. const std::vector<llama_token> & labels_sparse,
  2257. llama_batch & batch,
  2258. ggml_opt_epoch_callback callback,
  2259. bool train,
  2260. int64_t idata_in_loop,
  2261. int64_t ndata_in_loop,
  2262. int64_t t_loop_start) {
  2263. GGML_ASSERT(opt_ctx);
  2264. const uint32_t n_ctx = llama_model_n_ctx_train(&model);
  2265. const uint32_t n_batch = std::min(this->n_batch(), n_ctx);
  2266. const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
  2267. memory->clear(true);
  2268. for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
  2269. batch.n_tokens = n_batch;
  2270. for (uint32_t pos_batch = 0; pos_batch < n_batch; ++pos_batch) {
  2271. batch.token [pos_batch] = tokens[pos_ctx + pos_batch];
  2272. batch.pos [pos_batch] = pos_ctx + pos_batch;
  2273. batch.n_seq_id[pos_batch] = 1;
  2274. batch.seq_id [pos_batch][0] = 0;
  2275. batch.logits [pos_batch] = true;
  2276. }
  2277. if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd_inp(), cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
  2278. LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
  2279. return;
  2280. }
  2281. const uint32_t n_tokens_all = balloc->get_n_tokens();
  2282. n_queued_tokens += n_tokens_all;
  2283. embd_seq.clear();
  2284. uint32_t n_outputs_all = n_tokens_all;
  2285. auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true);
  2286. if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
  2287. LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
  2288. break;
  2289. }
  2290. // reserve output buffer
  2291. if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) {
  2292. LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
  2293. GGML_ABORT("TODO: handle this error");
2294. }
  2295. uint32_t pos_batch = 0;
  2296. do {
  2297. const auto & ubatch = mctx->get_ubatch();
  2298. n_outputs = ubatch.n_tokens;
  2299. if (!mctx->apply()) {
  2300. LLAMA_LOG_ERROR("%s: failed to update the memory context\n", __func__);
  2301. break;
  2302. }
  2303. auto * res = gf_res_prev.get();
  2304. const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT);
  2305. res->reset();
  2306. auto * gf = model.build_graph(gparams);
  2307. struct ggml_context * ctx_compute_opt;
  2308. {
  2309. const size_t size_gf = ggml_graph_size(gf);
  2310. const size_t size_meta = 4*size_gf*ggml_tensor_overhead() + 2*ggml_graph_overhead_custom(size_gf, /*grads = */ true);
  2311. struct ggml_init_params params = {
  2312. /*.mem_size =*/ size_meta,
  2313. /*.mem_buffer =*/ nullptr,
  2314. /*.no_alloc =*/ true,
  2315. };
  2316. ctx_compute_opt = ggml_init(params);
  2317. }
  2318. ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
  2319. ggml_opt_alloc(opt_ctx, train);
  2320. res->set_inputs(&ubatch);
  2321. {
  2322. struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
  2323. GGML_ASSERT(labels->ne[1] == n_ubatch);
  2324. ggml_set_zero(labels);
  2325. const float onef = 1.0f;
  2326. for (uint32_t pos_ubatch = 0; pos_ubatch < n_ubatch; ++pos_ubatch) {
  2327. const uint32_t ilabel = pos_ctx + pos_batch + pos_ubatch;
  2328. GGML_ASSERT(labels_sparse[ilabel] < labels->ne[0]);
  2329. ggml_backend_tensor_set(labels, &onef, (pos_ubatch*labels->ne[0] + labels_sparse[ilabel])*sizeof(float), sizeof(float));
  2330. }
  2331. }
  2332. ggml_opt_eval(opt_ctx, result);
  2333. if (callback) {
  2334. callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
  2335. }
  2336. ggml_free(ctx_compute_opt);
  2337. pos_batch += ubatch.n_tokens;
  2338. } while (mctx->next());
  2339. }
  2340. }
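
// Illustrative sketch (not part of the original file): the labels block above
// writes a one-hot target row per ubatch position, where labels->ne[0] is the
// vocabulary size and labels->ne[1] the ubatch size. A minimal standalone
// helper with the same byte-offset arithmetic, shown here only to make that
// indexing explicit and kept out of the build on purpose.
#if 0
static void example_set_one_hot_label(struct ggml_tensor * labels, uint32_t row, llama_token token) {
    GGML_ASSERT(token >= 0 && token < labels->ne[0]);
    const float onef = 1.0f;
    // column `token` of row `row`, expressed as a byte offset into the float tensor data
    const size_t offset = ((size_t) row*labels->ne[0] + (size_t) token)*sizeof(float);
    ggml_backend_tensor_set(labels, &onef, offset, sizeof(float));
}
#endif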

void llama_context::opt_epoch(
        ggml_opt_dataset_t dataset,
        ggml_opt_result_t result_train,
        ggml_opt_result_t result_eval,
        int64_t idata_split,
        ggml_opt_epoch_callback callback_train,
        ggml_opt_epoch_callback callback_eval) {
    const uint32_t n_ctx    = this->n_ctx();
    const uint32_t n_batch  = std::min(cparams.n_batch,  n_ctx);
    const uint32_t n_ubatch = std::min(cparams.n_ubatch, n_batch);

    const int64_t ndata = ggml_opt_dataset_ndata(dataset);

    GGML_ASSERT(idata_split >= 0);
    GGML_ASSERT(idata_split <= ndata);

    const uint32_t ubatch_per_ctx = n_ctx / n_ubatch;

    struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
    std::vector<llama_token> tokens(n_ctx);
    std::vector<llama_token> labels_sparse(n_ctx);

    int64_t idata = 0;

    int64_t t_loop_start  = ggml_time_us();
    int64_t ndata_in_loop = idata_split*ubatch_per_ctx;
    for (; idata < idata_split; ++idata) {
        constexpr bool train = true;
        const int64_t idata_in_loop = idata*ubatch_per_ctx;

        ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
        opt_epoch_iter(dataset, result_train, tokens, labels_sparse, batch,
            callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start);
    }

    t_loop_start  = ggml_time_us();
    ndata_in_loop = (ndata - idata_split)*ubatch_per_ctx;
    for (; idata < ndata; ++idata) {
        constexpr bool train = false;
        const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx;

        ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
        opt_epoch_iter(dataset, result_eval, tokens, labels_sparse, batch,
            callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start);
    }

    llama_batch_free(batch);
}

//
// interface implementation
//

llama_context_params llama_context_default_params() {
    llama_context_params result = {
        /*.n_ctx =*/ 512,
        /*.n_batch =*/ 2048,
        /*.n_ubatch =*/ 512,
        /*.n_seq_max =*/ 1,
        /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
        /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
        /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
        /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
        /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
        /*.flash_attn_type =*/ LLAMA_FLASH_ATTN_TYPE_AUTO,
        /*.rope_freq_base =*/ 0.0f,
        /*.rope_freq_scale =*/ 0.0f,
        /*.yarn_ext_factor =*/ -1.0f,
        /*.yarn_attn_factor =*/ -1.0f,
        /*.yarn_beta_fast =*/ -1.0f,
        /*.yarn_beta_slow =*/ -1.0f,
        /*.yarn_orig_ctx =*/ 0,
        /*.defrag_thold =*/ -1.0f,
        /*.cb_eval =*/ nullptr,
        /*.cb_eval_user_data =*/ nullptr,
        /*.type_k =*/ GGML_TYPE_F16,
        /*.type_v =*/ GGML_TYPE_F16,
        /*.abort_callback =*/ nullptr,
        /*.abort_callback_data =*/ nullptr,
        /*.embeddings =*/ false,
        /*.offload_kqv =*/ true,
        /*.no_perf =*/ true,
        /*.op_offload =*/ true,
        /*.swa_full =*/ true,
        /*.kv_unified =*/ false,
        /*.sampler =*/ nullptr,
        /*.n_sampler =*/ 0,
    };

    return result;
}
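
// Illustrative sketch (not part of the original file): typical callers start
// from llama_context_default_params() and override only the fields they need
// before creating a context. The specific values below are hypothetical.
#if 0
static llama_context_params example_context_params() {
    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx      = 4096;  // context window in tokens
    cparams.n_batch    = 1024;  // max tokens submitted per llama_decode call
    cparams.n_ubatch   = 256;   // max tokens per internal micro-batch
    cparams.embeddings = true;  // also expose embeddings for pooled outputs
    return cparams;
}
#endif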

llama_context * llama_init_from_model(
        llama_model * model,
        llama_context_params params) {
    if (!model) {
        LLAMA_LOG_ERROR("%s: model cannot be NULL\n", __func__);
        return nullptr;
    }

    if (params.n_batch == 0 && params.n_ubatch == 0) {
        LLAMA_LOG_ERROR("%s: n_batch and n_ubatch cannot both be zero\n", __func__);
        return nullptr;
    }

    if (params.n_ctx == 0 && model->hparams.n_ctx_train == 0) {
        LLAMA_LOG_ERROR("%s: n_ctx and model->hparams.n_ctx_train cannot both be zero\n", __func__);
        return nullptr;
    }

    if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && model->arch == LLM_ARCH_GROK) {
        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
        params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
    }

    if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) {
        const uint32_t blck_size = ggml_blck_size(params.type_k);
        if (model->hparams.n_embd_head_k % blck_size != 0) {
            LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n",
                    __func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k);
            return nullptr;
        }
    }

    if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) {
        const uint32_t blck_size = ggml_blck_size(params.type_v);
        if (model->hparams.n_embd_head_v % blck_size != 0) {
            LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_v=%u\n",
                    __func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v);
            return nullptr;
        }
    }

    if (ggml_is_quantized(params.type_v) && params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_DISABLED) {
        LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
        return nullptr;
    }

    if (params.pooling_type != LLAMA_POOLING_TYPE_UNSPECIFIED &&
        params.pooling_type != model->hparams.pooling_type) {
        // user-specified pooling-type is different from the model default
        LLAMA_LOG_WARN("%s: model default pooling_type is [%d], but [%d] was specified\n", __func__,
                model->hparams.pooling_type, params.pooling_type);
    }

    try {
        auto * ctx = new llama_context(*model, params);
        return ctx;
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: failed to initialize the context: %s\n", __func__, err.what());
    }

    return nullptr;
}

// deprecated
llama_context * llama_new_context_with_model(
        llama_model * model,
        llama_context_params params) {
    return llama_init_from_model(model, params);
}

void llama_free(llama_context * ctx) {
    delete ctx;
}
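
// Illustrative sketch (not part of the original file): end-to-end lifecycle of
// a context created from a model. llama_model_default_params,
// llama_model_load_from_file and llama_model_free are assumptions taken from
// the public llama.h API rather than from this file; the model path is
// hypothetical.
#if 0
static void example_context_lifecycle(const char * model_path) {
    llama_model_params mparams = llama_model_default_params();

    llama_model * model = llama_model_load_from_file(model_path, mparams);
    if (!model) {
        return;
    }

    llama_context_params cparams = llama_context_default_params();

    llama_context * ctx = llama_init_from_model(model, cparams);
    if (ctx) {
        // ... run decodes, read logits/embeddings ...
        llama_free(ctx);
    }

    llama_model_free(model);
}
#endif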

uint32_t llama_n_ctx(const llama_context * ctx) {
    return ctx->n_ctx();
}

uint32_t llama_n_ctx_seq(const llama_context * ctx) {
    return ctx->n_ctx_seq();
}

uint32_t llama_n_batch(const llama_context * ctx) {
    return ctx->n_batch();
}

uint32_t llama_n_ubatch(const llama_context * ctx) {
    return ctx->n_ubatch();
}

uint32_t llama_n_seq_max(const llama_context * ctx) {
    return ctx->n_seq_max();
}

const llama_model * llama_get_model(const llama_context * ctx) {
    return &ctx->get_model();
}

enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
    return ctx->pooling_type();
}

void llama_attach_threadpool(
        llama_context * ctx,
        ggml_threadpool_t threadpool,
        ggml_threadpool_t threadpool_batch) {
    ctx->attach_threadpool(threadpool, threadpool_batch);
}

void llama_detach_threadpool(llama_context * ctx) {
    ctx->detach_threadpool();
}

void llama_set_n_threads(llama_context * ctx, int32_t n_threads, int32_t n_threads_batch) {
    ctx->set_n_threads(n_threads, n_threads_batch);
}

int32_t llama_n_threads(llama_context * ctx) {
    return ctx->n_threads();
}

int32_t llama_n_threads_batch(llama_context * ctx) {
    return ctx->n_threads_batch();
}

void llama_set_abort_callback(llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) {
    ctx->set_abort_callback(abort_callback, abort_callback_data);
}

void llama_set_embeddings(llama_context * ctx, bool embeddings) {
    ctx->set_embeddings(embeddings);
}

void llama_set_causal_attn(llama_context * ctx, bool causal_attn) {
    ctx->set_causal_attn(causal_attn);
}

void llama_set_warmup(llama_context * ctx, bool warmup) {
    ctx->set_warmup(warmup);
}

void llama_synchronize(llama_context * ctx) {
    ctx->synchronize();
}

float * llama_get_logits(llama_context * ctx) {
    ctx->synchronize();

    return ctx->get_logits();
}

float * llama_get_logits_ith(llama_context * ctx, int32_t i) {
    ctx->synchronize();

    float * res = nullptr;

    res = ctx->get_sampled_logits_ith(i);
    if (!res) {
        res = ctx->get_logits_ith(i);
    }

    return res;
}

float * llama_get_embeddings(llama_context * ctx) {
    ctx->synchronize();

    return ctx->get_embeddings();
}

float * llama_get_embeddings_ith(llama_context * ctx, int32_t i) {
    ctx->synchronize();

    return ctx->get_embeddings_ith(i);
}

float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
    ctx->synchronize();

    return ctx->get_embeddings_seq(seq_id);
}

bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
    return ctx->set_sampler(seq_id, smpl);
}

llama_token llama_get_sampled_token_ith(llama_context * ctx, int32_t i) {
    ctx->synchronize();

    return ctx->get_sampled_token_ith(i);
}

float * llama_get_sampled_probs_ith(llama_context * ctx, int32_t i) {
    ctx->synchronize();

    return ctx->get_sampled_probs_ith(i);
}

float * llama_get_sampled_logits_ith(llama_context * ctx, int32_t i) {
    ctx->synchronize();

    return ctx->get_sampled_logits_ith(i);
}

llama_token * llama_get_sampled_candidates_ith(llama_context * ctx, int32_t i) {
    ctx->synchronize();

    return const_cast<llama_token *>(ctx->get_sampled_candidates_ith(i));
}

uint32_t llama_get_sampled_candidates_count_ith(llama_context * ctx, int32_t i) {
    ctx->synchronize();

    return static_cast<uint32_t>(ctx->get_sampled_candidates_count(i));
}

uint32_t llama_get_sampled_logits_count_ith(llama_context * ctx, int32_t i) {
    ctx->synchronize();

    return static_cast<uint32_t>(ctx->get_sampled_logits_count(i));
}

uint32_t llama_get_sampled_probs_count_ith(llama_context * ctx, int32_t i) {
    ctx->synchronize();

    return static_cast<uint32_t>(ctx->get_sampled_probs_count(i));
}
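
// Illustrative sketch (not part of the original file): attach a sampler chain
// to sequence 0 and read back the token sampled for output i after a decode.
// llama_sampler_chain_default_params, llama_sampler_chain_init,
// llama_sampler_chain_add and llama_sampler_init_greedy are assumptions taken
// from the llama.h sampling API; ownership/lifetime of the chain is not
// addressed here.
#if 0
static llama_token example_sampled_token(llama_context * ctx, int32_t i) {
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(chain, llama_sampler_init_greedy());

    // sample on-device for sequence 0 during decode
    llama_set_sampler(ctx, 0, chain);

    // ... llama_decode(...) with batch.logits[i] == true ...

    return llama_get_sampled_token_ith(ctx, i);
}
#endif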

// llama adapter API

int32_t llama_set_adapter_lora(
        llama_context * ctx,
        llama_adapter_lora * adapter,
        float scale) {
    ctx->set_adapter_lora(adapter, scale);
    return 0;
}

int32_t llama_rm_adapter_lora(
        llama_context * ctx,
        llama_adapter_lora * adapter) {
    bool res = ctx->rm_adapter_lora(adapter);
    return res ? 0 : -1;
}

void llama_clear_adapter_lora(llama_context * ctx) {
    ctx->clear_adapter_lora();
}

int32_t llama_apply_adapter_cvec(
        llama_context * ctx,
        const float * data,
        size_t len,
        int32_t n_embd,
        int32_t il_start,
        int32_t il_end) {
    bool res = ctx->apply_adapter_cvec(data, len, n_embd, il_start, il_end);
    return res ? 0 : -1;
}
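
// Illustrative sketch (not part of the original file): enable a LoRA adapter on
// a context at half strength, then remove it again. llama_adapter_lora_init and
// llama_adapter_lora_free are assumptions taken from the llama.h adapter API;
// the adapter path is hypothetical.
#if 0
static void example_lora_toggle(llama_context * ctx, llama_model * model) {
    llama_adapter_lora * adapter = llama_adapter_lora_init(model, "adapter.gguf");
    if (!adapter) {
        return;
    }

    llama_set_adapter_lora(ctx, adapter, 0.5f); // apply with scale 0.5
    // ... run some decodes with the adapter active ...
    llama_rm_adapter_lora(ctx, adapter);        // or llama_clear_adapter_lora(ctx)

    llama_adapter_lora_free(adapter);
}
#endif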

//
// memory
//

llama_memory_t llama_get_memory(const struct llama_context * ctx) {
    return ctx->get_memory();
}

void llama_memory_clear(llama_memory_t mem, bool data) {
    if (!mem) {
        return;
    }

    mem->clear(data);
}

bool llama_memory_seq_rm(
        llama_memory_t mem,
        llama_seq_id seq_id,
        llama_pos p0,
        llama_pos p1) {
    if (!mem) {
        return true;
    }

    return mem->seq_rm(seq_id, p0, p1);
}

void llama_memory_seq_cp(
        llama_memory_t mem,
        llama_seq_id seq_id_src,
        llama_seq_id seq_id_dst,
        llama_pos p0,
        llama_pos p1) {
    if (!mem) {
        return;
    }

    mem->seq_cp(seq_id_src, seq_id_dst, p0, p1);
}

void llama_memory_seq_keep(
        llama_memory_t mem,
        llama_seq_id seq_id) {
    if (!mem) {
        return;
    }

    mem->seq_keep(seq_id);
}

void llama_memory_seq_add(
        llama_memory_t mem,
        llama_seq_id seq_id,
        llama_pos p0,
        llama_pos p1,
        llama_pos delta) {
    if (!mem) {
        return;
    }

    mem->seq_add(seq_id, p0, p1, delta);
}

void llama_memory_seq_div(
        llama_memory_t mem,
        llama_seq_id seq_id,
        llama_pos p0,
        llama_pos p1,
        int d) {
    if (!mem) {
        return;
    }

    mem->seq_div(seq_id, p0, p1, d);
}

llama_pos llama_memory_seq_pos_min(
        llama_memory_t mem,
        llama_seq_id seq_id) {
    if (!mem) {
        return -1;
    }

    return mem->seq_pos_min(seq_id);
}

llama_pos llama_memory_seq_pos_max(
        llama_memory_t mem,
        llama_seq_id seq_id) {
    if (!mem) {
        return -1;
    }

    return mem->seq_pos_max(seq_id);
}

bool llama_memory_can_shift(llama_memory_t mem) {
    if (!mem) {
        return false;
    }

    return mem->get_can_shift();
}
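
// Illustrative sketch (not part of the original file): a simple "context shift"
// built on the sequence API above - drop the oldest n_discard positions of
// sequence 0 from the KV memory and slide the remaining positions back so new
// tokens fit. A minimal sketch, not the library's own shifting logic.
#if 0
static void example_context_shift(llama_context * ctx, llama_pos n_past, llama_pos n_discard) {
    llama_memory_t mem = llama_get_memory(ctx);
    if (!llama_memory_can_shift(mem)) {
        return;
    }

    // remove positions [0, n_discard) of sequence 0
    llama_memory_seq_rm(mem, 0, 0, n_discard);

    // shift positions [n_discard, n_past) of sequence 0 back by n_discard
    llama_memory_seq_add(mem, 0, n_discard, n_past, -n_discard);
}
#endif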

// llama state API

// deprecated
size_t llama_get_state_size(llama_context * ctx) {
    return llama_state_get_size(ctx);
}

// deprecated
size_t llama_copy_state_data(llama_context * ctx, uint8_t * dst) {
    return llama_state_get_data(ctx, dst, -1);
}

// deprecated
size_t llama_set_state_data(llama_context * ctx, const uint8_t * src) {
    return llama_state_set_data(ctx, src, -1);
}

// deprecated
bool llama_load_session_file(llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
    return llama_state_load_file(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
}

// deprecated
bool llama_save_session_file(llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
    return llama_state_save_file(ctx, path_session, tokens, n_token_count);
}

// Returns the *actual* size of the state.
// Intended to be used when saving the state to a buffer.
size_t llama_state_get_size(llama_context * ctx) {
    return ctx->state_get_size();
}

size_t llama_state_get_data(llama_context * ctx, uint8_t * dst, size_t size) {
    ctx->synchronize();

    return ctx->state_get_data(dst, size);
}

// Sets the state reading from the specified source address
size_t llama_state_set_data(llama_context * ctx, const uint8_t * src, size_t size) {
    ctx->synchronize();

    return ctx->state_set_data(src, size);
}
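
// Illustrative sketch (not part of the original file): snapshot the full
// context state into a heap buffer and restore it later, using the three
// functions above. A minimal sketch; real callers usually persist the token
// history alongside the raw state as well.
#if 0
static std::vector<uint8_t> example_state_snapshot(llama_context * ctx) {
    std::vector<uint8_t> buf(llama_state_get_size(ctx));
    const size_t written = llama_state_get_data(ctx, buf.data(), buf.size());
    buf.resize(written);
    return buf;
}

static bool example_state_restore(llama_context * ctx, const std::vector<uint8_t> & buf) {
    // returns the number of bytes consumed; a short read indicates failure
    return llama_state_set_data(ctx, buf.data(), buf.size()) == buf.size();
}
#endif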

bool llama_state_load_file(llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
    ctx->synchronize();

    try {
        return ctx->state_load_file(path_session, tokens_out, n_token_capacity, n_token_count_out);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what());
        return false;
    }
}

bool llama_state_save_file(llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
    ctx->synchronize();

    try {
        return ctx->state_save_file(path_session, tokens, n_token_count);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what());
        return false;
    }
}

size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) {
    return llama_state_seq_get_size_ext(ctx, seq_id, 0);
}

size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
    return llama_state_seq_get_data_ext(ctx, dst, size, seq_id, 0);
}

size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
    return llama_state_seq_set_data_ext(ctx, src, size, seq_id, 0);
}

size_t llama_state_seq_get_size_ext(llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) {
    return ctx->state_seq_get_size(seq_id, flags);
}

size_t llama_state_seq_get_data_ext(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
    ctx->synchronize();

    return ctx->state_seq_get_data(seq_id, dst, size, flags);
}

size_t llama_state_seq_set_data_ext(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
    ctx->synchronize();

    return ctx->state_seq_set_data(seq_id, src, size, flags);
}

size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
    ctx->synchronize();

    try {
        return ctx->state_seq_save_file(seq_id, filepath, tokens, n_token_count);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error saving sequence state file: %s\n", __func__, err.what());
        return 0;
    }
}

size_t llama_state_seq_load_file(llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
    ctx->synchronize();

    try {
        return ctx->state_seq_load_file(dest_seq_id, filepath, tokens_out, n_token_capacity, n_token_count_out);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error loading sequence state file: %s\n", __func__, err.what());
        return 0;
    }
}

///

int32_t llama_encode(
        llama_context * ctx,
        llama_batch batch) {
    const int ret = ctx->encode(batch);
    if (ret != 0) {
        LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
    }

    return ret;
}

int32_t llama_decode(
        llama_context * ctx,
        llama_batch batch) {
    const int ret = ctx->decode(batch);
    if (ret != 0 && ret != 1) {
        LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
    }

    return ret;
}
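
// Illustrative sketch (not part of the original file): evaluate a prompt with
// llama_decode and read the logits of its last token. llama_batch_get_one is
// assumed from the llama.h batch helpers; error handling is reduced to the
// return-code check performed by llama_decode above.
#if 0
static const float * example_eval_prompt(llama_context * ctx, llama_token * tokens, int32_t n_tokens) {
    // single-sequence batch over the whole prompt; logits are produced for the last token
    llama_batch batch = llama_batch_get_one(tokens, n_tokens);

    if (llama_decode(ctx, batch) != 0) {
        return nullptr;
    }

    // index -1 refers to the last output of the most recent decode
    return llama_get_logits_ith(ctx, -1);
}
#endif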

//
// perf
//

llama_perf_context_data llama_perf_context(const llama_context * ctx) {
    llama_perf_context_data data = {};

    if (ctx == nullptr) {
        return data;
    }

    data = ctx->perf_get_data();

    return data;
}

void llama_perf_context_print(const llama_context * ctx) {
    const auto data = llama_perf_context(ctx);

    const double t_end_ms = 1e-3 * ggml_time_us();

    LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, data.t_load_ms);
    LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
            __func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval);
    LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
            __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
    LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
    LLAMA_LOG_INFO("%s: graphs reused = %10d\n", __func__, data.n_reused);
}

void llama_perf_context_reset(llama_context * ctx) {
    ctx->perf_reset();
}

void llama_memory_breakdown_print(const struct llama_context * ctx) {
    const std::vector<ggml_backend_dev_t> & devices = ctx->get_model().devices;

    std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> memory_breakdown = ctx->memory_breakdown();

    std::vector<std::array<std::string, 9>> table_data;
    table_data.reserve(devices.size());
    const std::string template_header = "%s: | %s | %s %s %s %s %s %s %s |\n";
    const std::string template_gpu = "%s: | %s | %s = %s + (%s = %s + %s + %s) + %s |\n";
    const std::string template_other = "%s: | %s | %s %s %s = %s + %s + %s %s |\n";

    table_data.push_back({template_header, "memory breakdown [MiB]", "total", "free", "self", "model", "context", "compute", "unaccounted"});

    constexpr size_t MiB = 1024 * 1024;
    const std::vector<std::string> desc_prefixes_strip = {"NVIDIA ", "GeForce ", "Tesla ", "AMD ", "Radeon ", "Instinct "};

    // track seen buffer types to avoid double counting:
    std::set<ggml_backend_buffer_type_t> seen_buffer_types;

    // accumulative memory breakdown for each device and for host:
    std::vector<llama_memory_breakdown_data> mb_dev(devices.size());
    llama_memory_breakdown_data mb_host;

    for (const auto & buft_mb : memory_breakdown) {
        ggml_backend_buffer_type_t buft = buft_mb.first;
        const llama_memory_breakdown_data & mb = buft_mb.second;
        if (ggml_backend_buft_is_host(buft)) {
            mb_host.model   += mb.model;
            mb_host.context += mb.context;
            mb_host.compute += mb.compute;
            seen_buffer_types.insert(buft);
            continue;
        }
        ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft);
        if (dev) {
            int i_dev = -1;
            for (size_t i = 0; i < devices.size(); i++) {
                if (devices[i] == dev) {
                    i_dev = i;
                    break;
                }
            }
            if (i_dev != -1) {
                mb_dev[i_dev].model   += mb.model;
                mb_dev[i_dev].context += mb.context;
                mb_dev[i_dev].compute += mb.compute;
                seen_buffer_types.insert(buft);
                continue;
            }
        }
    }

    // print memory breakdown for each device:
    for (size_t i = 0; i < devices.size(); i++) {
        ggml_backend_dev_t dev = devices[i];
        llama_memory_breakdown_data mb = mb_dev[i];

        const std::string name = ggml_backend_dev_name(dev);
        std::string desc = ggml_backend_dev_description(dev);
        for (const std::string & prefix : desc_prefixes_strip) {
            if (desc.length() >= prefix.length() && desc.substr(0, prefix.length()) == prefix) {
                desc = desc.substr(prefix.length());
            }
        }

        size_t free, total;
        ggml_backend_dev_memory(dev, &free, &total);

        const size_t self = mb.model + mb.context + mb.compute;
        const size_t unaccounted = total - self - free;

        table_data.push_back({
            template_gpu,
            " - " + name + " (" + desc + ")",
            std::to_string(total / MiB),
            std::to_string(free / MiB),
            std::to_string(self / MiB),
            std::to_string(mb.model / MiB),
            std::to_string(mb.context / MiB),
            std::to_string(mb.compute / MiB),
            std::to_string(unaccounted / MiB)});
    }

    // print memory breakdown for host:
    {
        const size_t self = mb_host.model + mb_host.context + mb_host.compute;
        table_data.push_back({
            template_other,
            " - Host",
            "", // total
            "", // free
            std::to_string(self / MiB),
            std::to_string(mb_host.model / MiB),
            std::to_string(mb_host.context / MiB),
            std::to_string(mb_host.compute / MiB),
            ""}); // unaccounted
    }

    // print memory breakdown for all remaining buffer types:
    for (const auto & buft_mb : memory_breakdown) {
        ggml_backend_buffer_type_t buft = buft_mb.first;
        const llama_memory_breakdown_data & mb = buft_mb.second;
        if (seen_buffer_types.count(buft) == 1) {
            continue;
        }
        const std::string name = ggml_backend_buft_name(buft);
        const size_t self = mb.model + mb.context + mb.compute;
        table_data.push_back({
            template_other,
            " - " + name,
            "", // total
            "", // free
            std::to_string(self / MiB),
            std::to_string(mb.model / MiB),
            std::to_string(mb.context / MiB),
            std::to_string(mb.compute / MiB),
            ""}); // unaccounted
        seen_buffer_types.insert(buft);
    }

    // pad each column to its widest entry so the table lines up
    for (size_t j = 1; j < table_data[0].size(); j++) {
        size_t max_len = 0;
        for (const auto & td : table_data) {
            max_len = std::max(max_len, td[j].length());
        }
        for (auto & td : table_data) {
            td[j].insert(j == 1 ? td[j].length() : 0, max_len - td[j].length(), ' ');
        }
    }

    for (const auto & td : table_data) {
        LLAMA_LOG_INFO(td[0].c_str(),
            __func__, td[1].c_str(), td[2].c_str(), td[3].c_str(), td[4].c_str(), td[5].c_str(),
            td[6].c_str(), td[7].c_str(), td[8].c_str());
    }
}

//
// training
//

bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata) {
    GGML_UNUSED(tensor);
    GGML_UNUSED(userdata);
    return true;
}

void llama_opt_init(struct llama_context * ctx, struct llama_model * model, struct llama_opt_params lopt_params) {
    ctx->opt_init(model, lopt_params);
}

void llama_opt_epoch(
        struct llama_context * ctx,
        ggml_opt_dataset_t dataset,
        ggml_opt_result_t result_train,
        ggml_opt_result_t result_eval,
        int64_t idata_split,
        ggml_opt_epoch_callback callback_train,
        ggml_opt_epoch_callback callback_eval) {
    ctx->opt_epoch(
        dataset,
        result_train,
        result_eval,
        idata_split,
        callback_train,
        callback_eval);
}
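
// Illustrative sketch (not part of the original file): drive one finetuning
// epoch through the training API above. The llama_opt_params field names,
// ggml_opt_get_default_optimizer_params, ggml_opt_result_init/free and
// ggml_opt_epoch_callback_progress_bar are assumptions taken from the public
// llama.h/ggml-opt.h headers; the dataset is expected to hold tokenized
// sequences of the model's training context length.
#if 0
static void example_opt_epoch(llama_context * ctx, llama_model * model, ggml_opt_dataset_t dataset) {
    struct llama_opt_params lopt_params = {};
    lopt_params.param_filter    = llama_opt_param_filter_all;              // train all tensors
    lopt_params.param_filter_ud = nullptr;
    lopt_params.get_opt_pars    = ggml_opt_get_default_optimizer_params;   // default AdamW settings
    lopt_params.get_opt_pars_ud = nullptr;

    llama_opt_init(ctx, model, lopt_params);

    // use ~90% of the data for training and the rest for evaluation
    const int64_t ndata       = ggml_opt_dataset_ndata(dataset);
    const int64_t idata_split = ndata*9/10;

    ggml_opt_result_t result_train = ggml_opt_result_init();
    ggml_opt_result_t result_eval  = ggml_opt_result_init();

    llama_opt_epoch(ctx, dataset, result_train, result_eval, idata_split,
        ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);

    ggml_opt_result_free(result_train);
    ggml_opt_result_free(result_eval);
}
#endif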