// sampling.cpp

#include "sampling.h"

#include "common.h"

// the ring buffer works similarly to std::deque, but with a fixed capacity
// TODO: deduplicate with llama-impl.h
template<typename T>
struct ring_buffer {
    ring_buffer(size_t cap) : capacity(cap), data(cap) {}

    T & front() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[first];
    }

    const T & front() const {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[first];
    }
    T & back() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        // pos points at the next write slot, so the last element is one behind it
        return data[(pos + capacity - 1) % capacity];
    }

    const T & back() const {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[(pos + capacity - 1) % capacity];
    }
    void push_back(const T & value) {
        if (sz == capacity) {
            // advance the start when buffer is full
            first = (first + 1) % capacity;
        } else {
            sz++;
        }
        data[pos] = value;
        pos = (pos + 1) % capacity;
    }

    T pop_front() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        T value = data[first];
        first = (first + 1) % capacity;
        sz--;
        return value;
    }
    // reverse-at: rat(0) is the most recently pushed element, rat(sz - 1) the oldest
    const T & rat(size_t i) const {
        if (i >= sz) {
            throw std::runtime_error("ring buffer: index out of bounds");
        }
        return data[(first + sz - i - 1) % capacity];
    }
    std::vector<T> to_vector() const {
        std::vector<T> result;
        result.reserve(sz);
        for (size_t i = 0; i < sz; i++) {
            result.push_back(data[(first + i) % capacity]);
        }
        return result;
    }
    void clear() {
        // only reset the bookkeeping - the underlying storage stays allocated
        sz    = 0;
        first = 0;
        pos   = 0;
    }
    bool empty() const {
        return sz == 0;
    }

    size_t size() const {
        return sz;
    }

    size_t capacity = 0;
    size_t sz       = 0;
    size_t first    = 0;
    size_t pos      = 0;

    std::vector<T> data;
};
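
// usage sketch (illustrative, not called anywhere in this file): with capacity 3,
// the fourth push evicts the oldest element, and rat() indexes from the newest
static void ring_buffer_example() {
    ring_buffer<int> buf(3);

    buf.push_back(1);
    buf.push_back(2);
    buf.push_back(3);
    buf.push_back(4); // full: evicts 1, buffer now holds {2, 3, 4}

    GGML_ASSERT(buf.front() == 2); // oldest element
    GGML_ASSERT(buf.back()  == 4); // newest element (back() looks one slot behind pos)
    GGML_ASSERT(buf.rat(0)  == 4); // rat(0) == back()
    GGML_ASSERT(buf.rat(2)  == 2); // rat(size() - 1) == front()
}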
struct gpt_sampler {
    gpt_sampler_params params;

    struct llama_sampler * grmr;
    struct llama_sampler * chain;

    ring_buffer<llama_token> prev;

    std::vector<llama_token_data> cur;

    llama_token_data_array cur_p;

    void set_logits(struct llama_context * ctx, int idx) {
        const auto * logits = llama_get_logits_ith(ctx, idx);

        const int n_vocab = llama_n_vocab(llama_get_model(ctx));

        cur.resize(n_vocab);

        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
            cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
        }

        cur_p = { cur.data(), cur.size(), -1, false };
    }
};
std::string gpt_sampler_params::print() const {
    char result[1024];

    snprintf(result, sizeof(result),
            "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
            "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
            "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
            penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
            top_k, tfs_z, top_p, min_p, typ_p, temp,
            mirostat, mirostat_eta, mirostat_tau);

    return std::string(result);
}
struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) {
    llama_sampler_chain_params lparams = llama_sampler_chain_default_params();

    lparams.no_perf = false; // TODO: control via params

    auto * result = new gpt_sampler {
        /* .params = */ params,
        /* .grmr   = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"),
        /* .chain  = */ llama_sampler_chain_init(lparams),
        /* .prev   = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
        /* .cur    = */ {},
        /* .cur_p  = */ {},
    };

    llama_sampler_chain_add(result->chain,
            llama_sampler_init_logit_bias(
                llama_n_vocab(model),
                params.logit_bias.size(),
                params.logit_bias.data()));

    llama_sampler_chain_add(result->chain,
            llama_sampler_init_penalties(
                llama_n_vocab  (model),
                llama_token_eos(model),
                llama_token_nl (model),
                params.penalty_last_n,
                params.penalty_repeat,
                params.penalty_freq,
                params.penalty_present,
                params.penalize_nl,
                params.ignore_eos));

    if (params.temp > 0.0f) {
        if (params.mirostat == 0) {
            for (const auto & cnstr : params.samplers) {
                switch (cnstr) {
                    case GPT_SAMPLER_TYPE_TOP_K:
                        llama_sampler_chain_add(result->chain, llama_sampler_init_top_k    (params.top_k));
                        break;
                    case GPT_SAMPLER_TYPE_TOP_P:
                        llama_sampler_chain_add(result->chain, llama_sampler_init_top_p    (params.top_p, params.min_keep));
                        break;
                    case GPT_SAMPLER_TYPE_MIN_P:
                        llama_sampler_chain_add(result->chain, llama_sampler_init_min_p    (params.min_p, params.min_keep));
                        break;
                    case GPT_SAMPLER_TYPE_TFS_Z:
                        llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
                        break;
                    case GPT_SAMPLER_TYPE_TYPICAL_P:
                        llama_sampler_chain_add(result->chain, llama_sampler_init_typical  (params.typ_p, params.min_keep));
                        break;
                    case GPT_SAMPLER_TYPE_TEMPERATURE:
                        llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
                        break;
                    default:
                        GGML_ASSERT(false && "unknown sampler type");
                }
            }
            llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
            llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
        } else if (params.mirostat == 1) {
            llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
            llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
        } else if (params.mirostat == 2) {
            llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
            llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
        } else {
            GGML_ASSERT(false && "unknown mirostat version");
        }
    } else {
        llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
        llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
    }

    return result;
}
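
// construction sketch (illustrative, not part of the public API): builds a sampler
// that applies only top-k and temperature; assumes gpt_sampler_params (declared in
// sampling.h) is default-constructible with sensible defaults for the other fields
static struct gpt_sampler * gpt_sampler_init_topk_temp(const struct llama_model * model, int32_t top_k, float temp) {
    gpt_sampler_params sparams;

    sparams.top_k    = top_k;
    sparams.temp     = temp;
    sparams.samplers = { GPT_SAMPLER_TYPE_TOP_K, GPT_SAMPLER_TYPE_TEMPERATURE };

    return gpt_sampler_init(model, sparams); // caller must gpt_sampler_free() the result
}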
void gpt_sampler_free(struct gpt_sampler * gsmpl) {
    if (gsmpl) {
        llama_sampler_free(gsmpl->grmr);
        llama_sampler_free(gsmpl->chain);

        delete gsmpl;
    }
}

void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar) {
    if (accept_grammar) {
        llama_sampler_accept(gsmpl->grmr, token);
    }

    llama_sampler_accept(gsmpl->chain, token);

    gsmpl->prev.push_back(token);
}

void gpt_sampler_reset(struct gpt_sampler * gsmpl) {
    llama_sampler_reset(gsmpl->grmr);

    llama_sampler_reset(gsmpl->chain);
}
struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) {
    return new gpt_sampler {
        /* .params = */ gsmpl->params,
        /* .grmr   = */ llama_sampler_clone(gsmpl->grmr),
        /* .chain  = */ llama_sampler_clone(gsmpl->chain),
        /* .prev   = */ gsmpl->prev,
        /* .cur    = */ gsmpl->cur,
        /* .cur_p  = */ gsmpl->cur_p,
    };
}

void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler * gsmpl) {
    // TODO: measure grammar performance
    if (gsmpl) {
        llama_perf_print(gsmpl->chain, LLAMA_PERF_TYPE_SAMPLER_CHAIN);
    }
    if (ctx) {
        llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
    }
}
llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
    gsmpl->set_logits(ctx, idx);

    auto & grmr  = gsmpl->grmr;
    auto & chain = gsmpl->chain;
    auto & cur_p = gsmpl->cur_p; // initialized by set_logits

    if (grammar_first) {
        llama_sampler_apply(grmr, &cur_p);
    }

    llama_sampler_apply(chain, &cur_p);

    GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");

    const llama_token id = cur_p.data[cur_p.selected].id;

    if (grammar_first) {
        return id;
    }
    // check if the sampled token fits the grammar
    {
        llama_token_data       single_token_data       = { id, 1.0f, 0.0f };
        llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };

        llama_sampler_apply(grmr, &single_token_data_array);

        const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
        if (is_valid) {
            return id;
        }
    }

    // resampling:
    // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
    gsmpl->set_logits(ctx, idx);

    llama_sampler_apply(grmr,  &cur_p);
    llama_sampler_apply(chain, &cur_p);

    GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");

    return cur_p.data[cur_p.selected].id;
}
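
// usage sketch (illustrative): the decode/sample/accept loop a caller would run;
// decode_one() is a hypothetical stand-in for feeding the sampled token back
// through llama_decode(), and llama_token_is_eog() is assumed from llama.h
static void gpt_sampler_generate_example(struct gpt_sampler * gsmpl, struct llama_context * ctx, int n_predict) {
    for (int i = 0; i < n_predict; i++) {
        // sample from the logits of the last decoded token (idx == -1)
        const llama_token id = gpt_sampler_sample(gsmpl, ctx, -1, /* grammar_first = */ false);

        if (llama_token_is_eog(llama_get_model(ctx), id)) {
            break;
        }

        // record the token in the grammar, the chain and the history buffer
        gpt_sampler_accept(gsmpl, id, /* accept_grammar = */ true);

        // decode_one(ctx, id); // hypothetical: submit the token via llama_decode()
    }
}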
// helpers

llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl) {
    return &gsmpl->cur_p;
}

llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl) {
    return gsmpl->prev.rat(0);
}

std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) {
    std::string result = "\tlogits ";

    for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
        const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
        result += std::string("-> ") + llama_sampler_name(smpl) + " ";
    }

    return result;
}
std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx_main, int n) {
    n = std::min(n, (int) gsmpl->prev.size());

    if (n <= 0) {
        return "";
    }

    std::string result;
    result.reserve(8*n); // 8 is the average length of a token [citation needed], TODO: compute this from the vocab

    for (int i = n - 1; i >= 0; i--) {
        const llama_token id = gsmpl->prev.rat(i);

        GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen");

        result += llama_token_to_piece(ctx_main, id);
    }

    return result;
}
char gpt_sampler_type_to_chr(enum gpt_sampler_type cnstr) {
    switch (cnstr) {
        case GPT_SAMPLER_TYPE_TOP_K:       return 'k';
        case GPT_SAMPLER_TYPE_TFS_Z:       return 'f';
        case GPT_SAMPLER_TYPE_TYPICAL_P:   return 'y';
        case GPT_SAMPLER_TYPE_TOP_P:       return 'p';
        case GPT_SAMPLER_TYPE_MIN_P:       return 'm';
        case GPT_SAMPLER_TYPE_TEMPERATURE: return 't';
        default : return '?';
    }
}

std::string gpt_sampler_type_to_str(enum gpt_sampler_type cnstr) {
    switch (cnstr) {
        case GPT_SAMPLER_TYPE_TOP_K:       return "top_k";
        case GPT_SAMPLER_TYPE_TFS_Z:       return "tfs_z";
        case GPT_SAMPLER_TYPE_TYPICAL_P:   return "typ_p";
        case GPT_SAMPLER_TYPE_TOP_P:       return "top_p";
        case GPT_SAMPLER_TYPE_MIN_P:       return "min_p";
        case GPT_SAMPLER_TYPE_TEMPERATURE: return "temperature";
        default : return "";
    }
}
std::vector<gpt_sampler_type> gpt_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
    std::unordered_map<std::string, gpt_sampler_type> sampler_canonical_name_map {
        { "top_k",       GPT_SAMPLER_TYPE_TOP_K },
        { "top_p",       GPT_SAMPLER_TYPE_TOP_P },
        { "typ_p",       GPT_SAMPLER_TYPE_TYPICAL_P },
        { "min_p",       GPT_SAMPLER_TYPE_MIN_P },
        { "tfs_z",       GPT_SAMPLER_TYPE_TFS_Z },
        { "temperature", GPT_SAMPLER_TYPE_TEMPERATURE },
    };
    // sampler names are written in multiple ways in the wild,
    // so accept common alternative spellings in addition to the canonical names
    std::unordered_map<std::string, gpt_sampler_type> sampler_alt_name_map {
        { "top-k",     GPT_SAMPLER_TYPE_TOP_K },
        { "top-p",     GPT_SAMPLER_TYPE_TOP_P },
        { "nucleus",   GPT_SAMPLER_TYPE_TOP_P },
        { "typical-p", GPT_SAMPLER_TYPE_TYPICAL_P },
        { "typical",   GPT_SAMPLER_TYPE_TYPICAL_P },
        { "typ-p",     GPT_SAMPLER_TYPE_TYPICAL_P },
        { "typ",       GPT_SAMPLER_TYPE_TYPICAL_P },
        { "min-p",     GPT_SAMPLER_TYPE_MIN_P },
        { "tfs-z",     GPT_SAMPLER_TYPE_TFS_Z },
        { "tfs",       GPT_SAMPLER_TYPE_TFS_Z },
        { "temp",      GPT_SAMPLER_TYPE_TEMPERATURE },
    };

    std::vector<gpt_sampler_type> samplers;
    samplers.reserve(names.size());

    for (const auto & name : names) {
        auto sampler = sampler_canonical_name_map.find(name);
        if (sampler != sampler_canonical_name_map.end()) {
            samplers.push_back(sampler->second);
        } else {
            if (allow_alt_names) {
                sampler = sampler_alt_name_map.find(name);
                if (sampler != sampler_alt_name_map.end()) {
                    samplers.push_back(sampler->second);
                }
            }
        }
    }

    return samplers;
}
std::vector<gpt_sampler_type> gpt_sampler_types_from_chars(const std::string & chars) {
    std::unordered_map<char, gpt_sampler_type> sampler_name_map {
        { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TOP_K),       GPT_SAMPLER_TYPE_TOP_K },
        { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TFS_Z),       GPT_SAMPLER_TYPE_TFS_Z },
        { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TYPICAL_P),   GPT_SAMPLER_TYPE_TYPICAL_P },
        { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TOP_P),       GPT_SAMPLER_TYPE_TOP_P },
        { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_MIN_P),       GPT_SAMPLER_TYPE_MIN_P },
        { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TEMPERATURE), GPT_SAMPLER_TYPE_TEMPERATURE },
    };

    std::vector<gpt_sampler_type> samplers;
    samplers.reserve(chars.size());

    for (const auto & c : chars) {
        const auto sampler = sampler_name_map.find(c);
        if (sampler != sampler_name_map.end()) {
            samplers.push_back(sampler->second);
        }
    }

    return samplers;
}
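
// parsing sketch (illustrative, not called anywhere in this file): both entry
// points resolve equivalent inputs to the same sampler order - full names (with
// alternative spellings allowed) and the single-character shorthand
static void gpt_sampler_parse_example() {
    const auto from_names = gpt_sampler_types_from_names({"top-k", "typical", "temperature"}, /* allow_alt_names = */ true);
    const auto from_chars = gpt_sampler_types_from_chars("kyt");

    // both resolve to { TOP_K, TYPICAL_P, TEMPERATURE }
    GGML_ASSERT(from_names == from_chars);
}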