perplexity.cpp 79 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061
  1. #include "common.h"
  2. #include "llama.h"
  3. #include <cmath>
  4. #include <cstdio>
  5. #include <cstring>
  6. #include <ctime>
  7. #include <sstream>
  8. #include <thread>
  9. #include <mutex>
  10. #include <atomic>
  11. #include <vector>
  12. #include <array>
  13. #include <fstream>
  14. #include <sstream>
  15. #if defined(_MSC_VER)
  16. #pragma warning(disable: 4244 4267) // possible loss of data
  17. #endif
  18. struct results_perplexity {
  19. std::vector<llama_token> tokens;
  20. double ppl_value;
  21. std::vector<float> logits;
  22. std::vector<float> probs;
  23. };
  // Output of log_softmax() for one selected token.
  struct results_log_softmax {
      double log_softmax; // log p(tok) under the softmax of the row
      float  logit;       // raw logit of tok
      float  prob;        // softmax probability of tok
  };
  29. static void write_logfile(
  30. const llama_context * ctx, const gpt_params & params, const llama_model * model,
  31. const struct results_perplexity & results
  32. ) {
  33. if (params.logdir.empty()) {
  34. return;
  35. }
  36. if (params.hellaswag) {
  37. fprintf(stderr, "%s: warning: logging results is not implemented for HellaSwag. No files will be written.\n", __func__);
  38. return;
  39. }
  40. const std::string timestamp = string_get_sortable_timestamp();
  41. const bool success = fs_create_directory_with_parents(params.logdir);
  42. if (!success) {
  43. fprintf(stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n",
  44. __func__, params.logdir.c_str());
  45. return;
  46. }
  47. const std::string logfile_path = params.logdir + timestamp + ".yml";
  48. FILE * logfile = fopen(logfile_path.c_str(), "w");
  49. if (logfile == NULL) {
  50. fprintf(stderr, "%s: failed to open logfile %s\n", __func__, logfile_path.c_str());
  51. return;
  52. }
  53. fprintf(logfile, "binary: main\n");
  54. char model_desc[128];
  55. llama_model_desc(model, model_desc, sizeof(model_desc));
  56. yaml_dump_non_result_info(logfile, params, ctx, timestamp, results.tokens, model_desc);
  57. fprintf(logfile, "\n");
  58. fprintf(logfile, "######################\n");
  59. fprintf(logfile, "# Perplexity Results #\n");
  60. fprintf(logfile, "######################\n");
  61. fprintf(logfile, "\n");
  62. yaml_dump_vector_float(logfile, "logits", results.logits);
  63. fprintf(logfile, "ppl_value: %f\n", results.ppl_value);
  64. yaml_dump_vector_float(logfile, "probs", results.probs);
  65. llama_perf_dump_yaml(logfile, ctx);
  66. fclose(logfile);
  67. }
  // Numerically stable softmax over a vector of logits.
  // Returns a vector of the same length whose entries sum to ~1.
  // An empty input yields an empty output (previously logits[0] was read out of bounds).
  static std::vector<float> softmax(const std::vector<float>& logits) {
      std::vector<float> probs(logits.size());
      if (logits.empty()) {
          return probs;
      }

      float max_logit = logits[0];
      for (float v : logits) {
          max_logit = std::max(max_logit, v);
      }

      // Subtracting the maximum keeps expf() from overflowing; it cancels out in the normalization.
      double sum_exp = 0.0;
      for (size_t i = 0; i < logits.size(); i++) {
          const float exp_logit = expf(logits[i] - max_logit);
          sum_exp += exp_logit;
          probs[i] = exp_logit;
      }

      for (float & p : probs) {
          p /= sum_exp;
      }
      return probs;
  }
  87. static results_log_softmax log_softmax(int n_vocab, const float * logits, int tok) {
  88. float max_logit = logits[0];
  89. for (int i = 1; i < n_vocab; ++i) {
  90. max_logit = std::max(max_logit, logits[i]);
  91. }
  92. double sum_exp = 0.0;
  93. for (int i = 0; i < n_vocab; ++i) {
  94. sum_exp += expf(logits[i] - max_logit);
  95. }
  96. return {logits[tok] - max_logit - log(sum_exp), logits[tok], expf(logits[tok] - max_logit) / (float) sum_exp};
  97. }
  // Rounds fval to the nearest integer without a libm call.
  // Adding 12582912 (= 1.5 * 2^23) moves the value into a binade where the float
  // spacing is exactly 1, so the addition itself rounds to an integer (ties to even
  // in the default rounding mode). That integer lands in the low mantissa bits,
  // offset by 2^22, and the offset is subtracted back out.
  // Precondition: |fval| <= 4194303 (2^22 - 1).
  static inline int nearest_int(float fval) {
      const float shifted = fval + 12582912.f;
      int bits;
      std::memcpy(&bits, &shifted, sizeof(int)); // type-pun via memcpy, not a pointer cast
      const int mantissa = bits & 0x007fffff;
      return mantissa - 0x00400000;
  }
  104. static double log_softmax(int n_vocab, const float * logits, uint16_t * log_prob, int tok) {
  105. float max_logit = logits[0];
  106. float min_logit = logits[0];
  107. for (int i = 1; i < n_vocab; ++i) {
  108. max_logit = std::max(max_logit, logits[i]);
  109. min_logit = std::min(min_logit, logits[i]);
  110. }
  111. min_logit = std::max(min_logit, max_logit - 16);
  112. double sum_exp = 0.0;
  113. for (int i = 0; i < n_vocab; ++i) {
  114. sum_exp += expf(logits[i] - max_logit);
  115. }
  116. const float log_sum_exp = log(sum_exp);
  117. const float min_log_prob = min_logit - max_logit - log_sum_exp;
  118. const float scale = (max_logit - min_logit)/65535.f;
  119. float * d = (float *)log_prob;
  120. d[0] = scale;
  121. d[1] = min_log_prob;
  122. log_prob += 4;
  123. if (scale) {
  124. const float inv_scale = 1/scale;
  125. for (int i = 0; i < n_vocab; ++i) {
  126. log_prob[i] = logits[i] > min_logit ? nearest_int(inv_scale*(logits[i] - min_logit)) : 0;
  127. }
  128. } else {
  129. std::memset(log_prob, 0, n_vocab*sizeof(uint16_t));
  130. }
  131. return max_logit + log_sum_exp - logits[tok];
  132. }
  133. static void process_logits(
  134. int n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread> & workers,
  135. double & nll, double & nll2, float * logit_history, float * prob_history
  136. ) {
  137. std::mutex mutex;
  138. int counter = 0;
  139. auto compute = [&mutex, &counter, &nll, &nll2, logit_history, prob_history, n_vocab, logits, tokens, n_token] () {
  140. double local_nll = 0;
  141. double local_nll2 = 0;
  142. while (true) {
  143. std::unique_lock<std::mutex> lock(mutex);
  144. int i = counter++;
  145. if (i >= n_token) {
  146. nll += local_nll; nll2 += local_nll2;
  147. break;
  148. }
  149. lock.unlock();
  150. const results_log_softmax results = log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]);
  151. const double v = -results.log_softmax;
  152. local_nll += v;
  153. local_nll2 += v*v;
  154. logit_history[i] = results.logit;
  155. prob_history[i] = results.prob;
  156. }
  157. };
  158. for (auto & w : workers) {
  159. w = std::thread(compute);
  160. }
  161. compute();
  162. for (auto & w : workers) {
  163. w.join();
  164. }
  165. }
  166. static void process_logits(std::ostream& out, int n_vocab, const float * logits, const int * tokens, int n_token,
  167. std::vector<std::thread> & workers, std::vector<uint16_t> & log_probs, double & nll, double & nll2) {
  168. std::mutex mutex;
  169. const int nv = 2*((n_vocab + 1)/2) + 4;
  170. int counter = 0;
  171. auto compute = [&mutex, &counter, &log_probs, &nll, &nll2, n_vocab, logits, tokens, n_token, nv] () {
  172. double local_nll = 0;
  173. double local_nll2 = 0;
  174. while (true) {
  175. std::unique_lock<std::mutex> lock(mutex);
  176. int i = counter++;
  177. if (i >= n_token) {
  178. nll += local_nll; nll2 += local_nll2;
  179. break;
  180. }
  181. lock.unlock();
  182. const double v = log_softmax(n_vocab, logits + i*n_vocab, log_probs.data() + i*nv, tokens[i+1]);
  183. local_nll += v;
  184. local_nll2 += v*v;
  185. }
  186. };
  187. for (auto & w : workers) {
  188. w = std::thread(compute);
  189. }
  190. compute();
  191. for (auto & w : workers) {
  192. w.join();
  193. }
  194. out.write((const char *)log_probs.data(), n_token*nv*sizeof(uint16_t));
  195. }
  // Running statistics for comparing the current model against base-model log-probs.
  // Filled per position by the KL-divergence log_softmax() overload and merged across threads.
  struct kl_divergence_result {
      double sum_nll          = 0.0;  // sum of -log p(tok), current model
      double sum_nll2         = 0.0;  // sum of squares of the above
      double sum_nll_base     = 0.0;  // sum of -log p_base(tok)
      double sum_nll_base2    = 0.0;  // sum of squares of the above
      double sum_nll_nll_base = 0.0;  // sum of nll * nll_base
      double sum_kld          = 0.0;  // sum of KL(base || current)
      double sum_kld2         = 0.0;  // sum of squares of the above
      double sum_p_diff       = 0.0;  // sum of p(tok) - p_base(tok)
      double sum_p_diff2      = 0.0;  // sum of squared differences
      double sum_p_diff4      = 0.0;  // sum of fourth powers of the differences
      float  max_p_diff       = 0.0f; // largest |p(tok) - p_base(tok)|
      size_t n_same_top       = 0;    // positions where both models agree on the argmax token
      size_t count            = 0;    // positions processed
  };
  211. static std::pair<double, float> log_softmax(int n_vocab, const float * logits, const uint16_t * base_log_prob, int tok, kl_divergence_result & kld) {
  212. float max_logit = logits[0];
  213. int imax = 0;
  214. for (int i = 1; i < n_vocab; ++i) {
  215. if (logits[i] > max_logit) {
  216. max_logit = logits[i];
  217. imax = i;
  218. }
  219. }
  220. double sum_exp = 0.0;
  221. for (int i = 0; i < n_vocab; ++i) {
  222. sum_exp += expf(logits[i] - max_logit);
  223. }
  224. const float log_sum_exp = log(sum_exp);
  225. const float * d = (const float *)base_log_prob;
  226. const float scale = d[0];
  227. const float min_log_prob = d[1];
  228. base_log_prob += 4;
  229. const float nll = max_logit + log_sum_exp - logits[tok];
  230. kld.sum_nll += nll;
  231. kld.sum_nll2 += nll*nll;
  232. const float nll_base = -(scale*base_log_prob[tok] + min_log_prob);
  233. kld.sum_nll_base += nll_base;
  234. kld.sum_nll_base2 += nll_base*nll_base;
  235. kld.sum_nll_nll_base += nll*nll_base;
  236. max_logit += log_sum_exp;
  237. double sum = 0;
  238. int imax_base = -1;
  239. float p_log_base_max = 0;
  240. for (int i = 0; i < n_vocab; ++i) {
  241. const float p_log_base = scale*base_log_prob[i] + min_log_prob;
  242. if (i == 0 || p_log_base > p_log_base_max) {
  243. p_log_base_max = p_log_base;
  244. imax_base = i;
  245. }
  246. if (p_log_base > -16.f) {
  247. const float p_base = expf(p_log_base);
  248. sum += p_base * (p_log_base - logits[i] + max_logit);
  249. }
  250. }
  251. kld.sum_kld += sum;
  252. kld.sum_kld2 += sum*sum;
  253. ++kld.count;
  254. if (imax == imax_base) ++kld.n_same_top;
  255. const float p_base = expf(-nll_base);
  256. const float p = expf(-nll);
  257. const float p_diff = p - p_base;
  258. kld.sum_p_diff += p_diff;
  259. const double p_diff2 = p_diff*p_diff;
  260. kld.sum_p_diff2 += p_diff2;
  261. kld.sum_p_diff4 += p_diff2*p_diff2;
  262. kld.max_p_diff = std::max(kld.max_p_diff, std::fabs(p_diff));
  263. return std::make_pair(sum, p_diff);
  264. }
  265. static void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token,
  266. std::vector<std::thread> & workers, const std::vector<uint16_t> & base_log_probs, kl_divergence_result & kld,
  267. float * kld_values, float * p_diff_values) {
  268. std::mutex mutex;
  269. const int nv = 2*((n_vocab + 1)/2) + 4;
  270. int counter = 0;
  271. auto compute = [&mutex, &counter, &base_log_probs, &kld, n_vocab, logits, tokens, n_token, nv, kld_values, p_diff_values] () {
  272. kl_divergence_result local_kld;
  273. while (true) {
  274. std::unique_lock<std::mutex> lock(mutex);
  275. int i = counter++;
  276. if (i >= n_token) {
  277. kld.sum_nll += local_kld.sum_nll;
  278. kld.sum_nll2 += local_kld.sum_nll2;
  279. kld.sum_nll_base += local_kld.sum_nll_base;
  280. kld.sum_nll_base2 += local_kld.sum_nll_base2;
  281. kld.sum_nll_nll_base += local_kld.sum_nll_nll_base;
  282. kld.sum_kld += local_kld.sum_kld;
  283. kld.sum_kld2 += local_kld.sum_kld2;
  284. kld.sum_p_diff += local_kld.sum_p_diff;
  285. kld.sum_p_diff2 += local_kld.sum_p_diff2;
  286. kld.sum_p_diff4 += local_kld.sum_p_diff4;
  287. kld.n_same_top += local_kld.n_same_top;
  288. kld.max_p_diff = std::max(kld.max_p_diff, local_kld.max_p_diff);
  289. kld.count += local_kld.count;
  290. break;
  291. }
  292. lock.unlock();
  293. std::pair<double, float> v = log_softmax(n_vocab, logits + i*n_vocab, base_log_probs.data() + i*nv, tokens[i+1], local_kld);
  294. kld_values[i] = (float)v.first;
  295. p_diff_values[i] = v.second;
  296. }
  297. };
  298. for (auto & w : workers) {
  299. w = std::thread(compute);
  300. }
  301. compute();
  302. for (auto & w : workers) {
  303. w.join();
  304. }
  305. }
  306. static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) {
  307. // Download: https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
  308. // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
  309. // Output: `perplexity: 13.5106 [114/114]`
  310. // BOS tokens will be added for each chunk before eval
  311. const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
  312. GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx)));
  313. fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
  314. std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, true);
  315. const int n_ctx = llama_n_ctx(ctx);
  316. if (int(tokens.size()) < 2*n_ctx) {
  317. fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx,
  318. n_ctx);
  319. fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
  320. return {std::move(tokens), 0., {}, {}};
  321. }
  322. std::vector<float> logit_history;
  323. std::vector<float> prob_history;
  324. logit_history.resize(tokens.size());
  325. prob_history.resize(tokens.size());
  326. if (params.ppl_stride <= 0) {
  327. fprintf(stderr, "%s: stride is %d but must be greater than zero!\n",__func__,params.ppl_stride);
  328. return {tokens, -1, logit_history, prob_history};
  329. }
  330. const int calc_chunk = n_ctx;
  331. fprintf(stderr, "%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), calc_chunk);
  332. if (int(tokens.size()) <= calc_chunk) {
  333. fprintf(stderr, "%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n",__func__,
  334. tokens.size(), n_ctx, params.ppl_stride);
  335. return {tokens, -1, logit_history, prob_history};
  336. }
  337. const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1) / params.ppl_stride;
  338. const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
  339. const int n_vocab = llama_n_vocab(llama_get_model(ctx));
  340. const int n_batch = params.n_batch;
  341. int count = 0;
  342. double nll = 0.0;
  343. fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
  344. for (int i = 0; i < n_chunk; ++i) {
  345. const int start = i * params.ppl_stride;
  346. const int end = start + calc_chunk;
  347. const int num_batches = (calc_chunk + n_batch - 1) / n_batch;
  348. //fprintf(stderr, "%s: evaluating %d...%d using %d batches\n", __func__, start, end, num_batches);
  349. std::vector<float> logits;
  350. const auto t_start = std::chrono::high_resolution_clock::now();
  351. // clear the KV cache
  352. llama_kv_cache_clear(ctx);
  353. for (int j = 0; j < num_batches; ++j) {
  354. const int batch_start = start + j * n_batch;
  355. const int batch_size = std::min(end - batch_start, n_batch);
  356. //fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
  357. // TODO: use llama_batch.logits instead of relying on logits_all == true
  358. if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
  359. //fprintf(stderr, "%s : failed to eval\n", __func__);
  360. return {tokens, -1, logit_history, prob_history};
  361. }
  362. // save original token and restore it after eval
  363. const auto token_org = tokens[batch_start];
  364. // add BOS token for the first batch of each chunk
  365. if (add_bos && j == 0) {
  366. tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
  367. }
  368. const auto batch_logits = llama_get_logits(ctx);
  369. logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
  370. if (j == 0) {
  371. tokens[batch_start] = token_org;
  372. }
  373. }
  374. const auto t_end = std::chrono::high_resolution_clock::now();
  375. if (i == 0) {
  376. const float t_total = std::chrono::duration<float>(t_end - t_start).count();
  377. fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
  378. int total_seconds = (int)(t_total * n_chunk);
  379. if (total_seconds >= 60*60) {
  380. fprintf(stderr, "%d hours ", total_seconds / (60*60));
  381. total_seconds = total_seconds % (60*60);
  382. }
  383. fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
  384. }
  385. //fprintf(stderr, "%s: using tokens %d...%d\n",__func__,params.n_ctx - params.ppl_stride + start, params.n_ctx + start);
  386. for (int j = n_ctx - params.ppl_stride - 1; j < n_ctx - 1; ++j) {
  387. // Calculate probability of next token, given the previous ones.
  388. const std::vector<float> tok_logits(
  389. logits.begin() + (j + 0) * n_vocab,
  390. logits.begin() + (j + 1) * n_vocab);
  391. const float prob = softmax(tok_logits)[tokens[start + j + 1]];
  392. logit_history[start + j + 1] = tok_logits[tokens[start + j + 1]];
  393. prob_history[start + j + 1] = prob;
  394. nll += -std::log(prob);
  395. ++count;
  396. }
  397. // perplexity is e^(average negative log-likelihood)
  398. if (params.ppl_output_type == 0) {
  399. printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
  400. } else {
  401. printf("%8d %.4lf\n", i*params.ppl_stride, std::exp(nll / count));
  402. }
  403. fflush(stdout);
  404. }
  405. printf("\n");
  406. return {tokens, std::exp(nll / count), logit_history, prob_history};
  407. }
  408. static results_perplexity perplexity(llama_context * ctx, const gpt_params & params, const int32_t n_ctx) {
  409. if (params.ppl_stride > 0) {
  410. return perplexity_v2(ctx, params);
  411. }
  412. // Download: https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
  413. // Run `./llama-perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
  414. // Output: `perplexity: 13.5106 [114/114]`
  415. // BOS tokens will be added for each chunk before eval
  416. const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
  417. GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx)));
  418. std::ofstream logits_stream;
  419. if (!params.logits_file.empty()) {
  420. logits_stream.open(params.logits_file.c_str(), std::ios::binary);
  421. if (!logits_stream.is_open()) {
  422. fprintf(stderr, "%s: failed to open %s for writing\n", __func__, params.logits_file.c_str());
  423. return {};
  424. }
  425. fprintf(stderr, "%s: saving all logits to %s\n", __func__, params.logits_file.c_str());
  426. logits_stream.write("_logits_", 8);
  427. logits_stream.write(reinterpret_cast<const char *>(&n_ctx), sizeof(n_ctx));
  428. }
  429. auto tim1 = std::chrono::high_resolution_clock::now();
  430. fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
  431. std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, true);
  432. auto tim2 = std::chrono::high_resolution_clock::now();
  433. fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
  434. if (int(tokens.size()) < 2*n_ctx) {
  435. fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx,
  436. n_ctx);
  437. fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
  438. return {std::move(tokens), 0., {}, {}};
  439. }
  440. std::vector<float> logit_history;
  441. logit_history.resize(tokens.size());
  442. std::vector<float> prob_history;
  443. prob_history.resize(tokens.size());
  444. const int n_chunk_max = tokens.size() / n_ctx;
  445. const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
  446. const int n_vocab = llama_n_vocab(llama_get_model(ctx));
  447. const int n_batch = params.n_batch;
  448. int count = 0;
  449. double nll = 0.0;
  450. double nll2 = 0.0;
  451. const int num_batches = (n_ctx + n_batch - 1) / n_batch;
  452. const int n_seq = std::max(1, n_batch / n_ctx);
  453. GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0);
  454. GGML_ASSERT(params.n_ctx == n_seq * n_ctx);
  455. llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1);
  456. std::vector<float> logits;
  457. if (num_batches > 1) {
  458. logits.reserve((size_t)n_ctx * n_vocab);
  459. }
  460. fprintf(stderr, "%s: calculating perplexity over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
  461. std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
  462. std::vector<uint16_t> log_probs;
  463. if (!params.logits_file.empty()) {
  464. logits_stream.write((const char *)&n_vocab, sizeof(n_vocab));
  465. logits_stream.write((const char *)&n_chunk, sizeof(n_chunk));
  466. logits_stream.write((const char *)tokens.data(), n_chunk*n_ctx*sizeof(tokens[0]));
  467. const int nv = 2*((n_vocab + 1)/2) + 4;
  468. log_probs.resize(n_ctx * nv);
  469. }
  470. // We get the logits for all the tokens in the context window (params.n_ctx)
  471. // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
  472. // calculate the perplexity over the last half of the window (so the model always has
  473. // some context to predict the token).
  474. //
  475. // We rely on the fact that attention in the forward pass only looks at previous
  476. // tokens here, so the logits returned for each token are an accurate representation
  477. // of what the model would have predicted at that point.
  478. //
  479. // Example, we have a context window of 512, we will compute perplexity for each of the
  480. // last 256 tokens. Then, we split the input up into context window size chunks to
  481. // process the entire prompt.
  482. const int first = n_ctx/2;
  483. for (int i = 0; i < n_chunk; i += n_seq) {
  484. const int start = i * n_ctx;
  485. const int end = start + n_ctx;
  486. const int n_seq_batch = std::min(n_seq, n_chunk - i);
  487. const auto t_start = std::chrono::high_resolution_clock::now();
  488. // clear the KV cache
  489. llama_kv_cache_clear(ctx);
  490. for (int j = 0; j < num_batches; ++j) {
  491. const int batch_start = start + j * n_batch;
  492. const int batch_size = std::min(end - batch_start, n_batch);
  493. int n_outputs = 0;
  494. batch.n_tokens = 0;
  495. for (int seq = 0; seq < n_seq_batch; seq++) {
  496. int seq_start = batch_start + seq*n_ctx;
  497. // save original token and restore it after eval
  498. const auto token_org = tokens[seq_start];
  499. // add BOS token for the first batch of each chunk
  500. if (add_bos && j == 0) {
  501. tokens[seq_start] = llama_token_bos(llama_get_model(ctx));
  502. }
  503. for (int k = 0; k < batch_size; ++k) {
  504. const int idx = seq*n_ctx + k;
  505. batch.token [idx] = tokens[seq_start + k];
  506. batch.pos [idx] = j*n_batch + k;
  507. batch.n_seq_id[idx] = 1;
  508. batch.seq_id [idx][0] = seq;
  509. batch.logits [idx] = batch.pos[idx] >= first ? 1 : 0;
  510. n_outputs += batch.logits[idx] != 0;
  511. }
  512. batch.n_tokens += batch_size;
  513. // restore the original token in case it was set to BOS
  514. tokens[seq_start] = token_org;
  515. }
  516. if (llama_decode(ctx, batch)) {
  517. fprintf(stderr, "%s : failed to eval\n", __func__);
  518. return {tokens, -1, logit_history, prob_history};
  519. }
  520. if (num_batches > 1 && n_outputs > 0) {
  521. const auto * batch_logits = llama_get_logits(ctx);
  522. logits.insert(logits.end(), batch_logits, batch_logits + n_outputs * n_vocab);
  523. }
  524. }
  525. if (i == 0) {
  526. llama_synchronize(ctx);
  527. const auto t_end = std::chrono::high_resolution_clock::now();
  528. const float t_total = std::chrono::duration<float>(t_end - t_start).count();
  529. fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
  530. int total_seconds = (int)(t_total*n_chunk/n_seq);
  531. if (total_seconds >= 60*60) {
  532. fprintf(stderr, "%d hours ", total_seconds / (60*60));
  533. total_seconds = total_seconds % (60*60);
  534. }
  535. fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
  536. }
  537. for (int seq = 0; seq < n_seq_batch; seq++) {
  538. const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
  539. llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
  540. if (!params.logits_file.empty()) {
  541. process_logits(logits_stream, n_vocab, all_logits,
  542. tokens_data, n_ctx - 1 - first,
  543. workers, log_probs, nll, nll2);
  544. } else {
  545. process_logits(n_vocab, all_logits,
  546. tokens_data, n_ctx - 1 - first,
  547. workers, nll, nll2,
  548. logit_history.data() + start + seq*n_ctx + first,
  549. prob_history.data() + start + seq*n_ctx + first);
  550. }
  551. count += n_ctx - first - 1;
  552. // perplexity is e^(average negative log-likelihood)
  553. if (params.ppl_output_type == 0) {
  554. printf("[%d]%.4lf,", i + seq + 1, std::exp(nll / count));
  555. } else {
  556. double av = nll/count;
  557. double av2 = nll2/count - av*av;
  558. if (av2 > 0) av2 = sqrt(av2/(count-1));
  559. printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
  560. }
  561. }
  562. fflush(stdout);
  563. logits.clear();
  564. }
  565. printf("\n");
  566. nll2 /= count;
  567. nll /= count;
  568. const double ppl = exp(nll);
  569. nll2 -= nll * nll;
  570. if (nll2 > 0) {
  571. nll2 = sqrt(nll2/(count-1));
  572. printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
  573. } else {
  574. printf("Unexpected negative standard deviation of log(prob)\n");
  575. }
  576. llama_batch_free(batch);
  577. return {tokens, ppl, logit_history, prob_history};
  578. }
  579. static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int32_t n_batch, int32_t n_vocab) {
  580. int prev_outputs = 0;
  581. for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
  582. const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
  583. llama_batch batch_view = {
  584. n_tokens,
  585. batch.token + i,
  586. nullptr,
  587. batch.pos + i,
  588. batch.n_seq_id + i,
  589. batch.seq_id + i,
  590. batch.logits + i,
  591. 0, 0, 0, // unused
  592. };
  593. const int ret = llama_decode(ctx, batch_view);
  594. if (ret != 0) {
  595. LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
  596. return false;
  597. }
  598. int n_outputs = 0;
  599. for (int i = 0; i < n_tokens; ++i) {
  600. n_outputs += batch_view.logits[i] != 0;
  601. }
  602. memcpy(batch_logits.data() + prev_outputs*n_vocab, llama_get_logits(ctx), n_outputs*n_vocab*sizeof(float));
  603. prev_outputs += n_outputs;
  604. }
  605. return true;
  606. }
  607. #define K_TOKEN_CHUNK 4
  608. static void compute_logprobs(const float * batch_logits, int n_vocab, std::vector<std::thread>& workers,
  609. const std::vector<std::pair<size_t, llama_token>>& eval_pairs, std::vector<float>& eval_results) {
  610. if (eval_results.size() != eval_pairs.size()) {
  611. eval_results.resize(eval_pairs.size());
  612. }
  613. if (eval_pairs.empty()) return;
  614. size_t max_threads = std::min((eval_pairs.size() + K_TOKEN_CHUNK - 1)/K_TOKEN_CHUNK, workers.size());
  615. std::atomic<int> counter(0);
  616. auto compute = [&counter, &eval_pairs, &eval_results, batch_logits, n_vocab] () {
  617. float local_logprobs[K_TOKEN_CHUNK];
  618. while (true) {
  619. size_t first = counter.fetch_add(K_TOKEN_CHUNK, std::memory_order_relaxed);
  620. if (first >= eval_results.size()) break;
  621. size_t last = std::min(first + K_TOKEN_CHUNK, eval_results.size());
  622. for (size_t i = first; i < last; ++i) {
  623. auto logits = batch_logits + eval_pairs[i].first * n_vocab;
  624. float max_logit = logits[0];
  625. for (int j = 1; j < n_vocab; ++j) {
  626. max_logit = std::max(max_logit, logits[j]);
  627. }
  628. float sum_p = 0.f;
  629. for (int j = 0; j < n_vocab; ++j) {
  630. sum_p += expf(logits[j] - max_logit);
  631. }
  632. local_logprobs[i - first] = logits[eval_pairs[i].second] - max_logit - std::log(sum_p);
  633. }
  634. std::memcpy(eval_results.data() + first, local_logprobs, (last - first)*sizeof(float));
  635. }
  636. };
  637. for (size_t it = 0; it < max_threads; ++it) {
  638. workers[it] = std::thread(compute);
  639. }
  640. for (size_t it = 0; it < max_threads; ++it) {
  641. workers[it].join();
  642. }
  643. }
  644. static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
  645. // Calculates hellaswag score (acc_norm) from prompt
  646. //
  647. // Data extracted from the HellaSwag validation dataset (MIT license) https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl
  648. // All used data fields are preprocessed as in https://github.com/EleutherAI/lm-evaluation-harness/blob/df3da98c5405deafd519c2ddca52bb7c3fe36bef/lm_eval/tasks/hellaswag.py#L62-L68
  649. //
  650. // All 10042 tasks should be extracted to keep the results standardized like other implementations.
  651. //
  652. // Datafile layout:
  653. // ['??'] denotes json fields
  654. // 6 lines per task:
  655. // ['activity_label'] + ": " +['ctx'] - The first part of the query, the context
  656. // ['label'] - The index the best common sense ending aka gold ending
  657. // ['endings'][0] - Endings added to the first part of the query
  658. // ['endings'][1]
  659. // ['endings'][2]
  660. // ['endings'][3]
  661. std::vector<std::string> prompt_lines;
  662. std::istringstream strstream(params.prompt);
  663. std::string line;
  664. while (std::getline(strstream,line,'\n')) {
  665. prompt_lines.push_back(line);
  666. }
  667. if (prompt_lines.size() % 6 != 0) {
  668. fprintf(stderr, "%s : number of lines in prompt not a multiple of 6.\n", __func__);
  669. return;
  670. }
  671. size_t hs_task_count = prompt_lines.size()/6;
  672. fprintf(stderr, "%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count);
  673. const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM;
  674. fprintf(stderr, "================================= is_spm = %d\n", is_spm);
  675. // The tasks should be randomized so the score stabilizes quickly.
  676. bool randomize_tasks = true;
  677. // Number of tasks to use when computing the score
  678. if (params.hellaswag_tasks < hs_task_count) {
  679. hs_task_count = params.hellaswag_tasks;
  680. }
  681. // The random seed should not impact the final result if the computation is done over enough tasks, so kept hardcoded for now
  682. std::mt19937 rng(1);
  683. // Dataholder for hellaswag tasks
  684. struct hs_data_t {
  685. std::string context;
  686. size_t gold_ending_idx;
  687. std::string ending[4];
  688. size_t ending_logprob_count[4];
  689. double ending_logprob[4];
  690. size_t i_logits; // starting index of logits in the llama_batch
  691. size_t common_prefix; // max number of initial tokens that are the same in all sentences
  692. size_t required_tokens; // needed number of tokens to evaluate all 4 endings
  693. std::vector<llama_token> seq_tokens[4];
  694. };
  695. fprintf(stderr, "%s : selecting %zu %s tasks.\n", __func__, hs_task_count, (randomize_tasks?"randomized":"the first") );
  696. // Select and read data from prompt lines
  697. std::vector<hs_data_t> hs_data(hs_task_count);
  698. for (size_t i = 0; i < hs_task_count; i++) {
  699. size_t idx = i;
  700. auto & hs_cur = hs_data[i];
  701. // Select a random example of those left in the prompt
  702. if (randomize_tasks) {
  703. std::uniform_int_distribution<size_t> dist(0, prompt_lines.size()/6-1 ) ;
  704. idx = dist(rng);
  705. }
  706. hs_cur.context = prompt_lines[idx*6];
  707. hs_cur.gold_ending_idx = std::stoi( prompt_lines[idx*6+1] );
  708. for (size_t j = 0; j < 4; j++) {
  709. hs_cur.ending[j] = prompt_lines[idx*6+2+j];
  710. hs_cur.seq_tokens[j] = ::llama_tokenize(ctx, hs_cur.context + " " + hs_cur.ending[j], true);
  711. }
  712. // determine the common prefix of the endings
  713. hs_cur.common_prefix = 0;
  714. for (size_t k = 0; k < hs_cur.seq_tokens[0].size(); k++) {
  715. if (hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[1][k] ||
  716. hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[2][k] ||
  717. hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[3][k]) {
  718. break;
  719. }
  720. hs_cur.common_prefix++;
  721. }
  722. hs_cur.required_tokens = hs_cur.common_prefix +
  723. hs_cur.seq_tokens[0].size() - hs_cur.common_prefix +
  724. hs_cur.seq_tokens[1].size() - hs_cur.common_prefix +
  725. hs_cur.seq_tokens[2].size() - hs_cur.common_prefix +
  726. hs_cur.seq_tokens[3].size() - hs_cur.common_prefix;
  727. //GGML_ASSERT(hs_cur.common_prefix >= ::llama_tokenize(ctx, hs_cur.context, true).size());
  728. // Delete the selected random example from the prompt
  729. if (randomize_tasks) {
  730. prompt_lines.erase( std::next(prompt_lines.begin(),idx*6) , std::next(prompt_lines.begin(),idx*6+6) );
  731. }
  732. }
  733. fprintf(stderr, "%s : calculating hellaswag score over selected tasks.\n", __func__);
  734. printf("\ntask\tacc_norm\n");
  735. double acc = 0.0f;
  736. const int n_vocab = llama_n_vocab(llama_get_model(ctx));
  737. const int n_ctx = llama_n_ctx(ctx);
  738. const int n_batch = params.n_batch;
  739. const int max_tasks_per_batch = 32;
  740. const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
  741. llama_batch batch = llama_batch_init(n_ctx, 0, 4);
  742. std::vector<float> tok_logits(n_vocab);
  743. // TODO: this could be made smaller; it's currently the worst-case size
  744. std::vector<float> batch_logits(n_vocab*n_ctx);
  745. std::vector<std::pair<size_t, llama_token>> eval_pairs;
  746. std::vector<float> eval_results;
  747. std::vector<std::thread> workers(std::thread::hardware_concurrency());
  748. for (size_t i0 = 0; i0 < hs_task_count; i0++) {
  749. int n_cur = 0;
  750. size_t i1 = i0;
  751. size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch
  752. llama_batch_clear(batch);
  753. // batch as much tasks as possible into the available context
  754. // each task has 4 unique sequence ids - one for each ending
  755. // the common prefix is shared among the 4 sequences to save tokens
  756. // we extract logits only from the last common token and from all ending tokens of each sequence
  757. while (n_cur + (int) hs_data[i1].required_tokens <= n_ctx) {
  758. auto & hs_cur = hs_data[i1];
  759. int n_logits = 0;
  760. const int s0 = 4*(i1 - i0);
  761. if (s0 + 4 > max_seq) {
  762. break;
  763. }
  764. for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
  765. llama_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
  766. }
  767. batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
  768. n_logits += 1;
  769. for (int s = 0; s < 4; ++s) {
  770. const size_t seq_tokens_size = hs_cur.seq_tokens[s].size();
  771. // TODO: don't evaluate the last token of each sequence
  772. for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) {
  773. const bool needs_logits = i < seq_tokens_size - 1;
  774. llama_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
  775. n_logits += needs_logits;
  776. }
  777. }
  778. hs_cur.i_logits = i_logits;
  779. i_logits += n_logits;
  780. n_cur += hs_data[i1].required_tokens;
  781. if (++i1 == hs_task_count) {
  782. break;
  783. }
  784. }
  785. if (i0 == i1) {
  786. fprintf(stderr, "%s : task %zu does not fit in the context window\n", __func__, i0);
  787. return;
  788. }
  789. llama_kv_cache_clear(ctx);
  790. // decode all tasks [i0, i1)
  791. if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
  792. fprintf(stderr, "%s: llama_decode() failed\n", __func__);
  793. return;
  794. }
  795. // Compute log-probs in parallel
  796. // First we collect all tasks
  797. eval_pairs.clear();
  798. for (size_t i = i0; i < i1; ++i) {
  799. auto & hs_cur = hs_data[i];
  800. size_t li = 1; // skip the last logit of the common prefix (computed separately below)
  801. for (int s = 0; s < 4; ++s) {
  802. for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
  803. eval_pairs.emplace_back(hs_cur.i_logits + li++, hs_cur.seq_tokens[s][j + 1]);
  804. }
  805. }
  806. }
  807. // Then we do the actual calculation
  808. compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
  809. size_t ir = 0;
  810. // compute the logprobs for each ending of the decoded tasks
  811. for (size_t i = i0; i < i1; ++i) {
  812. auto & hs_cur = hs_data[i];
  813. // get the logits of the last token of the common prefix
  814. std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*hs_cur.i_logits, n_vocab*sizeof(float));
  815. const auto first_probs = softmax(tok_logits);
  816. for (int s = 0; s < 4; ++s) {
  817. hs_cur.ending_logprob_count[s] = 1;
  818. hs_cur.ending_logprob[s] = std::log(first_probs[hs_cur.seq_tokens[s][hs_cur.common_prefix]]);
  819. for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
  820. hs_cur.ending_logprob[s] += eval_results[ir++];
  821. hs_cur.ending_logprob_count[s]++;
  822. }
  823. hs_cur.ending_logprob[s] /= hs_cur.ending_logprob_count[s];
  824. }
  825. // Find the ending with maximum logprob
  826. size_t ending_logprob_max_idx = 0;
  827. double ending_logprob_max_val = hs_cur.ending_logprob[0];
  828. for (size_t s = 1; s < 4; s++) {
  829. if (hs_cur.ending_logprob[s] > ending_logprob_max_val) {
  830. ending_logprob_max_idx = s;
  831. ending_logprob_max_val = hs_cur.ending_logprob[s];
  832. }
  833. }
  834. //printf("max logprob ending idx %lu, gold ending idx %lu\n", ending_logprob_max_idx, hs_cur.gold_ending_idx);
  835. // If the gold ending got the maximum logprobe add one accuracy point
  836. if (ending_logprob_max_idx == hs_cur.gold_ending_idx) {
  837. acc += 1.0;
  838. }
  839. // Print the accumulated accuracy mean x 100
  840. printf("%zu\t%.8lf\n", i + 1, acc/double(i + 1)*100.0);
  841. fflush(stdout);
  842. }
  843. i0 = i1 - 1;
  844. }
  845. llama_batch_free(batch);
  846. printf("\n");
  847. }
  848. struct winogrande_entry {
  849. std::string first;
  850. std::string second;
  851. std::array<std::string, 2> choices;
  852. int answer;
  853. size_t i_logits;
  854. size_t common_prefix;
  855. size_t required_tokens;
  856. size_t n_base1; // number of tokens for context + choice 1
  857. size_t n_base2; // number of tokens for context + choice 2
  858. std::vector<llama_token> seq_tokens[2];
  859. };
  860. static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string & prompt) {
  861. std::vector<winogrande_entry> result;
  862. std::istringstream in(prompt);
  863. std::string line;
  864. std::array<int, 4> comma_pos;
  865. while (true) {
  866. std::getline(in, line);
  867. if (in.fail() || in.eof()) break;
  868. int ipos = 0;
  869. bool quote_open = false;
  870. for (int i = 0; i < int(line.size()); ++i) {
  871. if (!quote_open) {
  872. if (line[i] == ',') {
  873. comma_pos[ipos++] = i;
  874. if (ipos == 4) break;
  875. }
  876. else if (line[i] == '"') {
  877. quote_open = true;
  878. }
  879. }
  880. else {
  881. if (line[i] == '"') {
  882. quote_open = false;
  883. }
  884. }
  885. }
  886. if (ipos != 4) {
  887. printf("%s: failed to find comma separators in <%s>\n", __func__, line.c_str());
  888. continue;
  889. }
  890. auto sentence = line[comma_pos[0]+1] == '"' ? line.substr(comma_pos[0]+2, comma_pos[1] - comma_pos[0] - 3)
  891. : line.substr(comma_pos[0]+1, comma_pos[1] - comma_pos[0] - 1);
  892. auto choice1 = line.substr(comma_pos[1]+1, comma_pos[2] - comma_pos[1] - 1);
  893. auto choice2 = line.substr(comma_pos[2]+1, comma_pos[3] - comma_pos[2] - 1);
  894. auto answer = line.substr(comma_pos[3]+1, line.size() - comma_pos[3] - 1);
  895. auto index = line.substr(0, comma_pos[0]);
  896. int where = 0;
  897. for ( ; where < int(sentence.size()); ++where) {
  898. if (sentence[where] == '_') break;
  899. }
  900. if (where == int(sentence.size())) {
  901. printf("%s: no _ in <%s>\n", __func__, sentence.c_str());
  902. continue;
  903. }
  904. std::istringstream stream(answer.c_str());
  905. int i_answer; stream >> i_answer;
  906. if (stream.fail() || i_answer < 1 || i_answer > 2) {
  907. printf("%s: failed to parse answer <%s>\n", __func__, answer.c_str());
  908. continue;
  909. }
  910. result.emplace_back();
  911. auto& wg = result.back();
  912. wg.first = sentence.substr(0, where);
  913. wg.second = sentence.substr(where + 1, sentence.size() - where - 1);
  914. wg.choices[0] = std::move(choice1);
  915. wg.choices[1] = std::move(choice2);
  916. wg.answer = i_answer;
  917. }
  918. return result;
  919. }
  920. /*
  921. * Evaluates the Winogrande score.
  922. * Uses a CSV containing task index, dentence, choice 1, choice 2, answer (1 or 2)
  923. * You can get one such dataset from e.g. https://huggingface.co/datasets/ikawrakow/winogrande-eval-for-llama.cpp
  924. * As an example, the 1st row in the above dataset is
  925. *
  926. * 0,Sarah was a much better surgeon than Maria so _ always got the easier cases.,Sarah,Maria,2
  927. *
  928. */
  929. static void winogrande_score(llama_context * ctx, const gpt_params & params) {
  930. constexpr int k_min_trailing_ctx = 3;
  931. auto data = load_winogrande_from_csv(params.prompt);
  932. if (data.empty()) {
  933. fprintf(stderr, "%s: no tasks\n", __func__);
  934. return;
  935. }
  936. fprintf(stderr, "%s : loaded %zu tasks from prompt.\n", __func__, data.size());
  937. if (params.winogrande_tasks > 0 && params.winogrande_tasks < data.size()) {
  938. fprintf(stderr, "%s : selecting %zu random tasks\n", __func__, params.winogrande_tasks);
  939. std::mt19937 rng(1);
  940. std::vector<int> aux(data.size());
  941. for (int i = 0; i < int(data.size()); ++i) {
  942. aux[i] = i;
  943. }
  944. float scale = 1/(1.f + (float)rng.max());
  945. std::vector<winogrande_entry> selected;
  946. selected.resize(params.winogrande_tasks);
  947. for (int i = 0; i < int(params.winogrande_tasks); ++i) {
  948. int j = int(scale*rng()*aux.size());
  949. selected[i] = std::move(data[aux[j]]);
  950. aux[j] = aux.back();
  951. aux.pop_back();
  952. }
  953. data = std::move(selected);
  954. }
  955. fprintf(stderr, "%s : tokenizing selected tasks\n", __func__);
  956. for (auto & task : data) {
  957. task.seq_tokens[0] = ::llama_tokenize(ctx, task.first + task.choices[0] + task.second, true);
  958. task.seq_tokens[1] = ::llama_tokenize(ctx, task.first + task.choices[1] + task.second, true);
  959. task.common_prefix = 0;
  960. for (size_t k = 0; k < task.seq_tokens[0].size(); k++) {
  961. if (task.seq_tokens[0][k] != task.seq_tokens[1][k]) {
  962. break;
  963. }
  964. task.common_prefix++;
  965. }
  966. // TODO: the last token of each of the sequences don't need to be evaluated
  967. task.required_tokens = task.common_prefix +
  968. task.seq_tokens[0].size() - task.common_prefix +
  969. task.seq_tokens[1].size() - task.common_prefix;
  970. task.n_base1 = ::llama_tokenize(ctx, task.first + task.choices[0], true).size();
  971. task.n_base2 = ::llama_tokenize(ctx, task.first + task.choices[1], true).size();
  972. }
  973. fprintf(stderr, "%s : calculating winogrande score over selected tasks.\n", __func__);
  974. const int n_vocab = llama_n_vocab(llama_get_model(ctx));
  975. const int n_ctx = llama_n_ctx(ctx);
  976. const int n_batch = params.n_batch;
  977. const int max_tasks_per_batch = 128;
  978. const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
  979. llama_batch batch = llama_batch_init(n_ctx, 0, 2);
  980. std::vector<float> tok_logits(n_vocab);
  981. // TODO: this could be made smaller; it's currently the worst-case size
  982. std::vector<float> batch_logits(n_vocab*n_ctx);
  983. std::vector<std::pair<size_t, llama_token>> eval_pairs;
  984. std::vector<float> eval_results;
  985. std::vector<std::thread> workers(std::thread::hardware_concurrency());
  986. int n_correct = 0;
  987. int n_done = 0;
  988. for (size_t i0 = 0; i0 < data.size(); i0++) {
  989. int n_cur = 0;
  990. size_t i1 = i0;
  991. size_t i_logits = 0;
  992. llama_batch_clear(batch);
  993. while (n_cur + (int) data[i1].required_tokens <= n_ctx) {
  994. int n_logits = 0;
  995. const int s0 = 2*(i1 - i0);
  996. if (s0 + 2 > max_seq) {
  997. break;
  998. }
  999. for (size_t i = 0; i < data[i1].common_prefix; ++i) {
  1000. llama_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
  1001. }
  1002. batch.logits[batch.n_tokens - 1] = true;
  1003. n_logits += 1;
  1004. for (int s = 0; s < 2; ++s) {
  1005. // TODO: end before the last token, no need to predict past the end of the sequences
  1006. for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
  1007. llama_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true);
  1008. n_logits += 1;
  1009. }
  1010. }
  1011. data[i1].i_logits = i_logits;
  1012. i_logits += n_logits;
  1013. n_cur += data[i1].required_tokens;
  1014. if (++i1 == data.size()) {
  1015. break;
  1016. }
  1017. }
  1018. if (i0 == i1) {
  1019. fprintf(stderr, "%s : task %zu does not fit in the context window\n", __func__, i0);
  1020. return;
  1021. }
  1022. llama_kv_cache_clear(ctx);
  1023. // decode all tasks [i0, i1)
  1024. if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
  1025. fprintf(stderr, "%s: llama_decode() failed\n", __func__);
  1026. return;
  1027. }
  1028. eval_pairs.clear();
  1029. for (size_t i = i0; i < i1; ++i) {
  1030. auto & task = data[i];
  1031. const bool skip_choice =
  1032. task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx &&
  1033. task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx;
  1034. const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
  1035. const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
  1036. size_t li = n_base1 - task.common_prefix;
  1037. for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
  1038. eval_pairs.emplace_back(task.i_logits + li++, task.seq_tokens[0][j+1]);
  1039. }
  1040. const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
  1041. const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
  1042. // FIXME: this uses the wrong first logits when not skipping the choice word
  1043. li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - task.common_prefix;
  1044. for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
  1045. eval_pairs.emplace_back(task.i_logits + li++, task.seq_tokens[1][j+1]);
  1046. }
  1047. }
  1048. compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
  1049. size_t ir = 0;
  1050. for (size_t i = i0; i < i1; ++i) {
  1051. auto & task = data[i];
  1052. const bool skip_choice =
  1053. task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx &&
  1054. task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx;
  1055. float score_1st = 0;
  1056. const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
  1057. const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
  1058. for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
  1059. score_1st += eval_results[ir++];
  1060. }
  1061. score_1st /= (task.seq_tokens[0].size() - n_base1 - last_1st);
  1062. float score_2nd = 0;
  1063. const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
  1064. const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
  1065. for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
  1066. score_2nd += eval_results[ir++];
  1067. }
  1068. score_2nd /= (task.seq_tokens[1].size() - n_base2 - last_2nd);
  1069. int result = score_1st > score_2nd ? 1 : 2;
  1070. if (result == task.answer) {
  1071. ++n_correct;
  1072. }
  1073. ++n_done;
  1074. // print the accumulated accuracy mean x 100
  1075. printf("%zu\t%.4lf\t%10.6f %10.6f %d %d\n", i+1, 100.0 * n_correct/n_done, score_1st, score_2nd, result, task.answer);
  1076. fflush(stdout);
  1077. }
  1078. i0 = i1 - 1;
  1079. }
  1080. printf("\n");
  1081. if (n_done < 100) return;
  1082. const float p = 1.f*n_correct/n_done;
  1083. const float sigma = 100.f*sqrt(p*(1-p)/(n_done-1));
  1084. printf("Final Winogrande score(%d tasks): %.4lf +/- %.4lf\n", n_done, 100*p, sigma);
  1085. }
  // Reads a length-prefixed string: a native-endian uint32_t byte count followed by that many
  // bytes. Returns true only if both the count and the full payload were read.
  static bool deserialize_string(std::istream & in, std::string & str) {
      uint32_t n_bytes;
      if (in.read((char *)&n_bytes, sizeof(n_bytes)).fail()) {
          return false;
      }
      str.resize(n_bytes);
      return !in.read((char *)&str[0], n_bytes).fail();
  }
  1094. struct multiple_choice_answers {
  1095. std::vector<std::string> answers;
  1096. std::vector<int> labels;
  1097. bool deserialize(std::istream& in) {
  1098. uint32_t n;
  1099. in.read((char *)&n, sizeof(n));
  1100. if (in.fail() || n > 100) return false; // 100 as max. number of answers should be good enough for any practical purpose
  1101. answers.resize(n);
  1102. labels.resize(n);
  1103. for (auto& a : answers) {
  1104. if (!deserialize_string(in, a)) return false;
  1105. }
  1106. in.read((char *)labels.data(), n*sizeof(int));
  1107. return !in.fail();
  1108. }
  1109. };
  1110. struct multiple_choice_task {
  1111. std::string question; // the question (or context that needs to be continued)
  1112. multiple_choice_answers mc1; // possible answers (continuations) with a single correct answer
  1113. multiple_choice_answers mc2; // possible answers (continuations) with multiple correct answers - not handled yet
  1114. bool deserialize(std::istream& in) {
  1115. if (!deserialize_string(in, question)) return false;
  1116. return mc1.deserialize(in) && mc2.deserialize(in);
  1117. }
  1118. // For evaluation
  1119. size_t i_logits; // starting index of logits in the llama_batch
  1120. size_t common_prefix; // max number of initial tokens that are the same in all sentences
  1121. size_t required_tokens; // needed number of tokens to evaluate all answers
  1122. std::vector<std::vector<llama_token>> seq_tokens;
  1123. std::vector<float> log_probs;
  1124. };
  1125. static bool multiple_choice_prepare_one_task(llama_context * ctx, multiple_choice_task& task, bool log_error) {
  1126. if (task.question.empty() || task.mc1.answers.empty()) {
  1127. if (log_error) {
  1128. printf("%s: found bad task with empty question and/or answers\n", __func__);
  1129. }
  1130. return false;
  1131. }
  1132. task.seq_tokens.reserve(task.mc1.answers.size());
  1133. for (auto& answer : task.mc1.answers) {
  1134. if (answer.empty()) {
  1135. if (log_error) {
  1136. printf("%s: found empty answer\n", __func__);
  1137. }
  1138. return false;
  1139. }
  1140. task.seq_tokens.emplace_back(::llama_tokenize(ctx, task.question + " " + answer, true));
  1141. }
  1142. auto min_len = task.seq_tokens.front().size();
  1143. for (auto& seq : task.seq_tokens) {
  1144. min_len = std::min(min_len, seq.size());
  1145. }
  1146. task.common_prefix = 0;
  1147. for (size_t k = 0; k < min_len; ++k) {
  1148. auto token = task.seq_tokens[0][k];
  1149. bool all_same = true;
  1150. for (size_t i = 1; i < task.seq_tokens.size(); ++i) {
  1151. if (task.seq_tokens[i][k] != token) {
  1152. all_same = false;
  1153. break;
  1154. }
  1155. }
  1156. if (!all_same) {
  1157. break;
  1158. }
  1159. ++task.common_prefix;
  1160. }
  1161. task.required_tokens = task.common_prefix;
  1162. for (auto& seq : task.seq_tokens) {
  1163. task.required_tokens += seq.size() - task.common_prefix;
  1164. }
  1165. return true;
  1166. }
  1167. //
  1168. // Calculates score for multiple choice tasks with single correct answer from prompt.
  1169. // Commonly used LLM evaluation metrics of this type are
  1170. // * ARC
  1171. // * HellaSwag
  1172. // * MMLU
  1173. // * TruthfulQA
  1174. //
  1175. // Validation datasets for these 4 tests can be found at
  1176. // https://huggingface.co/datasets/ikawrakow/validation-datasets-for-llama.cpp
  1177. // The data for these datasets was extracted from
  1178. // git@hf.co:datasets/allenai/ai2_arc
  1179. // https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl
  1180. // git@hf.co:datasets/Stevross/mmlu
  1181. // https://huggingface.co/datasets/truthful_qa
  1182. //
  1183. static void multiple_choice_score(llama_context * ctx, const gpt_params & params) {
  1184. std::istringstream strstream(params.prompt);
  1185. uint32_t n_task;
  1186. strstream.read((char *)&n_task, sizeof(n_task));
  1187. if (strstream.fail() || n_task == 0) {
  1188. printf("%s: no tasks\n", __func__);
  1189. return;
  1190. }
  1191. printf("%s: there are %u tasks in prompt\n", __func__, n_task);
  1192. std::vector<uint32_t> task_pos(n_task);
  1193. strstream.read((char *)task_pos.data(), task_pos.size()*sizeof(uint32_t));
  1194. if (strstream.fail()) {
  1195. printf("%s: failed to read task positions from prompt\n", __func__);
  1196. return;
  1197. }
  1198. std::vector<multiple_choice_task> tasks;
  1199. if (params.multiple_choice_tasks == 0 || params.multiple_choice_tasks >= (size_t)n_task) {
  1200. // Use all tasks
  1201. tasks.resize(n_task);
  1202. printf("%s: reading tasks", __func__);
  1203. int n_dot = std::max((int) n_task/100, 1);
  1204. int i = 0;
  1205. for (auto& task : tasks) {
  1206. ++i;
  1207. if (!task.deserialize(strstream)) {
  1208. printf("%s: failed to read task %d of %u\n", __func__, i, n_task);
  1209. return;
  1210. }
  1211. if (i%n_dot == 0) printf(".");
  1212. }
  1213. printf("done\n");
  1214. }
  1215. else {
  1216. printf("%s: selecting %zu random tasks from %u tasks available\n", __func__, params.multiple_choice_tasks, n_task);
  1217. std::mt19937 rng(1);
  1218. std::vector<int> aux(n_task);
  1219. for (uint32_t i = 0; i < n_task; ++i) aux[i] = i;
  1220. float scale = 1.f/(1.f + (float)std::mt19937::max());
  1221. tasks.resize(params.multiple_choice_tasks);
  1222. for (auto& task : tasks) {
  1223. int j = (int)(scale * rng() * aux.size());
  1224. int idx = aux[j];
  1225. aux[j] = aux.back();
  1226. aux.pop_back();
  1227. strstream.seekg(task_pos[idx], std::ios::beg);
  1228. if (!task.deserialize(strstream)) {
  1229. printf("%s: failed to read task %d at position %u\n", __func__, idx, task_pos[idx]);
  1230. return;
  1231. }
  1232. }
  1233. n_task = params.multiple_choice_tasks;
  1234. }
  1235. printf("%s: preparing task data", __func__);
  1236. fflush(stdout);
  1237. if (n_task > 500) {
  1238. printf("...");
  1239. fflush(stdout);
  1240. std::atomic<int> counter(0);
  1241. std::atomic<int> n_bad(0);
  1242. auto prepare = [&counter, &n_bad, &tasks, ctx] () {
  1243. int num_tasks = tasks.size();
  1244. int n_bad_local = 0;
  1245. while (true) {
  1246. int first = counter.fetch_add(K_TOKEN_CHUNK);
  1247. if (first >= num_tasks) {
  1248. if (n_bad_local > 0) n_bad += n_bad_local;
  1249. break;
  1250. }
  1251. int last = std::min(first + K_TOKEN_CHUNK, num_tasks);
  1252. for (int i = first; i < last; ++i) {
  1253. if (!multiple_choice_prepare_one_task(ctx, tasks[i], false)) ++n_bad_local;
  1254. }
  1255. }
  1256. };
  1257. size_t max_thread = std::thread::hardware_concurrency();
  1258. max_thread = std::min(max_thread, (tasks.size() + K_TOKEN_CHUNK - 1)/K_TOKEN_CHUNK);
  1259. std::vector<std::thread> workers(max_thread-1);
  1260. for (auto& w : workers) w = std::thread(prepare);
  1261. prepare();
  1262. for (auto& w : workers) w.join();
  1263. printf("done\n");
  1264. fflush(stdout);
  1265. int nbad = n_bad;
  1266. if (nbad > 0) {
  1267. printf("%s: found %d malformed tasks\n", __func__, nbad);
  1268. return;
  1269. }
  1270. } else {
  1271. int n_dot = std::max((int) n_task/100, 1);
  1272. int i_task = 0;
  1273. for (auto& task : tasks) {
  1274. ++i_task;
  1275. if (!multiple_choice_prepare_one_task(ctx, task, true)) {
  1276. return;
  1277. }
  1278. if (i_task%n_dot == 0) {
  1279. printf(".");
  1280. fflush(stdout);
  1281. }
  1282. }
  1283. printf("done\n");
  1284. }
  1285. printf("%s : calculating TruthfulQA score over %zu tasks.\n", __func__, tasks.size());
  1286. printf("\ntask\tacc_norm\n");
  1287. const int n_vocab = llama_n_vocab(llama_get_model(ctx));
  1288. const int n_ctx = llama_n_ctx(ctx);
  1289. const int n_batch = params.n_batch;
  1290. const int max_tasks_per_batch = 32;
  1291. const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
  1292. llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
  1293. std::vector<float> tok_logits(n_vocab);
  1294. std::vector<float> batch_logits(n_vocab*n_ctx);
  1295. std::vector<std::pair<size_t, llama_token>> eval_pairs;
  1296. std::vector<float> eval_results;
  1297. std::vector<std::thread> workers(std::thread::hardware_concurrency());
  1298. std::vector<int> batch_indeces;
  1299. int n_done = 0;
  1300. int n_correct = 0;
  1301. int n_tot_answers = 0;
  1302. for (size_t i0 = 0; i0 < tasks.size(); i0++) {
  1303. int n_cur = 0;
  1304. size_t i1 = i0;
  1305. size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch
  1306. llama_batch_clear(batch);
  1307. // batch as much tasks as possible into the available context
  1308. // each task has 4 unique sequence ids - one for each ending
  1309. // the common prefix is shared among the 4 sequences to save tokens
  1310. // we extract logits only from the last common token and from all ending tokens of each sequence
  1311. int s0 = 0;
  1312. while (n_cur + (int) tasks[i1].required_tokens <= n_ctx) {
  1313. auto& cur_task = tasks[i1];
  1314. int n_logits = 0;
  1315. int num_answers = cur_task.seq_tokens.size();
  1316. if (s0 + num_answers > max_seq) {
  1317. break;
  1318. }
  1319. if (int(batch_indeces.size()) != num_answers) {
  1320. batch_indeces.resize(num_answers);
  1321. }
  1322. for (int s = 0; s < num_answers; ++s) batch_indeces[s] = s0 + s;
  1323. for (size_t i = 0; i < cur_task.common_prefix; ++i) {
  1324. //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
  1325. llama_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false);
  1326. }
  1327. batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
  1328. n_logits += 1;
  1329. for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
  1330. const size_t seq_tokens_size = cur_task.seq_tokens[s].size();
  1331. // TODO: don't evaluate the last token of each sequence
  1332. for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) {
  1333. const bool needs_logits = i < seq_tokens_size - 1;
  1334. llama_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
  1335. n_logits += needs_logits;
  1336. }
  1337. }
  1338. s0 += num_answers;
  1339. cur_task.i_logits = i_logits;
  1340. i_logits += n_logits;
  1341. n_cur += cur_task.required_tokens;
  1342. if (++i1 == tasks.size()) {
  1343. break;
  1344. }
  1345. }
  1346. if (i0 == i1) {
  1347. fprintf(stderr, "%s : task %zu does not fit in the context window\n", __func__, i0);
  1348. return;
  1349. }
  1350. llama_kv_cache_clear(ctx);
  1351. // decode all tasks [i0, i1)
  1352. if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
  1353. fprintf(stderr, "%s: llama_decode() failed\n", __func__);
  1354. return;
  1355. }
  1356. // Compute log-probs in parallel
  1357. // First we collect all tasks
  1358. eval_pairs.clear();
  1359. for (size_t i = i0; i < i1; ++i) {
  1360. auto& cur_task = tasks[i];
  1361. size_t li = 1; // skip the last logit of the common prefix (computed separately below)
  1362. for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
  1363. for (size_t j = cur_task.common_prefix; j < cur_task.seq_tokens[s].size() - 1; j++) {
  1364. eval_pairs.emplace_back(cur_task.i_logits + li++, cur_task.seq_tokens[s][j + 1]);
  1365. }
  1366. }
  1367. }
  1368. // Then we do the actual calculation
  1369. compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
  1370. size_t ir = 0;
  1371. // compute the logprobs for each ending of the decoded tasks
  1372. for (size_t i = i0; i < i1; ++i) {
  1373. auto & cur_task = tasks[i];
  1374. //printf("==== Evaluating <%s> with correct answer ", cur_task.question.c_str());
  1375. //for (int j = 0; j < int(cur_task.mc1.labels.size()); ++j) {
  1376. // if (cur_task.mc1.labels[j] == 1) {
  1377. // printf("%d", j+1);
  1378. // }
  1379. //}
  1380. //printf("\n common_prefix: %zu\n", cur_task.common_prefix);
  1381. // get the logits of the last token of the common prefix
  1382. std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*cur_task.i_logits, n_vocab*sizeof(float));
  1383. const auto first_probs = softmax(tok_logits);
  1384. cur_task.log_probs.resize(cur_task.seq_tokens.size());
  1385. for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
  1386. size_t count = 1;
  1387. float log_prob = std::log(first_probs[cur_task.seq_tokens[s][cur_task.common_prefix]]);
  1388. for (size_t j = cur_task.common_prefix; j < cur_task.seq_tokens[s].size() - 1; j++) {
  1389. //printf(" %zu %g\n", ir, eval_results[ir]);
  1390. ++count;
  1391. log_prob += eval_results[ir++];
  1392. }
  1393. cur_task.log_probs[s] = log_prob / count;
  1394. //printf(" Final: %g\n", log_prob / count);
  1395. //printf(" <%s> : %g\n", cur_task.mc1.answers[s].c_str(), log_prob/count);
  1396. }
  1397. // Find the ending with maximum logprob
  1398. size_t logprob_max_idx = 0;
  1399. float logprob_max_val = cur_task.log_probs[0];
  1400. for (size_t s = 1; s < cur_task.log_probs.size(); s++) {
  1401. if (cur_task.log_probs[s] > logprob_max_val) {
  1402. logprob_max_val = cur_task.log_probs[s];
  1403. logprob_max_idx = s;
  1404. }
  1405. }
  1406. n_tot_answers += cur_task.log_probs.size();
  1407. if (cur_task.mc1.labels[logprob_max_idx] == 1) {
  1408. ++n_correct;
  1409. }
  1410. ++n_done;
  1411. // Print the accumulated accuracy mean x 100
  1412. printf("%d\t%.8lf\n", n_done, 100.*n_correct/n_done);
  1413. fflush(stdout);
  1414. }
  1415. i0 = i1 - 1;
  1416. }
  1417. llama_batch_free(batch);
  1418. if (n_done < 100 && (params.multiple_choice_tasks != 0 && params.multiple_choice_tasks < (size_t)n_task)) return;
  1419. float p = 1.f*n_correct/n_done;
  1420. float sigma = sqrt(p*(1-p)/(n_done-1));
  1421. printf("\n Final result: %.4f +/- %.4f\n", 100.f*p, 100.f*sigma);
  1422. p = 1.f*n_done/n_tot_answers;
  1423. sigma = sqrt(p*(1-p)/(n_done-1));
  1424. printf("Random chance: %.4f +/- %.4f\n", 100.f*p, 100.f*sigma);
  1425. printf("\n");
  1426. }
  1427. static void kl_divergence(llama_context * ctx, const gpt_params & params) {
  1428. if (params.logits_file.empty()) {
  1429. fprintf(stderr, "%s: you must provide a name of a file containing the log probabilities of the base model\n", __func__);
  1430. return;
  1431. }
  1432. std::ifstream in(params.logits_file.c_str(), std::ios::binary);
  1433. if (!in) {
  1434. fprintf(stderr, "%s: failed to open %s\n", __func__, params.logits_file.c_str());
  1435. return;
  1436. }
  1437. {
  1438. char check[9]; check[8] = 0;
  1439. in.read(check, 8);
  1440. if (in.fail() || strncmp("_logits_", check, 8) != 0) {
  1441. fprintf(stderr, "%s: %s does not look like a file containing log-probabilities\n", __func__, params.logits_file.c_str());
  1442. return;
  1443. }
  1444. }
  1445. uint32_t n_ctx;
  1446. in.read((char *)&n_ctx, sizeof(n_ctx));
  1447. if (n_ctx > llama_n_ctx(ctx)) {
  1448. fprintf(stderr, "%s: %s has been computed with %u, while the current context is %d. Increase it with -c and retry\n",
  1449. __func__, params.logits_file.c_str(), n_ctx, params.n_ctx);
  1450. }
  1451. int n_vocab, n_chunk;
  1452. in.read((char *)&n_vocab, sizeof(n_vocab));
  1453. in.read((char *)&n_chunk, sizeof(n_chunk));
  1454. if (in.fail()) {
  1455. fprintf(stderr, "%s: failed reading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str());
  1456. return;
  1457. }
  1458. if (n_vocab != llama_n_vocab(llama_get_model(ctx))) {
  1459. fprintf(stderr, "%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_n_vocab(llama_get_model(ctx)));
  1460. }
  1461. std::vector<llama_token> tokens(n_ctx * n_chunk);
  1462. if (in.read((char *)tokens.data(), tokens.size()*sizeof(tokens[0])).fail()) {
  1463. fprintf(stderr, "%s: failed reading evaluation tokens from %s\n", __func__, params.logits_file.c_str());
  1464. return;
  1465. }
  1466. const int n_batch = params.n_batch;
  1467. const int num_batches = (n_ctx + n_batch - 1)/n_batch;
  1468. const int nv = 2*((n_vocab + 1)/2) + 4;
  1469. const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
  1470. GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx)));
  1471. std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv);
  1472. std::vector<float> kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
  1473. std::vector<float> p_diff_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
  1474. std::vector<float> logits;
  1475. if (num_batches > 1) {
  1476. logits.reserve(n_ctx * n_vocab);
  1477. }
  1478. std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
  1479. auto mean_and_uncertainty = [] (double sum, double sum2, size_t count) {
  1480. if (count < 1) {
  1481. return std::make_pair(0., 0.);
  1482. }
  1483. double f = sum/count;
  1484. double df = sum2/count - f*f;
  1485. df = df > 0 && count > 10 ? sqrt(df/(count-1)) : 0.;
  1486. return std::make_pair(f, df);
  1487. };
  1488. auto covariance = [] (double suma, double sumb, double sumab, size_t count) {
  1489. if (count < 10) {
  1490. return 0.0;
  1491. }
  1492. double var = sumab/count - (suma/count)*(sumb/count);
  1493. var /= count - 1;
  1494. return var;
  1495. };
  1496. kl_divergence_result kld;
  1497. auto kld_ptr = kld_values.data();
  1498. auto p_diff_ptr = p_diff_values.data();
  1499. for (int i = 0; i < n_chunk; ++i) {
  1500. const int start = i * n_ctx;
  1501. const int end = start + n_ctx;
  1502. const auto t_start = std::chrono::high_resolution_clock::now();
  1503. if (in.read((char *)log_probs_uint16.data(), log_probs_uint16.size()*sizeof(uint16_t)).fail()) {
  1504. fprintf(stderr, "%s: failed reading log-probs for chunk %d\n", __func__, i);
  1505. return;
  1506. }
  1507. // clear the KV cache
  1508. llama_kv_cache_clear(ctx);
  1509. for (int j = 0; j < num_batches; ++j) {
  1510. const int batch_start = start + j * n_batch;
  1511. const int batch_size = std::min(end - batch_start, n_batch);
  1512. // save original token and restore it after eval
  1513. const auto token_org = tokens[batch_start];
  1514. // add BOS token for the first batch of each chunk
  1515. if (add_bos && j == 0) {
  1516. tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
  1517. }
  1518. // TODO: use llama_batch.logits instead of relying on logits_all == true
  1519. if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
  1520. fprintf(stderr, "%s : failed to eval\n", __func__);
  1521. return;
  1522. }
  1523. // restore the original token in case it was set to BOS
  1524. tokens[batch_start] = token_org;
  1525. if (num_batches > 1) {
  1526. const auto * batch_logits = llama_get_logits(ctx);
  1527. logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
  1528. }
  1529. }
  1530. const auto t_end = std::chrono::high_resolution_clock::now();
  1531. if (i == 0) {
  1532. const float t_total = std::chrono::duration<float>(t_end - t_start).count();
  1533. fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
  1534. int total_seconds = (int)(t_total * n_chunk);
  1535. if (total_seconds >= 60*60) {
  1536. fprintf(stderr, "%d hours ", total_seconds / (60*60));
  1537. total_seconds = total_seconds % (60*60);
  1538. }
  1539. fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
  1540. printf("\nchunk PPL ln(PPL(Q)/PPL(base)) KL Divergence Δp RMS Same top p\n");
  1541. }
  1542. const int first = n_ctx/2;
  1543. const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
  1544. process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
  1545. workers, log_probs_uint16, kld, kld_ptr, p_diff_ptr);
  1546. p_diff_ptr += n_ctx - 1 - first;
  1547. kld_ptr += n_ctx - 1 - first;
  1548. printf("%4d", i+1);
  1549. auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
  1550. const double ppl_val = exp(log_ppl.first);
  1551. const double ppl_unc = ppl_val * log_ppl.second; // ppl_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl.second ** 2 )
  1552. printf(" %9.4lf ± %9.4lf", ppl_val, ppl_unc);
  1553. auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count);
  1554. const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count);
  1555. const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first;
  1556. const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov);
  1557. printf(" %10.5lf ± %10.5lf", log_ppl_ratio_val, log_ppl_ratio_unc);
  1558. auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
  1559. printf(" %10.5lf ± %10.5lf", kl_div.first, kl_div.second);
  1560. auto p_diff_mse = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count);
  1561. const double p_diff_rms_val = sqrt(p_diff_mse.first);
  1562. const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second;
  1563. printf(" %6.3lf ± %6.3lf %%", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc);
  1564. double p_top_val = 1.*kld.n_same_top/kld.count;
  1565. double p_top_unc = sqrt(p_top_val*(1 - p_top_val)/(kld.count - 1));
  1566. printf(" %6.3lf ± %6.3lf %%", 100.0*p_top_val, 100.0*p_top_unc);
  1567. printf("\n");
  1568. fflush(stdout);
  1569. logits.clear();
  1570. }
  1571. printf("\n");
  1572. if (kld.count < 100) return; // we do not wish to do statistics on so few values
  1573. std::sort(kld_values.begin(), kld_values.end());
  1574. std::sort(p_diff_values.begin(), p_diff_values.end());
  1575. printf("====== Perplexity statistics ======\n");
  1576. auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
  1577. const double ppl_val = exp(log_ppl.first);
  1578. const double ppl_unc = ppl_val * log_ppl.second; // ppl_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl.second ** 2 )
  1579. printf("Mean PPL(Q) : %10.6lf ± %10.6lf\n", ppl_val, ppl_unc);
  1580. auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count);
  1581. const double ppl_base_val = exp(log_ppl_base.first);
  1582. const double ppl_base_unc = ppl_base_val * log_ppl_base.second; // ppl_base_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl_base.second ** 2 )
  1583. printf("Mean PPL(base) : %10.6lf ± %10.6lf\n", ppl_base_val, ppl_base_unc);
  1584. const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count);
  1585. // printf("Cov(ln(PPL(Q)), ln(PPL(base))): %10.6lf\n", log_ppl_cov);
  1586. const double log_ppl_cor = log_ppl_cov / (log_ppl.second*log_ppl_base.second);
  1587. printf("Cor(ln(PPL(Q)), ln(PPL(base))): %6.2lf%%\n", 100.0*log_ppl_cor);
  1588. const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first;
  1589. const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov);
  1590. printf("Mean ln(PPL(Q)/PPL(base)) : %10.6lf ± %10.6lf\n", log_ppl_ratio_val, log_ppl_ratio_unc);
  1591. const double ppl_ratio_val = exp(log_ppl_ratio_val);
  1592. const double ppl_ratio_unc = ppl_ratio_val * log_ppl_ratio_unc; // ppl_ratio_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl_ratio.second ** 2 )
  1593. printf("Mean PPL(Q)/PPL(base) : %10.6lf ± %10.6lf\n", ppl_ratio_val, ppl_ratio_unc);
  1594. const double ppl_cov = ppl_val * ppl_base_val * log_ppl_cov;
  1595. const double ppl_diff_val = ppl_val - ppl_base_val;
  1596. const double ppl_diff_unc = sqrt(ppl_unc*ppl_unc + ppl_base_unc*ppl_base_unc - 2.0*ppl_cov);
  1597. printf("Mean PPL(Q)-PPL(base) : %10.6lf ± %10.6lf\n", ppl_diff_val, ppl_diff_unc);
  1598. printf("\n");
  1599. printf("====== KL divergence statistics ======\n");
  1600. auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
  1601. printf("Mean KLD: %10.6lf ± %10.6lf\n", kl_div.first, kl_div.second);
  1602. auto kld_median = kld_values.size()%2 == 0 ? 0.5f*(kld_values[kld_values.size()/2] + kld_values[kld_values.size()/2-1])
  1603. : kld_values[kld_values.size()/2];
  1604. auto percentile = [] (std::vector<float> values, float fraction) {
  1605. if (fraction <= 0) return values.front();
  1606. if (fraction >= 1) return values.back();
  1607. float p = fraction*(values.size() - 1);
  1608. size_t ip = size_t(p); p -= ip;
  1609. return (1 - p)*values[ip] + p*values[std::min(ip+1, values.size()-1)];
  1610. };
  1611. printf("Maximum KLD: %10.6f\n", kld_values.back());
  1612. printf("99.9%% KLD: %10.6f\n", percentile(kld_values, 0.999f));
  1613. printf("99.0%% KLD: %10.6f\n", percentile(kld_values, 0.990f));
  1614. printf("99.0%% KLD: %10.6f\n", percentile(kld_values, 0.990f));
  1615. printf("Median KLD: %10.6f\n", kld_median);
  1616. printf("10.0%% KLD: %10.6f\n", percentile(kld_values, 0.100f));
  1617. printf(" 5.0%% KLD: %10.6f\n", percentile(kld_values, 0.050f));
  1618. printf(" 1.0%% KLD: %10.6f\n", percentile(kld_values, 0.010f));
  1619. printf("Minimum KLD: %10.6f\n", kld_values.front());
  1620. printf("\n");
  1621. printf("====== Token probability statistics ======\n");
  1622. auto p_diff = mean_and_uncertainty(kld.sum_p_diff, kld.sum_p_diff2, kld.count);
  1623. printf("Mean Δp: %6.3lf ± %5.3lf %%\n", 100.0*p_diff.first, 100.0*p_diff.second);
  1624. auto p_diff_median = p_diff_values.size()%2 == 0 ? 0.5f*(p_diff_values[p_diff_values.size()/2] + p_diff_values[p_diff_values.size()/2-1])
  1625. : p_diff_values[p_diff_values.size()/2];
  1626. printf("Maximum Δp: %6.3lf%%\n", 100.0*p_diff_values.back());
  1627. printf("99.9%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.999f));
  1628. printf("99.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.990f));
  1629. printf("95.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.950f));
  1630. printf("90.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.900f));
  1631. printf("75.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.750f));
  1632. printf("Median Δp: %6.3lf%%\n", 100.0*p_diff_median);
  1633. printf("25.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.250f));
  1634. printf("10.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.100f));
  1635. printf(" 5.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.050f));
  1636. printf(" 1.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.010f));
  1637. printf(" 0.1%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.001f));
  1638. printf("Minimum Δp: %6.3lf%%\n", 100.0*p_diff_values.front());
  1639. auto p_diff_mse = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count);
  1640. // printf("MSE Δp : %10.6lf ± %10.6lf\n", p_diff_mse.first, p_diff_mse.second);
  1641. const double p_diff_rms_val = sqrt(p_diff_mse.first);
  1642. const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second;
  1643. printf("RMS Δp : %6.3lf ± %5.3lf %%\n", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc);
  1644. const double same_top_p = 1.0*kld.n_same_top/kld.count;
  1645. printf("Same top p: %6.3lf ± %5.3lf %%\n", 100.0*same_top_p, 100.0*sqrt(same_top_p*(1.0 - same_top_p)/(kld.count - 1)));
  1646. }
  1647. int main(int argc, char ** argv) {
  1648. gpt_params params;
  1649. params.n_ctx = 512;
  1650. params.logits_all = true;
  1651. if (!gpt_params_parse(argc, argv, params)) {
  1652. gpt_params_print_usage(argc, argv, params);
  1653. return 1;
  1654. }
  1655. const int32_t n_ctx = params.n_ctx;
  1656. if (n_ctx <= 0) {
  1657. fprintf(stderr, "%s: perplexity tool requires '--ctx-size' > 0\n", __func__);
  1658. return 1;
  1659. }
  1660. const bool ppl = !params.hellaswag && !params.winogrande && !params.multiple_choice && !params.kl_divergence;
  1661. if (ppl) {
  1662. const int32_t n_seq = std::max(1, params.n_batch / n_ctx);
  1663. const int32_t n_kv = n_seq * n_ctx;
  1664. params.n_parallel = n_seq;
  1665. params.n_ctx = n_kv;
  1666. params.n_batch = std::min(params.n_batch, n_kv);
  1667. } else {
  1668. params.n_batch = std::min(params.n_batch, params.n_ctx);
  1669. if (params.kl_divergence) {
  1670. params.n_parallel = 1;
  1671. } else {
  1672. // ensure there's at least enough seq_ids for HellaSwag
  1673. params.n_parallel = std::max(4, params.n_parallel);
  1674. }
  1675. }
  1676. if (params.ppl_stride > 0) {
  1677. fprintf(stderr, "Will perform strided perplexity calculation -> adjusting context size from %d to %d\n",
  1678. params.n_ctx, params.n_ctx + params.ppl_stride/2);
  1679. params.n_ctx += params.ppl_stride/2;
  1680. }
  1681. print_build_info();
  1682. LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);
  1683. llama_backend_init();
  1684. llama_numa_init(params.numa);
  1685. // load the model and apply lora adapter, if any
  1686. llama_init_result llama_init = llama_init_from_gpt_params(params);
  1687. llama_model * model = llama_init.model;
  1688. llama_context * ctx = llama_init.context;
  1689. if (model == NULL) {
  1690. fprintf(stderr, "%s: error: unable to load model\n", __func__);
  1691. return 1;
  1692. }
  1693. const int n_ctx_train = llama_n_ctx_train(model);
  1694. if (params.n_ctx > n_ctx_train) {
  1695. fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
  1696. __func__, n_ctx_train, params.n_ctx);
  1697. }
  1698. // print system information
  1699. {
  1700. fprintf(stderr, "\n");
  1701. fprintf(stderr, "%s\n", gpt_params_get_system_info(params).c_str());
  1702. }
  1703. struct results_perplexity results;
  1704. if (params.hellaswag) {
  1705. hellaswag_score(ctx, params);
  1706. } else if (params.winogrande) {
  1707. winogrande_score(ctx, params);
  1708. } else if (params.multiple_choice) {
  1709. multiple_choice_score(ctx, params);
  1710. } else if (params.kl_divergence) {
  1711. kl_divergence(ctx, params);
  1712. } else {
  1713. results = perplexity(ctx, params, n_ctx);
  1714. }
  1715. LOG_TEE("\n");
  1716. llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
  1717. write_logfile(ctx, params, model, results);
  1718. llama_free(ctx);
  1719. llama_free_model(model);
  1720. llama_backend_free();
  1721. return 0;
  1722. }