server-context.cpp 160 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
737783779378037813782378337843785378637873788378937903791379237933794379537963797379837993800380138023803380438053806380738083809381038113812381338143815381638173818381938203821382238233824382538263827382838293830383138323833383438353836383738383839384038413842384338443845384638473848384938503851385238533854385538563857385838593860386138623863386438653866386738683869387038713872387338743875387638773878387938803881388238833884388538863887388838893890389138923893389438953896389738983899390039013902390339043905390639073908390939103911391239133914391539163917391839193920392139223923392439253926392739283929393039313932393339343935393639373938393939403941394239433944394539463947394839493950395139523953395439553956395739583959396039613962396339643965396639673968396939703971397239733974397539763977397839793980398139823983398439853986398739883989399039913992399339943995399639973998399940004001400240034004
  1. #include "server-context.h"
  2. #include "server-common.h"
  3. #include "server-http.h"
  4. #include "server-task.h"
  5. #include "server-queue.h"
  6. #include "common.h"
  7. #include "llama.h"
  8. #include "log.h"
  9. #include "sampling.h"
  10. #include "speculative.h"
  11. #include "mtmd.h"
  12. #include "mtmd-helper.h"
  13. #include <cstddef>
  14. #include <cinttypes>
  15. #include <memory>
  16. #include <filesystem>
  17. // fix problem with std::min and std::max
  18. #if defined(_WIN32)
  19. #define WIN32_LEAN_AND_MEAN
  20. #ifndef NOMINMAX
  21. # define NOMINMAX
  22. #endif
  23. #include <windows.h>
  24. #endif
  25. using json = nlohmann::ordered_json;
  26. constexpr int HTTP_POLLING_SECONDS = 1;
  27. // state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
  28. enum slot_state {
  29. SLOT_STATE_IDLE,
  30. SLOT_STATE_WAIT_OTHER, // after assigning a task, but waiting for parent slot to process prompt
  31. SLOT_STATE_STARTED, // after assigning a task and about to process prompt
  32. SLOT_STATE_PROCESSING_PROMPT,
  33. SLOT_STATE_DONE_PROMPT,
  34. SLOT_STATE_GENERATING,
  35. };
  36. enum server_state {
  37. SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
  38. SERVER_STATE_READY, // Server is ready and model is loaded
  39. };
  40. static bool server_task_type_need_embd(server_task_type task_type) {
  41. switch (task_type) {
  42. case SERVER_TASK_TYPE_EMBEDDING:
  43. case SERVER_TASK_TYPE_RERANK:
  44. return true;
  45. default:
  46. return false;
  47. }
  48. }
  49. static bool server_task_type_need_logits(server_task_type task_type) {
  50. switch (task_type) {
  51. case SERVER_TASK_TYPE_COMPLETION:
  52. case SERVER_TASK_TYPE_INFILL:
  53. return true;
  54. default:
  55. return false;
  56. }
  57. }
  58. struct server_slot {
  59. int id;
  60. llama_batch batch_spec = {};
  61. // TODO: change to unique_ptrs for consistency:
  62. llama_context * ctx = nullptr;
  63. llama_context * ctx_dft = nullptr;
  64. // multimodal
  65. mtmd_context * mctx = nullptr;
  66. common_speculative * spec = nullptr;
  67. std::unique_ptr<const server_task> task;
  68. std::unique_ptr<const server_task> task_prev; // used for debugging
  69. // used to determine the slot that has been used the longest
  70. int64_t t_last_used = -1;
  71. // generation props
  72. int32_t n_ctx = 0; // context size per slot
  73. int32_t n_keep = 0;
  74. int32_t n_decoded = 0;
  75. int32_t n_remaining = -1;
  76. int32_t i_batch = -1;
  77. int32_t n_prompt_tokens_cache = 0;
  78. int32_t n_prompt_tokens_processed = 0;
  79. size_t last_nl_pos = 0;
  80. std::string generated_text;
  81. llama_tokens generated_tokens;
  82. // idx of draft tokens in the main batch
  83. // non-empty if we went to evaluate draft tokens
  84. // ref: https://github.com/ggml-org/llama.cpp/pull/17808
  85. std::vector<int32_t> i_batch_dft;
  86. std::vector<completion_token_output> generated_token_probs;
  87. bool has_next_token = true;
  88. bool has_new_line = false;
  89. bool truncated = false;
  90. stop_type stop;
  91. std::string stopping_word;
  92. // state
  93. slot_state state = SLOT_STATE_IDLE;
  94. server_prompt prompt;
  95. void prompt_save(server_prompt_cache & prompt_cache) const {
  96. GGML_ASSERT(prompt.data.size() == 0);
  97. const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0);
  98. SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB\n",
  99. (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0));
  100. auto * cur = prompt_cache.alloc(prompt, cur_size);
  101. if (cur == nullptr) {
  102. return;
  103. }
  104. llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0);
  105. }
  106. bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
  107. bool res = prompt_cache.load(prompt, tokens, ctx, id);
  108. if (!res) {
  109. SLT_WRN(*this, "%s", "failed to load prompt from cache\n");
  110. }
  111. return res;
  112. }
  113. std::vector<common_adapter_lora_info> lora;
  114. int32_t alora_invocation_start = -1;
  115. // sampling
  116. json json_schema;
  117. common_sampler_ptr smpl;
  118. llama_token sampled; // in speculative mode, this is the last accepted token
  119. llama_tokens drafted;
  120. // stats
  121. size_t n_sent_text = 0; // number of sent text character
  122. int64_t t_start_process_prompt;
  123. int64_t t_start_generation;
  124. double t_prompt_processing; // ms
  125. double t_token_generation; // ms
  126. std::function<void(int)> callback_on_release;
  127. // Speculative decoding stats
  128. int32_t n_draft_total = 0; // Total draft tokens generated
  129. int32_t n_draft_accepted = 0; // Draft tokens actually accepted
  130. void reset() {
  131. SLT_DBG(*this, "%s", "\n");
  132. n_prompt_tokens_cache = 0;
  133. last_nl_pos = 0;
  134. generated_text = "";
  135. has_new_line = false;
  136. truncated = false;
  137. stop = STOP_TYPE_NONE;
  138. stopping_word = "";
  139. n_sent_text = 0;
  140. drafted.clear();
  141. i_batch_dft.clear();
  142. generated_tokens.clear();
  143. generated_token_probs.clear();
  144. json_schema = json();
  145. // clear speculative decoding stats
  146. n_draft_total = 0;
  147. n_draft_accepted = 0;
  148. task.reset();
  149. task_prev.reset();
  150. // clear alora start
  151. alora_invocation_start = -1;
  152. }
  153. bool need_embd() const {
  154. GGML_ASSERT(task);
  155. return server_task_type_need_embd(task->type);
  156. }
  157. bool need_logits() const {
  158. GGML_ASSERT(task);
  159. return server_task_type_need_logits(task->type);
  160. }
  161. // if the context does not have a memory module then all embeddings have to be computed within a single ubatch
  162. // also we cannot split if the pooling would require any past tokens
  163. bool can_split() const {
  164. return
  165. !need_embd() ||
  166. (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
  167. }
  168. bool can_batch_with(server_slot & other_slot) const {
  169. GGML_ASSERT(task);
  170. return task->type == other_slot.task->type && are_lora_equal(lora, other_slot.lora);
  171. }
  172. bool has_budget(const common_params & global_params) {
  173. GGML_ASSERT(task);
  174. if (task->params.n_predict == -1 && global_params.n_predict == -1) {
  175. return true; // limitless
  176. }
  177. n_remaining = -1;
  178. if (task->params.n_predict != -1) {
  179. n_remaining = task->params.n_predict - n_decoded;
  180. } else if (global_params.n_predict != -1) {
  181. n_remaining = global_params.n_predict - n_decoded;
  182. }
  183. return n_remaining > 0; // no budget
  184. }
  185. bool is_processing() const {
  186. return state != SLOT_STATE_IDLE;
  187. }
  188. bool can_speculate() const {
  189. return ctx_dft;
  190. }
  191. void add_token(const completion_token_output & token) {
  192. if (!is_processing()) {
  193. SLT_WRN(*this, "%s", "slot is not processing\n");
  194. return;
  195. }
  196. generated_token_probs.push_back(token);
  197. }
  198. int get_n_draft_max() const {
  199. if (!can_speculate()) {
  200. return 0;
  201. }
  202. // determine the max draft that fits the current slot state
  203. int n_draft_max = task->params.speculative.n_max;
  204. // note: slot.prompt is not yet expanded with the `id` token sampled above
  205. // also, need to leave space for 1 extra token to allow context shifts
  206. n_draft_max = std::min(n_draft_max, n_ctx - prompt.n_tokens() - 2);
  207. if (n_remaining > 0) {
  208. n_draft_max = std::min(n_draft_max, n_remaining - 1);
  209. }
  210. SLT_DBG(*this, "max possible draft: %d\n", n_draft_max);
  211. if (n_draft_max < task->params.speculative.n_min) {
  212. SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, task->params.speculative.n_min);
  213. n_draft_max = 0;
  214. }
  215. return n_draft_max;
  216. }
  217. // note: a slot can also be either a parent or a child
  218. bool is_parent() const {
  219. return is_processing() && task->n_children > 0;
  220. }
  221. bool is_child() const {
  222. return is_processing() && task->id_parent >= 0;
  223. }
  224. void release() {
  225. if (is_processing()) {
  226. GGML_ASSERT(task);
  227. SLT_INF(*this, "stop processing: n_tokens = %d, truncated = %d\n", prompt.n_tokens(), truncated);
  228. t_last_used = ggml_time_us();
  229. t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
  230. state = SLOT_STATE_IDLE;
  231. task_prev = std::move(task);
  232. task.reset();
  233. callback_on_release(id);
  234. }
  235. }
  236. result_timings get_timings() const {
  237. result_timings timings;
  238. timings.cache_n = n_prompt_tokens_cache;
  239. timings.prompt_n = n_prompt_tokens_processed;
  240. timings.prompt_ms = t_prompt_processing;
  241. timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed;
  242. timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
  243. timings.predicted_n = n_decoded;
  244. timings.predicted_ms = t_token_generation;
  245. timings.predicted_per_token_ms = t_token_generation / n_decoded;
  246. timings.predicted_per_second = 1e3 / t_token_generation * n_decoded;
  247. // Add speculative metrics
  248. if (n_draft_total > 0) {
  249. timings.draft_n = n_draft_total;
  250. timings.draft_n_accepted = n_draft_accepted;
  251. }
  252. return timings;
  253. }
  254. size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
  255. GGML_ASSERT(task);
  256. size_t stop_pos = std::string::npos;
  257. for (const std::string & word : task->params.antiprompt) {
  258. size_t pos;
  259. if (is_full_stop) {
  260. const size_t tmp = word.size() + last_token_size;
  261. const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
  262. pos = text.find(word, from_pos);
  263. } else {
  264. // otherwise, partial stop
  265. pos = string_find_partial_stop(text, word);
  266. }
  267. if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
  268. if (is_full_stop) {
  269. stop = STOP_TYPE_WORD;
  270. stopping_word = word;
  271. has_next_token = false;
  272. }
  273. stop_pos = pos;
  274. }
  275. }
  276. return stop_pos;
  277. }
  278. void print_timings() const {
  279. const double t_prompt = t_prompt_processing / n_prompt_tokens_processed;
  280. const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
  281. const double t_gen = t_token_generation / n_decoded;
  282. const double n_gen_second = 1e3 / t_token_generation * n_decoded;
  283. SLT_INF(*this,
  284. "\n"
  285. "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
  286. " eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
  287. " total time = %10.2f ms / %5d tokens\n",
  288. t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second,
  289. t_token_generation, n_decoded, t_gen, n_gen_second,
  290. t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);
  291. if (n_draft_total > 0) {
  292. const float draft_ratio = (float) n_draft_accepted / n_draft_total;
  293. SLT_CNT(*this,
  294. "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n",
  295. draft_ratio, n_draft_accepted, n_draft_total
  296. );
  297. }
  298. }
  299. json to_json(bool only_metrics = false) const {
  300. json res;
  301. res = {
  302. {"id", id},
  303. {"n_ctx", n_ctx},
  304. {"speculative", can_speculate()},
  305. {"is_processing", is_processing()},
  306. };
  307. const auto & ptask = task ? task : task_prev;
  308. if (ptask) {
  309. res["id_task"] = ptask->id;
  310. res["params"] = ptask->params.to_json(only_metrics);
  311. res["next_token"] = {
  312. {
  313. {"has_next_token", has_next_token},
  314. {"has_new_line", has_new_line},
  315. {"n_remain", n_remaining},
  316. {"n_decoded", n_decoded},
  317. }
  318. };
  319. if (!only_metrics) {
  320. res["prompt"] = ptask->tokens.detokenize(ctx, true);
  321. res["generated"] = generated_text;
  322. }
  323. }
  324. return res;
  325. }
  326. void copy_state_to(server_slot & other) const {
  327. llama_memory_seq_rm(llama_get_memory(ctx), other.id, 0, -1);
  328. llama_memory_seq_cp(llama_get_memory(ctx), id, other.id, 0, -1);
  329. other.n_decoded = n_decoded;
  330. other.n_remaining = n_remaining;
  331. other.i_batch = i_batch;
  332. other.n_prompt_tokens_cache = n_prompt_tokens_cache;
  333. other.n_prompt_tokens_processed = n_prompt_tokens_processed;
  334. other.prompt = prompt.clone();
  335. }
  336. };
  337. //
  338. // server_metrics
  339. //
  340. struct server_metrics {
  341. int64_t t_start = 0;
  342. uint64_t n_prompt_tokens_processed_total = 0;
  343. uint64_t t_prompt_processing_total = 0;
  344. uint64_t n_tokens_predicted_total = 0;
  345. uint64_t t_tokens_generation_total = 0;
  346. uint64_t n_tokens_max = 0;
  347. uint64_t n_prompt_tokens_processed = 0;
  348. uint64_t t_prompt_processing = 0;
  349. uint64_t n_tokens_predicted = 0;
  350. uint64_t t_tokens_generation = 0;
  351. uint64_t n_decode_total = 0;
  352. uint64_t n_busy_slots_total = 0;
  353. void init() {
  354. t_start = ggml_time_us();
  355. }
  356. void on_prompt_eval(const server_slot & slot) {
  357. n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed;
  358. n_prompt_tokens_processed += slot.n_prompt_tokens_processed;
  359. t_prompt_processing += slot.t_prompt_processing;
  360. t_prompt_processing_total += slot.t_prompt_processing;
  361. n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens());
  362. }
  363. void on_prediction(const server_slot & slot) {
  364. n_tokens_predicted_total += slot.n_decoded;
  365. n_tokens_predicted += slot.n_decoded;
  366. t_tokens_generation += slot.t_token_generation;
  367. t_tokens_generation_total += slot.t_token_generation;
  368. }
  369. void on_decoded(const std::vector<server_slot> & slots) {
  370. n_decode_total++;
  371. for (const auto & slot : slots) {
  372. if (slot.is_processing()) {
  373. n_busy_slots_total++;
  374. }
  375. n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens());
  376. }
  377. }
  378. void reset_bucket() {
  379. n_prompt_tokens_processed = 0;
  380. t_prompt_processing = 0;
  381. n_tokens_predicted = 0;
  382. t_tokens_generation = 0;
  383. }
  384. };
  385. //
  386. // server_context_impl (private implementation)
  387. //
  388. struct server_context_impl {
  389. friend struct server_context;
  390. public:
  391. // only use these pointers outside of this class:
  392. // - when not in sleeping state
  393. // - and, with thread-safe APIs (e.g., tokenizer calls)
  394. llama_model * model = nullptr;
  395. mtmd_context * mctx = nullptr;
  396. const llama_vocab * vocab = nullptr;
  397. server_queue queue_tasks;
  398. server_response queue_results;
  399. common_chat_templates_ptr chat_templates;
  400. oaicompat_parser_options oai_parser_opt;
  401. ~server_context_impl() {
  402. if (!sleeping) {
  403. // destroy() is already called when entering sleeping state
  404. // we don't call it again here to avoid double free
  405. destroy();
  406. }
  407. }
  408. private:
  409. // note: accessing these fields outside of this class is not thread-safe
  410. // use server_context methods instead
  411. common_params params_base;
  412. // note: keep these alive - they determine the lifetime of the model, context, etc.
  413. common_init_result_ptr llama_init;
  414. common_init_result_ptr llama_init_dft;
  415. llama_context * ctx = nullptr;
  416. bool vocab_dft_compatible = true;
  417. llama_model * model_dft = nullptr;
  418. llama_context_params cparams_dft;
  419. llama_batch batch {};
  420. bool add_bos_token = true;
  421. int32_t n_ctx; // total context for all clients / slots
  422. // slots / clients
  423. std::vector<server_slot> slots;
  424. int slots_debug = 0;
  425. std::unique_ptr<server_prompt_cache> prompt_cache;
  426. server_metrics metrics;
  427. json json_webui_settings = json::object();
  428. // Necessary similarity of prompt for slot selection
  429. float slot_prompt_similarity = 0.0f;
  430. std::string model_name; // name of the loaded model, to be used by API
  431. bool sleeping = false;
  432. void destroy() {
  433. llama_init.reset();
  434. ctx = nullptr;
  435. model = nullptr;
  436. mtmd_free(mctx);
  437. mctx = nullptr;
  438. // Clear any sampling context
  439. for (server_slot & slot : slots) {
  440. llama_free(slot.ctx_dft);
  441. slot.ctx_dft = nullptr;
  442. common_speculative_free(slot.spec);
  443. slot.spec = nullptr;
  444. llama_batch_free(slot.batch_spec);
  445. }
  446. llama_batch_free(batch);
  447. }
  448. void handle_sleeping_state(bool new_state) {
  449. GGML_ASSERT(sleeping != new_state);
  450. if (new_state) {
  451. SRV_INF("%s", "server is entering sleeping state\n");
  452. destroy();
  453. } else {
  454. SRV_INF("%s", "server is exiting sleeping state\n");
  455. if (!load_model(params_base)) {
  456. GGML_ABORT("failed to reload model after sleeping");
  457. }
  458. }
  459. sleeping = new_state;
  460. }
  461. // load the model and initialize llama_context
  462. // this may also be called to resume from sleeping state
  463. bool load_model(const common_params & params) {
  464. bool is_resume = sleeping;
  465. SRV_INF("loading model '%s'\n", params.model.path.c_str());
  466. params_base = params;
  467. llama_init = common_init_from_params(params_base);
  468. model = llama_init->model();
  469. ctx = llama_init->context();
  470. if (model == nullptr) {
  471. SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str());
  472. return false;
  473. }
  474. vocab = llama_model_get_vocab(model);
  475. n_ctx = llama_n_ctx(ctx);
  476. add_bos_token = llama_vocab_get_add_bos(vocab);
  477. if (params_base.has_speculative()) {
  478. SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str());
  479. auto params_dft = params_base;
  480. params_dft.devices = params_base.speculative.devices;
  481. params_dft.model = params_base.speculative.model;
  482. params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx;
  483. params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
  484. params_dft.n_parallel = 1;
  485. params_dft.cache_type_k = params_base.speculative.cache_type_k;
  486. params_dft.cache_type_v = params_base.speculative.cache_type_v;
  487. params_dft.cpuparams.n_threads = params_base.speculative.cpuparams.n_threads;
  488. params_dft.cpuparams_batch.n_threads = params_base.speculative.cpuparams_batch.n_threads;
  489. params_dft.tensor_buft_overrides = params_base.speculative.tensor_buft_overrides;
  490. llama_init_dft = common_init_from_params(params_dft);
  491. model_dft = llama_init_dft->model();
  492. if (model_dft == nullptr) {
  493. SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str());
  494. return false;
  495. }
  496. vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft->context());
  497. if (!vocab_dft_compatible) {
  498. SRV_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str());
  499. }
  500. const int n_ctx_dft = llama_n_ctx(llama_init_dft->context());
  501. cparams_dft = common_context_params_to_llama(params_dft);
  502. cparams_dft.n_batch = n_ctx_dft;
  503. // the context is not needed - we will create one for each slot
  504. llama_init_dft->free_context();
  505. }
  506. chat_templates = common_chat_templates_init(model, params_base.chat_template);
  507. try {
  508. common_chat_format_example(chat_templates.get(), params.use_jinja, params.default_template_kwargs);
  509. } catch (const std::exception & e) {
  510. SRV_WRN("%s: Chat template parsing error: %s\n", __func__, e.what());
  511. SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
  512. chat_templates = common_chat_templates_init(model, "chatml");
  513. }
  514. std::string & mmproj_path = params_base.mmproj.path;
  515. if (!mmproj_path.empty()) {
  516. if (!is_resume) {
  517. mtmd_helper_log_set(common_log_default_callback, nullptr);
  518. }
  519. mtmd_context_params mparams = mtmd_context_params_default();
  520. mparams.use_gpu = params_base.mmproj_use_gpu;
  521. mparams.print_timings = false;
  522. mparams.n_threads = params_base.cpuparams.n_threads;
  523. mparams.flash_attn_type = params_base.flash_attn_type;
  524. mparams.warmup = params_base.warmup;
  525. mparams.image_min_tokens = params_base.image_min_tokens;
  526. mparams.image_max_tokens = params_base.image_max_tokens;
  527. mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
  528. if (mctx == nullptr) {
  529. SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());
  530. return false;
  531. }
  532. SRV_INF("loaded multimodal model, '%s'\n", mmproj_path.c_str());
  533. if (params_base.ctx_shift) {
  534. params_base.ctx_shift = false;
  535. SRV_WRN("%s\n", "ctx_shift is not supported by multimodal, it will be disabled");
  536. }
  537. if (params_base.n_cache_reuse) {
  538. params_base.n_cache_reuse = 0;
  539. SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled");
  540. }
  541. if (params_base.has_speculative()) {
  542. SRV_ERR("%s\n", "err: speculative decode is not supported by multimodal");
  543. return false;
  544. }
  545. }
  546. if (!llama_memory_can_shift(llama_get_memory(ctx))) {
  547. if (params_base.ctx_shift) {
  548. params_base.ctx_shift = false;
  549. SRV_WRN("%s\n", "ctx_shift is not supported by this context, it will be disabled");
  550. }
  551. if (params_base.n_cache_reuse) {
  552. params_base.n_cache_reuse = 0;
  553. SRV_WRN("%s\n", "cache_reuse is not supported by this context, it will be disabled");
  554. }
  555. }
  556. // Necessary similarity of prompt for slot selection
  557. slot_prompt_similarity = params_base.slot_prompt_similarity;
  558. // setup slots
  559. SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
  560. const int n_ctx_train = llama_model_n_ctx_train(model);
  561. int n_ctx_slot = llama_n_ctx_seq(ctx);
  562. if (n_ctx_slot > n_ctx_train) {
  563. SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train);
  564. n_ctx_slot = n_ctx_train;
  565. }
  566. slots.clear();
  567. for (int i = 0; i < params_base.n_parallel; i++) {
  568. server_slot slot;
  569. slot.id = i;
  570. slot.ctx = ctx;
  571. slot.n_ctx = n_ctx_slot;
  572. slot.mctx = mctx;
  573. slot.prompt.tokens.has_mtmd = mctx != nullptr;
  574. if (model_dft) {
  575. slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
  576. // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
  577. slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
  578. if (slot.ctx_dft == nullptr) {
  579. SRV_ERR("%s", "failed to create draft context\n");
  580. return false;
  581. }
  582. slot.spec = common_speculative_init(slot.ctx, slot.ctx_dft);
  583. if (slot.spec == nullptr) {
  584. SRV_ERR("%s", "failed to create speculator\n");
  585. return false;
  586. }
  587. for (auto & pair : params_base.speculative.replacements) {
  588. common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str());
  589. }
  590. }
  591. SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);
  592. slot.callback_on_release = [this](int) {
  593. queue_tasks.pop_deferred_task();
  594. };
  595. slot.reset();
  596. slots.push_back(std::move(slot));
  597. }
  598. {
  599. const char * LLAMA_SERVER_SLOTS_DEBUG = getenv("LLAMA_SERVER_SLOTS_DEBUG");
  600. slots_debug = LLAMA_SERVER_SLOTS_DEBUG ? atoi(LLAMA_SERVER_SLOTS_DEBUG) : 0;
  601. if (slots_debug) {
  602. SRV_WRN("slots debug = %d\n", slots_debug);
  603. }
  604. }
  605. // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
  606. // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
  607. {
  608. const int32_t n_batch = llama_n_batch(ctx);
  609. batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
  610. }
  611. if (params_base.cache_ram_mib != 0) {
  612. if (params_base.cache_ram_mib < 0) {
  613. SRV_WRN("prompt cache is enabled, size limit: %s\n", "no limit");
  614. } else {
  615. SRV_WRN("prompt cache is enabled, size limit: %d MiB\n", params_base.cache_ram_mib);
  616. }
  617. SRV_WRN("%s", "use `--cache-ram 0` to disable the prompt cache\n");
  618. prompt_cache = std::make_unique<server_prompt_cache>(params_base.cache_ram_mib, n_ctx);
  619. } else {
  620. SRV_WRN("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n");
  621. }
  622. SRV_WRN("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n");
  623. if (!params_base.model_alias.empty()) {
  624. // user explicitly specified model name
  625. model_name = params_base.model_alias;
  626. } else if (!params_base.model.name.empty()) {
  627. // use model name in registry format (for models in cache)
  628. model_name = params_base.model.name;
  629. } else {
  630. // fallback: derive model name from file name
  631. auto model_path = std::filesystem::path(params_base.model.path);
  632. model_name = model_path.filename().string();
  633. }
  634. // thinking is enabled if:
  635. // 1. It's not explicitly disabled (reasoning_budget == 0)
  636. // 2. The chat template supports it
  637. const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get());
  638. SRV_INF("thinking = %d\n", enable_thinking);
  639. oai_parser_opt = {
  640. /* use_jinja */ params_base.use_jinja,
  641. /* prefill_assistant */ params_base.prefill_assistant,
  642. /* reasoning_format */ params_base.reasoning_format,
  643. /* chat_template_kwargs */ params_base.default_template_kwargs,
  644. /* common_chat_templates */ chat_templates.get(),
  645. /* allow_image */ mctx ? mtmd_support_vision(mctx) : false,
  646. /* allow_audio */ mctx ? mtmd_support_audio (mctx) : false,
  647. /* enable_thinking */ enable_thinking,
  648. /* media_path */ params_base.media_path,
  649. };
  650. // print sample chat example to make it clear which template is used
  651. LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
  652. common_chat_templates_source(chat_templates.get()),
  653. common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str());
  654. if (!is_resume) {
  655. return init();
  656. }
  657. return true;
  658. }
  659. // unlike load_model(), this is only called once during initialization
  660. bool init() {
  661. GGML_ASSERT(ctx != nullptr);
  662. GGML_ASSERT(model != nullptr);
  663. GGML_ASSERT(!sleeping);
  664. // wiring up server queues
  665. queue_tasks.on_new_task([this](server_task && task) {
  666. process_single_task(std::move(task));
  667. });
  668. queue_tasks.on_update_slots([this]() {
  669. update_slots();
  670. });
  671. queue_tasks.on_sleeping_state([this](bool sleeping) {
  672. handle_sleeping_state(sleeping);
  673. });
  674. metrics.init();
  675. // populate webui settings
  676. {
  677. if (!params_base.webui_config_json.empty()) {
  678. try {
  679. json_webui_settings = json::parse(params_base.webui_config_json);
  680. } catch (const std::exception & e) {
  681. SRV_ERR("%s: failed to parse webui config: %s\n", __func__, e.what());
  682. return false;
  683. }
  684. }
  685. }
  686. return true;
  687. }
  688. server_slot * get_slot_by_id(int id) {
  689. for (server_slot & slot : slots) {
  690. if (slot.id == id) {
  691. return &slot;
  692. }
  693. }
  694. return nullptr;
  695. }
  696. server_slot * get_available_slot(const server_task & task) {
  697. server_slot * ret = nullptr;
  698. bool update_cache = false;
  699. // find the slot that has at least n% prompt similarity
  700. if (ret == nullptr && slot_prompt_similarity != 0.0f) {
  701. float sim_best = 0;
  702. for (server_slot & slot : slots) {
  703. // skip the slot if it is not available
  704. if (slot.is_processing()) {
  705. continue;
  706. }
  707. const auto & tokens = slot.prompt.tokens;
  708. // skip the slot if it does not contains cached tokens
  709. if (tokens.empty()) {
  710. continue;
  711. }
  712. // fraction of the Longest Common Prefix length with respect to the input prompt length
  713. const float sim_cur = float(tokens.get_common_prefix(task.tokens)) / task.tokens.size();
  714. // select the current slot if the criteria match
  715. if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) {
  716. sim_best = sim_cur;
  717. ret = &slot;
  718. }
  719. }
  720. if (ret != nullptr) {
  721. const float f_keep = (sim_best*task.tokens.size()) / ret->prompt.tokens.size();
  722. SLT_INF(*ret, "selected slot by LCP similarity, sim_best = %.3f (> %.3f thold), f_keep = %.3f\n",
  723. sim_best, slot_prompt_similarity, f_keep);
  724. // if we are about to lose a large portion of the existing context - save it in the prompt cache
  725. if (f_keep < 0.5f) {
  726. update_cache = true;
  727. }
  728. }
  729. }
  730. // find the slot that has been least recently used
  731. if (ret == nullptr) {
  732. int64_t t_last = -1;
  733. for (server_slot & slot : slots) {
  734. // skip the slot if it is not available
  735. if (slot.is_processing()) {
  736. continue;
  737. }
  738. // select the current slot if the criteria match
  739. if (!ret || slot.t_last_used <= t_last) {
  740. t_last = slot.t_last_used;
  741. ret = &slot;
  742. }
  743. }
  744. if (ret != nullptr) {
  745. SLT_INF(*ret, "selected slot by LRU, t_last = %" PRId64 "\n", t_last);
  746. update_cache = true;
  747. }
  748. }
  749. if (ret) {
  750. const auto & tokens = ret->prompt.tokens;
  751. update_cache = update_cache && prompt_cache;
  752. // cache prompts only for completion tasks
  753. update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION;
  754. // don't update the cache if the slot's context is empty
  755. update_cache = update_cache && tokens.size() > 0;
  756. // TODO: mtmd does not support prompt cache
  757. update_cache = update_cache && (ret->mctx == nullptr);
  758. if (update_cache) {
  759. SRV_WRN("%s", "updating prompt cache\n");
  760. const int64_t t_start = ggml_time_us();
  761. ret->prompt_save(*prompt_cache);
  762. if (!ret->prompt_load(*prompt_cache, task.tokens)) {
  763. clear_slot(*ret);
  764. }
  765. prompt_cache->update();
  766. SRV_WRN("prompt cache update took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
  767. }
  768. }
  769. return ret;
  770. }
  771. void clear_slot(server_slot & slot, bool allow_processing = false) const {
  772. if (!allow_processing) {
  773. GGML_ASSERT(!slot.is_processing());
  774. }
  775. SLT_WRN(slot, "clearing slot with %zu tokens\n", slot.prompt.tokens.size());
  776. llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
  777. slot.prompt.tokens.clear();
  778. }
  779. // return true if at least one slot has been cleared
  780. // TODO: improve logic
  781. // - smarter decision which slot to clear (LRU or longest prompt?)
  782. // - move slot to level 2 cache instead of removing?
  783. // - instead of purging, try to store and resume later?
  784. bool try_clear_idle_slots() {
  785. bool res = false;
  786. if (!params_base.kv_unified) {
  787. return res;
  788. }
  789. for (auto & slot : slots) {
  790. if (slot.is_processing()) {
  791. continue;
  792. }
  793. if (slot.prompt.n_tokens() > 0) {
  794. SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
  795. clear_slot(slot);
  796. res = true;
  797. // clear slots one by one
  798. break;
  799. }
  800. }
  801. return res;
  802. }
  803. std::vector<common_adapter_lora_info> construct_lora_list(const std::map<int, float> & config) {
  804. std::vector<common_adapter_lora_info> output = params_base.lora_adapters; // copy
  805. for (size_t i = 0; i < output.size(); ++i) {
  806. auto it = config.find(i);
  807. if (it != config.end()) {
  808. output[i].scale = it->second;
  809. } else {
  810. output[i].scale = 0.0f;
  811. }
  812. }
  813. return output;
  814. }
  815. bool launch_slot_with_task(server_slot & slot, server_task && task) {
  816. slot.reset();
  817. // process per-request lora adapters
  818. if (!task.params.lora.empty()) {
  819. auto task_loras = construct_lora_list(task.params.lora);
  820. if (!are_lora_equal(task_loras, slot.lora)) {
  821. // if lora has changed, check to see if the cache should be cleared
  822. if (lora_should_clear_cache(slot.lora, task_loras)) {
  823. SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), task.params.lora.size());
  824. slot.prompt.tokens.clear();
  825. } else {
  826. SLT_INF(slot, "keeping cache for alora. %zu target loras\n", task_loras.size());
  827. }
  828. slot.lora = task_loras;
  829. }
  830. } else {
  831. slot.lora = params_base.lora_adapters;
  832. }
  833. // if using alora, make sure it's only a single one requested and active
  834. size_t alora_invocation_start = task.tokens.size();
  835. if (lora_all_alora(slot.lora)) {
  836. const auto & enabled_ids = lora_get_enabled_ids(slot.lora);
  837. // TODO: This will error out if a user requests two aloras, but only
  838. // provides the activation string for one. We could, instead search
  839. // for all requested alora activation strings and then either keep
  840. // only the last one, or reject if multiple are found.
  841. if (enabled_ids.size() != 1) {
  842. send_error(task, "Cannot run multiple aLoRAs in a single request", ERROR_TYPE_INVALID_REQUEST);
  843. return false;
  844. }
  845. const auto & lora = slot.lora[enabled_ids[0]].ptr;
  846. // get the pointer and count for the invocation tokens
  847. const uint64_t n_invocation_tokens = llama_adapter_get_alora_n_invocation_tokens(lora);
  848. const llama_token * invocation_tokens = llama_adapter_get_alora_invocation_tokens (lora);
  849. // scan backwards through the prompt tokens to find the last
  850. // occurrence of the invocation sequence
  851. int match_idx = static_cast<int>(n_invocation_tokens) - 1;
  852. for (int i = task.tokens.size() - 1; i >= 0; --i) {
  853. // the token in this position matches the next token to find in
  854. // the invocation sequence
  855. if (task.tokens[i] == invocation_tokens[match_idx]) {
  856. // if it's a full match, we've found the start
  857. if (match_idx == 0) {
  858. alora_invocation_start = i;
  859. break;
  860. }
  861. // otherwise, check the next token in the sequence
  862. --match_idx;
  863. } else {
  864. // no match in this position, so start looking over again
  865. match_idx = static_cast<int>(n_invocation_tokens) - 1;
  866. }
  867. }
  868. // if the activation string is not found, disable the alora
  869. if (alora_invocation_start == task.tokens.size()) {
  870. SLT_DBG(slot, "alora %zu requested, but not found. deactivating\n", enabled_ids[0]);
  871. slot.lora[enabled_ids[0]].scale = 0.0f;
  872. } else {
  873. SLT_DBG(slot, "alora %zu activated starting at %zu\n", enabled_ids[0], alora_invocation_start);
  874. slot.alora_invocation_start = alora_invocation_start;
  875. }
  876. }
  877. if (!task.tokens.validate(ctx)) {
  878. send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
  879. return false;
  880. }
  881. SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
  882. // initialize samplers
  883. {
  884. slot.smpl.reset(common_sampler_init(model, task.params.sampling));
  885. if (slot.smpl == nullptr) {
  886. // for now, the only error that may happen here is invalid grammar
  887. send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
  888. return false;
  889. }
  890. const bool need_logits = task.params.sampling.n_probs > 0;
  891. bool backend_sampling = true;
  892. backend_sampling &= task.params.sampling.backend_sampling;
  893. // TODO: speculative decoding requires multiple samples per batch - not supported yet
  894. backend_sampling &= !(slot.ctx_dft && task.params.speculative.n_max > 0);
  895. // TODO: getting post/pre sampling logits is not yet supported with backend sampling
  896. backend_sampling &= !need_logits;
  897. // TODO: tmp until backend sampling is fully implemented
  898. if (backend_sampling) {
  899. llama_set_sampler(ctx, slot.id, common_sampler_get(slot.smpl.get()));
  900. } else {
  901. llama_set_sampler(ctx, slot.id, nullptr);
  902. }
  903. SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str());
  904. }
  905. // initialize draft batch
  906. // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
  907. if (slot.ctx_dft) {
  908. llama_batch_free(slot.batch_spec);
  909. slot.batch_spec = llama_batch_init(task.params.speculative.n_max + 1, 0, 1);
  910. }
  911. slot.task = std::make_unique<const server_task>(std::move(task));
  912. slot.state = slot.is_child()
  913. ? SLOT_STATE_WAIT_OTHER // wait for the parent to process prompt
  914. : SLOT_STATE_STARTED;
  915. SLT_INF(slot, "%s", "processing task\n");
  916. return true;
  917. }
  // Handle one sampled token for `slot`.
  // Appends the token text to slot.generated_text, strips a matched stop string,
  // hands the sendable text to slot.add_token() (and streams it when requested), then
  // evaluates the stop conditions: context capacity, generation budget, indentation
  // limit, time limit and end-of-generation token.
  // Returns slot.has_next_token: true while generation should continue.
  bool process_token(completion_token_output & result, server_slot & slot) {
      // remember which tokens were sampled - used for repetition penalties during sampling
      const std::string token_str = result.text_to_send;
      slot.sampled = result.tok;

      slot.generated_text += token_str;
      if (slot.task->params.return_tokens) {
          slot.generated_tokens.push_back(result.tok);
      }
      slot.has_next_token = true;

      // check if there is incomplete UTF-8 character at the end
      bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();

      // search stop word and delete it
      // (while the text ends in an incomplete UTF-8 sequence nothing is sent and
      //  slot.add_token() is not called for this token)
      if (!incomplete) {
          // only the not-yet-sent tail of the text is searched
          size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());

          const std::string str_test = slot.generated_text.substr(pos);
          bool send_text = true;

          // NOTE(review): the last argument presumably selects a full (true) vs partial
          // (false) stop-string match - confirm in server_slot::find_stopping_strings
          size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true);
          if (stop_pos != std::string::npos) {
              // drop the stop string and everything after it
              slot.generated_text.erase(
                  slot.generated_text.begin() + pos + stop_pos,
                  slot.generated_text.end());
              pos = std::min(slot.n_sent_text, slot.generated_text.size());
          } else if (slot.has_next_token && !llama_vocab_is_eog(vocab, result.tok) ) {
              // withhold the text while a stop string may still be forming (presumed partial match)
              stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false);
              send_text = stop_pos == std::string::npos;
          }

          // check if there is any token to predict
          if (send_text) {
              // do not send the stop word in the response
              result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
              slot.n_sent_text += result.text_to_send.size();
              // add the token to slot queue and cache
          } else {
              result.text_to_send = "";
          }

          slot.add_token(result);
          if (slot.task->params.stream) {
              send_partial_response(slot, result, false);
          }
      }

      if (incomplete) {
          slot.has_next_token = true;
      }

      // if context shifting is disabled, make sure that we don't run out of context
      if (!params_base.ctx_shift && slot.prompt.n_tokens() + 1 >= slot.n_ctx) {
          slot.truncated      = true;
          slot.stop           = STOP_TYPE_LIMIT;
          slot.has_next_token = false;

          SLT_DBG(slot, "stopped due to running out of context capacity, prompt.n_tokens() = %d, task.n_tokens = %d, n_decoded = %d, n_ctx = %d\n",
                  slot.prompt.n_tokens(), slot.task->n_tokens(), slot.n_decoded, slot.n_ctx);
      }

      // check the limits
      if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) {
          slot.stop           = STOP_TYPE_LIMIT;
          slot.has_next_token = false;

          SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.task->params.n_predict);
      }

      if (slot.has_new_line) {
          // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent
          if (slot.task->params.n_indent > 0) {
              // check the current indentation
              // TODO: improve by not doing it more than once for each new line
              if (slot.last_nl_pos > 0) {
                  // count spaces/tabs from the start of the current line
                  size_t pos = slot.last_nl_pos;

                  int n_indent = 0;
                  while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) {
                      n_indent++;
                      pos++;
                  }

                  // the line has non-whitespace content but is indented too little
                  if (pos < slot.generated_text.size() && n_indent < slot.task->params.n_indent) {
                      slot.stop           = STOP_TYPE_LIMIT;
                      slot.has_next_token = false;

                      // cut the last line
                      slot.generated_text.erase(pos, std::string::npos);

                      SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent);
                  }
              }

              // find the next new line
              {
                  const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos);

                  if (pos != std::string::npos) {
                      slot.last_nl_pos = pos + 1;
                  }
              }
          }
      }

      // check if there is a new line in the generated text
      if (result.text_to_send.find('\n') != std::string::npos) {
          slot.has_new_line = true;

          // if we have seen a new line, we stop after a certain time limit, but only upon another new line
          if (slot.task->params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.task->params.t_max_predict_ms)) {
              slot.stop           = STOP_TYPE_LIMIT;
              slot.has_next_token = false;

              SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.task->params.t_max_predict_ms);
          }
      }

      if (llama_vocab_is_eog(vocab, result.tok)) {
          slot.stop           = STOP_TYPE_EOS;
          slot.has_next_token = false;

          SLT_DBG(slot, "%s", "stopped by EOS\n");
      }

      SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str());

      return slot.has_next_token; // continue
  }
  1022. void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const {
  1023. const size_t n_probs = slot.task->params.sampling.n_probs;
  1024. if (post_sampling) {
  1025. const auto * cur_p = common_sampler_get_candidates(slot.smpl.get(), true);
  1026. const size_t max_probs = cur_p->size;
  1027. // set probability for sampled token
  1028. for (size_t i = 0; i < max_probs; i++) {
  1029. if (cur_p->data[i].id == result.tok) {
  1030. result.prob = cur_p->data[i].p;
  1031. break;
  1032. }
  1033. }
  1034. // set probability for top n_probs tokens
  1035. result.probs.reserve(max_probs);
  1036. for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
  1037. result.probs.push_back({
  1038. cur_p->data[i].id,
  1039. common_token_to_piece(ctx, cur_p->data[i].id, special),
  1040. cur_p->data[i].p
  1041. });
  1042. }
  1043. } else {
  1044. // TODO: optimize this with min-p optimization
  1045. std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
  1046. // set probability for sampled token
  1047. for (size_t i = 0; i < cur.size(); i++) {
  1048. // set probability for sampled token
  1049. if (cur[i].id == result.tok) {
  1050. result.prob = cur[i].p;
  1051. break;
  1052. }
  1053. }
  1054. // set probability for top n_probs tokens
  1055. result.probs.reserve(n_probs);
  1056. for (size_t i = 0; i < std::min(cur.size(), n_probs); i++) {
  1057. result.probs.push_back({
  1058. cur[i].id,
  1059. common_token_to_piece(ctx, cur[i].id, special),
  1060. cur[i].p
  1061. });
  1062. }
  1063. }
  1064. }
  1065. void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
  1066. send_error(task.id, error, type);
  1067. }
  1068. void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
  1069. send_error(slot.task->id, error, type, slot.task->n_tokens(), slot.n_ctx);
  1070. }
  1071. void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) {
  1072. SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str());
  1073. if (type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) {
  1074. GGML_ASSERT(n_ctx > 0 && n_prompt_tokens > 0);
  1075. }
  1076. auto res = std::make_unique<server_task_result_error>();
  1077. res->id = id_task;
  1078. res->err_type = type;
  1079. res->err_msg = error;
  1080. res->n_prompt_tokens = n_prompt_tokens;
  1081. res->n_ctx = n_ctx;
  1082. queue_results.send(std::move(res));
  1083. }
  1084. // if multimodal is enabled, send an error and return false
  1085. bool check_no_mtmd(const int id_task) {
  1086. if (mctx) {
  1087. send_error(id_task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED);
  1088. return false;
  1089. }
  1090. return true;
  1091. }
  1092. void send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress) {
  1093. auto res = std::make_unique<server_task_result_cmpl_partial>();
  1094. res->id = slot.task->id;
  1095. res->index = slot.task->index;
  1096. if (is_progress) {
  1097. res->is_progress = true;
  1098. res->progress.total = slot.task->n_tokens();
  1099. res->progress.cache = slot.n_prompt_tokens_cache;
  1100. res->progress.processed = slot.prompt.tokens.size();
  1101. res->progress.time_ms = (ggml_time_us() - slot.t_start_process_prompt) / 1000;
  1102. } else {
  1103. res->content = tkn.text_to_send;
  1104. res->tokens = { tkn.tok };
  1105. }
  1106. res->n_decoded = slot.n_decoded;
  1107. res->n_prompt_tokens = slot.task->n_tokens();
  1108. res->post_sampling_probs = slot.task->params.post_sampling_probs;
  1109. res->verbose = slot.task->params.verbose;
  1110. res->res_type = slot.task->params.res_type;
  1111. res->oaicompat_model = slot.task->params.oaicompat_model;
  1112. res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
  1113. // populate res.probs_output
  1114. if (slot.task->params.sampling.n_probs > 0) {
  1115. res->prob_output = tkn; // copy the token probs
  1116. }
  1117. // populate timings if this is final response or timings_per_token is enabled
  1118. if (slot.stop != STOP_TYPE_NONE || slot.task->params.timings_per_token) {
  1119. res->timings = slot.get_timings();
  1120. }
  1121. queue_results.send(std::move(res));
  1122. }
  1123. void send_final_response(server_slot & slot) {
  1124. auto res = std::make_unique<server_task_result_cmpl_final>();
  1125. res->id = slot.task->id;
  1126. res->id_slot = slot.id;
  1127. res->index = slot.task->index;
  1128. // in stream mode, content and tokens are already in last partial chunk
  1129. if (slot.task->params.stream) {
  1130. res->content = "";
  1131. res->tokens = llama_tokens{};
  1132. } else {
  1133. res->content = std::move(slot.generated_text);
  1134. res->tokens = std::move(slot.generated_tokens);
  1135. }
  1136. res->timings = slot.get_timings();
  1137. res->prompt = slot.task->tokens.detokenize(ctx, true);
  1138. res->response_fields = std::move(slot.task->params.response_fields);
  1139. res->truncated = slot.truncated;
  1140. res->n_decoded = slot.n_decoded;
  1141. res->n_prompt_tokens = slot.task->n_tokens();
  1142. res->n_tokens_cached = slot.prompt.n_tokens();
  1143. res->has_new_line = slot.has_new_line;
  1144. res->stopping_word = slot.stopping_word;
  1145. res->stop = slot.stop;
  1146. res->post_sampling_probs = slot.task->params.post_sampling_probs;
  1147. res->verbose = slot.task->params.verbose;
  1148. res->stream = slot.task->params.stream;
  1149. res->include_usage = slot.task->params.include_usage;
  1150. res->res_type = slot.task->params.res_type;
  1151. res->oaicompat_model = slot.task->params.oaicompat_model;
  1152. res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
  1153. // populate res.probs_output
  1154. if (slot.task->params.sampling.n_probs > 0) {
  1155. if (!slot.task->params.stream && slot.stop == STOP_TYPE_WORD) {
  1156. const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
  1157. size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
  1158. res->probs_output = std::vector<completion_token_output>(
  1159. slot.generated_token_probs.begin(),
  1160. slot.generated_token_probs.end() - safe_offset);
  1161. } else {
  1162. res->probs_output = std::vector<completion_token_output>(
  1163. slot.generated_token_probs.begin(),
  1164. slot.generated_token_probs.end());
  1165. }
  1166. }
  1167. res->generation_params = slot.task->params; // copy the parameters
  1168. queue_results.send(std::move(res));
  1169. }
  1170. void send_embedding(const server_slot & slot, const llama_batch & batch) {
  1171. auto res = std::make_unique<server_task_result_embd>();
  1172. res->id = slot.task->id;
  1173. res->index = slot.task->index;
  1174. res->n_tokens = slot.task->n_tokens();
  1175. res->res_type = slot.task->params.res_type;
  1176. const int n_embd_out = llama_model_n_embd_out(model);
  1177. std::vector<float> embd_res(n_embd_out, 0.0f);
  1178. for (int i = 0; i < batch.n_tokens; ++i) {
  1179. if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
  1180. continue;
  1181. }
  1182. const float * embd = nullptr;
  1183. if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) {
  1184. embd = llama_get_embeddings_ith(ctx, i);
  1185. } else {
  1186. embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
  1187. }
  1188. if (embd == nullptr) {
  1189. SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
  1190. res->embedding.push_back(std::vector<float>(n_embd_out, 0.0f));
  1191. continue;
  1192. }
  1193. // normalize only when there is pooling
  1194. if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
  1195. common_embd_normalize(embd, embd_res.data(), n_embd_out, slot.task->params.embd_normalize);
  1196. res->embedding.push_back(embd_res);
  1197. break;
  1198. }
  1199. res->embedding.emplace_back(embd, embd + n_embd_out);
  1200. }
  1201. SLT_DBG(slot, "%s", "sending embeddings\n");
  1202. queue_results.send(std::move(res));
  1203. }
  1204. void send_rerank(const server_slot & slot, const llama_batch & batch) {
  1205. auto res = std::make_unique<server_task_result_rerank>();
  1206. res->id = slot.task->id;
  1207. res->index = slot.task->index;
  1208. res->n_tokens = slot.task->n_tokens();
  1209. for (int i = 0; i < batch.n_tokens; ++i) {
  1210. if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
  1211. continue;
  1212. }
  1213. const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
  1214. if (embd == NULL) {
  1215. embd = llama_get_embeddings_ith(ctx, i);
  1216. }
  1217. if (embd == NULL) {
  1218. SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
  1219. res->score = -1e6;
  1220. continue;
  1221. }
  1222. res->score = embd[0];
  1223. }
  1224. SLT_DBG(slot, "sending rerank result, res.score = %f\n", res->score);
  1225. queue_results.send(std::move(res));
  1226. }
  1227. //
  1228. // Functions to process the task
  1229. //
  1230. // tokenize the input if it's set by CLI, return false on error
  1231. bool tokenize_cli_input(server_task & task) {
  1232. if (task.cli_input == nullptr) {
  1233. return true; // nothing to do
  1234. }
  1235. try {
  1236. auto & opt = oai_parser_opt;
  1237. common_chat_templates_inputs inputs;
  1238. inputs.messages = common_chat_msgs_parse_oaicompat(task.cli_input);
  1239. inputs.tools = {}; // TODO
  1240. inputs.tool_choice = COMMON_CHAT_TOOL_CHOICE_NONE;
  1241. inputs.json_schema = ""; // TODO
  1242. inputs.grammar = ""; // TODO
  1243. inputs.use_jinja = opt.use_jinja;
  1244. inputs.parallel_tool_calls = false;
  1245. inputs.add_generation_prompt = true;
  1246. inputs.reasoning_format = opt.reasoning_format;
  1247. inputs.enable_thinking = opt.enable_thinking;
  1248. // Apply chat template to the list of messages
  1249. auto chat_params = common_chat_templates_apply(opt.tmpls, inputs);
  1250. // tokenize the resulting prompt
  1251. auto & prompt = chat_params.prompt;
  1252. if (mctx != nullptr) {
  1253. task.tokens = process_mtmd_prompt(mctx, prompt, task.cli_files);
  1254. } else {
  1255. task.tokens = std::move(tokenize_input_prompts(vocab, mctx, prompt, true, true)[0]);
  1256. }
  1257. task.cli_input.clear();
  1258. task.cli_files.clear();
  1259. } catch (const std::exception & e) {
  1260. send_error(task, std::string("Failed to format input: ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
  1261. return false;
  1262. }
  1263. return true;
  1264. }
  1265. void process_single_task(server_task && task) {
  1266. switch (task.type) {
  1267. case SERVER_TASK_TYPE_COMPLETION:
  1268. case SERVER_TASK_TYPE_INFILL:
  1269. case SERVER_TASK_TYPE_EMBEDDING:
  1270. case SERVER_TASK_TYPE_RERANK:
  1271. {
  1272. if (!tokenize_cli_input(task)) {
  1273. break;
  1274. }
  1275. const int id_slot = task.id_slot;
  1276. server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task);
  1277. if (slot == nullptr) {
  1278. // if no slot is available, we defer this task for processing later
  1279. SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id);
  1280. queue_tasks.defer(std::move(task));
  1281. break;
  1282. }
  1283. if (slot->is_processing()) {
  1284. // if requested slot is unavailable, we defer this task for processing later
  1285. SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
  1286. queue_tasks.defer(std::move(task));
  1287. break;
  1288. }
  1289. if (!launch_slot_with_task(*slot, std::move(task))) {
  1290. SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
  1291. break;
  1292. }
  1293. } break;
  1294. case SERVER_TASK_TYPE_CANCEL:
  1295. {
  1296. // release slot linked with the task id
  1297. for (auto & slot : slots) {
  1298. if (slot.task && slot.task->id == task.id_target) {
  1299. slot.release();
  1300. break;
  1301. }
  1302. }
  1303. } break;
  1304. case SERVER_TASK_TYPE_NEXT_RESPONSE:
  1305. {
  1306. // do nothing
  1307. } break;
  1308. case SERVER_TASK_TYPE_METRICS:
  1309. {
  1310. json slots_data = json::array();
  1311. int n_idle_slots = 0;
  1312. int n_processing_slots = 0;
  1313. for (server_slot & slot : slots) {
  1314. json slot_data = slot.to_json(slots_debug == 0);
  1315. if (slot.is_processing()) {
  1316. n_processing_slots++;
  1317. } else {
  1318. n_idle_slots++;
  1319. }
  1320. slots_data.push_back(slot_data);
  1321. }
  1322. SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots);
  1323. auto res = std::make_unique<server_task_result_metrics>();
  1324. res->id = task.id;
  1325. res->slots_data = std::move(slots_data);
  1326. res->n_idle_slots = n_idle_slots;
  1327. res->n_processing_slots = n_processing_slots;
  1328. res->n_tasks_deferred = queue_tasks.queue_tasks_deferred_size();
  1329. res->t_start = metrics.t_start;
  1330. res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total;
  1331. res->t_prompt_processing_total = metrics.t_prompt_processing_total;
  1332. res->n_tokens_predicted_total = metrics.n_tokens_predicted_total;
  1333. res->t_tokens_generation_total = metrics.t_tokens_generation_total;
  1334. res->n_tokens_max = metrics.n_tokens_max;
  1335. res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed;
  1336. res->t_prompt_processing = metrics.t_prompt_processing;
  1337. res->n_tokens_predicted = metrics.n_tokens_predicted;
  1338. res->t_tokens_generation = metrics.t_tokens_generation;
  1339. res->n_decode_total = metrics.n_decode_total;
  1340. res->n_busy_slots_total = metrics.n_busy_slots_total;
  1341. if (task.metrics_reset_bucket) {
  1342. metrics.reset_bucket();
  1343. }
  1344. queue_results.send(std::move(res));
  1345. } break;
  1346. case SERVER_TASK_TYPE_SLOT_SAVE:
  1347. {
  1348. if (!check_no_mtmd(task.id)) {
  1349. break;
  1350. }
  1351. int id_slot = task.slot_action.slot_id;
  1352. server_slot * slot = get_slot_by_id(id_slot);
  1353. if (slot == nullptr) {
  1354. send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
  1355. break;
  1356. }
  1357. if (slot->is_processing()) {
  1358. // if requested slot is unavailable, we defer this task for processing later
  1359. SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
  1360. queue_tasks.defer(std::move(task));
  1361. break;
  1362. }
  1363. const size_t token_count = slot->prompt.tokens.size();
  1364. const int64_t t_start = ggml_time_us();
  1365. std::string filename = task.slot_action.filename;
  1366. std::string filepath = task.slot_action.filepath;
  1367. const llama_tokens & tokens = slot->prompt.tokens.get_text_tokens();
  1368. const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count);
  1369. const int64_t t_end = ggml_time_us();
  1370. const double t_save_ms = (t_end - t_start) / 1000.0;
  1371. auto res = std::make_unique<server_task_result_slot_save_load>();
  1372. res->id = task.id;
  1373. res->id_slot = id_slot;
  1374. res->filename = filename;
  1375. res->is_save = true;
  1376. res->n_tokens = token_count;
  1377. res->n_bytes = nwrite;
  1378. res->t_ms = t_save_ms;
  1379. queue_results.send(std::move(res));
  1380. } break;
  1381. case SERVER_TASK_TYPE_SLOT_RESTORE:
  1382. {
  1383. if (!check_no_mtmd(task.id)) break;
  1384. int id_slot = task.slot_action.slot_id;
  1385. server_slot * slot = get_slot_by_id(id_slot);
  1386. if (slot == nullptr) {
  1387. send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
  1388. break;
  1389. }
  1390. if (slot->is_processing()) {
  1391. // if requested slot is unavailable, we defer this task for processing later
  1392. SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
  1393. queue_tasks.defer(std::move(task));
  1394. break;
  1395. }
  1396. const int64_t t_start = ggml_time_us();
  1397. std::string filename = task.slot_action.filename;
  1398. std::string filepath = task.slot_action.filepath;
  1399. llama_tokens tokens;
  1400. tokens.resize(slot->n_ctx);
  1401. size_t token_count = 0;
  1402. size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count);
  1403. if (nread == 0) {
  1404. slot->prompt.tokens.clear(); // KV may already been invalidated?
  1405. send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
  1406. break;
  1407. }
  1408. tokens.resize(token_count);
  1409. slot->prompt.tokens.clear();
  1410. slot->prompt.tokens.insert(tokens);
  1411. const int64_t t_end = ggml_time_us();
  1412. const double t_restore_ms = (t_end - t_start) / 1000.0;
  1413. auto res = std::make_unique<server_task_result_slot_save_load>();
  1414. res->id = task.id;
  1415. res->id_slot = id_slot;
  1416. res->filename = filename;
  1417. res->is_save = false;
  1418. res->n_tokens = token_count;
  1419. res->n_bytes = nread;
  1420. res->t_ms = t_restore_ms;
  1421. queue_results.send(std::move(res));
  1422. } break;
  1423. case SERVER_TASK_TYPE_SLOT_ERASE:
  1424. {
  1425. if (!check_no_mtmd(task.id)) {
  1426. break;
  1427. }
  1428. int id_slot = task.slot_action.slot_id;
  1429. server_slot * slot = get_slot_by_id(id_slot);
  1430. if (slot == nullptr) {
  1431. send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
  1432. break;
  1433. }
  1434. if (slot->is_processing()) {
  1435. // if requested slot is unavailable, we defer this task for processing later
  1436. SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
  1437. queue_tasks.defer(std::move(task));
  1438. break;
  1439. }
  1440. // Erase token cache
  1441. const size_t n_erased = slot->prompt.tokens.size();
  1442. clear_slot(*slot);
  1443. auto res = std::make_unique<server_task_result_slot_erase>();
  1444. res->id = task.id;
  1445. res->id_slot = id_slot;
  1446. res->n_erased = n_erased;
  1447. queue_results.send(std::move(res));
  1448. } break;
  1449. case SERVER_TASK_TYPE_GET_LORA:
  1450. {
  1451. // TODO @ngxson : make lora_adapters a dedicated member of server_context
  1452. auto & loras = params_base.lora_adapters;
  1453. auto res = std::make_unique<server_task_result_get_lora>();
  1454. res->id = task.id;
  1455. for (size_t i = 0; i < loras.size(); ++i) {
  1456. auto & lora = loras[i];
  1457. std::string alora_invocation_string = "";
  1458. const uint64_t n_alora_tokens = llama_adapter_get_alora_n_invocation_tokens(lora.ptr);
  1459. llama_tokens alora_invocation_tokens;
  1460. if (n_alora_tokens) {
  1461. const llama_token * alora_tokens = llama_adapter_get_alora_invocation_tokens(lora.ptr);
  1462. for (uint64_t j = 0; j < n_alora_tokens; ++j) {
  1463. alora_invocation_string += common_token_to_piece(vocab, alora_tokens[j]);
  1464. alora_invocation_tokens.push_back(alora_tokens[j]);
  1465. }
  1466. }
  1467. res->loras.push_back(server_task_result_get_lora::lora{
  1468. lora,
  1469. alora_invocation_string,
  1470. alora_invocation_tokens,
  1471. });
  1472. }
  1473. queue_results.send(std::move(res));
  1474. } break;
  1475. case SERVER_TASK_TYPE_SET_LORA:
  1476. {
  1477. auto new_loras = construct_lora_list(task.set_lora);
  1478. // logging
  1479. for (size_t i = 0; i < new_loras.size(); ++i) {
  1480. SRV_INF("set lora adapter idx=%zu scale=%f\n", i, new_loras[i].scale);
  1481. }
  1482. // TODO @ngxson : make lora_adapters a dedicated member of server_context
  1483. params_base.lora_adapters = new_loras;
  1484. auto res = std::make_unique<server_task_result_apply_lora>();
  1485. res->id = task.id;
  1486. queue_results.send(std::move(res));
  1487. } break;
  1488. }
  1489. }
  1490. void update_slots() {
  1491. // check if all slots are idle
  1492. {
  1493. bool all_idle = true;
  1494. for (auto & slot : slots) {
  1495. if (slot.is_processing()) {
  1496. all_idle = false;
  1497. break;
  1498. }
  1499. }
  1500. if (all_idle) {
  1501. SRV_INF("%s", "all slots are idle\n");
  1502. return;
  1503. }
  1504. }
  1505. {
  1506. SRV_DBG("%s", "posting NEXT_RESPONSE\n");
  1507. server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE);
  1508. task.id = queue_tasks.get_new_id();
  1509. queue_tasks.post(std::move(task));
  1510. }
  1511. // apply context-shift if needed
  1512. // TODO: simplify and improve
  1513. for (server_slot & slot : slots) {
  1514. if (slot.state == SLOT_STATE_GENERATING && slot.prompt.n_tokens() + 1 >= slot.n_ctx) {
  1515. if (!params_base.ctx_shift) {
  1516. // this check is redundant (for good)
  1517. // we should never get here, because generation should already stopped in process_token()
  1518. send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
  1519. slot.release();
  1520. continue;
  1521. }
  1522. if (mctx) {
  1523. // we should never reach this because params_base.ctx_shift is automatically disabled if mmproj is loaded
  1524. // we don't support ctx_shift because an image chunk may contains multiple tokens
  1525. GGML_ABORT("not supported by multimodal");
  1526. }
  1527. if (slot.is_parent() || slot.is_child()) {
  1528. send_error(slot, "context shift cannot be used for shared prompt", ERROR_TYPE_SERVER);
  1529. slot.release();
  1530. continue;
  1531. }
  1532. // Shift context
  1533. int n_keep = slot.task->params.n_keep < 0 ? slot.task->n_tokens() : slot.task->params.n_keep;
  1534. if (add_bos_token) {
  1535. n_keep += 1;
  1536. }
  1537. n_keep = std::min(slot.n_ctx - 4, n_keep);
  1538. const int n_left = slot.prompt.n_tokens() - n_keep;
  1539. const int n_discard = slot.task->params.n_discard ? slot.task->params.n_discard : (n_left / 2);
  1540. SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
  1541. llama_memory_seq_rm (llama_get_memory(ctx), slot.id, n_keep , n_keep + n_discard);
  1542. llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard);
  1543. // add generated tokens to cache
  1544. // ref: https://github.com/ggml-org/llama.cpp/pull/16818#discussion_r2473269481
  1545. {
  1546. GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
  1547. llama_tokens new_tokens = slot.prompt.tokens.get_text_tokens(); // copy
  1548. for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
  1549. new_tokens[i - n_discard] = new_tokens[i];
  1550. }
  1551. new_tokens.resize(slot.prompt.tokens.size() - n_discard);
  1552. slot.prompt.tokens.clear();
  1553. slot.prompt.tokens.insert(new_tokens);
  1554. }
  1555. slot.truncated = true;
  1556. }
  1557. }
  1558. // start populating the batch for this iteration
  1559. common_batch_clear(batch);
  1560. // track if given slot can be batched with slots already in the batch
  1561. server_slot * slot_batched = nullptr;
  1562. auto accept_special_token = [&](server_slot & slot, llama_token token) {
  1563. return params_base.special ||
  1564. slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end();
  1565. };
  1566. // first, add sampled tokens from any ongoing sequences
  1567. for (auto & slot : slots) {
  1568. if (slot.state != SLOT_STATE_GENERATING) {
  1569. continue;
  1570. }
  1571. // check if we can batch this slot with the previous one
  1572. if (!slot_batched) {
  1573. slot_batched = &slot;
  1574. } else if (!slot_batched->can_batch_with(slot)) {
  1575. continue;
  1576. }
  1577. // generate draft tokens in speculative decoding mode
  1578. // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
  1579. // perform the speculative drafting for all sequences at the same time in a single batch
  1580. int n_draft_max = slot.get_n_draft_max();
  1581. if (n_draft_max > 0) {
  1582. if (mctx) {
  1583. // we should never reach this, as speculative is automatically disabled if mmproj is loaded
  1584. GGML_ABORT("not supported by multimodal");
  1585. }
  1586. struct common_speculative_params params_spec;
  1587. params_spec.n_draft = n_draft_max;
  1588. params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max;
  1589. params_spec.p_min = slot.task->params.speculative.p_min;
  1590. const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
  1591. llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
  1592. // add the sampled token to the batch
  1593. slot.i_batch_dft.push_back(batch.n_tokens);
  1594. common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
  1595. slot.prompt.tokens.push_back(slot.sampled);
  1596. if (slot.task->params.speculative.n_min > (int) draft.size()) {
  1597. SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min);
  1598. // fallback to normal decoding
  1599. slot.i_batch = slot.i_batch_dft[0];
  1600. slot.drafted.clear();
  1601. slot.i_batch_dft.clear();
  1602. } else {
  1603. // keep track of total number of drafted tokens tested
  1604. slot.n_draft_total += draft.size();
  1605. // add all drafted tokens to the batch
  1606. for (size_t i = 0; i < draft.size(); i++) {
  1607. slot.i_batch_dft.push_back(batch.n_tokens);
  1608. common_batch_add(batch, draft[i], slot.prompt.tokens.pos_next(), { slot.id }, true);
  1609. slot.prompt.tokens.push_back(draft[i]);
  1610. }
  1611. slot.drafted = std::move(draft);
  1612. }
  1613. } else {
  1614. // no speculative decoding
  1615. slot.i_batch = batch.n_tokens;
  1616. common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
  1617. slot.prompt.tokens.push_back(slot.sampled);
  1618. SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
  1619. slot.n_ctx, slot.prompt.n_tokens(), slot.truncated);
  1620. }
  1621. }
  1622. // process in chunks of params.n_batch
  1623. int32_t n_batch = llama_n_batch(ctx);
  1624. int32_t n_ubatch = llama_n_ubatch(ctx);
  1625. float alora_scale = -1.0f;
  1626. size_t alora_disabled_id = 0;
  1627. // next, batch any pending prompts without exceeding n_batch
  1628. if (params_base.cont_batching || batch.n_tokens == 0) {
  1629. for (auto & slot : slots) {
  1630. if (!slot.is_processing()) {
  1631. continue;
  1632. }
  1633. // check if we can batch this slot with the previous one
  1634. if (slot_batched && !slot_batched->can_batch_with(slot)) {
  1635. continue;
  1636. }
  1637. // this slot still has a prompt to be processed
  1638. if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
  1639. const auto & input_tokens = slot.task->tokens;
  1640. // TODO: maybe move branch to outside of this loop in the future
  1641. if (slot.state == SLOT_STATE_STARTED) {
  1642. slot.t_start_process_prompt = ggml_time_us();
  1643. slot.t_start_generation = 0;
  1644. slot.state = SLOT_STATE_PROCESSING_PROMPT;
  1645. SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, task.n_tokens = %d\n",
  1646. slot.n_ctx, slot.task->params.n_keep, slot.task->n_tokens());
  1647. // print prompt tokens (for debugging)
  1648. /*if (1) {
  1649. // first 16 tokens (avoid flooding logs)
  1650. for (int i = 0; i < std::min<int>(16, input_tokens.size()); i++) {
  1651. SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str());
  1652. }
  1653. } else {
  1654. // all
  1655. for (int i = 0; i < (int) input_tokens.size(); i++) {
  1656. SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str());
  1657. }
  1658. }*/
  1659. // keep track how many tokens we can reuse from the previous state
  1660. int n_past = 0;
  1661. // empty prompt passed -> release the slot and send empty response
  1662. if (input_tokens.empty()) {
  1663. SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
  1664. slot.print_timings();
  1665. send_final_response(slot);
  1666. slot.release();
  1667. continue;
  1668. }
  1669. // TODO: support memory-less logits computation
  1670. if (slot.need_logits() && !llama_get_memory(ctx)) {
  1671. send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER);
  1672. slot.release();
  1673. continue;
  1674. }
  1675. if (!slot.can_split()) {
  1676. if (slot.task->n_tokens() > n_ubatch) {
  1677. send_error(slot,
  1678. string_format(
  1679. "input (%d tokens) is too large to process. increase the physical batch "
  1680. "size (current batch size: %d)",
  1681. slot.task->n_tokens(), n_ubatch),
  1682. ERROR_TYPE_SERVER);
  1683. slot.release();
  1684. continue;
  1685. }
  1686. if (slot.task->n_tokens() > slot.n_ctx) {
  1687. send_error(
  1688. slot,
  1689. string_format(
  1690. "input (%d tokens) is larger than the max context size (%d tokens). skipping",
  1691. slot.task->n_tokens(), slot.n_ctx),
  1692. ERROR_TYPE_EXCEED_CONTEXT_SIZE);
  1693. slot.release();
  1694. continue;
  1695. }
  1696. } else {
  1697. if (slot.task->n_tokens() >= slot.n_ctx) {
  1698. send_error(slot,
  1699. string_format("request (%d tokens) exceeds the available context size (%d "
  1700. "tokens), try increasing it",
  1701. slot.task->n_tokens(), slot.n_ctx),
  1702. ERROR_TYPE_EXCEED_CONTEXT_SIZE);
  1703. slot.release();
  1704. continue;
  1705. }
  1706. if (slot.task->params.cache_prompt) {
  1707. // reuse any previously computed tokens that are common with the new prompt
  1708. n_past = slot.prompt.tokens.get_common_prefix(input_tokens);
  1709. // if there is an alora invoked, don't cache after the invocation start
  1710. if (slot.alora_invocation_start > 0) {
  1711. SLT_DBG(slot, "only caching to alora invocation start (n_past = %d, alora_invocation_start = %d)\n", n_past, slot.alora_invocation_start);
  1712. n_past = std::min(n_past, slot.alora_invocation_start - 1);
  1713. }
  1714. const auto n_cache_reuse = slot.task->params.n_cache_reuse;
  1715. const bool can_cache_reuse =
  1716. llama_memory_can_shift(llama_get_memory(ctx)) &&
  1717. !slot.prompt.tokens.has_mtmd;
  1718. if (!can_cache_reuse && n_cache_reuse > 0) {
  1719. SLT_WRN(slot, "cache reuse is not supported - ignoring n_cache_reuse = %d\n", n_cache_reuse);
  1720. }
  1721. // reuse chunks from the cached prompt by shifting their KV cache in the new position
  1722. if (can_cache_reuse && n_cache_reuse > 0) {
  1723. GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
  1724. size_t head_c = n_past; // cache
  1725. size_t head_p = n_past; // current prompt
  1726. if (mctx) {
  1727. // we should never reach this
  1728. GGML_ABORT("not supported by multimodal");
  1729. }
  1730. SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", n_cache_reuse, n_past);
  1731. while (head_c < slot.prompt.tokens.size() &&
  1732. head_p < input_tokens.size()) {
  1733. size_t n_match = 0;
  1734. while (head_c + n_match < slot.prompt.tokens.size() &&
  1735. head_p + n_match < input_tokens.size() &&
  1736. slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) {
  1737. n_match++;
  1738. }
  1739. if (n_match >= (size_t) n_cache_reuse) {
  1740. SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
  1741. //for (size_t i = head_p; i < head_p + n_match; i++) {
  1742. // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
  1743. //}
  1744. const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
  1745. llama_memory_seq_rm (llama_get_memory(ctx), slot.id, head_p, head_c);
  1746. llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift);
  1747. for (size_t i = 0; i < n_match; i++) {
  1748. slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]);
  1749. n_past++;
  1750. }
  1751. head_c += n_match;
  1752. head_p += n_match;
  1753. } else {
  1754. head_c += 1;
  1755. }
  1756. }
  1757. SLT_DBG(slot, "after context reuse, new n_past = %d\n", n_past);
  1758. }
  1759. } else {
  1760. // if we don't cache the prompt, we have to remove all previous tokens
  1761. n_past = 0;
  1762. }
  1763. // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
  1764. const auto n_swa = std::max(1, llama_model_n_swa(model));
  1765. // the largest pos_min required for a checkpoint to be useful
  1766. const auto pos_min_thold = std::max(0, n_past - n_swa);
  1767. // note: disallow with mtmd contexts for now
  1768. // https://github.com/ggml-org/llama.cpp/issues/17043
  1769. if (!mctx && n_past > 0 && n_past < slot.prompt.n_tokens()) {
  1770. const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
  1771. if (pos_min == -1) {
  1772. SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
  1773. GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
  1774. }
  1775. // when the prompt prefix does not match, print the tokens around the mismatch
  1776. // this is useful for debugging prompt caching
  1777. if (slots_debug) {
  1778. const int np0 = std::max<int>(n_past - 4, 0);
  1779. const int np1 = std::min<int>(n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size()));
  1780. std::stringstream ss0;
  1781. std::stringstream ss1;
  1782. std::stringstream st0;
  1783. std::stringstream st1;
  1784. ss0 << "old: ... ";
  1785. ss1 << "new: ... ";
  1786. for (int i = np0; i < np1; i++) {
  1787. if (i == n_past) {
  1788. ss0 << " | ";
  1789. ss1 << " | ";
  1790. }
  1791. {
  1792. const auto token = slot.prompt.tokens[i];
  1793. const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]";
  1794. ss0 << piece;
  1795. st0 << std::setw(8) << token;
  1796. }
  1797. {
  1798. const auto token = slot.task->tokens[i];
  1799. const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]";
  1800. ss1 << piece;
  1801. st1 << std::setw(8) << token;
  1802. }
  1803. }
  1804. SLT_WRN(slot, "%s\n", ss0.str().c_str());
  1805. SLT_WRN(slot, "%s\n", ss1.str().c_str());
  1806. SLT_WRN(slot, "%s\n", st0.str().c_str());
  1807. SLT_WRN(slot, "%s\n", st1.str().c_str());
  1808. }
  1809. if (pos_min > pos_min_thold) {
  1810. // TODO: support can be added in the future when corresponding vision models get released
  1811. GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
  1812. SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
  1813. // search for a context checkpoint
  1814. const auto it = std::find_if(
  1815. slot.prompt.checkpoints.rbegin(),
  1816. slot.prompt.checkpoints.rend(),
  1817. [&](const auto & cur) {
  1818. // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
  1819. return cur.pos_min < pos_min_thold;
  1820. }
  1821. );
  1822. bool do_reset = it == slot.prompt.checkpoints.rend();
  1823. if (!do_reset) {
  1824. // restore the context checkpoint
  1825. const size_t checkpoint_size = it->data.size();
  1826. const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
  1827. if (n != checkpoint_size) {
  1828. SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
  1829. do_reset = true;
  1830. //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
  1831. } else {
  1832. n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max));
  1833. SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
  1834. }
  1835. }
  1836. if (do_reset) {
  1837. SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
  1838. "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
  1839. n_past = 0;
  1840. }
  1841. }
  1842. }
  1843. {
  1844. // erase any checkpoints with pos_min > pos_min_thold
  1845. for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) {
  1846. const auto & cur = *it;
  1847. if (cur.pos_min > pos_min_thold) {
  1848. SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
  1849. it = slot.prompt.checkpoints.erase(it);
  1850. } else {
  1851. ++it;
  1852. }
  1853. }
  1854. }
  1855. }
  1856. // [TAG_PROMPT_LOGITS]
  1857. if (n_past == slot.task->n_tokens() && n_past > 0) {
  1858. SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, task.n_tokens() = %d)\n", n_past, slot.task->n_tokens());
  1859. n_past--;
  1860. SLT_WRN(slot, "n_past was set to %d\n", n_past);
  1861. }
  1862. slot.n_prompt_tokens_cache = n_past;
  1863. slot.n_prompt_tokens_processed = 0;
  1864. slot.prompt.tokens.keep_first(n_past);
  1865. // send initial 0% progress update if needed
  1866. // this is to signal the client that the request has started processing
  1867. if (slot.task->params.stream && slot.task->params.return_progress) {
  1868. send_partial_response(slot, {}, true);
  1869. }
  1870. }
  1871. if (!slot.can_split()) {
  1872. // cannot fit the prompt in the current batch - will try next iter
  1873. if (batch.n_tokens + slot.task->n_tokens() > n_batch) {
  1874. continue;
  1875. }
  1876. }
  1877. // truncate any tokens that are beyond n_past for this slot
  1878. const llama_pos p0 = slot.prompt.tokens.pos_next();
  1879. SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
  1880. if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
  1881. SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
  1882. clear_slot(slot, /*allow_processing=*/true);
  1883. // there is no common part left
  1884. slot.n_prompt_tokens_cache = 0;
  1885. }
  1886. // check if we should process the image
  1887. if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
  1888. // process the image
  1889. size_t n_tokens_out = 0;
  1890. int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
  1891. if (res != 0) {
  1892. SLT_ERR(slot, "failed to process image, res = %d\n", res);
  1893. send_error(slot, "failed to process image", ERROR_TYPE_SERVER);
  1894. slot.release();
  1895. continue;
  1896. }
  1897. slot.n_prompt_tokens_processed += n_tokens_out;
  1898. // add the image chunk to cache
  1899. {
  1900. const auto & chunk = input_tokens.find_chunk(slot.prompt.n_tokens());
  1901. slot.prompt.tokens.push_back(chunk.get()); // copy
  1902. }
  1903. }
  1904. // If using an alora, there may be uncached tokens that come
  1905. // before the invocation sequence. When this happens, the
  1906. // tokens before the invocation sequence need to be
  1907. // processed without the adapter in a separate batch, then
  1908. // the adapter needs to be enabled for the remaining tokens.
  1909. if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.prompt.n_tokens()) {
  1910. SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start);
  1911. const auto & enabled_loras = lora_get_enabled_ids(slot.lora);
  1912. GGML_ASSERT(enabled_loras.size() == 1);
  1913. alora_scale = slot.lora[enabled_loras[0]].scale;
  1914. slot.lora[enabled_loras[0]].scale = 0.0f;
  1915. alora_disabled_id = enabled_loras[0];
  1916. }
  1917. bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
  1918. // make checkpoints only for completion tasks
  1919. do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION;
  1920. // make a checkpoint of the parts of the memory that cannot be rolled back.
  1921. // checkpoints are created only if:
  1922. // - the model uses SWA and we are not using `swa_full`
  1923. // - the model architecture is marked as recurrent or hybrid
  1924. //
  1925. // TODO: try to make this conditional on the context or the memory module, instead of the model type
  1926. do_checkpoint = do_checkpoint && (
  1927. llama_model_is_recurrent(model) ||
  1928. llama_model_is_hybrid(model) ||
  1929. (llama_model_n_swa(model) > 0 && !params_base.swa_full)
  1930. );
  1931. // add prompt tokens for processing in the current batch
  1932. while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) {
  1933. // get next token to process
  1934. llama_token cur_tok = input_tokens[slot.prompt.n_tokens()];
  1935. if (cur_tok == LLAMA_TOKEN_NULL) {
  1936. break; // end of text chunk
  1937. }
  1938. // if this is an alora request with pre-invocation
  1939. // tokens that are not cached, we need to stop filling
  1940. // this batch at those pre-invocation tokens.
  1941. if (alora_scale > 0 && slot.prompt.n_tokens() == slot.alora_invocation_start - 1) {
  1942. SLT_DBG(slot, "stop prompt batch filling at (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start);
  1943. break;
  1944. }
  1945. // embedding requires all tokens in the batch to be output
  1946. common_batch_add(batch,
  1947. cur_tok,
  1948. slot.prompt.tokens.pos_next(),
  1949. { slot.id },
  1950. slot.need_embd());
  1951. slot.prompt.tokens.push_back(cur_tok);
  1952. slot.n_prompt_tokens_processed++;
  1953. // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
  1954. if (do_checkpoint && slot.task->n_tokens() - slot.prompt.n_tokens() == 64) {
  1955. break;
  1956. }
  1957. }
  1958. // SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.slot.prompt.tokens.str().c_str());
  1959. SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
  1960. // entire prompt has been processed
  1961. if (slot.prompt.n_tokens() == slot.task->n_tokens()) {
  1962. slot.state = SLOT_STATE_DONE_PROMPT;
  1963. GGML_ASSERT(batch.n_tokens > 0);
  1964. common_sampler_reset(slot.smpl.get());
  1965. // Process all prompt tokens through sampler system
  1966. for (int i = 0; i < slot.task->n_tokens(); ++i) {
  1967. llama_token id = input_tokens[i];
  1968. if (id != LLAMA_TOKEN_NULL) {
  1969. common_sampler_accept(slot.smpl.get(), id, false);
  1970. }
  1971. }
  1972. // extract the logits only for the last token
  1973. batch.logits[batch.n_tokens - 1] = true;
  1974. slot.n_decoded = 0;
  1975. slot.i_batch = batch.n_tokens - 1;
  1976. SLT_INF(slot, "prompt done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
  1977. const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
  1978. const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
  1979. // no need for empty or small checkpoints
  1980. do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
  1981. // no need to create checkpoints that are too close together
  1982. do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64);
  1983. if (do_checkpoint) {
  1984. while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
  1985. // make room for the new checkpoint, if needed
  1986. const auto & cur = slot.prompt.checkpoints.front();
  1987. SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
  1988. cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
  1989. slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
  1990. }
  1991. const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
  1992. auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
  1993. /*.pos_min = */ pos_min,
  1994. /*.pos_max = */ pos_max,
  1995. /*.data = */ std::vector<uint8_t>(checkpoint_size),
  1996. });
  1997. llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
  1998. SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
  1999. (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
  2000. }
  2001. }
  2002. }
  2003. if (!slot_batched) {
  2004. slot_batched = &slot;
  2005. }
  2006. if (batch.n_tokens >= n_batch) {
  2007. break;
  2008. }
  2009. }
  2010. }
  2011. if (batch.n_tokens == 0) {
  2012. SRV_WRN("%s", "no tokens to decode\n");
  2013. return;
  2014. }
  2015. SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
  2016. if (slot_batched) {
  2017. // apply lora, only need to do it once per batch
  2018. common_set_adapter_lora(ctx, slot_batched->lora);
  2019. // if the lora is temporarily disabled for an alora, re-enable it
  2020. // for next time
  2021. if (alora_scale > 0.0f) {
  2022. SRV_DBG("re-enabling alora with scale %f\n", alora_scale);
  2023. slot_batched->lora[alora_disabled_id].scale = alora_scale;
  2024. }
  2025. llama_set_embeddings(ctx, slot_batched->need_embd());
  2026. }
  2027. int32_t i_next = 0;
  2028. // process the created batch of tokens
  2029. for (int32_t i = 0; i < batch.n_tokens; i = i_next) {
  2030. const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
  2031. llama_batch batch_view = {
  2032. n_tokens,
  2033. batch.token + i,
  2034. nullptr,
  2035. batch.pos + i,
  2036. batch.n_seq_id + i,
  2037. batch.seq_id + i,
  2038. batch.logits + i,
  2039. };
  2040. const int ret = llama_decode(ctx, batch_view);
  2041. metrics.on_decoded(slots);
  2042. if (ret != 0) {
  2043. {
  2044. std::string err;
  2045. if (n_batch == 1 && ret == 1) {
  2046. // TODO: try to terminate only the largest active slot/sequence and continue with the rest
  2047. // need to remove the tokens from the current batch too
  2048. err = "Context size has been exceeded.";
  2049. }
  2050. if (ret == -1) {
  2051. err = "Invalid input batch.";
  2052. }
  2053. if (ret < -1) {
  2054. // TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max()
  2055. err = "Compute error.";
  2056. }
  2057. // TODO: handle ret == 2 (abort) when we start aborting
  2058. if (!err.empty()) {
  2059. SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
  2060. for (auto & slot : slots) {
  2061. if (slot.is_processing()) {
  2062. send_error(slot, err);
  2063. slot.release();
  2064. // note: it's complicated to keep track of how much of the current batch has been
  2065. // processed before the error occurred, so we simply clear the entire context
  2066. clear_slot(slot);
  2067. }
  2068. }
  2069. break;
  2070. }
  2071. }
  2072. // retry with half the batch size to try to find a free slot in the KV cache
  2073. if (!try_clear_idle_slots()) {
  2074. n_batch /= 2;
  2075. }
  2076. SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
  2077. continue; // continue loop of n_batch
  2078. }
  2079. // move the head of the batch forward with the number of tokens we just processed
  2080. i_next = i + n_tokens;
  2081. // on successful decode, restore the original batch size
  2082. n_batch = llama_n_batch(ctx);
  2083. // technically, measuring the time here excludes the sampling time for the last batch
  2084. // but on the other hand, we don't want to do too many system calls to measure the time, so it's ok
  2085. const int64_t t_current = ggml_time_us();
  2086. for (auto & slot : slots) {
  2087. // may need to copy state to other slots
  2088. if (slot.state == SLOT_STATE_DONE_PROMPT && slot.is_parent()) {
  2089. std::vector<server_slot *> child_slots;
  2090. for (auto & other : slots) {
  2091. if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) {
  2092. child_slots.push_back(&other);
  2093. }
  2094. }
  2095. // we can only proceed if all child slots are having the correct tasks
  2096. if (child_slots.size() == slot.task->n_children) {
  2097. // copy state to the child slots
  2098. for (auto & child : child_slots) {
  2099. SLT_INF(slot, "copying state to child %d\n", child->id);
  2100. slot.copy_state_to(*child);
  2101. child->state = SLOT_STATE_DONE_PROMPT;
  2102. }
  2103. }
  2104. }
  2105. // optionally send prompt processing progress
  2106. if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
  2107. if (slot.task->params.stream && slot.task->params.return_progress) {
  2108. send_partial_response(slot, {}, true);
  2109. }
  2110. }
  2111. if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
  2112. continue; // continue loop of slots
  2113. }
  2114. if (slot.state == SLOT_STATE_DONE_PROMPT) {
  2115. if (slot.task->type == SERVER_TASK_TYPE_EMBEDDING) {
  2116. // prompt evaluated for embedding
  2117. send_embedding(slot, batch_view);
  2118. slot.release();
  2119. slot.i_batch = -1;
  2120. continue; // continue loop of slots
  2121. }
  2122. if (slot.task->type == SERVER_TASK_TYPE_RERANK) {
  2123. send_rerank(slot, batch_view);
  2124. slot.release();
  2125. slot.i_batch = -1;
  2126. continue; // continue loop of slots
  2127. }
  2128. // prompt evaluated for next-token prediction
  2129. slot.state = SLOT_STATE_GENERATING;
  2130. } else if (slot.state != SLOT_STATE_GENERATING) {
  2131. continue; // continue loop of slots
  2132. }
  2133. if (slot.i_batch_dft.size() > 0) {
  2134. continue; // sample using speculative decoding
  2135. }
  2136. const int tok_idx = slot.i_batch - i;
  2137. llama_token id = common_sampler_sample(slot.smpl.get(), ctx, tok_idx);
  2138. slot.i_batch = -1;
  2139. common_sampler_accept(slot.smpl.get(), id, true);
  2140. slot.n_decoded += 1;
  2141. if (slot.n_decoded == 1) {
  2142. slot.t_start_generation = t_current;
  2143. slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
  2144. metrics.on_prompt_eval(slot);
  2145. }
  2146. slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
  2147. completion_token_output result;
  2148. result.tok = id;
  2149. result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
  2150. result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
  2151. if (slot.task->params.sampling.n_probs > 0) {
  2152. populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx);
  2153. }
  2154. if (!process_token(result, slot)) {
  2155. // release slot because of stop condition
  2156. slot.print_timings();
  2157. send_final_response(slot);
  2158. metrics.on_prediction(slot);
  2159. slot.release();
  2160. continue;
  2161. }
  2162. }
  2163. // speculative decoding - main model sample and accept
  2164. for (auto & slot : slots) {
  2165. if (slot.state != SLOT_STATE_GENERATING || slot.i_batch_dft.empty()) {
  2166. continue;
  2167. }
  2168. size_t n_draft = slot.drafted.size();
  2169. // the accepted tokens from the speculation
  2170. const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted);
  2171. slot.i_batch_dft.clear();
  2172. slot.drafted.clear();
  2173. slot.n_decoded += ids.size();
  2174. slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
  2175. // update how many tokens out of those tested were accepted
  2176. slot.n_draft_accepted += ids.size() - 1;
  2177. // rollback to the state before sampling the draft tokens
  2178. slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);
  2179. // add accepted tokens to the prompt
  2180. slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
  2181. slot.sampled = ids.back(); // last accepted token
  2182. llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);
  2183. for (size_t i = 0; i < ids.size(); ++i) {
  2184. completion_token_output result;
  2185. result.tok = ids[i];
  2186. result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
  2187. result.prob = 1.0f; // set later
  2188. // TODO: set result.probs
  2189. if (!process_token(result, slot)) {
  2190. slot.print_timings();
  2191. send_final_response(slot);
  2192. metrics.on_prediction(slot);
  2193. slot.release();
  2194. break;
  2195. }
  2196. }
  2197. SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) n_draft, slot.prompt.n_tokens());
  2198. }
  2199. }
  2200. SRV_DBG("%s", "run slots completed\n");
  2201. }
  2202. int get_slot_n_ctx() {
  2203. return slots.back().n_ctx;
  2204. }
  2205. server_response_reader get_response_reader() {
  2206. return server_response_reader(queue_tasks, queue_results, HTTP_POLLING_SECONDS);
  2207. }
  2208. };
  2209. //
  2210. // server_context (public API)
  2211. //
  2212. server_context::server_context() : impl(new server_context_impl()) {}
  2213. server_context::~server_context() = default;
  2214. bool server_context::load_model(const common_params & params) {
  2215. return impl->load_model(params);
  2216. }
  2217. void server_context::start_loop() {
  2218. auto & params = impl->params_base;
  2219. impl->queue_tasks.start_loop(params.sleep_idle_seconds * 1000);
  2220. }
  2221. void server_context::terminate() {
  2222. impl->queue_tasks.terminate();
  2223. }
  2224. llama_context * server_context::get_llama_context() const {
  2225. return impl->ctx;
  2226. }
  2227. server_response_reader server_context::get_response_reader() {
  2228. return impl->get_response_reader();
  2229. }
  2230. server_context_meta server_context::get_meta() const {
  2231. auto tool_use_src = common_chat_templates_source(impl->chat_templates.get(), "tool_use");
  2232. auto bos_id = llama_vocab_bos(impl->vocab);
  2233. auto eos_id = llama_vocab_eos(impl->vocab);
  2234. auto bos_token_str = bos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx, bos_id, true) : "";
  2235. auto eos_token_str = eos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx, eos_id, true) : "";
  2236. return server_context_meta {
  2237. /* build_info */ build_info,
  2238. /* model_name */ impl->model_name,
  2239. /* model_path */ impl->params_base.model.path,
  2240. /* has_mtmd */ impl->mctx != nullptr,
  2241. /* has_inp_image */ impl->oai_parser_opt.allow_image,
  2242. /* has_inp_audio */ impl->oai_parser_opt.allow_audio,
  2243. /* json_webui_settings */ impl->json_webui_settings,
  2244. /* slot_n_ctx */ impl->get_slot_n_ctx(),
  2245. /* pooling_type */ llama_pooling_type(impl->ctx),
  2246. /* chat_template */ common_chat_templates_source(impl->chat_templates.get()),
  2247. /* chat_template_tool_use */ tool_use_src ? tool_use_src : "",
  2248. /* bos_token_str */ bos_token_str,
  2249. /* eos_token_str */ eos_token_str,
  2250. /* fim_pre_token */ llama_vocab_fim_pre(impl->vocab),
  2251. /* fim_sub_token */ llama_vocab_fim_suf(impl->vocab),
  2252. /* fim_mid_token */ llama_vocab_fim_mid(impl->vocab),
  2253. /* model_vocab_type */ llama_vocab_type(impl->vocab),
  2254. /* model_vocab_n_tokens */ llama_vocab_n_tokens(impl->vocab),
  2255. /* model_n_ctx_train */ llama_model_n_ctx_train(impl->model),
  2256. /* model_n_embd_inp */ llama_model_n_embd(impl->model),
  2257. /* model_n_params */ llama_model_n_params(impl->model),
  2258. /* model_size */ llama_model_size(impl->model),
  2259. };
  2260. }
  2261. // generator-like API for HTTP response generation
  2262. // may have bypass_sleep = true if the task does not use ctx_server
  2263. struct server_res_generator : server_http_res {
  2264. server_response_reader rd;
  2265. server_res_generator(server_queue & queue_tasks, server_response & queue_results, int sleep_idle_seconds, bool bypass_sleep = false)
  2266. : rd(queue_tasks, queue_results, HTTP_POLLING_SECONDS) {
  2267. // fast path in case sleeping is disabled
  2268. bypass_sleep |= sleep_idle_seconds < 0;
  2269. if (!bypass_sleep) {
  2270. queue_tasks.wait_until_no_sleep();
  2271. }
  2272. }
  2273. void ok(const json & response_data) {
  2274. status = 200;
  2275. data = safe_json_to_str(response_data);
  2276. }
  2277. void error(const json & error_data) {
  2278. status = json_value(error_data, "code", 500);
  2279. data = safe_json_to_str({{ "error", error_data }});
  2280. }
  2281. };
  2282. //
  2283. // server_routes
  2284. //
  2285. std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
  2286. const server_http_req & req,
  2287. server_task_type type,
  2288. const json & data,
  2289. const std::vector<raw_buffer> & files,
  2290. task_response_type res_type) {
  2291. GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
  2292. auto res = create_response();
  2293. auto completion_id = gen_chatcmplid();
  2294. auto & rd = res->rd;
  2295. try {
  2296. std::vector<server_task> tasks;
  2297. const auto & prompt = data.at("prompt");
  2298. // TODO: this log can become very long, put it behind a flag or think about a more compact format
  2299. //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
  2300. // process prompt
  2301. std::vector<server_tokens> inputs;
  2302. if (res_type != TASK_RESPONSE_TYPE_NONE && ctx_server.mctx != nullptr) {
  2303. // This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below.
  2304. inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get<std::string>(), files));
  2305. } else {
  2306. // Everything else, including multimodal completions.
  2307. inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
  2308. }
  2309. tasks.reserve(inputs.size());
  2310. for (size_t i = 0; i < inputs.size(); i++) {
  2311. server_task task = server_task(type);
  2312. task.id = rd.get_new_id();
  2313. task.tokens = std::move(inputs[i]);
  2314. task.params = server_task::params_from_json_cmpl(
  2315. ctx_server.vocab,
  2316. params,
  2317. meta->slot_n_ctx,
  2318. data);
  2319. task.id_slot = json_value(data, "id_slot", -1);
  2320. // OAI-compat
  2321. task.params.res_type = res_type;
  2322. task.params.oaicompat_cmpl_id = completion_id;
  2323. task.params.oaicompat_model = meta->model_name;
  2324. if (task.params.n_cmpl > 1) {
  2325. task.n_children = task.params.n_cmpl - 1;
  2326. for (size_t j = 0; j < task.n_children; j++) {
  2327. server_task child = task.create_child(task.id, rd.get_new_id());
  2328. // use different sampling seed for each child
  2329. // note: https://github.com/ggml-org/llama.cpp/pull/18700#discussion_r2675115723
  2330. if (child.params.sampling.seed != LLAMA_DEFAULT_SEED) {
  2331. child.params.sampling.seed += j + 1;
  2332. }
  2333. tasks.push_back(std::move(child));
  2334. }
  2335. }
  2336. tasks.push_back(std::move(task));
  2337. }
  2338. rd.post_tasks(std::move(tasks));
  2339. } catch (const std::exception & e) {
  2340. res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
  2341. return res;
  2342. }
  2343. bool stream = json_value(data, "stream", false);
  2344. if (!stream) {
  2345. // non-stream, wait for the results
  2346. auto all_results = rd.wait_for_all(req.should_stop);
  2347. if (all_results.is_terminated) {
  2348. return res; // connection is closed
  2349. } else if (all_results.error) {
  2350. res->error(all_results.error->to_json());
  2351. return res;
  2352. } else {
  2353. json arr = json::array();
  2354. for (auto & res : all_results.results) {
  2355. GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
  2356. arr.push_back(res->to_json());
  2357. }
  2358. GGML_ASSERT(!arr.empty() && "empty results");
  2359. if (arr.size() == 1) {
  2360. // if single request, return single object instead of array
  2361. res->ok(arr[0]);
  2362. } else if (res_type == TASK_RESPONSE_TYPE_OAI_CHAT || res_type == TASK_RESPONSE_TYPE_OAI_CMPL) {
  2363. // if multiple results in OAI format, we need to re-format them
  2364. json & choices = arr[0]["choices"];
  2365. for (size_t i = 1; i < arr.size(); i++) {
  2366. choices.push_back(std::move(arr[i]["choices"][0]));
  2367. }
  2368. res->ok(arr[0]);
  2369. } else {
  2370. // multi-results, non-OAI compat
  2371. res->ok(arr);
  2372. }
  2373. }
  2374. } else {
  2375. // in streaming mode, the first error must be treated as non-stream response
  2376. // this is to match the OAI API behavior
  2377. // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309
  2378. auto first_result = rd.next(req.should_stop);
  2379. if (first_result == nullptr) {
  2380. GGML_ASSERT(req.should_stop());
  2381. return res; // connection is closed
  2382. }
  2383. if (first_result->is_error()) {
  2384. res->error(first_result->to_json());
  2385. return res;
  2386. }
  2387. GGML_ASSERT(
  2388. dynamic_cast<server_task_result_cmpl_partial*>(first_result.get()) != nullptr ||
  2389. dynamic_cast<server_task_result_cmpl_final*> (first_result.get()) != nullptr
  2390. );
  2391. // next responses are streamed
  2392. // to be sent immediately
  2393. json first_result_json = first_result->to_json();
  2394. if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
  2395. res->data = format_anthropic_sse(first_result_json);
  2396. } else {
  2397. res->data = format_oai_sse(first_result_json);
  2398. }
  2399. res->status = 200;
  2400. res->content_type = "text/event-stream";
  2401. res->next = [res_this = res.get(), res_type, &req](std::string & output) -> bool {
  2402. static auto format_error = [](task_response_type res_type, const json & res_json) {
  2403. if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
  2404. return format_anthropic_sse({
  2405. {"event", "error"},
  2406. {"data", res_json},
  2407. });
  2408. } else {
  2409. return format_oai_sse(json {{ "error", res_json }});
  2410. }
  2411. };
  2412. try {
  2413. if (req.should_stop()) {
  2414. SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
  2415. return false; // should_stop condition met
  2416. }
  2417. if (!res_this->data.empty()) {
  2418. // flush the first chunk
  2419. output = std::move(res_this->data);
  2420. res_this->data.clear();
  2421. return true;
  2422. }
  2423. server_response_reader & rd = res_this->rd;
  2424. // check if there is more data
  2425. if (!rd.has_next()) {
  2426. if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
  2427. // Anthropic doesn't send [DONE], message_stop was already sent
  2428. output = "";
  2429. } else if (res_type != TASK_RESPONSE_TYPE_NONE) {
  2430. output = "data: [DONE]\n\n";
  2431. } else {
  2432. output = "";
  2433. }
  2434. SRV_DBG("%s", "all results received, terminating stream\n");
  2435. return false; // no more data, terminate
  2436. }
  2437. // receive subsequent results
  2438. auto result = rd.next(req.should_stop);
  2439. if (result == nullptr) {
  2440. SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
  2441. GGML_ASSERT(req.should_stop());
  2442. return false; // should_stop condition met
  2443. }
  2444. // send the results
  2445. if (result->is_error()) {
  2446. json res_json = result->to_json();
  2447. output = format_error(res_type, res_json);
  2448. SRV_DBG("%s", "error received during streaming, terminating stream\n");
  2449. return false; // terminate on error
  2450. } else {
  2451. GGML_ASSERT(
  2452. dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
  2453. || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
  2454. );
  2455. json res_json = result->to_json();
  2456. if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
  2457. output = format_anthropic_sse(res_json);
  2458. } else {
  2459. output = format_oai_sse(res_json);
  2460. }
  2461. }
  2462. // has next data, continue
  2463. return true;
  2464. } catch (const std::exception & e) {
  2465. json error_json = format_error_response(e.what(), ERROR_TYPE_SERVER);
  2466. output = format_error(res_type, error_json);
  2467. // terminate on exception
  2468. return false;
  2469. }
  2470. };
  2471. }
  2472. return res;
  2473. }
  2474. std::unique_ptr<server_res_generator> server_routes::create_response(bool bypass_sleep) {
  2475. return std::make_unique<server_res_generator>(queue_tasks, queue_results, params.sleep_idle_seconds, bypass_sleep);
  2476. }
  2477. server_routes::server_routes(const common_params & params, server_context & ctx_server)
  2478. : params(params),
  2479. ctx_server(*ctx_server.impl),
  2480. queue_tasks(ctx_server.impl->queue_tasks),
  2481. queue_results(ctx_server.impl->queue_results) {
  2482. init_routes();
  2483. }
  2484. void server_routes::init_routes() {
  2485. // IMPORTANT: all lambda functions must start with create_response()
  2486. // this is to ensure that the server_res_generator can handle sleeping case correctly
  2487. this->get_health = [this](const server_http_req &) {
  2488. // error and loading states are handled by middleware
  2489. auto res = create_response(true);
  2490. // this endpoint can be accessed during sleeping
  2491. // the next LOC is to avoid someone accidentally use ctx_server
  2492. bool server_ctx; // do NOT delete this line
  2493. GGML_UNUSED(server_ctx);
  2494. res->ok({{"status", "ok"}});
  2495. return res;
  2496. };
  2497. this->get_metrics = [this](const server_http_req & req) {
  2498. auto res = create_response();
  2499. if (!params.endpoint_metrics) {
  2500. res->error(format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED));
  2501. return res;
  2502. }
  2503. // request slots data using task queue
  2504. {
  2505. server_task task(SERVER_TASK_TYPE_METRICS);
  2506. task.id = res->rd.get_new_id();
  2507. res->rd.post_task(std::move(task), true); // high-priority task
  2508. }
  2509. // get the result
  2510. auto result = res->rd.next(req.should_stop);
  2511. if (!result) {
  2512. // connection was closed
  2513. GGML_ASSERT(req.should_stop());
  2514. return res;
  2515. }
  2516. if (result->is_error()) {
  2517. res->error(result->to_json());
  2518. return res;
  2519. }
  2520. // TODO: get rid of this dynamic_cast
  2521. auto res_task = dynamic_cast<server_task_result_metrics*>(result.get());
  2522. GGML_ASSERT(res_task != nullptr);
  2523. // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
  2524. json all_metrics_def = json {
  2525. {"counter", {{
  2526. {"name", "prompt_tokens_total"},
  2527. {"help", "Number of prompt tokens processed."},
  2528. {"value", (uint64_t) res_task->n_prompt_tokens_processed_total}
  2529. }, {
  2530. {"name", "prompt_seconds_total"},
  2531. {"help", "Prompt process time"},
  2532. {"value", (uint64_t) res_task->t_prompt_processing_total / 1.e3}
  2533. }, {
  2534. {"name", "tokens_predicted_total"},
  2535. {"help", "Number of generation tokens processed."},
  2536. {"value", (uint64_t) res_task->n_tokens_predicted_total}
  2537. }, {
  2538. {"name", "tokens_predicted_seconds_total"},
  2539. {"help", "Predict process time"},
  2540. {"value", (uint64_t) res_task->t_tokens_generation_total / 1.e3}
  2541. }, {
  2542. {"name", "n_decode_total"},
  2543. {"help", "Total number of llama_decode() calls"},
  2544. {"value", res_task->n_decode_total}
  2545. }, {
  2546. {"name", "n_tokens_max"},
  2547. {"help", "Largest observed n_tokens."},
  2548. {"value", res_task->n_tokens_max}
  2549. }, {
  2550. {"name", "n_busy_slots_per_decode"},
  2551. {"help", "Average number of busy slots per llama_decode() call"},
  2552. {"value", (float) res_task->n_busy_slots_total / std::max((float) res_task->n_decode_total, 1.f)}
  2553. }}},
  2554. {"gauge", {{
  2555. {"name", "prompt_tokens_seconds"},
  2556. {"help", "Average prompt throughput in tokens/s."},
  2557. {"value", res_task->n_prompt_tokens_processed ? 1.e3 / res_task->t_prompt_processing * res_task->n_prompt_tokens_processed : 0.}
  2558. },{
  2559. {"name", "predicted_tokens_seconds"},
  2560. {"help", "Average generation throughput in tokens/s."},
  2561. {"value", res_task->n_tokens_predicted ? 1.e3 / res_task->t_tokens_generation * res_task->n_tokens_predicted : 0.}
  2562. },{
  2563. {"name", "requests_processing"},
  2564. {"help", "Number of requests processing."},
  2565. {"value", (uint64_t) res_task->n_processing_slots}
  2566. },{
  2567. {"name", "requests_deferred"},
  2568. {"help", "Number of requests deferred."},
  2569. {"value", (uint64_t) res_task->n_tasks_deferred}
  2570. }}}
  2571. };
  2572. std::stringstream prometheus;
  2573. for (const auto & el : all_metrics_def.items()) {
  2574. const auto & type = el.key();
  2575. const auto & metrics_def = el.value();
  2576. for (const auto & metric_def : metrics_def) {
  2577. const std::string name = metric_def.at("name");
  2578. const std::string help = metric_def.at("help");
  2579. auto value = json_value(metric_def, "value", 0.);
  2580. prometheus << "# HELP llamacpp:" << name << " " << help << "\n"
  2581. << "# TYPE llamacpp:" << name << " " << type << "\n"
  2582. << "llamacpp:" << name << " " << value << "\n";
  2583. }
  2584. }
  2585. res->headers["Process-Start-Time-Unix"] = std::to_string(res_task->t_start);
  2586. res->content_type = "text/plain; version=0.0.4";
  2587. res->status = 200;
  2588. res->data = prometheus.str();
  2589. return res;
  2590. };
  2591. this->get_slots = [this](const server_http_req & req) {
  2592. auto res = create_response();
  2593. if (!params.endpoint_slots) {
  2594. res->error(format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED));
  2595. return res;
  2596. }
  2597. // request slots data using task queue
  2598. {
  2599. server_task task(SERVER_TASK_TYPE_METRICS);
  2600. task.id = res->rd.get_new_id();
  2601. res->rd.post_task(std::move(task), true); // high-priority task
  2602. }
  2603. // get the result
  2604. auto result = res->rd.next(req.should_stop);
  2605. if (!result) {
  2606. // connection was closed
  2607. GGML_ASSERT(req.should_stop());
  2608. return res;
  2609. }
  2610. if (result->is_error()) {
  2611. res->error(result->to_json());
  2612. return res;
  2613. }
  2614. // TODO: get rid of this dynamic_cast
  2615. auto res_task = dynamic_cast<server_task_result_metrics*>(result.get());
  2616. GGML_ASSERT(res_task != nullptr);
  2617. // optionally return "fail_on_no_slot" error
  2618. if (!req.get_param("fail_on_no_slot").empty()) {
  2619. if (res_task->n_idle_slots == 0) {
  2620. res->error(format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
  2621. return res;
  2622. }
  2623. }
  2624. res->ok(res_task->slots_data);
  2625. return res;
  2626. };
  2627. this->post_slots = [this](const server_http_req & req) {
  2628. auto res = create_response();
  2629. if (params.slot_save_path.empty()) {
  2630. res->error(format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED));
  2631. return res;
  2632. }
  2633. std::string id_slot_str = req.get_param("id_slot");
  2634. int id_slot;
  2635. try {
  2636. id_slot = std::stoi(id_slot_str);
  2637. } catch (const std::exception &) {
  2638. res->error(format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST));
  2639. return res;
  2640. }
  2641. std::string action = req.get_param("action");
  2642. if (action == "save") {
  2643. return handle_slots_save(req, id_slot);
  2644. } else if (action == "restore") {
  2645. return handle_slots_restore(req, id_slot);
  2646. } else if (action == "erase") {
  2647. return handle_slots_erase(req, id_slot);
  2648. } else {
  2649. res->error(format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST));
  2650. return res;
  2651. }
  2652. };
  2653. this->get_props = [this](const server_http_req &) {
  2654. auto res = create_response(true);
  2655. // this endpoint can be accessed during sleeping
  2656. // the next LOC is to avoid someone accidentally use ctx_server
  2657. bool server_ctx; // do NOT delete this line
  2658. GGML_UNUSED(server_ctx);
  2659. task_params tparams;
  2660. tparams.sampling = params.sampling;
  2661. json default_generation_settings_for_props = json {
  2662. { "params", tparams.to_json(true) },
  2663. { "n_ctx", meta->slot_n_ctx },
  2664. };
  2665. json props = {
  2666. { "default_generation_settings", default_generation_settings_for_props },
  2667. { "total_slots", params.n_parallel },
  2668. { "model_alias", meta->model_name },
  2669. { "model_path", meta->model_path },
  2670. { "modalities", json {
  2671. {"vision", meta->has_inp_image},
  2672. {"audio", meta->has_inp_audio},
  2673. } },
  2674. { "endpoint_slots", params.endpoint_slots },
  2675. { "endpoint_props", params.endpoint_props },
  2676. { "endpoint_metrics", params.endpoint_metrics },
  2677. { "webui", params.webui },
  2678. { "webui_settings", meta->json_webui_settings },
  2679. { "chat_template", meta->chat_template },
  2680. { "bos_token", meta->bos_token_str },
  2681. { "eos_token", meta->eos_token_str },
  2682. { "build_info", meta->build_info },
  2683. { "is_sleeping", queue_tasks.is_sleeping() },
  2684. };
  2685. if (params.use_jinja) {
  2686. if (!meta->chat_template_tool_use.empty()) {
  2687. props["chat_template_tool_use"] = meta->chat_template_tool_use;
  2688. }
  2689. }
  2690. res->ok(props);
  2691. return res;
  2692. };
  2693. this->post_props = [this](const server_http_req &) {
  2694. auto res = create_response();
  2695. if (!params.endpoint_props) {
  2696. res->error(format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED));
  2697. return res;
  2698. }
  2699. // update any props here
  2700. res->ok({{ "success", true }});
  2701. return res;
  2702. };
  2703. this->get_api_show = [this](const server_http_req &) {
  2704. auto res = create_response();
  2705. json data = {
  2706. {
  2707. "model_info", {
  2708. { "llama.context_length", meta->slot_n_ctx },
  2709. }
  2710. },
  2711. {"modelfile", ""},
  2712. {"parameters", ""},
  2713. {"template", meta->chat_template},
  2714. {"details", {
  2715. {"parent_model", ""},
  2716. {"format", "gguf"},
  2717. {"family", ""},
  2718. {"families", {""}},
  2719. {"parameter_size", ""},
  2720. {"quantization_level", ""}
  2721. }},
  2722. {"model_info", ""},
  2723. {"capabilities", meta->has_mtmd ? json({"completion","multimodal"}) : json({"completion"})}
  2724. };
  2725. res->ok(data);
  2726. return res;
  2727. };
  2728. this->post_infill = [this](const server_http_req & req) {
  2729. auto res = create_response();
  2730. // check model compatibility
  2731. std::string err;
  2732. if (llama_vocab_fim_pre(ctx_server.vocab) == LLAMA_TOKEN_NULL) {
  2733. err += "prefix token is missing. ";
  2734. }
  2735. if (llama_vocab_fim_suf(ctx_server.vocab) == LLAMA_TOKEN_NULL) {
  2736. err += "suffix token is missing. ";
  2737. }
  2738. if (llama_vocab_fim_mid(ctx_server.vocab) == LLAMA_TOKEN_NULL) {
  2739. err += "middle token is missing. ";
  2740. }
  2741. if (!err.empty()) {
  2742. res->error(format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
  2743. return res;
  2744. }
  2745. // validate input
  2746. json data = json::parse(req.body);
  2747. if (data.contains("prompt") && !data.at("prompt").is_string()) {
  2748. // prompt is optional
  2749. res->error(format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST));
  2750. }
  2751. if (!data.contains("input_prefix")) {
  2752. res->error(format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
  2753. }
  2754. if (!data.contains("input_suffix")) {
  2755. res->error(format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST));
  2756. }
  2757. if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
  2758. // input_extra is optional
  2759. res->error(format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
  2760. return res;
  2761. }
  2762. json input_extra = json_value(data, "input_extra", json::array());
  2763. for (const auto & chunk : input_extra) {
  2764. // { "text": string, "filename": string }
  2765. if (!chunk.contains("text") || !chunk.at("text").is_string()) {
  2766. res->error(format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST));
  2767. return res;
  2768. }
  2769. // filename is optional
  2770. if (chunk.contains("filename") && !chunk.at("filename").is_string()) {
  2771. res->error(format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST));
  2772. return res;
  2773. }
  2774. }
  2775. data["input_extra"] = input_extra; // default to empty array if it's not exist
  2776. std::string prompt = json_value(data, "prompt", std::string());
  2777. std::vector<server_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, false, true);
  2778. SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
  2779. data["prompt"] = format_prompt_infill(
  2780. ctx_server.vocab,
  2781. data.at("input_prefix"),
  2782. data.at("input_suffix"),
  2783. data.at("input_extra"),
  2784. params.n_batch,
  2785. params.n_predict,
  2786. meta->slot_n_ctx,
  2787. params.spm_infill,
  2788. tokenized_prompts[0].get_text_tokens() // TODO: this could maybe be multimodal.
  2789. );
  2790. std::vector<raw_buffer> files; // dummy
  2791. return handle_completions_impl(
  2792. req,
  2793. SERVER_TASK_TYPE_INFILL,
  2794. data,
  2795. files,
  2796. TASK_RESPONSE_TYPE_NONE); // infill is not OAI compatible
  2797. };
  2798. this->post_completions = [this](const server_http_req & req) {
  2799. auto res = create_response();
  2800. std::vector<raw_buffer> files; // dummy
  2801. const json body = json::parse(req.body);
  2802. return handle_completions_impl(
  2803. req,
  2804. SERVER_TASK_TYPE_COMPLETION,
  2805. body,
  2806. files,
  2807. TASK_RESPONSE_TYPE_NONE);
  2808. };
  2809. this->post_completions_oai = [this](const server_http_req & req) {
  2810. auto res = create_response();
  2811. std::vector<raw_buffer> files; // dummy
  2812. const json body = json::parse(req.body);
  2813. return handle_completions_impl(
  2814. req,
  2815. SERVER_TASK_TYPE_COMPLETION,
  2816. body,
  2817. files,
  2818. TASK_RESPONSE_TYPE_OAI_CMPL);
  2819. };
  2820. this->post_chat_completions = [this](const server_http_req & req) {
  2821. auto res = create_response();
  2822. std::vector<raw_buffer> files;
  2823. json body = json::parse(req.body);
  2824. json body_parsed = oaicompat_chat_params_parse(
  2825. body,
  2826. ctx_server.oai_parser_opt,
  2827. files);
  2828. return handle_completions_impl(
  2829. req,
  2830. SERVER_TASK_TYPE_COMPLETION,
  2831. body_parsed,
  2832. files,
  2833. TASK_RESPONSE_TYPE_OAI_CHAT);
  2834. };
  2835. this->post_anthropic_messages = [this](const server_http_req & req) {
  2836. auto res = create_response();
  2837. std::vector<raw_buffer> files;
  2838. json body = convert_anthropic_to_oai(json::parse(req.body));
  2839. json body_parsed = oaicompat_chat_params_parse(
  2840. body,
  2841. ctx_server.oai_parser_opt,
  2842. files);
  2843. return handle_completions_impl(
  2844. req,
  2845. SERVER_TASK_TYPE_COMPLETION,
  2846. body_parsed,
  2847. files,
  2848. TASK_RESPONSE_TYPE_ANTHROPIC);
  2849. };
  2850. this->post_anthropic_count_tokens = [this](const server_http_req & req) {
  2851. auto res = create_response();
  2852. std::vector<raw_buffer> files;
  2853. json body = convert_anthropic_to_oai(json::parse(req.body));
  2854. json body_parsed = oaicompat_chat_params_parse(
  2855. body,
  2856. ctx_server.oai_parser_opt,
  2857. files);
  2858. json prompt = body_parsed.at("prompt");
  2859. llama_tokens tokens = tokenize_mixed(ctx_server.vocab, prompt, true, true);
  2860. res->ok({{"input_tokens", static_cast<int>(tokens.size())}});
  2861. return res;
  2862. };
  2863. // same with handle_chat_completions, but without inference part
  2864. this->post_apply_template = [this](const server_http_req & req) {
  2865. auto res = create_response();
  2866. std::vector<raw_buffer> files; // dummy, unused
  2867. json body = json::parse(req.body);
  2868. json data = oaicompat_chat_params_parse(
  2869. body,
  2870. ctx_server.oai_parser_opt,
  2871. files);
  2872. res->ok({{ "prompt", std::move(data.at("prompt")) }});
  2873. return res;
  2874. };
    // GET models: describe the single loaded model.
    // The response lists the same model twice, in two shapes:
    //   - "models": entries with modified_at/digest/details fields, mostly empty
    //     placeholders (presumably for Ollama-style clients — TODO confirm);
    //   - "object": "list" + "data": an OpenAI-style model list whose entry also
    //     carries a llama.cpp-specific "meta" block with model statistics.
    this->get_models = [this](const server_http_req &) {
        auto res = create_response(true);
        // this endpoint can be accessed during sleeping
        // the next LOC is to avoid someone accidentally use ctx_server
        // NOTE(review): the member used elsewhere is `ctx_server`, so a variable named
        // `server_ctx` does not shadow it — confirm the intended name.
        bool server_ctx; // do NOT delete this line
        GGML_UNUSED(server_ctx);
        // Brace nesting matters here: `{ {...} }` produces a one-element array and
        // `{""}` produces [""], while lists of {"key", value} pairs produce objects.
        json models = {
            {"models", {
                {
                    {"name", meta->model_name},
                    {"model", meta->model_name},
                    {"modified_at", ""},
                    {"size", ""},
                    {"digest", ""}, // dummy value, llama.cpp does not support managing model file's hash
                    {"type", "model"},
                    {"description", ""},
                    {"tags", {""}},
                    // "multimodal" is listed only when a multimodal projector is loaded
                    {"capabilities", meta->has_mtmd ? json({"completion","multimodal"}) : json({"completion"})},
                    {"parameters", ""},
                    {"details", {
                        {"parent_model", ""},
                        {"format", "gguf"},
                        {"family", ""},
                        {"families", {""}},
                        {"parameter_size", ""},
                        {"quantization_level", ""}
                    }}
                }
            }},
            {"object", "list"},
            {"data", {
                {
                    {"id", meta->model_name},
                    {"object", "model"},
                    // current time at request handling, not a model file timestamp
                    {"created", std::time(0)},
                    {"owned_by", "llamacpp"},
                    {"meta", {
                        {"vocab_type", meta->model_vocab_type},
                        {"n_vocab", meta->model_vocab_n_tokens},
                        {"n_ctx_train", meta->model_n_ctx_train},
                        {"n_embd", meta->model_n_embd_inp},
                        {"n_params", meta->model_n_params},
                        {"size", meta->model_size},
                    }},
                },
            }}
        };
        res->ok(models);
        return res;
    };
  2925. this->post_tokenize = [this](const server_http_req & req) {
  2926. auto res = create_response();
  2927. const json body = json::parse(req.body);
  2928. json tokens_response = json::array();
  2929. if (body.count("content") != 0) {
  2930. const bool add_special = json_value(body, "add_special", false);
  2931. const bool parse_special = json_value(body, "parse_special", true);
  2932. const bool with_pieces = json_value(body, "with_pieces", false);
  2933. llama_tokens tokens = tokenize_mixed(ctx_server.vocab, body.at("content"), add_special, parse_special);
  2934. if (with_pieces) {
  2935. for (const auto& token : tokens) {
  2936. std::string piece = common_token_to_piece(ctx_server.vocab, token);
  2937. json piece_json;
  2938. // Check if the piece is valid UTF-8
  2939. if (is_valid_utf8(piece)) {
  2940. piece_json = piece;
  2941. } else {
  2942. // If not valid UTF-8, store as array of byte values
  2943. piece_json = json::array();
  2944. for (unsigned char c : piece) {
  2945. piece_json.push_back(static_cast<int>(c));
  2946. }
  2947. }
  2948. tokens_response.push_back({
  2949. {"id", token},
  2950. {"piece", piece_json}
  2951. });
  2952. }
  2953. } else {
  2954. tokens_response = tokens;
  2955. }
  2956. }
  2957. res->ok(json{{"tokens", std::move(tokens_response)}});
  2958. return res;
  2959. };
  2960. this->post_detokenize = [this](const server_http_req & req) {
  2961. auto res = create_response();
  2962. const json body = json::parse(req.body);
  2963. std::string content;
  2964. if (body.count("tokens") != 0) {
  2965. const llama_tokens tokens = body.at("tokens");
  2966. content = tokens_to_str(ctx_server.vocab, tokens);
  2967. }
  2968. res->ok(json{{"content", std::move(content)}});
  2969. return res;
  2970. };
  2971. this->post_embeddings = [this](const server_http_req & req) {
  2972. return handle_embeddings_impl(req, TASK_RESPONSE_TYPE_NONE);
  2973. };
  2974. this->post_embeddings_oai = [this](const server_http_req & req) {
  2975. return handle_embeddings_impl(req, TASK_RESPONSE_TYPE_OAI_EMBD);
  2976. };
  2977. this->post_rerank = [this](const server_http_req & req) {
  2978. auto res = create_response();
  2979. if (!params.embedding || params.pooling_type != LLAMA_POOLING_TYPE_RANK) {
  2980. res->error(format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
  2981. return res;
  2982. }
  2983. const json body = json::parse(req.body);
  2984. // if true, use TEI API format, otherwise use Jina API format
  2985. // Jina: https://jina.ai/reranker/
  2986. // TEI: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/rerank
  2987. bool is_tei_format = body.contains("texts");
  2988. json query;
  2989. if (body.count("query") == 1) {
  2990. query = body.at("query");
  2991. if (!query.is_string()) {
  2992. res->error(format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
  2993. return res;
  2994. }
  2995. } else {
  2996. res->error(format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
  2997. return res;
  2998. }
  2999. std::vector<std::string> documents = json_value(body, "documents",
  3000. json_value(body, "texts", std::vector<std::string>()));
  3001. if (documents.empty()) {
  3002. res->error(format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
  3003. return res;
  3004. }
  3005. int top_n = json_value(body, "top_n", (int)documents.size());
  3006. // create and queue the task
  3007. json responses = json::array();
  3008. auto & rd = res->rd;
  3009. {
  3010. std::vector<server_task> tasks;
  3011. tasks.reserve(documents.size());
  3012. for (size_t i = 0; i < documents.size(); i++) {
  3013. auto tmp = format_prompt_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]);
  3014. server_task task = server_task(SERVER_TASK_TYPE_RERANK);
  3015. task.id = rd.get_new_id();
  3016. task.tokens = std::move(tmp);
  3017. tasks.push_back(std::move(task));
  3018. }
  3019. rd.post_tasks(std::move(tasks));
  3020. }
  3021. // wait for the results
  3022. auto all_results = rd.wait_for_all(req.should_stop);
  3023. // collect results
  3024. if (all_results.is_terminated) {
  3025. return res; // connection is closed
  3026. } else if (all_results.error) {
  3027. res->error(all_results.error->to_json());
  3028. return res;
  3029. } else {
  3030. for (auto & res : all_results.results) {
  3031. GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
  3032. responses.push_back(res->to_json());
  3033. }
  3034. }
  3035. // write JSON response
  3036. json root = format_response_rerank(
  3037. body,
  3038. meta->model_name,
  3039. responses,
  3040. is_tei_format,
  3041. documents,
  3042. top_n);
  3043. res->ok(root);
  3044. return res;
  3045. };
  3046. this->get_lora_adapters = [this](const server_http_req & req) {
  3047. auto res = create_response();
  3048. auto & rd = res->rd;
  3049. {
  3050. server_task task(SERVER_TASK_TYPE_GET_LORA);
  3051. task.id = rd.get_new_id();
  3052. rd.post_task(std::move(task));
  3053. }
  3054. // get the result
  3055. auto result = rd.next(req.should_stop);
  3056. if (!result) {
  3057. // connection was closed
  3058. GGML_ASSERT(req.should_stop());
  3059. return res;
  3060. }
  3061. if (result->is_error()) {
  3062. res->error(result->to_json());
  3063. return res;
  3064. }
  3065. GGML_ASSERT(dynamic_cast<server_task_result_get_lora*>(result.get()) != nullptr);
  3066. res->ok(result->to_json());
  3067. return res;
  3068. };
  3069. this->post_lora_adapters = [this](const server_http_req & req) {
  3070. auto res = create_response();
  3071. const json body = json::parse(req.body);
  3072. if (!body.is_array()) {
  3073. res->error(format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
  3074. return res;
  3075. }
  3076. auto & rd = res->rd;
  3077. {
  3078. server_task task(SERVER_TASK_TYPE_SET_LORA);
  3079. task.id = rd.get_new_id();
  3080. task.set_lora = parse_lora_request(body);
  3081. rd.post_task(std::move(task));
  3082. }
  3083. // get the result
  3084. auto result = rd.next(req.should_stop);
  3085. if (!result) {
  3086. // connection was closed
  3087. GGML_ASSERT(req.should_stop());
  3088. return res;
  3089. }
  3090. if (result->is_error()) {
  3091. res->error(result->to_json());
  3092. return res;
  3093. }
  3094. GGML_ASSERT(dynamic_cast<server_task_result_apply_lora*>(result.get()) != nullptr);
  3095. res->ok(result->to_json());
  3096. return res;
  3097. };
  3098. }
  3099. std::unique_ptr<server_res_generator> server_routes::handle_slots_save(const server_http_req & req, int id_slot) {
  3100. auto res = create_response();
  3101. const json request_data = json::parse(req.body);
  3102. std::string filename = request_data.at("filename");
  3103. if (!fs_validate_filename(filename)) {
  3104. res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
  3105. return res;
  3106. }
  3107. std::string filepath = params.slot_save_path + filename;
  3108. auto & rd = res->rd;
  3109. {
  3110. server_task task(SERVER_TASK_TYPE_SLOT_SAVE);
  3111. task.id = rd.get_new_id();
  3112. task.slot_action.slot_id = id_slot;
  3113. task.slot_action.filename = filename;
  3114. task.slot_action.filepath = filepath;
  3115. rd.post_task(std::move(task));
  3116. }
  3117. auto result = rd.next(req.should_stop);
  3118. if (!result) {
  3119. // connection was closed
  3120. GGML_ASSERT(req.should_stop());
  3121. return res;
  3122. }
  3123. if (result->is_error()) {
  3124. res->error(result->to_json());
  3125. return res;
  3126. }
  3127. res->ok(result->to_json());
  3128. return res;
  3129. }
  3130. std::unique_ptr<server_res_generator> server_routes::handle_slots_restore(const server_http_req & req, int id_slot) {
  3131. auto res = create_response();
  3132. const json request_data = json::parse(req.body);
  3133. std::string filename = request_data.at("filename");
  3134. if (!fs_validate_filename(filename)) {
  3135. res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
  3136. return res;
  3137. }
  3138. std::string filepath = params.slot_save_path + filename;
  3139. auto & rd = res->rd;
  3140. {
  3141. server_task task(SERVER_TASK_TYPE_SLOT_RESTORE);
  3142. task.id = rd.get_new_id();
  3143. task.slot_action.slot_id = id_slot;
  3144. task.slot_action.filename = filename;
  3145. task.slot_action.filepath = filepath;
  3146. rd.post_task(std::move(task));
  3147. }
  3148. auto result = rd.next(req.should_stop);
  3149. if (!result) {
  3150. // connection was closed
  3151. GGML_ASSERT(req.should_stop());
  3152. return res;
  3153. }
  3154. if (result->is_error()) {
  3155. res->error(result->to_json());
  3156. return res;
  3157. }
  3158. GGML_ASSERT(dynamic_cast<server_task_result_slot_save_load*>(result.get()) != nullptr);
  3159. res->ok(result->to_json());
  3160. return res;
  3161. }
  3162. std::unique_ptr<server_res_generator> server_routes::handle_slots_erase(const server_http_req & req, int id_slot) {
  3163. auto res = create_response();
  3164. auto & rd = res->rd;
  3165. {
  3166. server_task task(SERVER_TASK_TYPE_SLOT_ERASE);
  3167. task.id = rd.get_new_id();
  3168. task.slot_action.slot_id = id_slot;
  3169. rd.post_task(std::move(task));
  3170. }
  3171. auto result = rd.next(req.should_stop);
  3172. if (!result) {
  3173. // connection was closed
  3174. GGML_ASSERT(req.should_stop());
  3175. return res;
  3176. }
  3177. if (result->is_error()) {
  3178. res->error(result->to_json());
  3179. return res;
  3180. }
  3181. GGML_ASSERT(dynamic_cast<server_task_result_slot_erase*>(result.get()) != nullptr);
  3182. res->ok(result->to_json());
  3183. return res;
  3184. }
  3185. std::unique_ptr<server_res_generator> server_routes::handle_embeddings_impl(const server_http_req & req, task_response_type res_type) {
  3186. auto res = create_response();
  3187. if (!params.embedding) {
  3188. res->error(format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
  3189. return res;
  3190. }
  3191. if (res_type != TASK_RESPONSE_TYPE_NONE && meta->pooling_type == LLAMA_POOLING_TYPE_NONE) {
  3192. res->error(format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
  3193. return res;
  3194. }
  3195. const json body = json::parse(req.body);
  3196. // for the shape of input/content, see tokenize_input_prompts()
  3197. json prompt;
  3198. if (body.count("input") != 0) {
  3199. prompt = body.at("input");
  3200. } else if (body.contains("content")) {
  3201. res_type = TASK_RESPONSE_TYPE_NONE; // "content" field is not OAI compatible
  3202. prompt = body.at("content");
  3203. } else {
  3204. res->error(format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
  3205. return res;
  3206. }
  3207. bool use_base64 = false;
  3208. if (body.count("encoding_format") != 0) {
  3209. const std::string & format = body.at("encoding_format");
  3210. if (format == "base64") {
  3211. use_base64 = true;
  3212. } else if (format != "float") {
  3213. res->error(format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST));
  3214. return res;
  3215. }
  3216. }
  3217. auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
  3218. for (const auto & tokens : tokenized_prompts) {
  3219. // this check is necessary for models that do not add BOS token to the input
  3220. if (tokens.empty()) {
  3221. res->error(format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST));
  3222. return res;
  3223. }
  3224. }
  3225. int embd_normalize = 2; // default to Euclidean/L2 norm
  3226. if (body.count("embd_normalize") != 0) {
  3227. embd_normalize = body.at("embd_normalize");
  3228. if (meta->pooling_type == LLAMA_POOLING_TYPE_NONE) {
  3229. SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", meta->pooling_type);
  3230. }
  3231. }
  3232. // create and queue the task
  3233. json responses = json::array();
  3234. auto & rd = res->rd;
  3235. {
  3236. std::vector<server_task> tasks;
  3237. for (size_t i = 0; i < tokenized_prompts.size(); i++) {
  3238. server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
  3239. task.id = rd.get_new_id();
  3240. task.tokens = std::move(tokenized_prompts[i]);
  3241. // OAI-compat
  3242. task.params.res_type = res_type;
  3243. task.params.embd_normalize = embd_normalize;
  3244. tasks.push_back(std::move(task));
  3245. }
  3246. rd.post_tasks(std::move(tasks));
  3247. }
  3248. // wait for the results
  3249. auto all_results = rd.wait_for_all(req.should_stop);
  3250. // collect results
  3251. if (all_results.is_terminated) {
  3252. return res; // connection is closed
  3253. } else if (all_results.error) {
  3254. res->error(all_results.error->to_json());
  3255. return res;
  3256. } else {
  3257. for (auto & res : all_results.results) {
  3258. GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
  3259. responses.push_back(res->to_json());
  3260. }
  3261. }
  3262. // write JSON response
  3263. json root = res_type == TASK_RESPONSE_TYPE_OAI_EMBD
  3264. ? format_embeddings_response_oaicompat(body, meta->model_name, responses, use_base64)
  3265. : json(responses);
  3266. res->ok(root);
  3267. return res;
  3268. }