server-context.cpp 145 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620
  1. #include "server-context.h"
  2. #include "server-common.h"
  3. #include "server-http.h"
  4. #include "server-task.h"
  5. #include "server-queue.h"
  6. #include "arg.h"
  7. #include "common.h"
  8. #include "llama.h"
  9. #include "log.h"
  10. #include "sampling.h"
  11. #include "speculative.h"
  12. #include "mtmd.h"
  13. #include "mtmd-helper.h"
  14. #include <cstddef>
  15. #include <cinttypes>
  16. #include <memory>
  17. #include <unordered_set>
  18. // fix problem with std::min and std::max
  19. #if defined(_WIN32)
  20. #define WIN32_LEAN_AND_MEAN
  21. #ifndef NOMINMAX
  22. # define NOMINMAX
  23. #endif
  24. #include <windows.h>
  25. #endif
  26. using json = nlohmann::ordered_json;
  27. constexpr int HTTP_POLLING_SECONDS = 1;
  // Lifecycle of a server_slot.
  // state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
  // Values are spelled out explicitly; they equal the implicit 0..4 numbering.
  enum slot_state {
      SLOT_STATE_IDLE              = 0,
      SLOT_STATE_STARTED           = 1, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future
      SLOT_STATE_PROCESSING_PROMPT = 2,
      SLOT_STATE_DONE_PROMPT       = 3,
      SLOT_STATE_GENERATING        = 4,
  };
  // Server-wide readiness; values equal the implicit 0..1 numbering.
  enum server_state {
      SERVER_STATE_LOADING_MODEL = 0, // Server is starting up, model not fully loaded yet
      SERVER_STATE_READY         = 1, // Server is ready and model is loaded
  };
  40. static bool server_task_type_need_embd(server_task_type task_type) {
  41. switch (task_type) {
  42. case SERVER_TASK_TYPE_EMBEDDING:
  43. case SERVER_TASK_TYPE_RERANK:
  44. return true;
  45. default:
  46. return false;
  47. }
  48. }
  49. static bool server_task_type_need_logits(server_task_type task_type) {
  50. switch (task_type) {
  51. case SERVER_TASK_TYPE_COMPLETION:
  52. case SERVER_TASK_TYPE_INFILL:
  53. return true;
  54. default:
  55. return false;
  56. }
  57. }
  58. struct server_slot {
  59. int id;
  60. llama_batch batch_spec = {};
  61. // TODO: change to unique_ptrs for consistency:
  62. llama_context * ctx = nullptr;
  63. llama_context * ctx_dft = nullptr;
  64. // multimodal
  65. mtmd_context * mctx = nullptr;
  66. common_speculative * spec = nullptr;
  67. std::unique_ptr<const server_task> task;
  68. std::unique_ptr<const server_task> task_prev; // used for debugging
  69. // used to determine the slot that has been used the longest
  70. int64_t t_last_used = -1;
  71. // generation props
  72. int32_t n_ctx = 0; // context size per slot
  73. int32_t n_keep = 0;
  74. int32_t n_decoded = 0;
  75. int32_t n_remaining = -1;
  76. int32_t i_batch = -1;
  77. int32_t n_prompt_tokens_cache = 0;
  78. int32_t n_prompt_tokens_processed = 0;
  79. size_t last_nl_pos = 0;
  80. std::string generated_text;
  81. llama_tokens generated_tokens;
  82. common_chat_msg chat_msg;
  83. std::vector<completion_token_output> generated_token_probs;
  84. bool has_next_token = true;
  85. bool has_new_line = false;
  86. bool truncated = false;
  87. stop_type stop;
  88. std::string stopping_word;
  89. // state
  90. slot_state state = SLOT_STATE_IDLE;
  91. server_prompt prompt;
  92. void prompt_save(server_prompt_cache & prompt_cache) const {
  93. GGML_ASSERT(prompt.data.size() == 0);
  94. const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0);
  95. SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB\n",
  96. (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0));
  97. auto * cur = prompt_cache.alloc(prompt, cur_size);
  98. if (cur == nullptr) {
  99. return;
  100. }
  101. llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0);
  102. }
  103. bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
  104. bool res = prompt_cache.load(prompt, tokens, ctx, id);
  105. if (!res) {
  106. SLT_WRN(*this, "%s", "failed to load prompt from cache\n");
  107. }
  108. return res;
  109. }
  110. std::vector<common_adapter_lora_info> lora;
  111. int32_t alora_invocation_start = -1;
  112. // sampling
  113. json json_schema;
  114. struct common_sampler * smpl = nullptr;
  115. llama_token sampled;
  116. common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
  117. std::vector<std::string> generated_tool_call_ids;
  118. // stats
  119. size_t n_sent_text = 0; // number of sent text character
  120. int64_t t_start_process_prompt;
  121. int64_t t_start_generation;
  122. double t_prompt_processing; // ms
  123. double t_token_generation; // ms
  124. std::function<void(int)> callback_on_release;
  125. // Speculative decoding stats
  126. int32_t n_draft_total = 0; // Total draft tokens generated
  127. int32_t n_draft_accepted = 0; // Draft tokens actually accepted
  128. void reset() {
  129. SLT_DBG(*this, "%s", "\n");
  130. n_prompt_tokens_cache = 0;
  131. last_nl_pos = 0;
  132. generated_text = "";
  133. has_new_line = false;
  134. truncated = false;
  135. stop = STOP_TYPE_NONE;
  136. stopping_word = "";
  137. n_sent_text = 0;
  138. chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
  139. generated_tokens.clear();
  140. generated_token_probs.clear();
  141. chat_msg = {};
  142. json_schema = json();
  143. generated_tool_call_ids.clear();
  144. // clear speculative decoding stats
  145. n_draft_total = 0;
  146. n_draft_accepted = 0;
  147. task.reset();
  148. task_prev.reset();
  149. // clear alora start
  150. alora_invocation_start = -1;
  151. }
  152. bool need_embd() const {
  153. GGML_ASSERT(task);
  154. return server_task_type_need_embd(task->type);
  155. }
  156. bool need_logits() const {
  157. GGML_ASSERT(task);
  158. return server_task_type_need_logits(task->type);
  159. }
  160. // if the context does not have a memory module then all embeddings have to be computed within a single ubatch
  161. // also we cannot split if the pooling would require any past tokens
  162. bool can_split() const {
  163. return
  164. !need_embd() ||
  165. (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
  166. }
  167. bool can_batch_with(server_slot & other_slot) const {
  168. GGML_ASSERT(task);
  169. return task->type == other_slot.task->type && are_lora_equal(lora, other_slot.lora);
  170. }
  171. bool has_budget(const common_params & global_params) {
  172. GGML_ASSERT(task);
  173. if (task->params.n_predict == -1 && global_params.n_predict == -1) {
  174. return true; // limitless
  175. }
  176. n_remaining = -1;
  177. if (task->params.n_predict != -1) {
  178. n_remaining = task->params.n_predict - n_decoded;
  179. } else if (global_params.n_predict != -1) {
  180. n_remaining = global_params.n_predict - n_decoded;
  181. }
  182. return n_remaining > 0; // no budget
  183. }
  184. bool is_processing() const {
  185. return state != SLOT_STATE_IDLE;
  186. }
  187. bool can_speculate() const {
  188. return ctx_dft;
  189. }
  190. void add_token(const completion_token_output & token) {
  191. if (!is_processing()) {
  192. SLT_WRN(*this, "%s", "slot is not processing\n");
  193. return;
  194. }
  195. generated_token_probs.push_back(token);
  196. }
  197. void release() {
  198. if (is_processing()) {
  199. GGML_ASSERT(task);
  200. SLT_INF(*this, "stop processing: n_tokens = %d, truncated = %d\n", prompt.n_tokens(), truncated);
  201. t_last_used = ggml_time_us();
  202. t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
  203. state = SLOT_STATE_IDLE;
  204. task_prev = std::move(task);
  205. task.reset();
  206. callback_on_release(id);
  207. }
  208. }
  209. result_timings get_timings() const {
  210. result_timings timings;
  211. timings.cache_n = n_prompt_tokens_cache;
  212. timings.prompt_n = n_prompt_tokens_processed;
  213. timings.prompt_ms = t_prompt_processing;
  214. timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed;
  215. timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
  216. timings.predicted_n = n_decoded;
  217. timings.predicted_ms = t_token_generation;
  218. timings.predicted_per_token_ms = t_token_generation / n_decoded;
  219. timings.predicted_per_second = 1e3 / t_token_generation * n_decoded;
  220. // Add speculative metrics
  221. if (n_draft_total > 0) {
  222. timings.draft_n = n_draft_total;
  223. timings.draft_n_accepted = n_draft_accepted;
  224. }
  225. return timings;
  226. }
  227. const common_chat_msg & update_chat_msg(std::vector<common_chat_msg_diff> & diffs) {
  228. GGML_ASSERT(task);
  229. auto previous_msg = chat_msg;
  230. SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
  231. auto new_msg = common_chat_parse(
  232. generated_text,
  233. /* is_partial= */ stop != STOP_TYPE_EOS,
  234. task->params.oaicompat_chat_syntax);
  235. if (!new_msg.empty()) {
  236. new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id);
  237. chat_msg = new_msg;
  238. diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg);
  239. }
  240. return chat_msg;
  241. }
  242. size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
  243. GGML_ASSERT(task);
  244. size_t stop_pos = std::string::npos;
  245. for (const std::string & word : task->params.antiprompt) {
  246. size_t pos;
  247. if (is_full_stop) {
  248. const size_t tmp = word.size() + last_token_size;
  249. const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
  250. pos = text.find(word, from_pos);
  251. } else {
  252. // otherwise, partial stop
  253. pos = string_find_partial_stop(text, word);
  254. }
  255. if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
  256. if (is_full_stop) {
  257. stop = STOP_TYPE_WORD;
  258. stopping_word = word;
  259. has_next_token = false;
  260. }
  261. stop_pos = pos;
  262. }
  263. }
  264. return stop_pos;
  265. }
  266. void print_timings() const {
  267. const double t_prompt = t_prompt_processing / n_prompt_tokens_processed;
  268. const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
  269. const double t_gen = t_token_generation / n_decoded;
  270. const double n_gen_second = 1e3 / t_token_generation * n_decoded;
  271. SLT_INF(*this,
  272. "\n"
  273. "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
  274. " eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
  275. " total time = %10.2f ms / %5d tokens\n",
  276. t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second,
  277. t_token_generation, n_decoded, t_gen, n_gen_second,
  278. t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);
  279. if (n_draft_total > 0) {
  280. const float draft_ratio = (float) n_draft_accepted / n_draft_total;
  281. SLT_INF(*this,
  282. "\n"
  283. "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n",
  284. draft_ratio, n_draft_accepted, n_draft_total
  285. );
  286. }
  287. }
  288. json to_json(bool only_metrics = false) const {
  289. json res;
  290. res = {
  291. {"id", id},
  292. {"n_ctx", n_ctx},
  293. {"speculative", can_speculate()},
  294. {"is_processing", is_processing()},
  295. };
  296. const auto & ptask = task ? task : task_prev;
  297. if (ptask) {
  298. res["id_task"] = ptask->id;
  299. res["params"] = ptask->params.to_json(only_metrics);
  300. res["next_token"] = {
  301. {
  302. {"has_next_token", has_next_token},
  303. {"has_new_line", has_new_line},
  304. {"n_remain", n_remaining},
  305. {"n_decoded", n_decoded},
  306. }
  307. };
  308. if (!only_metrics) {
  309. res["prompt"] = ptask->tokens.detokenize(ctx, true);
  310. res["generated"] = generated_text;
  311. }
  312. }
  313. return res;
  314. }
  315. };
  316. //
  317. // server_metrics
  318. //
  319. struct server_metrics {
  320. int64_t t_start = 0;
  321. uint64_t n_prompt_tokens_processed_total = 0;
  322. uint64_t t_prompt_processing_total = 0;
  323. uint64_t n_tokens_predicted_total = 0;
  324. uint64_t t_tokens_generation_total = 0;
  325. uint64_t n_tokens_max = 0;
  326. uint64_t n_prompt_tokens_processed = 0;
  327. uint64_t t_prompt_processing = 0;
  328. uint64_t n_tokens_predicted = 0;
  329. uint64_t t_tokens_generation = 0;
  330. uint64_t n_decode_total = 0;
  331. uint64_t n_busy_slots_total = 0;
  332. void init() {
  333. t_start = ggml_time_us();
  334. }
  335. void on_prompt_eval(const server_slot & slot) {
  336. n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed;
  337. n_prompt_tokens_processed += slot.n_prompt_tokens_processed;
  338. t_prompt_processing += slot.t_prompt_processing;
  339. t_prompt_processing_total += slot.t_prompt_processing;
  340. n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens());
  341. }
  342. void on_prediction(const server_slot & slot) {
  343. n_tokens_predicted_total += slot.n_decoded;
  344. n_tokens_predicted += slot.n_decoded;
  345. t_tokens_generation += slot.t_token_generation;
  346. t_tokens_generation_total += slot.t_token_generation;
  347. }
  348. void on_decoded(const std::vector<server_slot> & slots) {
  349. n_decode_total++;
  350. for (const auto & slot : slots) {
  351. if (slot.is_processing()) {
  352. n_busy_slots_total++;
  353. }
  354. n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens());
  355. }
  356. }
  357. void reset_bucket() {
  358. n_prompt_tokens_processed = 0;
  359. t_prompt_processing = 0;
  360. n_tokens_predicted = 0;
  361. t_tokens_generation = 0;
  362. }
  363. };
  364. //
  365. // server_context_impl (private implementation)
  366. //
  367. struct server_context_impl {
  368. common_params params_base;
  369. // note: keep these alive - they determine the lifetime of the model, context, etc.
  370. common_init_result llama_init;
  371. common_init_result llama_init_dft;
  372. llama_model * model = nullptr;
  373. llama_context * ctx = nullptr;
  374. // multimodal
  375. mtmd_context * mctx = nullptr;
  376. const llama_vocab * vocab = nullptr;
  377. bool vocab_dft_compatible = true;
  378. llama_model * model_dft = nullptr;
  379. llama_context_params cparams_dft;
  380. llama_batch batch {};
  381. bool add_bos_token = true;
  382. int32_t n_ctx; // total context for all clients / slots
  383. // slots / clients
  384. std::vector<server_slot> slots;
  385. int slots_debug = 0;
  386. server_queue queue_tasks;
  387. server_response queue_results;
  388. std::unique_ptr<server_prompt_cache> prompt_cache;
  389. server_metrics metrics;
  390. // Necessary similarity of prompt for slot selection
  391. float slot_prompt_similarity = 0.0f;
  392. common_chat_templates_ptr chat_templates;
  393. oaicompat_parser_options oai_parser_opt;
  394. ~server_context_impl() {
  395. mtmd_free(mctx);
  396. // Clear any sampling context
  397. for (server_slot & slot : slots) {
  398. common_sampler_free(slot.smpl);
  399. slot.smpl = nullptr;
  400. llama_free(slot.ctx_dft);
  401. slot.ctx_dft = nullptr;
  402. common_speculative_free(slot.spec);
  403. slot.spec = nullptr;
  404. llama_batch_free(slot.batch_spec);
  405. }
  406. llama_batch_free(batch);
  407. }
  408. // load the model and initialize llama_context
  409. bool load_model(const common_params & params) {
  410. SRV_INF("loading model '%s'\n", params.model.path.c_str());
  411. params_base = params;
  412. llama_init = common_init_from_params(params_base);
  413. model = llama_init.model.get();
  414. ctx = llama_init.context.get();
  415. if (model == nullptr) {
  416. SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str());
  417. return false;
  418. }
  419. vocab = llama_model_get_vocab(model);
  420. n_ctx = llama_n_ctx(ctx);
  421. add_bos_token = llama_vocab_get_add_bos(vocab);
  422. if (params_base.has_speculative()) {
  423. SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str());
  424. auto params_dft = params_base;
  425. params_dft.devices = params_base.speculative.devices;
  426. params_dft.model = params_base.speculative.model;
  427. params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx;
  428. params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
  429. params_dft.n_parallel = 1;
  430. params_dft.cache_type_k = params_base.speculative.cache_type_k;
  431. params_dft.cache_type_v = params_base.speculative.cache_type_v;
  432. params_dft.cpuparams.n_threads = params_base.speculative.cpuparams.n_threads;
  433. params_dft.cpuparams_batch.n_threads = params_base.speculative.cpuparams_batch.n_threads;
  434. params_dft.tensor_buft_overrides = params_base.speculative.tensor_buft_overrides;
  435. llama_init_dft = common_init_from_params(params_dft);
  436. model_dft = llama_init_dft.model.get();
  437. if (model_dft == nullptr) {
  438. SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str());
  439. return false;
  440. }
  441. vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft.context.get());
  442. if (!vocab_dft_compatible) {
  443. SRV_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str());
  444. }
  445. const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get());
  446. cparams_dft = common_context_params_to_llama(params_dft);
  447. cparams_dft.n_batch = n_ctx_dft;
  448. // the context is not needed - we will create one for each slot
  449. llama_init_dft.context.reset();
  450. }
  451. chat_templates = common_chat_templates_init(model, params_base.chat_template);
  452. try {
  453. common_chat_format_example(chat_templates.get(), params.use_jinja, params.default_template_kwargs);
  454. } catch (const std::exception & e) {
  455. SRV_WRN("%s: Chat template parsing error: %s\n", __func__, e.what());
  456. SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
  457. chat_templates = common_chat_templates_init(model, "chatml");
  458. }
  459. std::string & mmproj_path = params_base.mmproj.path;
  460. if (!mmproj_path.empty()) {
  461. mtmd_helper_log_set(common_log_default_callback, nullptr);
  462. mtmd_context_params mparams = mtmd_context_params_default();
  463. mparams.use_gpu = params_base.mmproj_use_gpu;
  464. mparams.print_timings = false;
  465. mparams.n_threads = params_base.cpuparams.n_threads;
  466. mparams.flash_attn_type = params_base.flash_attn_type;
  467. mparams.warmup = params_base.warmup;
  468. mparams.image_min_tokens = params_base.image_min_tokens;
  469. mparams.image_max_tokens = params_base.image_max_tokens;
  470. mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
  471. if (mctx == nullptr) {
  472. SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());
  473. return false;
  474. }
  475. SRV_INF("loaded multimodal model, '%s'\n", mmproj_path.c_str());
  476. if (params_base.ctx_shift) {
  477. params_base.ctx_shift = false;
  478. SRV_WRN("%s\n", "ctx_shift is not supported by multimodal, it will be disabled");
  479. }
  480. if (params_base.n_cache_reuse) {
  481. params_base.n_cache_reuse = 0;
  482. SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled");
  483. }
  484. if (params_base.has_speculative()) {
  485. SRV_ERR("%s\n", "err: speculative decode is not supported by multimodal");
  486. return false;
  487. }
  488. }
  489. if (!llama_memory_can_shift(llama_get_memory(ctx))) {
  490. if (params_base.ctx_shift) {
  491. params_base.ctx_shift = false;
  492. SRV_WRN("%s\n", "ctx_shift is not supported by this context, it will be disabled");
  493. }
  494. if (params_base.n_cache_reuse) {
  495. params_base.n_cache_reuse = 0;
  496. SRV_WRN("%s\n", "cache_reuse is not supported by this context, it will be disabled");
  497. }
  498. }
  499. return true;
  500. }
  501. // initialize slots and server-related data
  // Called after the model is loaded. In order, it:
  //   - registers the new-task / update-slots callbacks on queue_tasks,
  //   - creates params_base.n_parallel slots; each slot also gets a draft context and a
  //     speculator when a draft model (model_dft) is loaded,
  //   - reads LLAMA_SERVER_SLOTS_DEBUG, allocates the shared decode batch, initializes metrics
  //     and, unless params_base.cache_ram_mib == 0, the prompt cache,
  //   - resolves the "thinking" flag and builds oai_parser_opt from the chat templates.
  // NOTE(review): init() returns void. If creating a draft context or a speculator fails it only
  // logs and returns early, so the remaining slots, the batch, metrics and the prompt cache are
  // left uninitialized without the caller being told.
  502. void init() {
  503. // wiring up server queues
  504. queue_tasks.on_new_task([this](server_task && task) {
  505. process_single_task(std::move(task));
  506. });
  507. queue_tasks.on_update_slots([this]() {
  508. update_slots();
  509. });
  510. // Necessary similarity of prompt for slot selection
  511. slot_prompt_similarity = params_base.slot_prompt_similarity;
  512. // setup slots
  513. SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
  514. const int n_ctx_train = llama_model_n_ctx_train(model);
  // per-slot context = the per-sequence context of ctx, capped at the model's training context
  515. int n_ctx_slot = llama_n_ctx_seq(ctx);
  516. if (n_ctx_slot > n_ctx_train) {
  517. SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train);
  518. n_ctx_slot = n_ctx_train;
  519. }
  // slot ids are the indices 0 .. n_parallel-1
  520. for (int i = 0; i < params_base.n_parallel; i++) {
  521. server_slot slot;
  522. slot.id = i;
  523. slot.ctx = ctx;
  524. slot.n_ctx = n_ctx_slot;
  525. slot.mctx = mctx;
  526. slot.prompt.tokens.has_mtmd = mctx != nullptr;
  527. if (model_dft) {
  // draft batch sized for at most speculative.n_max drafted tokens + 1
  528. slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
  529. // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
  530. slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
  531. if (slot.ctx_dft == nullptr) {
  532. SRV_ERR("%s", "failed to create draft context\n");
  // early exit: slots from index i onwards are never created (see NOTE above)
  533. return;
  534. }
  535. slot.spec = common_speculative_init(slot.ctx, slot.ctx_dft);
  536. if (slot.spec == nullptr) {
  537. SRV_ERR("%s", "failed to create speculator\n");
  538. return;
  539. }
  // apply the user-configured target->draft string replacements to this slot's speculator
  540. for (auto & pair : params_base.speculative.replacements) {
  541. common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str());
  542. }
  543. }
  544. SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);
  // whenever a slot is released, give one deferred task a chance to run
  545. slot.callback_on_release = [this](int) {
  546. queue_tasks.pop_deferred_task();
  547. };
  548. slot.reset();
  549. slots.push_back(std::move(slot));
  550. }
  // optional verbosity for slot dumps, controlled by the LLAMA_SERVER_SLOTS_DEBUG env var (integer)
  551. {
  552. const char * LLAMA_SERVER_SLOTS_DEBUG = getenv("LLAMA_SERVER_SLOTS_DEBUG");
  553. slots_debug = LLAMA_SERVER_SLOTS_DEBUG ? atoi(LLAMA_SERVER_SLOTS_DEBUG) : 0;
  554. if (slots_debug) {
  555. SRV_WRN("slots debug = %d\n", slots_debug);
  556. }
  557. }
  558. // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
  559. // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
  560. {
  561. const int32_t n_batch = llama_n_batch(ctx);
  562. batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
  563. }
  564. metrics.init();
  // cache_ram_mib: 0 = prompt cache disabled, < 0 = no size limit, > 0 = limit in MiB
  565. if (params_base.cache_ram_mib != 0) {
  566. if (params_base.cache_ram_mib < 0) {
  567. SRV_WRN("prompt cache is enabled, size limit: %s\n", "no limit");
  568. } else {
  569. SRV_WRN("prompt cache is enabled, size limit: %d MiB\n", params_base.cache_ram_mib);
  570. }
  571. SRV_WRN("%s", "use `--cache-ram 0` to disable the prompt cache\n");
  572. prompt_cache = std::make_unique<server_prompt_cache>(params_base.cache_ram_mib, n_ctx);
  573. } else {
  574. SRV_WRN("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n");
  575. }
  576. SRV_WRN("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n");
  577. // thinking is enabled if:
  // 0. Jinja templating is enabled (use_jinja)
  578. // 1. It's not explicitly disabled (reasoning_budget == 0)
  579. // 2. The chat template supports it
  580. const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get());
  581. SRV_INF("thinking = %d\n", enable_thinking);
  // options used when parsing OpenAI-compatible requests; image/audio input is allowed only
  // when a multimodal context exists and reports support for it
  582. oai_parser_opt = {
  583. /* use_jinja */ params_base.use_jinja,
  584. /* prefill_assistant */ params_base.prefill_assistant,
  585. /* reasoning_format */ params_base.reasoning_format,
  586. /* chat_template_kwargs */ params_base.default_template_kwargs,
  587. /* common_chat_templates */ chat_templates.get(),
  588. /* allow_image */ mctx ? mtmd_support_vision(mctx) : false,
  589. /* allow_audio */ mctx ? mtmd_support_audio (mctx) : false,
  590. /* enable_thinking */ enable_thinking,
  591. };
  592. // print sample chat example to make it clear which template is used
  593. LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
  594. common_chat_templates_source(chat_templates.get()),
  595. common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str());
  596. }
  597. server_slot * get_slot_by_id(int id) {
  598. for (server_slot & slot : slots) {
  599. if (slot.id == id) {
  600. return &slot;
  601. }
  602. }
  603. return nullptr;
  604. }
  605. server_slot * get_available_slot(const server_task & task) {
  606. server_slot * ret = nullptr;
  607. bool update_cache = false;
  608. // find the slot that has at least n% prompt similarity
  609. if (ret == nullptr && slot_prompt_similarity != 0.0f) {
  610. float sim_best = 0;
  611. for (server_slot & slot : slots) {
  612. // skip the slot if it is not available
  613. if (slot.is_processing()) {
  614. continue;
  615. }
  616. const auto & tokens = slot.prompt.tokens;
  617. // skip the slot if it does not contains cached tokens
  618. if (tokens.empty()) {
  619. continue;
  620. }
  621. // fraction of the Longest Common Prefix length with respect to the input prompt length
  622. const float sim_cur = float(tokens.get_common_prefix(task.tokens)) / task.tokens.size();
  623. // select the current slot if the criteria match
  624. if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) {
  625. sim_best = sim_cur;
  626. ret = &slot;
  627. }
  628. }
  629. if (ret != nullptr) {
  630. const float f_keep = (sim_best*task.tokens.size()) / ret->prompt.tokens.size();
  631. SLT_INF(*ret, "selected slot by LCP similarity, sim_best = %.3f (> %.3f thold), f_keep = %.3f\n",
  632. sim_best, slot_prompt_similarity, f_keep);
  633. // if we are about to lose a large portion of the existing context - save it in the prompt cache
  634. if (f_keep < 0.5f) {
  635. update_cache = true;
  636. }
  637. }
  638. }
  639. // find the slot that has been least recently used
  640. if (ret == nullptr) {
  641. int64_t t_last = -1;
  642. for (server_slot & slot : slots) {
  643. // skip the slot if it is not available
  644. if (slot.is_processing()) {
  645. continue;
  646. }
  647. // select the current slot if the criteria match
  648. if (!ret || slot.t_last_used <= t_last) {
  649. t_last = slot.t_last_used;
  650. ret = &slot;
  651. }
  652. }
  653. if (ret != nullptr) {
  654. SLT_INF(*ret, "selected slot by LRU, t_last = %" PRId64 "\n", t_last);
  655. update_cache = true;
  656. }
  657. }
  658. if (ret) {
  659. const auto & tokens = ret->prompt.tokens;
  660. update_cache = update_cache && prompt_cache;
  661. // cache prompts only for completion tasks
  662. update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION;
  663. // don't update the cache if the slot's context is empty
  664. update_cache = update_cache && tokens.size() > 0;
  665. // TODO: mtmd does not support prompt cache
  666. update_cache = update_cache && (ret->mctx == nullptr);
  667. if (update_cache) {
  668. SRV_WRN("%s", "updating prompt cache\n");
  669. const int64_t t_start = ggml_time_us();
  670. ret->prompt_save(*prompt_cache);
  671. if (!ret->prompt_load(*prompt_cache, task.tokens)) {
  672. clear_slot(*ret);
  673. }
  674. prompt_cache->update();
  675. SRV_WRN("prompt cache update took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
  676. }
  677. }
  678. return ret;
  679. }
  680. void clear_slot(server_slot & slot) const {
  681. GGML_ASSERT(!slot.is_processing());
  682. SLT_WRN(slot, "clearing slot with %zu tokens\n", slot.prompt.tokens.size());
  683. llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
  684. slot.prompt.tokens.clear();
  685. }
  686. // return true if at least one slot has been cleared
  687. // TODO: improve logic
  688. // - smarter decision which slot to clear (LRU or longest prompt?)
  689. // - move slot to level 2 cache instead of removing?
  690. // - instead of purging, try to store and resume later?
  691. bool try_clear_idle_slots() {
  692. bool res = false;
  693. if (!params_base.kv_unified) {
  694. return res;
  695. }
  696. for (auto & slot : slots) {
  697. if (slot.is_processing()) {
  698. continue;
  699. }
  700. if (slot.prompt.n_tokens() > 0) {
  701. SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
  702. clear_slot(slot);
  703. res = true;
  704. // clear slots one by one
  705. break;
  706. }
  707. }
  708. return res;
  709. }
  710. bool launch_slot_with_task(server_slot & slot, server_task && task) {
  711. slot.reset();
  712. if (!are_lora_equal(task.params.lora, slot.lora)) {
  713. // if lora has changed, check to see if the cache should be cleared
  714. if (lora_should_clear_cache(slot.lora, task.params.lora)) {
  715. SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), task.params.lora.size());
  716. slot.prompt.tokens.clear();
  717. } else {
  718. SLT_INF(slot, "keeping cache for alora. %zu target loras\n", task.params.lora.size());
  719. }
  720. slot.lora = task.params.lora;
  721. }
  722. // if using alora, make sure it's only a single one requested and active
  723. size_t alora_invocation_start = task.tokens.size();
  724. if (lora_all_alora(slot.lora)) {
  725. const auto & enabled_ids = lora_get_enabled_ids(slot.lora);
  726. // TODO: This will error out if a user requests two aloras, but only
  727. // provides the activation string for one. We could, instead search
  728. // for all requested alora activation strings and then either keep
  729. // only the last one, or reject if multiple are found.
  730. if (enabled_ids.size() != 1) {
  731. send_error(task, "Cannot run multiple aLoRAs in a single request", ERROR_TYPE_INVALID_REQUEST);
  732. return false;
  733. }
  734. const auto & lora = slot.lora[enabled_ids[0]].ptr;
  735. // get the pointer and count for the invocation tokens
  736. const uint64_t n_invocation_tokens = llama_adapter_get_alora_n_invocation_tokens(lora);
  737. const llama_token * invocation_tokens = llama_adapter_get_alora_invocation_tokens (lora);
  738. // scan backwards through the prompt tokens to find the last
  739. // occurrence of the invocation sequence
  740. int match_idx = static_cast<int>(n_invocation_tokens) - 1;
  741. for (int i = task.tokens.size() - 1; i >= 0; --i) {
  742. // the token in this position matches the next token to find in
  743. // the invocation sequence
  744. if (task.tokens[i] == invocation_tokens[match_idx]) {
  745. // if it's a full match, we've found the start
  746. if (match_idx == 0) {
  747. alora_invocation_start = i;
  748. break;
  749. }
  750. // otherwise, check the next token in the sequence
  751. --match_idx;
  752. } else {
  753. // no match in this position, so start looking over again
  754. match_idx = static_cast<int>(n_invocation_tokens) - 1;
  755. }
  756. }
  757. // if the activation string is not found, disable the alora
  758. if (alora_invocation_start == task.tokens.size()) {
  759. SLT_DBG(slot, "alora %zu requested, but not found. deactivating\n", enabled_ids[0]);
  760. slot.lora[enabled_ids[0]].scale = 0.0f;
  761. } else {
  762. SLT_DBG(slot, "alora %zu activated starting at %zu\n", enabled_ids[0], alora_invocation_start);
  763. slot.alora_invocation_start = alora_invocation_start;
  764. }
  765. }
  766. if (!task.tokens.validate(ctx)) {
  767. send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
  768. return false;
  769. }
  770. SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
  771. // initialize samplers
  772. {
  773. if (slot.smpl != nullptr) {
  774. common_sampler_free(slot.smpl);
  775. }
  776. slot.smpl = common_sampler_init(model, task.params.sampling);
  777. if (slot.smpl == nullptr) {
  778. // for now, the only error that may happen here is invalid grammar
  779. send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
  780. return false;
  781. }
  782. SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl).c_str());
  783. }
  784. // initialize draft batch
  785. // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
  786. if (slot.ctx_dft) {
  787. llama_batch_free(slot.batch_spec);
  788. slot.batch_spec = llama_batch_init(task.params.speculative.n_max + 1, 0, 1);
  789. }
  790. slot.task = std::make_unique<const server_task>(std::move(task));
  791. slot.state = SLOT_STATE_STARTED;
  792. SLT_INF(slot, "%s", "processing task\n");
  793. return true;
  794. }
  // Consume one sampled token for `slot`.
  // Appends the token text (result.text_to_send) to slot.generated_text and handles stop strings.
  // Unless the text ends in an incomplete UTF-8 sequence, it records the token with
  // slot.add_token() and, when streaming, sends a partial response.
  // It then evaluates the stop conditions: context capacity, token budget (slot.has_budget),
  // indentation (n_indent), time limit (t_max_predict_ms) and end-of-generation token.
  // Returns slot.has_next_token: true if generation should continue.
  795. bool process_token(completion_token_output & result, server_slot & slot) {
  796. // remember which tokens were sampled - used for repetition penalties during sampling
  797. const std::string token_str = result.text_to_send;
  798. slot.sampled = result.tok;
  799. slot.generated_text += token_str;
  800. if (slot.task->params.return_tokens) {
  801. slot.generated_tokens.push_back(result.tok);
  802. }
  // optimistically continue; any of the stop checks below may clear this
  803. slot.has_next_token = true;
  804. // check if there is incomplete UTF-8 character at the end
  // while incomplete, nothing is emitted or recorded for this token (the block below is skipped)
  805. bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();
  806. // search stop word and delete it
  807. if (!incomplete) {
  // only the not-yet-sent tail of the generated text is searched for stop strings
  808. size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
  809. const std::string str_test = slot.generated_text.substr(pos);
  810. bool send_text = true;
  811. size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true);
  812. if (stop_pos != std::string::npos) {
  // full stop string found: cut the generated text at the stop string
  813. slot.generated_text.erase(
  814. slot.generated_text.begin() + pos + stop_pos,
  815. slot.generated_text.end());
  816. pos = std::min(slot.n_sent_text, slot.generated_text.size());
  // NOTE(review): has_next_token was set to true just above, so this condition reduces to the EOG check
  817. } else if (slot.has_next_token && !llama_vocab_is_eog(vocab, result.tok) ) {
  // presumably a partial (prefix) stop-string match - confirm in server_slot::find_stopping_strings;
  // if one is found the text is held back instead of being sent
  818. stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false);
  819. send_text = stop_pos == std::string::npos;
  820. }
  821. // check if there is any token to predict
  822. if (send_text) {
  823. // do not send the stop word in the response
  824. result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
  825. slot.n_sent_text += result.text_to_send.size();
  826. // (the token itself is added to the slot via slot.add_token below, in both branches)
  827. } else {
  828. result.text_to_send = "";
  829. }
  830. slot.add_token(result);
  831. if (slot.task->params.stream) {
  832. send_partial_response(slot, result, false);
  833. }
  834. }
  // NOTE(review): redundant - has_next_token is already true at this point
  835. if (incomplete) {
  836. slot.has_next_token = true;
  837. }
  838. // if context shifting is disabled, make sure that we don't run out of context
  839. if (!params_base.ctx_shift && slot.prompt.n_tokens() + 1 >= slot.n_ctx) {
  840. slot.truncated = true;
  841. slot.stop = STOP_TYPE_LIMIT;
  842. slot.has_next_token = false;
  843. SLT_DBG(slot, "stopped due to running out of context capacity, prompt.n_tokens() = %d, task.n_tokens = %d, n_decoded = %d, n_ctx = %d\n",
  844. slot.prompt.n_tokens(), slot.task->n_tokens(), slot.n_decoded, slot.n_ctx);
  845. }
  846. // check the limits
  847. if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) {
  848. slot.stop = STOP_TYPE_LIMIT;
  849. slot.has_next_token = false;
  850. SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.task->params.n_predict);
  851. }
  // has_new_line is set further below once a sent chunk contained '\n', so this runs from the next token on
  852. if (slot.has_new_line) {
  853. // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent
  854. if (slot.task->params.n_indent > 0) {
  855. // check the current indentation
  856. // TODO: improve by not doing it more than once for each new line
  857. if (slot.last_nl_pos > 0) {
  // count spaces/tabs at the start of the current line (which begins at last_nl_pos)
  858. size_t pos = slot.last_nl_pos;
  859. int n_indent = 0;
  860. while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) {
  861. n_indent++;
  862. pos++;
  863. }
  // only decide once a non-whitespace character follows the indentation
  864. if (pos < slot.generated_text.size() && n_indent < slot.task->params.n_indent) {
  865. slot.stop = STOP_TYPE_LIMIT;
  866. slot.has_next_token = false;
  867. // cut the last line
  868. slot.generated_text.erase(pos, std::string::npos);
  869. SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent);
  870. }
  871. }
  872. // find the next new line
  873. {
  874. const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos);
  875. if (pos != std::string::npos) {
  876. slot.last_nl_pos = pos + 1;
  877. }
  878. }
  879. }
  880. }
  881. // check if there is a new line in the generated text
  882. if (result.text_to_send.find('\n') != std::string::npos) {
  883. slot.has_new_line = true;
  884. // if we have seen a new line, we stop after a certain time limit, but only upon another new line
  // elapsed time is in microseconds (ggml_time_us), hence the 1000x factor on the ms limit
  885. if (slot.task->params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.task->params.t_max_predict_ms)) {
  886. slot.stop = STOP_TYPE_LIMIT;
  887. slot.has_next_token = false;
  888. SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.task->params.t_max_predict_ms);
  889. }
  890. }
  // an end-of-generation token always stops generation (and overrides any stop type set above)
  891. if (llama_vocab_is_eog(vocab, result.tok)) {
  892. slot.stop = STOP_TYPE_EOS;
  893. slot.has_next_token = false;
  894. SLT_DBG(slot, "%s", "stopped by EOS\n");
  895. }
  896. SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str());
  897. return slot.has_next_token; // continue
  898. }
  899. void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const {
  900. size_t n_probs = slot.task->params.sampling.n_probs;
  901. size_t n_vocab = llama_vocab_n_tokens(vocab);
  902. if (post_sampling) {
  903. const auto * cur_p = common_sampler_get_candidates(slot.smpl, true);
  904. const size_t max_probs = cur_p->size;
  905. // set probability for sampled token
  906. for (size_t i = 0; i < max_probs; i++) {
  907. if (cur_p->data[i].id == result.tok) {
  908. result.prob = cur_p->data[i].p;
  909. break;
  910. }
  911. }
  912. // set probability for top n_probs tokens
  913. result.probs.reserve(max_probs);
  914. for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
  915. result.probs.push_back({
  916. cur_p->data[i].id,
  917. common_token_to_piece(ctx, cur_p->data[i].id, special),
  918. cur_p->data[i].p
  919. });
  920. }
  921. } else {
  922. // TODO: optimize this with min-p optimization
  923. std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
  924. // set probability for sampled token
  925. for (size_t i = 0; i < n_vocab; i++) {
  926. // set probability for sampled token
  927. if (cur[i].id == result.tok) {
  928. result.prob = cur[i].p;
  929. break;
  930. }
  931. }
  932. // set probability for top n_probs tokens
  933. result.probs.reserve(n_probs);
  934. for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
  935. result.probs.push_back({
  936. cur[i].id,
  937. common_token_to_piece(ctx, cur[i].id, special),
  938. cur[i].p
  939. });
  940. }
  941. }
  942. }
  943. void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
  944. send_error(task.id, error, type);
  945. }
  946. void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
  947. send_error(slot.task->id, error, type, slot.task->n_tokens(), slot.n_ctx);
  948. }
  949. void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) {
  950. SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str());
  951. if (type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) {
  952. GGML_ASSERT(n_ctx > 0 && n_prompt_tokens > 0);
  953. }
  954. auto res = std::make_unique<server_task_result_error>();
  955. res->id = id_task;
  956. res->err_type = type;
  957. res->err_msg = error;
  958. res->n_prompt_tokens = n_prompt_tokens;
  959. res->n_ctx = n_ctx;
  960. queue_results.send(std::move(res));
  961. }
  962. // if multimodal is enabled, send an error and return false
  963. bool check_no_mtmd(const int id_task) {
  964. if (mctx) {
  965. send_error(id_task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED);
  966. return false;
  967. }
  968. return true;
  969. }
  970. void send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress) {
  971. auto res = std::make_unique<server_task_result_cmpl_partial>();
  972. res->id = slot.task->id;
  973. res->index = slot.task->index;
  974. if (is_progress) {
  975. res->is_progress = true;
  976. res->progress.total = slot.task->n_tokens();
  977. res->progress.cache = slot.n_prompt_tokens_cache;
  978. res->progress.processed = slot.prompt.tokens.size();
  979. res->progress.time_ms = (ggml_time_us() - slot.t_start_process_prompt) / 1000;
  980. } else {
  981. res->content = tkn.text_to_send;
  982. res->tokens = { tkn.tok };
  983. slot.update_chat_msg(res->oaicompat_msg_diffs);
  984. }
  985. res->n_decoded = slot.n_decoded;
  986. res->n_prompt_tokens = slot.task->n_tokens();
  987. res->post_sampling_probs = slot.task->params.post_sampling_probs;
  988. res->verbose = slot.task->params.verbose;
  989. res->res_type = slot.task->params.res_type;
  990. res->oaicompat_model = slot.task->params.oaicompat_model;
  991. res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
  992. // populate res.probs_output
  993. if (slot.task->params.sampling.n_probs > 0) {
  994. res->prob_output = tkn; // copy the token probs
  995. }
  996. // populate timings if this is final response or timings_per_token is enabled
  997. if (slot.stop != STOP_TYPE_NONE || slot.task->params.timings_per_token) {
  998. res->timings = slot.get_timings();
  999. }
  1000. queue_results.send(std::move(res));
  1001. }
  1002. void send_final_response(server_slot & slot) {
  1003. auto res = std::make_unique<server_task_result_cmpl_final>();
  1004. res->id = slot.task->id;
  1005. res->id_slot = slot.id;
  1006. res->index = slot.task->index;
  1007. res->content = slot.generated_text;
  1008. res->tokens = std::move(slot.generated_tokens);
  1009. res->timings = slot.get_timings();
  1010. res->prompt = slot.task->tokens.detokenize(ctx, true);
  1011. res->response_fields = std::move(slot.task->params.response_fields);
  1012. res->truncated = slot.truncated;
  1013. res->n_decoded = slot.n_decoded;
  1014. res->n_prompt_tokens = slot.task->n_tokens();
  1015. res->n_tokens_cached = slot.prompt.n_tokens();
  1016. res->has_new_line = slot.has_new_line;
  1017. res->stopping_word = slot.stopping_word;
  1018. res->stop = slot.stop;
  1019. res->post_sampling_probs = slot.task->params.post_sampling_probs;
  1020. res->verbose = slot.task->params.verbose;
  1021. res->stream = slot.task->params.stream;
  1022. res->include_usage = slot.task->params.include_usage;
  1023. res->res_type = slot.task->params.res_type;
  1024. res->oaicompat_model = slot.task->params.oaicompat_model;
  1025. res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
  1026. res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);
  1027. // populate res.probs_output
  1028. if (slot.task->params.sampling.n_probs > 0) {
  1029. if (!slot.task->params.stream && slot.stop == STOP_TYPE_WORD) {
  1030. const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
  1031. size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
  1032. res->probs_output = std::vector<completion_token_output>(
  1033. slot.generated_token_probs.begin(),
  1034. slot.generated_token_probs.end() - safe_offset);
  1035. } else {
  1036. res->probs_output = std::vector<completion_token_output>(
  1037. slot.generated_token_probs.begin(),
  1038. slot.generated_token_probs.end());
  1039. }
  1040. }
  1041. res->generation_params = slot.task->params; // copy the parameters
  1042. queue_results.send(std::move(res));
  1043. }
  1044. void send_embedding(const server_slot & slot, const llama_batch & batch) {
  1045. auto res = std::make_unique<server_task_result_embd>();
  1046. res->id = slot.task->id;
  1047. res->index = slot.task->index;
  1048. res->n_tokens = slot.task->n_tokens();
  1049. res->res_type = slot.task->params.res_type;
  1050. const int n_embd = llama_model_n_embd(model);
  1051. std::vector<float> embd_res(n_embd, 0.0f);
  1052. for (int i = 0; i < batch.n_tokens; ++i) {
  1053. if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
  1054. continue;
  1055. }
  1056. const float * embd = nullptr;
  1057. if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) {
  1058. embd = llama_get_embeddings_ith(ctx, i);
  1059. } else {
  1060. embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
  1061. }
  1062. if (embd == nullptr) {
  1063. SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
  1064. res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
  1065. continue;
  1066. }
  1067. // normalize only when there is pooling
  1068. if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
  1069. common_embd_normalize(embd, embd_res.data(), n_embd, slot.task->params.embd_normalize);
  1070. res->embedding.push_back(embd_res);
  1071. break;
  1072. }
  1073. res->embedding.emplace_back(embd, embd + n_embd);
  1074. }
  1075. SLT_DBG(slot, "%s", "sending embeddings\n");
  1076. queue_results.send(std::move(res));
  1077. }
  1078. void send_rerank(const server_slot & slot, const llama_batch & batch) {
  1079. auto res = std::make_unique<server_task_result_rerank>();
  1080. res->id = slot.task->id;
  1081. res->index = slot.task->index;
  1082. res->n_tokens = slot.task->n_tokens();
  1083. for (int i = 0; i < batch.n_tokens; ++i) {
  1084. if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
  1085. continue;
  1086. }
  1087. const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
  1088. if (embd == NULL) {
  1089. embd = llama_get_embeddings_ith(ctx, i);
  1090. }
  1091. if (embd == NULL) {
  1092. SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
  1093. res->score = -1e6;
  1094. continue;
  1095. }
  1096. res->score = embd[0];
  1097. }
  1098. SLT_DBG(slot, "sending rerank result, res.score = %f\n", res->score);
  1099. queue_results.send(std::move(res));
  1100. }
  1101. //
  1102. // Functions to process the task
  1103. //
  1104. void process_single_task(server_task && task) {
  1105. switch (task.type) {
  1106. case SERVER_TASK_TYPE_COMPLETION:
  1107. case SERVER_TASK_TYPE_INFILL:
  1108. case SERVER_TASK_TYPE_EMBEDDING:
  1109. case SERVER_TASK_TYPE_RERANK:
  1110. {
  1111. const int id_slot = task.id_slot;
  1112. server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task);
  1113. if (slot == nullptr) {
  1114. // if no slot is available, we defer this task for processing later
  1115. SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id);
  1116. queue_tasks.defer(std::move(task));
  1117. break;
  1118. }
  1119. if (slot->is_processing()) {
  1120. // if requested slot is unavailable, we defer this task for processing later
  1121. SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
  1122. queue_tasks.defer(std::move(task));
  1123. break;
  1124. }
  1125. if (!launch_slot_with_task(*slot, std::move(task))) {
  1126. SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
  1127. break;
  1128. }
  1129. } break;
  1130. case SERVER_TASK_TYPE_CANCEL:
  1131. {
  1132. // release slot linked with the task id
  1133. for (auto & slot : slots) {
  1134. if (slot.task && slot.task->id == task.id_target) {
  1135. slot.release();
  1136. break;
  1137. }
  1138. }
  1139. } break;
  1140. case SERVER_TASK_TYPE_NEXT_RESPONSE:
  1141. {
  1142. // do nothing
  1143. } break;
  1144. case SERVER_TASK_TYPE_METRICS:
  1145. {
  1146. json slots_data = json::array();
  1147. int n_idle_slots = 0;
  1148. int n_processing_slots = 0;
  1149. for (server_slot & slot : slots) {
  1150. json slot_data = slot.to_json(slots_debug == 0);
  1151. if (slot.is_processing()) {
  1152. n_processing_slots++;
  1153. } else {
  1154. n_idle_slots++;
  1155. }
  1156. slots_data.push_back(slot_data);
  1157. }
  1158. SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots);
  1159. auto res = std::make_unique<server_task_result_metrics>();
  1160. res->id = task.id;
  1161. res->slots_data = std::move(slots_data);
  1162. res->n_idle_slots = n_idle_slots;
  1163. res->n_processing_slots = n_processing_slots;
  1164. res->n_tasks_deferred = queue_tasks.queue_tasks_deferred_size();
  1165. res->t_start = metrics.t_start;
  1166. res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total;
  1167. res->t_prompt_processing_total = metrics.t_prompt_processing_total;
  1168. res->n_tokens_predicted_total = metrics.n_tokens_predicted_total;
  1169. res->t_tokens_generation_total = metrics.t_tokens_generation_total;
  1170. res->n_tokens_max = metrics.n_tokens_max;
  1171. res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed;
  1172. res->t_prompt_processing = metrics.t_prompt_processing;
  1173. res->n_tokens_predicted = metrics.n_tokens_predicted;
  1174. res->t_tokens_generation = metrics.t_tokens_generation;
  1175. res->n_decode_total = metrics.n_decode_total;
  1176. res->n_busy_slots_total = metrics.n_busy_slots_total;
  1177. if (task.metrics_reset_bucket) {
  1178. metrics.reset_bucket();
  1179. }
  1180. queue_results.send(std::move(res));
  1181. } break;
  1182. case SERVER_TASK_TYPE_SLOT_SAVE:
  1183. {
  1184. if (!check_no_mtmd(task.id)) {
  1185. break;
  1186. }
  1187. int id_slot = task.slot_action.slot_id;
  1188. server_slot * slot = get_slot_by_id(id_slot);
  1189. if (slot == nullptr) {
  1190. send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
  1191. break;
  1192. }
  1193. if (slot->is_processing()) {
  1194. // if requested slot is unavailable, we defer this task for processing later
  1195. SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
  1196. queue_tasks.defer(std::move(task));
  1197. break;
  1198. }
  1199. const size_t token_count = slot->prompt.tokens.size();
  1200. const int64_t t_start = ggml_time_us();
  1201. std::string filename = task.slot_action.filename;
  1202. std::string filepath = task.slot_action.filepath;
  1203. const llama_tokens & tokens = slot->prompt.tokens.get_text_tokens();
  1204. const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count);
  1205. const int64_t t_end = ggml_time_us();
  1206. const double t_save_ms = (t_end - t_start) / 1000.0;
  1207. auto res = std::make_unique<server_task_result_slot_save_load>();
  1208. res->id = task.id;
  1209. res->id_slot = id_slot;
  1210. res->filename = filename;
  1211. res->is_save = true;
  1212. res->n_tokens = token_count;
  1213. res->n_bytes = nwrite;
  1214. res->t_ms = t_save_ms;
  1215. queue_results.send(std::move(res));
  1216. } break;
  1217. case SERVER_TASK_TYPE_SLOT_RESTORE:
  1218. {
  1219. if (!check_no_mtmd(task.id)) break;
  1220. int id_slot = task.slot_action.slot_id;
  1221. server_slot * slot = get_slot_by_id(id_slot);
  1222. if (slot == nullptr) {
  1223. send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
  1224. break;
  1225. }
  1226. if (slot->is_processing()) {
  1227. // if requested slot is unavailable, we defer this task for processing later
  1228. SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
  1229. queue_tasks.defer(std::move(task));
  1230. break;
  1231. }
  1232. const int64_t t_start = ggml_time_us();
  1233. std::string filename = task.slot_action.filename;
  1234. std::string filepath = task.slot_action.filepath;
  1235. llama_tokens tokens;
  1236. tokens.resize(slot->n_ctx);
  1237. size_t token_count = 0;
  1238. size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count);
  1239. if (nread == 0) {
  1240. slot->prompt.tokens.clear(); // KV may already been invalidated?
  1241. send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
  1242. break;
  1243. }
  1244. tokens.resize(token_count);
  1245. slot->prompt.tokens.clear();
  1246. slot->prompt.tokens.insert(tokens);
  1247. const int64_t t_end = ggml_time_us();
  1248. const double t_restore_ms = (t_end - t_start) / 1000.0;
  1249. auto res = std::make_unique<server_task_result_slot_save_load>();
  1250. res->id = task.id;
  1251. res->id_slot = id_slot;
  1252. res->filename = filename;
  1253. res->is_save = false;
  1254. res->n_tokens = token_count;
  1255. res->n_bytes = nread;
  1256. res->t_ms = t_restore_ms;
  1257. queue_results.send(std::move(res));
  1258. } break;
  1259. case SERVER_TASK_TYPE_SLOT_ERASE:
  1260. {
  1261. if (!check_no_mtmd(task.id)) {
  1262. break;
  1263. }
  1264. int id_slot = task.slot_action.slot_id;
  1265. server_slot * slot = get_slot_by_id(id_slot);
  1266. if (slot == nullptr) {
  1267. send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
  1268. break;
  1269. }
  1270. if (slot->is_processing()) {
  1271. // if requested slot is unavailable, we defer this task for processing later
  1272. SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
  1273. queue_tasks.defer(std::move(task));
  1274. break;
  1275. }
  1276. // Erase token cache
  1277. const size_t n_erased = slot->prompt.tokens.size();
  1278. clear_slot(*slot);
  1279. auto res = std::make_unique<server_task_result_slot_erase>();
  1280. res->id = task.id;
  1281. res->id_slot = id_slot;
  1282. res->n_erased = n_erased;
  1283. queue_results.send(std::move(res));
  1284. } break;
  1285. case SERVER_TASK_TYPE_SET_LORA:
  1286. {
  1287. params_base.lora_adapters = std::move(task.set_lora);
  1288. auto res = std::make_unique<server_task_result_apply_lora>();
  1289. res->id = task.id;
  1290. queue_results.send(std::move(res));
  1291. } break;
  1292. }
  1293. }
  1294. void update_slots() {
  1295. // check if all slots are idle
  1296. {
  1297. bool all_idle = true;
  1298. for (auto & slot : slots) {
  1299. if (slot.is_processing()) {
  1300. all_idle = false;
  1301. break;
  1302. }
  1303. }
  1304. if (all_idle) {
  1305. SRV_INF("%s", "all slots are idle\n");
  1306. return;
  1307. }
  1308. }
  1309. {
  1310. SRV_DBG("%s", "posting NEXT_RESPONSE\n");
  1311. server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE);
  1312. task.id = queue_tasks.get_new_id();
  1313. queue_tasks.post(std::move(task));
  1314. }
  1315. // apply context-shift if needed
  1316. // TODO: simplify and improve
  1317. for (server_slot & slot : slots) {
  1318. if (slot.state == SLOT_STATE_GENERATING && slot.prompt.n_tokens() + 1 >= slot.n_ctx) {
  1319. if (!params_base.ctx_shift) {
  1320. // this check is redundant (for good)
  1321. // we should never get here, because generation should already stopped in process_token()
  1322. send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
  1323. slot.release();
  1324. continue;
  1325. }
  1326. if (mctx) {
  1327. // we should never reach this because params_base.ctx_shift is automatically disabled if mmproj is loaded
  1328. // we don't support ctx_shift because an image chunk may contains multiple tokens
  1329. GGML_ABORT("not supported by multimodal");
  1330. }
  1331. // Shift context
  1332. int n_keep = slot.task->params.n_keep < 0 ? slot.task->n_tokens() : slot.task->params.n_keep;
  1333. if (add_bos_token) {
  1334. n_keep += 1;
  1335. }
  1336. n_keep = std::min(slot.n_ctx - 4, n_keep);
  1337. const int n_left = slot.prompt.n_tokens() - n_keep;
  1338. const int n_discard = slot.task->params.n_discard ? slot.task->params.n_discard : (n_left / 2);
  1339. SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
  1340. llama_memory_seq_rm (llama_get_memory(ctx), slot.id, n_keep , n_keep + n_discard);
  1341. llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard);
  1342. // add generated tokens to cache
  1343. // ref: https://github.com/ggml-org/llama.cpp/pull/16818#discussion_r2473269481
  1344. {
  1345. GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
  1346. llama_tokens new_tokens = slot.prompt.tokens.get_text_tokens(); // copy
  1347. for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
  1348. new_tokens[i - n_discard] = new_tokens[i];
  1349. }
  1350. new_tokens.resize(slot.prompt.tokens.size() - n_discard);
  1351. slot.prompt.tokens.clear();
  1352. slot.prompt.tokens.insert(new_tokens);
  1353. }
  1354. slot.truncated = true;
  1355. }
  1356. }
  1357. // start populating the batch for this iteration
  1358. common_batch_clear(batch);
  1359. // track if given slot can be batched with slots already in the batch
  1360. server_slot * slot_batched = nullptr;
  1361. auto accept_special_token = [&](server_slot & slot, llama_token token) {
  1362. return params_base.special ||
  1363. slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end();
  1364. };
  1365. // first, add sampled tokens from any ongoing sequences
  1366. for (auto & slot : slots) {
  1367. if (slot.state != SLOT_STATE_GENERATING) {
  1368. continue;
  1369. }
  1370. // check if we can batch this slot with the previous one
  1371. if (!slot_batched) {
  1372. slot_batched = &slot;
  1373. } else if (!slot_batched->can_batch_with(slot)) {
  1374. continue;
  1375. }
  1376. slot.i_batch = batch.n_tokens;
  1377. common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
  1378. slot.prompt.tokens.push_back(slot.sampled);
  1379. SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
  1380. slot.n_ctx, slot.prompt.n_tokens(), slot.truncated);
  1381. }
  1382. // process in chunks of params.n_batch
  1383. int32_t n_batch = llama_n_batch(ctx);
  1384. int32_t n_ubatch = llama_n_ubatch(ctx);
  1385. float alora_scale = -1.0f;
  1386. size_t alora_disabled_id = 0;
  1387. // next, batch any pending prompts without exceeding n_batch
  1388. if (params_base.cont_batching || batch.n_tokens == 0) {
  1389. for (auto & slot : slots) {
  1390. if (!slot.is_processing()) {
  1391. continue;
  1392. }
  1393. // check if we can batch this slot with the previous one
  1394. if (slot_batched && !slot_batched->can_batch_with(slot)) {
  1395. continue;
  1396. }
  1397. // this slot still has a prompt to be processed
  1398. if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
  1399. const auto & input_tokens = slot.task->tokens;
  1400. // TODO: maybe move branch to outside of this loop in the future
  1401. if (slot.state == SLOT_STATE_STARTED) {
  1402. slot.t_start_process_prompt = ggml_time_us();
  1403. slot.t_start_generation = 0;
  1404. slot.state = SLOT_STATE_PROCESSING_PROMPT;
  1405. SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, task.n_tokens = %d\n",
  1406. slot.n_ctx, slot.task->params.n_keep, slot.task->n_tokens());
  1407. // print prompt tokens (for debugging)
  1408. /*if (1) {
  1409. // first 16 tokens (avoid flooding logs)
  1410. for (int i = 0; i < std::min<int>(16, input_tokens.size()); i++) {
  1411. SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str());
  1412. }
  1413. } else {
  1414. // all
  1415. for (int i = 0; i < (int) input_tokens.size(); i++) {
  1416. SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str());
  1417. }
  1418. }*/
  1419. // keep track how many tokens we can reuse from the previous state
  1420. int n_past = 0;
  1421. // empty prompt passed -> release the slot and send empty response
  1422. if (input_tokens.empty()) {
  1423. SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
  1424. slot.print_timings();
  1425. send_final_response(slot);
  1426. slot.release();
  1427. continue;
  1428. }
  1429. // TODO: support memory-less logits computation
  1430. if (slot.need_logits() && !llama_get_memory(ctx)) {
  1431. send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER);
  1432. slot.release();
  1433. continue;
  1434. }
  1435. if (!slot.can_split()) {
  1436. if (slot.task->n_tokens() > n_ubatch) {
  1437. send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
  1438. slot.release();
  1439. continue;
  1440. }
  1441. if (slot.task->n_tokens() > slot.n_ctx) {
  1442. send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
  1443. slot.release();
  1444. continue;
  1445. }
  1446. } else {
  1447. if (slot.task->n_tokens() >= slot.n_ctx) {
  1448. send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
  1449. slot.release();
  1450. continue;
  1451. }
  1452. if (slot.task->params.cache_prompt) {
  1453. // reuse any previously computed tokens that are common with the new prompt
  1454. n_past = slot.prompt.tokens.get_common_prefix(input_tokens);
  1455. // if there is an alora invoked, don't cache after the invocation start
  1456. if (slot.alora_invocation_start > 0) {
  1457. SLT_DBG(slot, "only caching to alora invocation start (n_past = %d, alora_invocation_start = %d)\n", n_past, slot.alora_invocation_start);
  1458. n_past = std::min(n_past, slot.alora_invocation_start - 1);
  1459. }
  1460. // reuse chunks from the cached prompt by shifting their KV cache in the new position
  1461. if (params_base.n_cache_reuse > 0) {
  1462. GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
  1463. size_t head_c = n_past; // cache
  1464. size_t head_p = n_past; // current prompt
  1465. if (mctx) {
  1466. // we should never reach this
  1467. GGML_ABORT("not supported by multimodal");
  1468. }
  1469. SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", params_base.n_cache_reuse, n_past);
  1470. while (head_c < slot.prompt.tokens.size() &&
  1471. head_p < input_tokens.size()) {
  1472. size_t n_match = 0;
  1473. while (head_c + n_match < slot.prompt.tokens.size() &&
  1474. head_p + n_match < input_tokens.size() &&
  1475. slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) {
  1476. n_match++;
  1477. }
  1478. if (n_match >= (size_t) params_base.n_cache_reuse) {
  1479. SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
  1480. //for (size_t i = head_p; i < head_p + n_match; i++) {
  1481. // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
  1482. //}
  1483. const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
  1484. llama_memory_seq_rm (llama_get_memory(ctx), slot.id, head_p, head_c);
  1485. llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift);
  1486. for (size_t i = 0; i < n_match; i++) {
  1487. slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]);
  1488. n_past++;
  1489. }
  1490. head_c += n_match;
  1491. head_p += n_match;
  1492. } else {
  1493. head_c += 1;
  1494. }
  1495. }
  1496. SLT_DBG(slot, "after context reuse, new n_past = %d\n", n_past);
  1497. }
  1498. } else {
  1499. // if we don't cache the prompt, we have to remove all previous tokens
  1500. n_past = 0;
  1501. }
  1502. // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
  1503. const auto n_swa = std::max(1, llama_model_n_swa(model));
  1504. // the largest pos_min required for a checkpoint to be useful
  1505. const auto pos_min_thold = std::max(0, n_past - n_swa);
  1506. // note: disallow with mtmd contexts for now
  1507. // https://github.com/ggml-org/llama.cpp/issues/17043
  1508. if (!mctx && n_past > 0 && n_past < slot.prompt.n_tokens()) {
  1509. const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
  1510. if (pos_min == -1) {
  1511. SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
  1512. GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
  1513. }
  1514. // when the prompt prefix does not match, print the tokens around the mismatch
  1515. // this is useful for debugging prompt caching
  1516. if (slots_debug) {
  1517. const int np0 = std::max<int>(n_past - 4, 0);
  1518. const int np1 = std::min<int>(n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size()));
  1519. std::stringstream ss0;
  1520. std::stringstream ss1;
  1521. std::stringstream st0;
  1522. std::stringstream st1;
  1523. ss0 << "old: ... ";
  1524. ss1 << "new: ... ";
  1525. for (int i = np0; i < np1; i++) {
  1526. if (i == n_past) {
  1527. ss0 << " | ";
  1528. ss1 << " | ";
  1529. }
  1530. {
  1531. const auto token = slot.prompt.tokens[i];
  1532. const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]";
  1533. ss0 << piece;
  1534. st0 << std::setw(8) << token;
  1535. }
  1536. {
  1537. const auto token = slot.task->tokens[i];
  1538. const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]";
  1539. ss1 << piece;
  1540. st1 << std::setw(8) << token;
  1541. }
  1542. }
  1543. SLT_WRN(slot, "%s\n", ss0.str().c_str());
  1544. SLT_WRN(slot, "%s\n", ss1.str().c_str());
  1545. SLT_WRN(slot, "%s\n", st0.str().c_str());
  1546. SLT_WRN(slot, "%s\n", st1.str().c_str());
  1547. }
  1548. if (pos_min > pos_min_thold) {
  1549. // TODO: support can be added in the future when corresponding vision models get released
  1550. GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
  1551. SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
  1552. // search for a context checkpoint
  1553. const auto it = std::find_if(
  1554. slot.prompt.checkpoints.rbegin(),
  1555. slot.prompt.checkpoints.rend(),
  1556. [&](const auto & cur) {
  1557. // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
  1558. return cur.pos_min < pos_min_thold;
  1559. }
  1560. );
  1561. bool do_reset = it == slot.prompt.checkpoints.rend();
  1562. if (!do_reset) {
  1563. // restore the context checkpoint
  1564. const size_t checkpoint_size = it->data.size();
  1565. const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
  1566. if (n != checkpoint_size) {
  1567. SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
  1568. do_reset = true;
  1569. //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
  1570. } else {
  1571. n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max));
  1572. SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
  1573. }
  1574. }
  1575. if (do_reset) {
  1576. SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
  1577. "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
  1578. n_past = 0;
  1579. }
  1580. }
  1581. }
  1582. {
  1583. // erase any checkpoints with pos_min > pos_min_thold
  1584. for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) {
  1585. const auto & cur = *it;
  1586. if (cur.pos_min > pos_min_thold) {
  1587. SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
  1588. it = slot.prompt.checkpoints.erase(it);
  1589. } else {
  1590. ++it;
  1591. }
  1592. }
  1593. }
  1594. }
  1595. // [TAG_PROMPT_LOGITS]
  1596. if (n_past == slot.task->n_tokens() && n_past > 0) {
  1597. SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, task.n_tokens() = %d)\n", n_past, slot.task->n_tokens());
  1598. n_past--;
  1599. SLT_WRN(slot, "n_past was set to %d\n", n_past);
  1600. }
  1601. slot.n_prompt_tokens_cache = n_past;
  1602. slot.n_prompt_tokens_processed = 0;
  1603. slot.prompt.tokens.keep_first(n_past);
  1604. }
  1605. if (!slot.can_split()) {
  1606. // cannot fit the prompt in the current batch - will try next iter
  1607. if (batch.n_tokens + slot.task->n_tokens() > n_batch) {
  1608. continue;
  1609. }
  1610. }
  1611. // truncate any tokens that are beyond n_past for this slot
  1612. const llama_pos p0 = slot.prompt.tokens.pos_next();
  1613. SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
  1614. if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
  1615. SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
  1616. clear_slot(slot);
  1617. // there is no common part left
  1618. slot.n_prompt_tokens_cache = 0;
  1619. }
  1620. // check if we should process the image
  1621. if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
  1622. // process the image
  1623. size_t n_tokens_out = 0;
  1624. int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
  1625. if (res != 0) {
  1626. SLT_ERR(slot, "failed to process image, res = %d\n", res);
  1627. send_error(slot, "failed to process image", ERROR_TYPE_SERVER);
  1628. slot.release();
  1629. continue;
  1630. }
  1631. slot.n_prompt_tokens_processed += n_tokens_out;
  1632. // add the image chunk to cache
  1633. {
  1634. const auto & chunk = input_tokens.find_chunk(slot.prompt.n_tokens());
  1635. slot.prompt.tokens.push_back(chunk.get()); // copy
  1636. }
  1637. }
  1638. // If using an alora, there may be uncached tokens that come
  1639. // before the invocation sequence. When this happens, the
  1640. // tokens before the invocation sequence need to be
  1641. // processed without the adapter in a separate batch, then
  1642. // the adapter needs to be enabled for the remaining tokens.
  1643. if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.prompt.n_tokens()) {
  1644. SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start);
  1645. const auto & enabled_loras = lora_get_enabled_ids(slot.lora);
  1646. GGML_ASSERT(enabled_loras.size() == 1);
  1647. alora_scale = slot.lora[enabled_loras[0]].scale;
  1648. slot.lora[enabled_loras[0]].scale = 0.0f;
  1649. alora_disabled_id = enabled_loras[0];
  1650. }
  1651. bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
  1652. // make checkpoints only for completion tasks
  1653. do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION;
  1654. // make a checkpoint of the parts of the memory that cannot be rolled back.
  1655. // checkpoints are created only if:
  1656. // - the model uses SWA and we are not using `swa_full`
  1657. // - the model architecture is marked as recurrent or hybrid
  1658. //
  1659. // TODO: try to make this conditional on the context or the memory module, instead of the model type
  1660. do_checkpoint = do_checkpoint && (
  1661. llama_model_is_recurrent(model) ||
  1662. llama_model_is_hybrid(model) ||
  1663. (llama_model_n_swa(model) > 0 && !params_base.swa_full)
  1664. );
  1665. // add prompt tokens for processing in the current batch
  1666. while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) {
  1667. // get next token to process
  1668. llama_token cur_tok = input_tokens[slot.prompt.n_tokens()];
  1669. if (cur_tok == LLAMA_TOKEN_NULL) {
  1670. break; // end of text chunk
  1671. }
  1672. // if this is an alora request with pre-invocation
  1673. // tokens that are not cached, we need to stop filling
  1674. // this batch at those pre-invocation tokens.
  1675. if (alora_scale > 0 && slot.prompt.n_tokens() == slot.alora_invocation_start - 1) {
  1676. SLT_DBG(slot, "stop prompt batch filling at (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start);
  1677. break;
  1678. }
  1679. // embedding requires all tokens in the batch to be output
  1680. common_batch_add(batch,
  1681. cur_tok,
  1682. slot.prompt.tokens.pos_next(),
  1683. { slot.id },
  1684. slot.need_embd());
  1685. slot.prompt.tokens.push_back(cur_tok);
  1686. slot.n_prompt_tokens_processed++;
  1687. // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
  1688. if (do_checkpoint && slot.task->n_tokens() - slot.prompt.n_tokens() == 64) {
  1689. break;
  1690. }
  1691. }
  1692. // SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.slot.prompt.tokens.str().c_str());
  1693. SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
  1694. // entire prompt has been processed
  1695. if (slot.prompt.n_tokens() == slot.task->n_tokens()) {
  1696. slot.state = SLOT_STATE_DONE_PROMPT;
  1697. GGML_ASSERT(batch.n_tokens > 0);
  1698. common_sampler_reset(slot.smpl);
  1699. // Process all prompt tokens through sampler system
  1700. for (int i = 0; i < slot.task->n_tokens(); ++i) {
  1701. llama_token id = input_tokens[i];
  1702. if (id != LLAMA_TOKEN_NULL) {
  1703. common_sampler_accept(slot.smpl, id, false);
  1704. }
  1705. }
  1706. // extract the logits only for the last token
  1707. batch.logits[batch.n_tokens - 1] = true;
  1708. slot.n_decoded = 0;
  1709. slot.i_batch = batch.n_tokens - 1;
  1710. SLT_INF(slot, "prompt done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
  1711. const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
  1712. const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
  1713. // no need for empty or small checkpoints
  1714. do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
  1715. // no need to create checkpoints that are too close together
  1716. do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64);
  1717. if (do_checkpoint) {
  1718. while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
  1719. // make room for the new checkpoint, if needed
  1720. const auto & cur = slot.prompt.checkpoints.front();
  1721. SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
  1722. cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
  1723. slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
  1724. }
  1725. const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
  1726. auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
  1727. /*.pos_min = */ pos_min,
  1728. /*.pos_max = */ pos_max,
  1729. /*.data = */ std::vector<uint8_t>(checkpoint_size),
  1730. });
  1731. llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
  1732. SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
  1733. (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
  1734. }
  1735. }
  1736. }
  1737. if (!slot_batched) {
  1738. slot_batched = &slot;
  1739. }
  1740. if (batch.n_tokens >= n_batch) {
  1741. break;
  1742. }
  1743. }
  1744. }
  1745. if (batch.n_tokens == 0) {
  1746. SRV_WRN("%s", "no tokens to decode\n");
  1747. return;
  1748. }
  1749. SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
  1750. if (slot_batched) {
  1751. // apply lora, only need to do it once per batch
  1752. common_set_adapter_lora(ctx, slot_batched->lora);
  1753. // if the lora is temporarily disabled for an alora, re-enable it
  1754. // for next time
  1755. if (alora_scale > 0.0f) {
  1756. SRV_DBG("re-enabling alora with scale %f\n", alora_scale);
  1757. slot_batched->lora[alora_disabled_id].scale = alora_scale;
  1758. }
  1759. llama_set_embeddings(ctx, slot_batched->need_embd());
  1760. }
  1761. int32_t i_next = 0;
  1762. // process the created batch of tokens
  1763. for (int32_t i = 0; i < batch.n_tokens; i = i_next) {
  1764. const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
  1765. llama_batch batch_view = {
  1766. n_tokens,
  1767. batch.token + i,
  1768. nullptr,
  1769. batch.pos + i,
  1770. batch.n_seq_id + i,
  1771. batch.seq_id + i,
  1772. batch.logits + i,
  1773. };
  1774. const int ret = llama_decode(ctx, batch_view);
  1775. metrics.on_decoded(slots);
  1776. if (ret != 0) {
  1777. {
  1778. std::string err;
  1779. if (n_batch == 1 && ret == 1) {
  1780. // TODO: try to terminate only the largest active slot/sequence and continue with the rest
  1781. // need to remove the tokens from the current batch too
  1782. err = "Context size has been exceeded.";
  1783. }
  1784. if (ret == -1) {
  1785. err = "Invalid input batch.";
  1786. }
  1787. if (ret < -1) {
  1788. // TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max()
  1789. err = "Compute error.";
  1790. }
  1791. // TODO: handle ret == 2 (abort) when we start aborting
  1792. if (!err.empty()) {
  1793. SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
  1794. for (auto & slot : slots) {
  1795. if (slot.is_processing()) {
  1796. send_error(slot, err);
  1797. slot.release();
  1798. // note: it's complicated to keep track of how much of the current batch has been
  1799. // processed before the error occurred, so we simply clear the entire context
  1800. clear_slot(slot);
  1801. }
  1802. }
  1803. break;
  1804. }
  1805. }
  1806. // retry with half the batch size to try to find a free slot in the KV cache
  1807. if (!try_clear_idle_slots()) {
  1808. n_batch /= 2;
  1809. }
  1810. SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
  1811. continue; // continue loop of n_batch
  1812. }
  1813. // move the head of the batch forward with the number of tokens we just processed
  1814. i_next = i + n_tokens;
  1815. // on successful decode, restore the original batch size
  1816. n_batch = llama_n_batch(ctx);
  1817. for (auto & slot : slots) {
  1818. // optionally send prompt processing progress
  1819. if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
  1820. if (slot.task->params.stream && slot.task->params.return_progress) {
  1821. send_partial_response(slot, {}, true);
  1822. }
  1823. }
  1824. if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
  1825. continue; // continue loop of slots
  1826. }
  1827. if (slot.state == SLOT_STATE_DONE_PROMPT) {
  1828. if (slot.task->type == SERVER_TASK_TYPE_EMBEDDING) {
  1829. // prompt evaluated for embedding
  1830. send_embedding(slot, batch_view);
  1831. slot.release();
  1832. slot.i_batch = -1;
  1833. continue; // continue loop of slots
  1834. }
  1835. if (slot.task->type == SERVER_TASK_TYPE_RERANK) {
  1836. send_rerank(slot, batch_view);
  1837. slot.release();
  1838. slot.i_batch = -1;
  1839. continue; // continue loop of slots
  1840. }
  1841. // prompt evaluated for next-token prediction
  1842. slot.state = SLOT_STATE_GENERATING;
  1843. } else if (slot.state != SLOT_STATE_GENERATING) {
  1844. continue; // continue loop of slots
  1845. }
  1846. const int tok_idx = slot.i_batch - i;
  1847. llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
  1848. slot.i_batch = -1;
  1849. common_sampler_accept(slot.smpl, id, true);
  1850. slot.n_decoded += 1;
  1851. const int64_t t_current = ggml_time_us();
  1852. if (slot.n_decoded == 1) {
  1853. slot.t_start_generation = t_current;
  1854. slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
  1855. metrics.on_prompt_eval(slot);
  1856. }
  1857. slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
  1858. completion_token_output result;
  1859. result.tok = id;
  1860. result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
  1861. result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
  1862. if (slot.task->params.sampling.n_probs > 0) {
  1863. populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx);
  1864. }
  1865. if (!process_token(result, slot)) {
  1866. // release slot because of stop condition
  1867. slot.print_timings();
  1868. send_final_response(slot);
  1869. metrics.on_prediction(slot);
  1870. slot.release();
  1871. continue;
  1872. }
  1873. }
  1874. // do speculative decoding
  1875. // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
  1876. // perform the speculative drafting for all sequences at the same time in a single batch
  1877. for (auto & slot : slots) {
  1878. if (!slot.is_processing() || !slot.can_speculate()) {
  1879. continue;
  1880. }
  1881. if (slot.state != SLOT_STATE_GENERATING) {
  1882. continue;
  1883. }
  1884. if (mctx) {
  1885. // we should never reach this, as speculative is automatically disabled if mmproj is loaded
  1886. GGML_ABORT("not supported by multimodal");
  1887. }
  1888. // determine the max draft that fits the current slot state
  1889. int n_draft_max = slot.task->params.speculative.n_max;
  1890. // note: slot.prompt is not yet expanded with the `id` token sampled above
  1891. // also, need to leave space for 1 extra token to allow context shifts
  1892. n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.prompt.n_tokens() - 2);
  1893. if (slot.n_remaining > 0) {
  1894. n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
  1895. }
  1896. SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
  1897. if (n_draft_max < slot.task->params.speculative.n_min) {
  1898. SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.task->params.speculative.n_min);
  1899. continue;
  1900. }
  1901. llama_token id = slot.sampled;
  1902. struct common_speculative_params params_spec;
  1903. params_spec.n_draft = n_draft_max;
  1904. params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max;
  1905. params_spec.p_min = slot.task->params.speculative.p_min;
  1906. const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
  1907. llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
  1908. // ignore small drafts
  1909. if (slot.task->params.speculative.n_min > (int) draft.size()) {
  1910. SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min);
  1911. continue;
  1912. }
  1913. // keep track of total number of drafted tokens tested
  1914. slot.n_draft_total += draft.size();
  1915. // construct the speculation batch
  1916. common_batch_clear(slot.batch_spec);
  1917. common_batch_add (slot.batch_spec, id, slot.prompt.tokens.pos_next(), { slot.id }, true);
  1918. for (size_t i = 0; i < draft.size(); ++i) {
  1919. common_batch_add(slot.batch_spec, draft[i], slot.prompt.tokens.pos_next() + 1 + i, { slot.id }, true);
  1920. }
  1921. SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
  1922. llama_decode(ctx, slot.batch_spec);
  1923. // the accepted tokens from the speculation
  1924. const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
  1925. slot.n_decoded += ids.size();
  1926. // update how many tokens out of those tested were accepted
  1927. slot.n_draft_accepted += ids.size() - 1;
  1928. slot.prompt.tokens.push_back(id);
  1929. slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
  1930. llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);
  1931. for (size_t i = 0; i < ids.size(); ++i) {
  1932. completion_token_output result;
  1933. result.tok = ids[i];
  1934. result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
  1935. result.prob = 1.0f; // set later
  1936. // TODO: set result.probs
  1937. if (!process_token(result, slot)) {
  1938. slot.print_timings();
  1939. send_final_response(slot);
  1940. metrics.on_prediction(slot);
  1941. slot.release();
  1942. break;
  1943. }
  1944. }
  1945. SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.prompt.n_tokens());
  1946. }
  1947. }
  1948. SRV_DBG("%s", "run slots completed\n");
  1949. }
  1950. json model_meta() const {
  1951. return json {
  1952. {"vocab_type", llama_vocab_type (vocab)},
  1953. {"n_vocab", llama_vocab_n_tokens (vocab)},
  1954. {"n_ctx_train", llama_model_n_ctx_train(model)},
  1955. {"n_embd", llama_model_n_embd (model)},
  1956. {"n_params", llama_model_n_params (model)},
  1957. {"size", llama_model_size (model)},
  1958. };
  1959. }
  1960. int get_slot_n_ctx() {
  1961. return slots.back().n_ctx;
  1962. }
  1963. };
  1964. //
  1965. // server_context (public API)
  1966. //
  1967. server_context::server_context() : impl(new server_context_impl()) {}
  1968. server_context::~server_context() = default;
  1969. void server_context::init() {
  1970. impl->init();
  1971. }
  1972. bool server_context::load_model(const common_params & params) {
  1973. return impl->load_model(params);
  1974. }
  1975. void server_context::start_loop() {
  1976. impl->queue_tasks.start_loop();
  1977. }
  1978. void server_context::terminate() {
  1979. impl->queue_tasks.terminate();
  1980. }
  1981. llama_context * server_context::get_llama_context() const {
  1982. return impl->ctx;
  1983. }
  1984. std::pair<server_queue &, server_response &> server_context::get_queues() {
  1985. return { impl->queue_tasks, impl->queue_results };
  1986. }
  1987. // generator-like API for HTTP response generation
  1988. struct server_res_generator : server_http_res {
  1989. server_response_reader rd;
  1990. server_res_generator(server_context_impl & ctx_server)
  1991. : rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS) {}
  1992. void ok(const json & response_data) {
  1993. status = 200;
  1994. data = safe_json_to_str(response_data);
  1995. }
  1996. void error(const json & error_data) {
  1997. status = json_value(error_data, "code", 500);
  1998. data = safe_json_to_str({{ "error", error_data }});
  1999. }
  2000. };
  2001. //
  2002. // server_routes
  2003. //
  2004. static std::unique_ptr<server_res_generator> handle_completions_impl(
  2005. server_context_impl & ctx_server,
  2006. server_task_type type,
  2007. const json & data,
  2008. const std::vector<raw_buffer> & files,
  2009. const std::function<bool()> & should_stop,
  2010. task_response_type res_type) {
  2011. GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
  2012. auto res = std::make_unique<server_res_generator>(ctx_server);
  2013. auto completion_id = gen_chatcmplid();
  2014. auto & rd = res->rd;
  2015. try {
  2016. std::vector<server_task> tasks;
  2017. const auto & prompt = data.at("prompt");
  2018. // TODO: this log can become very long, put it behind a flag or think about a more compact format
  2019. //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
  2020. // process prompt
  2021. std::vector<server_tokens> inputs;
  2022. if (res_type != TASK_RESPONSE_TYPE_NONE && ctx_server.mctx != nullptr) {
  2023. // This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below.
  2024. inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get<std::string>(), files));
  2025. } else {
  2026. // Everything else, including multimodal completions.
  2027. inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
  2028. }
  2029. tasks.reserve(inputs.size());
  2030. for (size_t i = 0; i < inputs.size(); i++) {
  2031. server_task task = server_task(type);
  2032. task.id = ctx_server.queue_tasks.get_new_id();
  2033. task.index = i;
  2034. task.tokens = std::move(inputs[i]);
  2035. task.params = server_task::params_from_json_cmpl(
  2036. ctx_server.ctx,
  2037. ctx_server.params_base,
  2038. data);
  2039. task.id_slot = json_value(data, "id_slot", -1);
  2040. // OAI-compat
  2041. task.params.res_type = res_type;
  2042. task.params.oaicompat_cmpl_id = completion_id;
  2043. // oaicompat_model is already populated by params_from_json_cmpl
  2044. tasks.push_back(std::move(task));
  2045. }
  2046. rd.post_tasks(std::move(tasks));
  2047. } catch (const std::exception & e) {
  2048. res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
  2049. return res;
  2050. }
  2051. bool stream = json_value(data, "stream", false);
  2052. if (!stream) {
  2053. // non-stream, wait for the results
  2054. auto all_results = rd.wait_for_all(should_stop);
  2055. if (all_results.is_terminated) {
  2056. return res; // connection is closed
  2057. } else if (all_results.error) {
  2058. res->error(all_results.error->to_json());
  2059. return res;
  2060. } else {
  2061. json arr = json::array();
  2062. for (auto & res : all_results.results) {
  2063. GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
  2064. arr.push_back(res->to_json());
  2065. }
  2066. // if single request, return single object instead of array
  2067. res->ok(arr.size() == 1 ? arr[0] : arr);
  2068. }
  2069. } else {
  2070. // in streaming mode, the first error must be treated as non-stream response
  2071. // this is to match the OAI API behavior
  2072. // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309
  2073. server_task_result_ptr first_result = rd.next(should_stop);
  2074. if (first_result == nullptr) {
  2075. return res; // connection is closed
  2076. } else if (first_result->is_error()) {
  2077. res->error(first_result->to_json());
  2078. return res;
  2079. } else {
  2080. GGML_ASSERT(
  2081. dynamic_cast<server_task_result_cmpl_partial*>(first_result.get()) != nullptr
  2082. || dynamic_cast<server_task_result_cmpl_final*>(first_result.get()) != nullptr
  2083. );
  2084. }
  2085. // next responses are streamed
  2086. if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
  2087. res->data = format_anthropic_sse(first_result->to_json());
  2088. } else {
  2089. res->data = format_oai_sse(first_result->to_json()); // to be sent immediately
  2090. }
  2091. res->status = 200;
  2092. res->content_type = "text/event-stream";
  2093. res->next = [res_this = res.get(), res_type, &should_stop](std::string & output) -> bool {
  2094. if (should_stop()) {
  2095. SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
  2096. return false; // should_stop condition met
  2097. }
  2098. if (!res_this->data.empty()) {
  2099. // flush the first chunk
  2100. output = std::move(res_this->data);
  2101. res_this->data.clear();
  2102. return true;
  2103. }
  2104. server_response_reader & rd = res_this->rd;
  2105. // check if there is more data
  2106. if (!rd.has_next()) {
  2107. if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
  2108. // Anthropic doesn't send [DONE], message_stop was already sent
  2109. output = "";
  2110. } else if (res_type != TASK_RESPONSE_TYPE_NONE) {
  2111. output = "data: [DONE]\n\n";
  2112. } else {
  2113. output = "";
  2114. }
  2115. SRV_DBG("%s", "all results received, terminating stream\n");
  2116. return false; // no more data, terminate
  2117. }
  2118. // receive subsequent results
  2119. auto result = rd.next(should_stop);
  2120. if (result == nullptr) {
  2121. SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
  2122. return false; // should_stop condition met
  2123. }
  2124. // send the results
  2125. json res_json = result->to_json();
  2126. if (result->is_error()) {
  2127. if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
  2128. output = format_anthropic_sse({
  2129. {"event", "error"},
  2130. {"data", res_json},
  2131. });
  2132. } else {
  2133. output = format_oai_sse(json {{ "error", res_json }});
  2134. }
  2135. SRV_DBG("%s", "error received during streaming, terminating stream\n");
  2136. return false; // terminate on error
  2137. } else {
  2138. GGML_ASSERT(
  2139. dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
  2140. || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
  2141. );
  2142. if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
  2143. output = format_anthropic_sse(res_json);
  2144. } else {
  2145. output = format_oai_sse(res_json);
  2146. }
  2147. }
  2148. // has next data, continue
  2149. return true;
  2150. };
  2151. }
  2152. return res;
  2153. }
  2154. void server_routes::init_routes() {
  2155. this->get_health = [this](const server_http_req &) {
  2156. // error and loading states are handled by middleware
  2157. auto res = std::make_unique<server_res_generator>(ctx_server);
  2158. res->ok({{"status", "ok"}});
  2159. return res;
  2160. };
  2161. this->get_metrics = [this](const server_http_req &) {
  2162. auto res = std::make_unique<server_res_generator>(ctx_server);
  2163. if (!params.endpoint_metrics) {
  2164. res->error(format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED));
  2165. return res;
  2166. }
  2167. // request slots data using task queue
  2168. // TODO: use server_response_reader
  2169. int task_id = ctx_server.queue_tasks.get_new_id();
  2170. {
  2171. server_task task(SERVER_TASK_TYPE_METRICS);
  2172. task.id = task_id;
  2173. ctx_server.queue_results.add_waiting_task_id(task_id);
  2174. ctx_server.queue_tasks.post(std::move(task), true); // high-priority task
  2175. }
  2176. // get the result
  2177. server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
  2178. ctx_server.queue_results.remove_waiting_task_id(task_id);
  2179. if (result->is_error()) {
  2180. res->error(result->to_json());
  2181. return res;
  2182. }
  2183. // TODO: get rid of this dynamic_cast
  2184. auto res_task = dynamic_cast<server_task_result_metrics*>(result.get());
  2185. GGML_ASSERT(res_task != nullptr);
  2186. // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
  2187. json all_metrics_def = json {
  2188. {"counter", {{
  2189. {"name", "prompt_tokens_total"},
  2190. {"help", "Number of prompt tokens processed."},
  2191. {"value", (uint64_t) res_task->n_prompt_tokens_processed_total}
  2192. }, {
  2193. {"name", "prompt_seconds_total"},
  2194. {"help", "Prompt process time"},
  2195. {"value", (uint64_t) res_task->t_prompt_processing_total / 1.e3}
  2196. }, {
  2197. {"name", "tokens_predicted_total"},
  2198. {"help", "Number of generation tokens processed."},
  2199. {"value", (uint64_t) res_task->n_tokens_predicted_total}
  2200. }, {
  2201. {"name", "tokens_predicted_seconds_total"},
  2202. {"help", "Predict process time"},
  2203. {"value", (uint64_t) res_task->t_tokens_generation_total / 1.e3}
  2204. }, {
  2205. {"name", "n_decode_total"},
  2206. {"help", "Total number of llama_decode() calls"},
  2207. {"value", res_task->n_decode_total}
  2208. }, {
  2209. {"name", "n_tokens_max"},
  2210. {"help", "Largest observed n_tokens."},
  2211. {"value", res_task->n_tokens_max}
  2212. }, {
  2213. {"name", "n_busy_slots_per_decode"},
  2214. {"help", "Average number of busy slots per llama_decode() call"},
  2215. {"value", (float) res_task->n_busy_slots_total / std::max((float) res_task->n_decode_total, 1.f)}
  2216. }}},
  2217. {"gauge", {{
  2218. {"name", "prompt_tokens_seconds"},
  2219. {"help", "Average prompt throughput in tokens/s."},
  2220. {"value", res_task->n_prompt_tokens_processed ? 1.e3 / res_task->t_prompt_processing * res_task->n_prompt_tokens_processed : 0.}
  2221. },{
  2222. {"name", "predicted_tokens_seconds"},
  2223. {"help", "Average generation throughput in tokens/s."},
  2224. {"value", res_task->n_tokens_predicted ? 1.e3 / res_task->t_tokens_generation * res_task->n_tokens_predicted : 0.}
  2225. },{
  2226. {"name", "requests_processing"},
  2227. {"help", "Number of requests processing."},
  2228. {"value", (uint64_t) res_task->n_processing_slots}
  2229. },{
  2230. {"name", "requests_deferred"},
  2231. {"help", "Number of requests deferred."},
  2232. {"value", (uint64_t) res_task->n_tasks_deferred}
  2233. }}}
  2234. };
  2235. std::stringstream prometheus;
  2236. for (const auto & el : all_metrics_def.items()) {
  2237. const auto & type = el.key();
  2238. const auto & metrics_def = el.value();
  2239. for (const auto & metric_def : metrics_def) {
  2240. const std::string name = metric_def.at("name");
  2241. const std::string help = metric_def.at("help");
  2242. auto value = json_value(metric_def, "value", 0.);
  2243. prometheus << "# HELP llamacpp:" << name << " " << help << "\n"
  2244. << "# TYPE llamacpp:" << name << " " << type << "\n"
  2245. << "llamacpp:" << name << " " << value << "\n";
  2246. }
  2247. }
  2248. res->headers["Process-Start-Time-Unix"] = std::to_string(res_task->t_start);
  2249. res->content_type = "text/plain; version=0.0.4";
  2250. res->status = 200;
  2251. res->data = prometheus.str();
  2252. return res;
  2253. };
  2254. this->get_slots = [this](const server_http_req & req) {
  2255. auto res = std::make_unique<server_res_generator>(ctx_server);
  2256. if (!params.endpoint_slots) {
  2257. res->error(format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED));
  2258. return res;
  2259. }
  2260. // request slots data using task queue
  2261. int task_id = ctx_server.queue_tasks.get_new_id();
  2262. {
  2263. server_task task(SERVER_TASK_TYPE_METRICS);
  2264. task.id = task_id;
  2265. ctx_server.queue_results.add_waiting_task_id(task_id);
  2266. ctx_server.queue_tasks.post(std::move(task), true); // high-priority task
  2267. }
  2268. // get the result
  2269. server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
  2270. ctx_server.queue_results.remove_waiting_task_id(task_id);
  2271. if (result->is_error()) {
  2272. res->error(result->to_json());
  2273. return res;
  2274. }
  2275. // TODO: get rid of this dynamic_cast
  2276. auto res_task = dynamic_cast<server_task_result_metrics*>(result.get());
  2277. GGML_ASSERT(res_task != nullptr);
  2278. // optionally return "fail_on_no_slot" error
  2279. if (!req.get_param("fail_on_no_slot").empty()) {
  2280. if (res_task->n_idle_slots == 0) {
  2281. res->error(format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
  2282. return res;
  2283. }
  2284. }
  2285. res->ok(res_task->slots_data);
  2286. return res;
  2287. };
  2288. this->post_slots = [this](const server_http_req & req) {
  2289. auto res = std::make_unique<server_res_generator>(ctx_server);
  2290. if (params.slot_save_path.empty()) {
  2291. res->error(format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED));
  2292. return res;
  2293. }
  2294. std::string id_slot_str = req.get_param("id_slot");
  2295. int id_slot;
  2296. try {
  2297. id_slot = std::stoi(id_slot_str);
  2298. } catch (const std::exception &) {
  2299. res->error(format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST));
  2300. return res;
  2301. }
  2302. std::string action = req.get_param("action");
  2303. if (action == "save") {
  2304. return handle_slots_save(req, id_slot);
  2305. } else if (action == "restore") {
  2306. return handle_slots_restore(req, id_slot);
  2307. } else if (action == "erase") {
  2308. return handle_slots_erase(req, id_slot);
  2309. } else {
  2310. res->error(format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST));
  2311. return res;
  2312. }
  2313. };
  2314. this->get_props = [this](const server_http_req &) {
  2315. auto res = std::make_unique<server_res_generator>(ctx_server);
  2316. json default_generation_settings_for_props;
  2317. {
  2318. task_params params;
  2319. params.sampling = ctx_server.params_base.sampling;
  2320. default_generation_settings_for_props = json {
  2321. {"params", params.to_json(true)},
  2322. {"n_ctx", ctx_server.get_slot_n_ctx()},
  2323. };
  2324. }
  2325. // this endpoint is publicly available, please only return what is safe to be exposed
  2326. json data = {
  2327. { "default_generation_settings", default_generation_settings_for_props },
  2328. { "total_slots", ctx_server.params_base.n_parallel },
  2329. { "model_alias", ctx_server.params_base.model_alias },
  2330. { "model_path", ctx_server.params_base.model.path },
  2331. { "modalities", json {
  2332. {"vision", ctx_server.oai_parser_opt.allow_image},
  2333. {"audio", ctx_server.oai_parser_opt.allow_audio},
  2334. } },
  2335. { "endpoint_slots", params.endpoint_slots },
  2336. { "endpoint_props", params.endpoint_props },
  2337. { "endpoint_metrics", params.endpoint_metrics },
  2338. { "webui", params.webui },
  2339. { "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) },
  2340. { "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)},
  2341. { "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)},
  2342. { "build_info", build_info },
  2343. };
  2344. if (ctx_server.params_base.use_jinja) {
  2345. if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) {
  2346. data["chat_template_tool_use"] = tool_use_src;
  2347. }
  2348. }
  2349. res->ok(data);
  2350. return res;
  2351. };
  2352. this->post_props = [this](const server_http_req &) {
  2353. auto res = std::make_unique<server_res_generator>(ctx_server);
  2354. if (!params.endpoint_props) {
  2355. res->error(format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED));
  2356. return res;
  2357. }
  2358. // update any props here
  2359. res->ok({{ "success", true }});
  2360. return res;
  2361. };
  2362. this->get_api_show = [this](const server_http_req &) {
  2363. auto res = std::make_unique<server_res_generator>(ctx_server);
  2364. bool has_mtmd = ctx_server.mctx != nullptr;
  2365. json data = {
  2366. {
  2367. "template", common_chat_templates_source(ctx_server.chat_templates.get()),
  2368. },
  2369. {
  2370. "model_info", {
  2371. { "llama.context_length", ctx_server.get_slot_n_ctx() },
  2372. }
  2373. },
  2374. {"modelfile", ""},
  2375. {"parameters", ""},
  2376. {"template", common_chat_templates_source(ctx_server.chat_templates.get())},
  2377. {"details", {
  2378. {"parent_model", ""},
  2379. {"format", "gguf"},
  2380. {"family", ""},
  2381. {"families", {""}},
  2382. {"parameter_size", ""},
  2383. {"quantization_level", ""}
  2384. }},
  2385. {"model_info", ""},
  2386. {"capabilities", has_mtmd ? json({"completion","multimodal"}) : json({"completion"})}
  2387. };
  2388. res->ok(data);
  2389. return res;
  2390. };
  2391. this->post_infill = [this](const server_http_req & req) {
  2392. auto res = std::make_unique<server_res_generator>(ctx_server);
  2393. // check model compatibility
  2394. std::string err;
  2395. if (llama_vocab_fim_pre(ctx_server.vocab) == LLAMA_TOKEN_NULL) {
  2396. err += "prefix token is missing. ";
  2397. }
  2398. if (llama_vocab_fim_suf(ctx_server.vocab) == LLAMA_TOKEN_NULL) {
  2399. err += "suffix token is missing. ";
  2400. }
  2401. if (llama_vocab_fim_mid(ctx_server.vocab) == LLAMA_TOKEN_NULL) {
  2402. err += "middle token is missing. ";
  2403. }
  2404. if (!err.empty()) {
  2405. res->error(format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
  2406. return res;
  2407. }
  2408. // validate input
  2409. json data = json::parse(req.body);
  2410. if (data.contains("prompt") && !data.at("prompt").is_string()) {
  2411. // prompt is optional
  2412. res->error(format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST));
  2413. }
  2414. if (!data.contains("input_prefix")) {
  2415. res->error(format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
  2416. }
  2417. if (!data.contains("input_suffix")) {
  2418. res->error(format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST));
  2419. }
  2420. if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
  2421. // input_extra is optional
  2422. res->error(format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
  2423. return res;
  2424. }
  2425. json input_extra = json_value(data, "input_extra", json::array());
  2426. for (const auto & chunk : input_extra) {
  2427. // { "text": string, "filename": string }
  2428. if (!chunk.contains("text") || !chunk.at("text").is_string()) {
  2429. res->error(format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST));
  2430. return res;
  2431. }
  2432. // filename is optional
  2433. if (chunk.contains("filename") && !chunk.at("filename").is_string()) {
  2434. res->error(format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST));
  2435. return res;
  2436. }
  2437. }
  2438. data["input_extra"] = input_extra; // default to empty array if it's not exist
  2439. std::string prompt = json_value(data, "prompt", std::string());
  2440. std::vector<server_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, false, true);
  2441. SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
  2442. data["prompt"] = format_prompt_infill(
  2443. ctx_server.vocab,
  2444. data.at("input_prefix"),
  2445. data.at("input_suffix"),
  2446. data.at("input_extra"),
  2447. ctx_server.params_base.n_batch,
  2448. ctx_server.params_base.n_predict,
  2449. ctx_server.get_slot_n_ctx(),
  2450. ctx_server.params_base.spm_infill,
  2451. tokenized_prompts[0].get_text_tokens() // TODO: this could maybe be multimodal.
  2452. );
  2453. std::vector<raw_buffer> files; // dummy
  2454. return handle_completions_impl(
  2455. ctx_server,
  2456. SERVER_TASK_TYPE_INFILL,
  2457. data,
  2458. files,
  2459. req.should_stop,
  2460. TASK_RESPONSE_TYPE_NONE); // infill is not OAI compatible
  2461. };
  2462. this->post_completions = [this](const server_http_req & req) {
  2463. std::vector<raw_buffer> files; // dummy
  2464. const json body = json::parse(req.body);
  2465. return handle_completions_impl(
  2466. ctx_server,
  2467. SERVER_TASK_TYPE_COMPLETION,
  2468. body,
  2469. files,
  2470. req.should_stop,
  2471. TASK_RESPONSE_TYPE_NONE);
  2472. };
  2473. this->post_completions_oai = [this](const server_http_req & req) {
  2474. std::vector<raw_buffer> files; // dummy
  2475. const json body = json::parse(req.body);
  2476. return handle_completions_impl(
  2477. ctx_server,
  2478. SERVER_TASK_TYPE_COMPLETION,
  2479. body,
  2480. files,
  2481. req.should_stop,
  2482. TASK_RESPONSE_TYPE_OAI_CMPL);
  2483. };
  2484. this->post_chat_completions = [this](const server_http_req & req) {
  2485. std::vector<raw_buffer> files;
  2486. json body = json::parse(req.body);
  2487. json body_parsed = oaicompat_chat_params_parse(
  2488. body,
  2489. ctx_server.oai_parser_opt,
  2490. files);
  2491. return handle_completions_impl(
  2492. ctx_server,
  2493. SERVER_TASK_TYPE_COMPLETION,
  2494. body_parsed,
  2495. files,
  2496. req.should_stop,
  2497. TASK_RESPONSE_TYPE_OAI_CHAT);
  2498. };
  2499. this->post_anthropic_messages = [this](const server_http_req & req) {
  2500. std::vector<raw_buffer> files;
  2501. json body = convert_anthropic_to_oai(json::parse(req.body));
  2502. json body_parsed = oaicompat_chat_params_parse(
  2503. body,
  2504. ctx_server.oai_parser_opt,
  2505. files);
  2506. return handle_completions_impl(
  2507. ctx_server,
  2508. SERVER_TASK_TYPE_COMPLETION,
  2509. body_parsed,
  2510. files,
  2511. req.should_stop,
  2512. TASK_RESPONSE_TYPE_ANTHROPIC);
  2513. };
  2514. this->post_anthropic_count_tokens = [this](const server_http_req & req) {
  2515. auto res = std::make_unique<server_res_generator>(ctx_server);
  2516. std::vector<raw_buffer> files;
  2517. json body = convert_anthropic_to_oai(json::parse(req.body));
  2518. json body_parsed = oaicompat_chat_params_parse(
  2519. body,
  2520. ctx_server.oai_parser_opt,
  2521. files);
  2522. json prompt = body_parsed.at("prompt");
  2523. llama_tokens tokens = tokenize_mixed(ctx_server.vocab, prompt, true, true);
  2524. res->ok({{"input_tokens", static_cast<int>(tokens.size())}});
  2525. return res;
  2526. };
  2527. // same with handle_chat_completions, but without inference part
  2528. this->post_apply_template = [this](const server_http_req & req) {
  2529. auto res = std::make_unique<server_res_generator>(ctx_server);
  2530. std::vector<raw_buffer> files; // dummy, unused
  2531. json body = json::parse(req.body);
  2532. json data = oaicompat_chat_params_parse(
  2533. body,
  2534. ctx_server.oai_parser_opt,
  2535. files);
  2536. res->ok({{ "prompt", std::move(data.at("prompt")) }});
  2537. return res;
  2538. };
  2539. this->get_models = [this](const server_http_req &) {
  2540. auto res = std::make_unique<server_res_generator>(ctx_server);
  2541. json model_meta = nullptr;
  2542. if (is_ready()) {
  2543. model_meta = ctx_server.model_meta();
  2544. }
  2545. bool has_mtmd = ctx_server.mctx != nullptr;
  2546. json models = {
  2547. {"models", {
  2548. {
  2549. {"name", params.model_alias.empty() ? params.model.path : params.model_alias},
  2550. {"model", params.model_alias.empty() ? params.model.path : params.model_alias},
  2551. {"modified_at", ""},
  2552. {"size", ""},
  2553. {"digest", ""}, // dummy value, llama.cpp does not support managing model file's hash
  2554. {"type", "model"},
  2555. {"description", ""},
  2556. {"tags", {""}},
  2557. {"capabilities", has_mtmd ? json({"completion","multimodal"}) : json({"completion"})},
  2558. {"parameters", ""},
  2559. {"details", {
  2560. {"parent_model", ""},
  2561. {"format", "gguf"},
  2562. {"family", ""},
  2563. {"families", {""}},
  2564. {"parameter_size", ""},
  2565. {"quantization_level", ""}
  2566. }}
  2567. }
  2568. }},
  2569. {"object", "list"},
  2570. {"data", {
  2571. {
  2572. {"id", params.model_alias.empty() ? params.model.path : params.model_alias},
  2573. {"object", "model"},
  2574. {"created", std::time(0)},
  2575. {"owned_by", "llamacpp"},
  2576. {"meta", model_meta},
  2577. },
  2578. }}
  2579. };
  2580. res->ok(models);
  2581. return res;
  2582. };
  2583. this->post_tokenize = [this](const server_http_req & req) {
  2584. auto res = std::make_unique<server_res_generator>(ctx_server);
  2585. const json body = json::parse(req.body);
  2586. json tokens_response = json::array();
  2587. if (body.count("content") != 0) {
  2588. const bool add_special = json_value(body, "add_special", false);
  2589. const bool parse_special = json_value(body, "parse_special", true);
  2590. const bool with_pieces = json_value(body, "with_pieces", false);
  2591. llama_tokens tokens = tokenize_mixed(ctx_server.vocab, body.at("content"), add_special, parse_special);
  2592. if (with_pieces) {
  2593. for (const auto& token : tokens) {
  2594. std::string piece = common_token_to_piece(ctx_server.ctx, token);
  2595. json piece_json;
  2596. // Check if the piece is valid UTF-8
  2597. if (is_valid_utf8(piece)) {
  2598. piece_json = piece;
  2599. } else {
  2600. // If not valid UTF-8, store as array of byte values
  2601. piece_json = json::array();
  2602. for (unsigned char c : piece) {
  2603. piece_json.push_back(static_cast<int>(c));
  2604. }
  2605. }
  2606. tokens_response.push_back({
  2607. {"id", token},
  2608. {"piece", piece_json}
  2609. });
  2610. }
  2611. } else {
  2612. tokens_response = tokens;
  2613. }
  2614. }
  2615. res->ok(json{{"tokens", std::move(tokens_response)}});
  2616. return res;
  2617. };
  2618. this->post_detokenize = [this](const server_http_req & req) {
  2619. auto res = std::make_unique<server_res_generator>(ctx_server);
  2620. const json body = json::parse(req.body);
  2621. std::string content;
  2622. if (body.count("tokens") != 0) {
  2623. const llama_tokens tokens = body.at("tokens");
  2624. content = tokens_to_str(ctx_server.ctx, tokens);
  2625. }
  2626. res->ok(json{{"content", std::move(content)}});
  2627. return res;
  2628. };
  2629. this->post_embeddings = [this](const server_http_req & req) {
  2630. return handle_embeddings_impl(req, TASK_RESPONSE_TYPE_NONE);
  2631. };
  2632. this->post_embeddings_oai = [this](const server_http_req & req) {
  2633. return handle_embeddings_impl(req, TASK_RESPONSE_TYPE_OAI_EMBD);
  2634. };
  2635. this->post_rerank = [this](const server_http_req & req) {
  2636. auto res = std::make_unique<server_res_generator>(ctx_server);
  2637. if (!ctx_server.params_base.embedding || ctx_server.params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) {
  2638. res->error(format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
  2639. return res;
  2640. }
  2641. const json body = json::parse(req.body);
  2642. // if true, use TEI API format, otherwise use Jina API format
  2643. // Jina: https://jina.ai/reranker/
  2644. // TEI: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/rerank
  2645. bool is_tei_format = body.contains("texts");
  2646. json query;
  2647. if (body.count("query") == 1) {
  2648. query = body.at("query");
  2649. if (!query.is_string()) {
  2650. res->error(format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
  2651. return res;
  2652. }
  2653. } else {
  2654. res->error(format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
  2655. return res;
  2656. }
  2657. std::vector<std::string> documents = json_value(body, "documents",
  2658. json_value(body, "texts", std::vector<std::string>()));
  2659. if (documents.empty()) {
  2660. res->error(format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
  2661. return res;
  2662. }
  2663. int top_n = json_value(body, "top_n", (int)documents.size());
  2664. // create and queue the task
  2665. json responses = json::array();
  2666. server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS);
  2667. {
  2668. std::vector<server_task> tasks;
  2669. tasks.reserve(documents.size());
  2670. for (size_t i = 0; i < documents.size(); i++) {
  2671. auto tmp = format_prompt_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]);
  2672. server_task task = server_task(SERVER_TASK_TYPE_RERANK);
  2673. task.id = ctx_server.queue_tasks.get_new_id();
  2674. task.index = i;
  2675. task.tokens = std::move(tmp);
  2676. tasks.push_back(std::move(task));
  2677. }
  2678. rd.post_tasks(std::move(tasks));
  2679. }
  2680. // wait for the results
  2681. auto all_results = rd.wait_for_all(req.should_stop);
  2682. // collect results
  2683. if (all_results.is_terminated) {
  2684. return res; // connection is closed
  2685. } else if (all_results.error) {
  2686. res->error(all_results.error->to_json());
  2687. return res;
  2688. } else {
  2689. for (auto & res : all_results.results) {
  2690. GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
  2691. responses.push_back(res->to_json());
  2692. }
  2693. }
  2694. // write JSON response
  2695. json root = format_response_rerank(
  2696. body,
  2697. responses,
  2698. is_tei_format,
  2699. documents,
  2700. top_n);
  2701. res->ok(root);
  2702. return res;
  2703. };
  2704. this->get_lora_adapters = [this](const server_http_req &) {
  2705. auto res = std::make_unique<server_res_generator>(ctx_server);
  2706. json result = json::array();
  2707. const auto & loras = ctx_server.params_base.lora_adapters;
  2708. for (size_t i = 0; i < loras.size(); ++i) {
  2709. auto & lora = loras[i];
  2710. json entry = {
  2711. {"id", i},
  2712. {"path", lora.path},
  2713. {"scale", lora.scale},
  2714. {"task_name", lora.task_name},
  2715. {"prompt_prefix", lora.prompt_prefix},
  2716. };
  2717. std::string alora_invocation_string = "";
  2718. const uint64_t n_alora_tokens = llama_adapter_get_alora_n_invocation_tokens(lora.ptr);
  2719. std::vector<llama_token> alora_invocation_tokens;
  2720. if (n_alora_tokens) {
  2721. const llama_token * alora_tokens = llama_adapter_get_alora_invocation_tokens(lora.ptr);
  2722. for (uint64_t i = 0; i < n_alora_tokens; ++i) {
  2723. alora_invocation_string += common_token_to_piece(ctx_server.ctx, alora_tokens[i]);
  2724. alora_invocation_tokens.push_back(alora_tokens[i]);
  2725. }
  2726. entry["alora_invocation_string"] = alora_invocation_string;
  2727. entry["alora_invocation_tokens"] = alora_invocation_tokens;
  2728. }
  2729. result.push_back(std::move(entry));
  2730. }
  2731. res->ok(result);
  2732. return res;
  2733. };
  2734. this->post_lora_adapters = [this](const server_http_req & req) {
  2735. auto res = std::make_unique<server_res_generator>(ctx_server);
  2736. const json body = json::parse(req.body);
  2737. if (!body.is_array()) {
  2738. res->error(format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
  2739. return res;
  2740. }
  2741. int task_id = ctx_server.queue_tasks.get_new_id();
  2742. {
  2743. server_task task(SERVER_TASK_TYPE_SET_LORA);
  2744. task.id = task_id;
  2745. task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body);
  2746. ctx_server.queue_results.add_waiting_task_id(task_id);
  2747. ctx_server.queue_tasks.post(std::move(task));
  2748. }
  2749. // get the result
  2750. server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
  2751. ctx_server.queue_results.remove_waiting_task_id(task_id);
  2752. if (result->is_error()) {
  2753. res->error(result->to_json());
  2754. return res;
  2755. }
  2756. GGML_ASSERT(dynamic_cast<server_task_result_apply_lora*>(result.get()) != nullptr);
  2757. res->ok(result->to_json());
  2758. return res;
  2759. };
  2760. }
  2761. std::unique_ptr<server_res_generator> server_routes::handle_slots_save(const server_http_req & req, int id_slot) {
  2762. auto res = std::make_unique<server_res_generator>(ctx_server);
  2763. const json request_data = json::parse(req.body);
  2764. std::string filename = request_data.at("filename");
  2765. if (!fs_validate_filename(filename)) {
  2766. res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
  2767. return res;
  2768. }
  2769. std::string filepath = params.slot_save_path + filename;
  2770. int task_id = ctx_server.queue_tasks.get_new_id();
  2771. {
  2772. server_task task(SERVER_TASK_TYPE_SLOT_SAVE);
  2773. task.id = task_id;
  2774. task.slot_action.slot_id = id_slot;
  2775. task.slot_action.filename = filename;
  2776. task.slot_action.filepath = filepath;
  2777. // TODO: use server_response_reader
  2778. ctx_server.queue_results.add_waiting_task_id(task_id);
  2779. ctx_server.queue_tasks.post(std::move(task));
  2780. }
  2781. server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
  2782. ctx_server.queue_results.remove_waiting_task_id(task_id);
  2783. if (result->is_error()) {
  2784. res->error(result->to_json());
  2785. return res;
  2786. }
  2787. res->ok(result->to_json());
  2788. return res;
  2789. }
  2790. std::unique_ptr<server_res_generator> server_routes::handle_slots_restore(const server_http_req & req, int id_slot) {
  2791. auto res = std::make_unique<server_res_generator>(ctx_server);
  2792. const json request_data = json::parse(req.body);
  2793. std::string filename = request_data.at("filename");
  2794. if (!fs_validate_filename(filename)) {
  2795. res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
  2796. return res;
  2797. }
  2798. std::string filepath = params.slot_save_path + filename;
  2799. int task_id = ctx_server.queue_tasks.get_new_id();
  2800. {
  2801. server_task task(SERVER_TASK_TYPE_SLOT_RESTORE);
  2802. task.id = task_id;
  2803. task.slot_action.slot_id = id_slot;
  2804. task.slot_action.filename = filename;
  2805. task.slot_action.filepath = filepath;
  2806. // TODO: use server_response_reader
  2807. ctx_server.queue_results.add_waiting_task_id(task_id);
  2808. ctx_server.queue_tasks.post(std::move(task));
  2809. }
  2810. server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
  2811. ctx_server.queue_results.remove_waiting_task_id(task_id);
  2812. if (result->is_error()) {
  2813. res->error(result->to_json());
  2814. return res;
  2815. }
  2816. GGML_ASSERT(dynamic_cast<server_task_result_slot_save_load*>(result.get()) != nullptr);
  2817. res->ok(result->to_json());
  2818. return res;
  2819. }
  2820. std::unique_ptr<server_res_generator> server_routes::handle_slots_erase(const server_http_req &, int id_slot) {
  2821. auto res = std::make_unique<server_res_generator>(ctx_server);
  2822. int task_id = ctx_server.queue_tasks.get_new_id();
  2823. {
  2824. server_task task(SERVER_TASK_TYPE_SLOT_ERASE);
  2825. task.id = task_id;
  2826. task.slot_action.slot_id = id_slot;
  2827. // TODO: use server_response_reader
  2828. ctx_server.queue_results.add_waiting_task_id(task_id);
  2829. ctx_server.queue_tasks.post(std::move(task));
  2830. }
  2831. server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
  2832. ctx_server.queue_results.remove_waiting_task_id(task_id);
  2833. if (result->is_error()) {
  2834. res->error(result->to_json());
  2835. return res;
  2836. }
  2837. GGML_ASSERT(dynamic_cast<server_task_result_slot_erase*>(result.get()) != nullptr);
  2838. res->ok(result->to_json());
  2839. return res;
  2840. }
  2841. std::unique_ptr<server_res_generator> server_routes::handle_embeddings_impl(const server_http_req & req, task_response_type res_type) {
  2842. auto res = std::make_unique<server_res_generator>(ctx_server);
  2843. if (!ctx_server.params_base.embedding) {
  2844. res->error(format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
  2845. return res;
  2846. }
  2847. if (res_type != TASK_RESPONSE_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
  2848. res->error(format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
  2849. return res;
  2850. }
  2851. const json body = json::parse(req.body);
  2852. // for the shape of input/content, see tokenize_input_prompts()
  2853. json prompt;
  2854. if (body.count("input") != 0) {
  2855. prompt = body.at("input");
  2856. } else if (body.contains("content")) {
  2857. res_type = TASK_RESPONSE_TYPE_NONE; // "content" field is not OAI compatible
  2858. prompt = body.at("content");
  2859. } else {
  2860. res->error(format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
  2861. return res;
  2862. }
  2863. bool use_base64 = false;
  2864. if (body.count("encoding_format") != 0) {
  2865. const std::string& format = body.at("encoding_format");
  2866. if (format == "base64") {
  2867. use_base64 = true;
  2868. } else if (format != "float") {
  2869. res->error(format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST));
  2870. return res;
  2871. }
  2872. }
  2873. auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
  2874. for (const auto & tokens : tokenized_prompts) {
  2875. // this check is necessary for models that do not add BOS token to the input
  2876. if (tokens.empty()) {
  2877. res->error(format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST));
  2878. return res;
  2879. }
  2880. }
  2881. int embd_normalize = 2; // default to Euclidean/L2 norm
  2882. if (body.count("embd_normalize") != 0) {
  2883. embd_normalize = body.at("embd_normalize");
  2884. if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
  2885. SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", llama_pooling_type(ctx_server.ctx));
  2886. }
  2887. }
  2888. // create and queue the task
  2889. json responses = json::array();
  2890. server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS);
  2891. {
  2892. std::vector<server_task> tasks;
  2893. for (size_t i = 0; i < tokenized_prompts.size(); i++) {
  2894. server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
  2895. task.id = ctx_server.queue_tasks.get_new_id();
  2896. task.index = i;
  2897. task.tokens = std::move(tokenized_prompts[i]);
  2898. // OAI-compat
  2899. task.params.res_type = res_type;
  2900. task.params.embd_normalize = embd_normalize;
  2901. tasks.push_back(std::move(task));
  2902. }
  2903. rd.post_tasks(std::move(tasks));
  2904. }
  2905. // wait for the results
  2906. auto all_results = rd.wait_for_all(req.should_stop);
  2907. // collect results
  2908. if (all_results.is_terminated) {
  2909. return res; // connection is closed
  2910. } else if (all_results.error) {
  2911. res->error(all_results.error->to_json());
  2912. return res;
  2913. } else {
  2914. for (auto & res : all_results.results) {
  2915. GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
  2916. responses.push_back(res->to_json());
  2917. }
  2918. }
  2919. // write JSON response
  2920. json root = res_type == TASK_RESPONSE_TYPE_OAI_EMBD
  2921. ? format_embeddings_response_oaicompat(body, responses, use_base64)
  2922. : json(responses);
  2923. res->ok(root);
  2924. return res;
  2925. }