// llama-sampling.cpp
  1. #include "llama-sampling.h"
  2. #include <algorithm>
  3. #include <cstring>
  4. #include <ctime>
  5. #include <cfloat>
  6. #include <numeric>
  7. #include <unordered_map>
  // Converts the `size` raw logits in `array` into log-probabilities, in place.
  //
  // Uses the log-sum-exp form
  //     log p_i = (x_i - max) - log(sum_j exp(x_j - max))
  // so that an entry whose exponential underflows to zero still receives a
  // finite value instead of log(0) = -inf.
  //
  // An empty array (size == 0) is left untouched.
  static void llama_log_softmax(float * array, size_t size) {
      if (size == 0) {
          // std::max_element would return `array + size`, which must not be dereferenced
          return;
      }

      const float max_l = *std::max_element(array, array + size);

      float sum = 0.f;
      for (size_t i = 0; i < size; ++i) {
          sum += expf(array[i] - max_l);
      }

      // the maximum element contributes exp(0) == 1, so the logarithm below is well defined
      const float log_sum = logf(sum);

      for (size_t i = 0; i < size; ++i) {
          array[i] = array[i] - max_l - log_sum;
      }
  }
  20. void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) {
  21. if (seed == LLAMA_DEFAULT_SEED) {
  22. seed = time(NULL);
  23. }
  24. smpl->rng.seed(seed);
  25. }
  26. void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
  27. GGML_ASSERT(candidates->size > 0);
  28. const int64_t t_start_sample_us = ggml_time_us();
  29. // Sort the logits in descending order
  30. if (!candidates->sorted) {
  31. std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
  32. return a.logit > b.logit;
  33. });
  34. candidates->sorted = true;
  35. }
  36. float max_l = candidates->data[0].logit;
  37. float cum_sum = 0.0f;
  38. for (size_t i = 0; i < candidates->size; ++i) {
  39. float p = expf(candidates->data[i].logit - max_l);
  40. candidates->data[i].p = p;
  41. cum_sum += p;
  42. }
  43. for (size_t i = 0; i < candidates->size; ++i) {
  44. candidates->data[i].p /= cum_sum;
  45. }
  46. if (smpl) {
  47. smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
  48. }
  49. }
  50. void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
  51. // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
  52. // if (k >= (int32_t)candidates->size) {
  53. // return;
  54. // }
  55. const int64_t t_start_sample_us = ggml_time_us();
  56. if (k <= 0) {
  57. k = candidates->size;
  58. }
  59. k = std::max(k, (int) min_keep);
  60. k = std::min(k, (int) candidates->size);
  61. // Sort scores in descending order
  62. if (!candidates->sorted) {
  63. auto comp = [](const llama_token_data & a, const llama_token_data & b) {
  64. return a.logit > b.logit;
  65. };
  66. if (k <= 128) {
  67. std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
  68. } else {
  69. constexpr int nbuckets = 128;
  70. constexpr float bucket_low = -10.0f;
  71. constexpr float bucket_high = 10.0f;
  72. constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
  73. constexpr float bucket_inter = -bucket_low * bucket_scale;
  74. std::vector<int> bucket_idx(candidates->size);
  75. std::vector<int> histo(nbuckets, 0);
  76. for (int i = 0; i < (int)candidates->size; ++i) {
  77. const float val = candidates->data[i].logit;
  78. int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
  79. ib = std::max(0, std::min(nbuckets-1, ib));
  80. bucket_idx[i] = ib;
  81. ++histo[ib];
  82. }
  83. int nhave = 0;
  84. int ib = nbuckets - 1;
  85. for ( ; ib >= 0; --ib) {
  86. nhave += histo[ib];
  87. if (nhave >= k) break;
  88. }
  89. std::vector<llama_token_data> tmp_tokens(nhave);
  90. auto ptr = tmp_tokens.data();
  91. std::vector<llama_token_data*> bucket_ptrs;
  92. bucket_ptrs.reserve(nbuckets - ib);
  93. for (int j = nbuckets - 1; j >= ib; --j) {
  94. bucket_ptrs.push_back(ptr);
  95. ptr += histo[j];
  96. }
  97. for (int i = 0; i < (int)candidates->size; ++i) {
  98. int j = bucket_idx[i];
  99. if (j >= ib) {
  100. *bucket_ptrs[nbuckets-1-j]++ = candidates->data[i];
  101. }
  102. }
  103. ptr = tmp_tokens.data();
  104. int ndone = 0;
  105. for (int j = nbuckets-1; j > ib; --j) {
  106. std::sort(ptr, ptr + histo[j], comp);
  107. ptr += histo[j];
  108. ndone += histo[j];
  109. }
  110. std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
  111. std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data));
  112. }
  113. candidates->sorted = true;
  114. }
  115. candidates->size = k;
  116. if (smpl) {
  117. smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
  118. }
  119. }
  120. void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
  121. if (p >= 1.0f) {
  122. return;
  123. }
  124. llama_sample_softmax_impl(smpl, candidates);
  125. const int64_t t_start_sample_us = ggml_time_us();
  126. // Compute the cumulative probabilities
  127. float cum_sum = 0.0f;
  128. size_t last_idx = candidates->size;
  129. for (size_t i = 0; i < candidates->size; ++i) {
  130. cum_sum += candidates->data[i].p;
  131. // Check if the running sum is at least p or if we have kept at least min_keep tokens
  132. // we set the last index to i+1 to indicate that the current iterate should be included in the set
  133. if (cum_sum >= p && i + 1 >= min_keep) {
  134. last_idx = i + 1;
  135. break;
  136. }
  137. }
  138. // Resize the output vector to keep only the top-p tokens
  139. candidates->size = last_idx;
  140. if (smpl) {
  141. smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
  142. }
  143. }
  144. void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
  145. if (p <= 0.0f || !candidates->size) {
  146. return;
  147. }
  148. const int64_t t_start_sample_us = ggml_time_us();
  149. bool min_p_applied = false;
  150. // if the candidates aren't sorted, try the unsorted implementation first
  151. if (!candidates->sorted) {
  152. std::vector<llama_token_data> filtered_tokens;
  153. float max_logit = -FLT_MAX;
  154. for (size_t i = 0; i < candidates->size; ++i) {
  155. max_logit = std::max(max_logit, candidates->data[i].logit);
  156. }
  157. const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max
  158. for (size_t i = 0; i < candidates->size; ++i) {
  159. if (candidates->data[i].logit >= min_logit) {
  160. filtered_tokens.push_back(candidates->data[i]);
  161. }
  162. }
  163. // if we have enough values the operation was a success
  164. if (filtered_tokens.size() >= min_keep) {
  165. memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
  166. candidates->size = filtered_tokens.size();
  167. min_p_applied = true;
  168. }
  169. }
  170. // if the candidates are sorted or the unsorted implementation failed, use this implementation
  171. if (!min_p_applied) {
  172. // Sort the logits in descending order
  173. if (!candidates->sorted) {
  174. std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
  175. return a.logit > b.logit;
  176. });
  177. candidates->sorted = true;
  178. }
  179. const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max
  180. size_t i = 1; // first token always matches
  181. for (; i < candidates->size; ++i) {
  182. if (candidates->data[i].logit < min_logit && i >= min_keep) {
  183. break; // prob too small
  184. }
  185. }
  186. // Resize the output vector to keep only the matching tokens
  187. candidates->size = i;
  188. }
  189. if (smpl) {
  190. smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
  191. }
  192. }
  193. void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
  194. if (z >= 1.0f || candidates->size <= 2) {
  195. return;
  196. }
  197. llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
  198. const int64_t t_start_sample_us = ggml_time_us();
  199. // Compute the first and second derivatives
  200. std::vector<float> first_derivatives(candidates->size - 1);
  201. std::vector<float> second_derivatives(candidates->size - 2);
  202. for (size_t i = 0; i < first_derivatives.size(); ++i) {
  203. first_derivatives[i] = candidates->data[i].p - candidates->data[i + 1].p;
  204. }
  205. for (size_t i = 0; i < second_derivatives.size(); ++i) {
  206. second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
  207. }
  208. // Calculate absolute value of second derivatives
  209. for (size_t i = 0; i < second_derivatives.size(); ++i) {
  210. second_derivatives[i] = std::abs(second_derivatives[i]);
  211. }
  212. // Normalize the second derivatives
  213. {
  214. const float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f);
  215. if (second_derivatives_sum > 1e-6f) {
  216. for (float & value : second_derivatives) {
  217. value /= second_derivatives_sum;
  218. }
  219. } else {
  220. for (float & value : second_derivatives) {
  221. value = 1.0f / second_derivatives.size();
  222. }
  223. }
  224. }
  225. float cum_sum = 0.0f;
  226. size_t last_idx = candidates->size;
  227. for (size_t i = 0; i < second_derivatives.size(); ++i) {
  228. cum_sum += second_derivatives[i];
  229. // Check if the running sum is greater than z or if we have kept at least min_keep tokens
  230. if (cum_sum > z && i >= min_keep) {
  231. last_idx = i;
  232. break;
  233. }
  234. }
  235. // Resize the output vector to keep only the tokens above the tail location
  236. candidates->size = last_idx;
  237. if (smpl) {
  238. smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
  239. }
  240. }
  241. void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
  242. // Reference implementation:
  243. // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
  244. if (p >= 1.0f) {
  245. return;
  246. }
  247. // Compute the softmax of logits and calculate entropy
  248. llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
  249. const int64_t t_start_sample_us = ggml_time_us();
  250. float entropy = 0.0f;
  251. for (size_t i = 0; i < candidates->size; ++i) {
  252. entropy += -candidates->data[i].p * logf(candidates->data[i].p);
  253. }
  254. // Compute the absolute difference between negative log probability and entropy for each candidate
  255. std::vector<float> shifted_scores;
  256. for (size_t i = 0; i < candidates->size; ++i) {
  257. float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy);
  258. shifted_scores.push_back(shifted_score);
  259. }
  260. // Sort tokens based on the shifted_scores and their corresponding indices
  261. std::vector<size_t> indices(candidates->size);
  262. std::iota(indices.begin(), indices.end(), 0);
  263. std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
  264. return shifted_scores[a] < shifted_scores[b];
  265. });
  266. // Compute the cumulative probabilities
  267. float cum_sum = 0.0f;
  268. size_t last_idx = indices.size();
  269. for (size_t i = 0; i < indices.size(); ++i) {
  270. size_t idx = indices[i];
  271. cum_sum += candidates->data[idx].p;
  272. // Check if the running sum is greater than typical or if we have kept at least min_keep tokens
  273. if (cum_sum > p && i >= min_keep - 1) {
  274. last_idx = i + 1;
  275. break;
  276. }
  277. }
  278. // Resize the output vector to keep only the locally typical tokens
  279. std::vector<llama_token_data> new_candidates;
  280. for (size_t i = 0; i < last_idx; ++i) {
  281. size_t idx = indices[i];
  282. new_candidates.push_back(candidates->data[idx]);
  283. }
  284. // Replace the data in candidates with the new_candidates data
  285. std::copy(new_candidates.begin(), new_candidates.end(), candidates->data);
  286. candidates->size = new_candidates.size();
  287. candidates->sorted = false;
  288. if (smpl) {
  289. smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
  290. }
  291. }
  292. void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) {
  293. const int64_t t_start_sample_us = ggml_time_us();
  294. // no need to do anything if there is only one (or zero) candidates
  295. if(candidates->size <= 1) {
  296. return;
  297. }
  298. // Calculate maximum possible entropy
  299. float max_entropy = -logf(1.0f / candidates->size);
  300. llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
  301. // Calculate entropy of the softmax probabilities
  302. float entropy = 0.0f;
  303. for (size_t i = 0; i < candidates->size; ++i) {
  304. float prob = candidates->data[i].p;
  305. if (prob > 0.0f) { // Ensure no log(0)
  306. entropy -= prob * logf(prob);
  307. }
  308. }
  309. // Normalize the entropy (max_entropy cannot be 0 here because we checked candidates->size != 1 above)
  310. float normalized_entropy = entropy / max_entropy;
  311. // Map the normalized entropy to the desired temperature range using the power function
  312. float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);
  313. #ifdef DEBUG
  314. LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
  315. LLAMA_LOG_INFO("Entropy: %f\n", entropy);
  316. LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
  317. LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
  318. LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
  319. LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
  320. #endif
  321. // Apply the dynamically calculated temperature scaling
  322. for (size_t i = 0; i < candidates->size; ++i) {
  323. candidates->data[i].logit /= dyn_temp;
  324. }
  325. // Re-compute softmax probabilities after scaling logits with dynamic temperature
  326. double max_l_double = candidates->data[0].logit;
  327. double cum_sum_double = 0.0;
  328. for (size_t i = 0; i < candidates->size; ++i) {
  329. double p = exp(candidates->data[i].logit - max_l_double);
  330. candidates->data[i].p = p; // Store the scaled probability
  331. cum_sum_double += p;
  332. }
  333. for (size_t i = 0; i < candidates->size; ++i) {
  334. candidates->data[i].p /= cum_sum_double; // Re-normalize the probabilities
  335. }
  336. #ifdef DEBUG
  337. // Print the updated top 25 probabilities after temperature scaling
  338. LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
  339. for (size_t i = 0; i < 25 && i < candidates->size; ++i) {
  340. LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates->data[i].p * 100.0f);
  341. }
  342. #endif
  343. if (smpl) {
  344. smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
  345. }
  346. }
  347. void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) {
  348. const int64_t t_start_sample_us = ggml_time_us();
  349. for (size_t i = 0; i < candidates->size; ++i) {
  350. candidates->data[i].logit /= temp;
  351. }
  352. if (smpl) {
  353. smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
  354. }
  355. }
  356. void llama_sample_repetition_penalties_impl(
  357. struct llama_sampling * smpl,
  358. llama_token_data_array * candidates,
  359. const llama_token * last_tokens,
  360. size_t penalty_last_n,
  361. float penalty_repeat,
  362. float penalty_freq,
  363. float penalty_present) {
  364. if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
  365. return;
  366. }
  367. const int64_t t_start_sample_us = ggml_time_us();
  368. // Create a frequency map to count occurrences of each token in last_tokens
  369. std::unordered_map<llama_token, int> token_count;
  370. for (size_t i = 0; i < penalty_last_n; ++i) {
  371. token_count[last_tokens[i]]++;
  372. }
  373. // Apply frequency and presence penalties to the candidates
  374. for (size_t i = 0; i < candidates->size; ++i) {
  375. const auto token_iter = token_count.find(candidates->data[i].id);
  376. if (token_iter == token_count.end()) {
  377. continue;
  378. }
  379. const int count = token_iter->second;
  380. // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
  381. // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
  382. if (candidates->data[i].logit <= 0) {
  383. candidates->data[i].logit *= penalty_repeat;
  384. } else {
  385. candidates->data[i].logit /= penalty_repeat;
  386. }
  387. candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
  388. }
  389. candidates->sorted = false;
  390. if (smpl) {
  391. smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
  392. }
  393. }
  394. void llama_sample_apply_guidance_impl(
  395. struct llama_sampling * smpl,
  396. float * logits,
  397. float * logits_guidance,
  398. float scale) {
  399. GGML_ASSERT(smpl);
  400. const auto t_start_sample_us = ggml_time_us();
  401. const auto n_vocab = smpl->n_vocab;
  402. llama_log_softmax(logits, n_vocab);
  403. llama_log_softmax(logits_guidance, n_vocab);
  404. for (int i = 0; i < n_vocab; ++i) {
  405. auto & l = logits[i];
  406. const auto & g = logits_guidance[i];
  407. l = scale * (l - g) + g;
  408. }
  409. smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
  410. }
  411. llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
  412. GGML_ASSERT(smpl);
  413. const int32_t n_vocab = float(smpl->n_vocab);
  414. int64_t t_start_sample_us = ggml_time_us();
  415. llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
  416. // Estimate s_hat using the most probable m tokens
  417. float s_hat = 0.0;
  418. float sum_ti_bi = 0.0;
  419. float sum_ti_sq = 0.0;
  420. for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) {
  421. float t_i = logf(float(i + 2) / float(i + 1));
  422. float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p);
  423. sum_ti_bi += t_i * b_i;
  424. sum_ti_sq += t_i * t_i;
  425. }
  426. s_hat = sum_ti_bi / sum_ti_sq;
  427. // Compute k from the estimated s_hat and target surprise value
  428. float epsilon_hat = s_hat - 1;
  429. float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat);
  430. // Sample the next word X using top-k sampling
  431. llama_sample_top_k_impl((struct llama_sampling *) nullptr, candidates, int(k), 1);
  432. smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
  433. llama_token X = llama_sample_token_impl(smpl, candidates);
  434. t_start_sample_us = ggml_time_us();
  435. // Compute error as the difference between observed surprise and target surprise value
  436. size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
  437. return candidate.id == X;
  438. }));
  439. float observed_surprise = -log2f(candidates->data[X_idx].p);
  440. float e = observed_surprise - tau;
  441. // Update mu using the learning rate and error
  442. *mu = *mu - eta * e;
  443. smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
  444. return X;
  445. }
  446. llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) {
  447. int64_t t_start_sample_us;
  448. t_start_sample_us = ggml_time_us();
  449. llama_sample_softmax_impl(smpl, candidates);
  450. // Truncate the words with surprise values greater than mu
  451. candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
  452. return -log2f(candidate.p) > *mu;
  453. }));
  454. if (candidates->size == 0) {
  455. candidates->size = 1;
  456. }
  457. if (smpl) {
  458. smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
  459. }
  460. // Normalize the probabilities of the remaining words
  461. llama_sample_softmax_impl(smpl, candidates);
  462. // Sample the next word X from the remaining words
  463. llama_token X = llama_sample_token_impl(smpl, candidates);
  464. t_start_sample_us = ggml_time_us();
  465. // Compute error as the difference between observed surprise and target surprise value
  466. size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
  467. return candidate.id == X;
  468. }));
  469. float observed_surprise = -log2f(candidates->data[X_idx].p);
  470. float e = observed_surprise - tau;
  471. // Update mu using the learning rate and error
  472. *mu = *mu - eta * e;
  473. if (smpl) {
  474. smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
  475. }
  476. return X;
  477. }
  478. llama_token llama_sample_token_greedy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
  479. const int64_t t_start_sample_us = ggml_time_us();
  480. // Find max element
  481. auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
  482. return a.logit < b.logit;
  483. });
  484. llama_token result = max_iter->id;
  485. if (smpl) {
  486. smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
  487. smpl->n_sample++;
  488. }
  489. return result;
  490. }
  491. llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) {
  492. GGML_ASSERT(smpl);
  493. const int64_t t_start_sample_us = ggml_time_us();
  494. llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
  495. std::vector<float> probs;
  496. probs.reserve(candidates->size);
  497. for (size_t i = 0; i < candidates->size; ++i) {
  498. probs.push_back(candidates->data[i].p);
  499. }
  500. std::discrete_distribution<> dist(probs.begin(), probs.end());
  501. int idx = dist(rng);
  502. llama_token result = candidates->data[idx].id;
  503. smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
  504. smpl->n_sample++;
  505. return result;
  506. }
  507. llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
  508. return llama_sample_token_with_rng_impl(smpl, candidates, smpl->rng);
  509. }