test-backend-sampler.cpp

#include "ggml.h"
#include "llama.h"
#include "llama-cpp.h"
#include "get-model.h"
#include "common.h"

#ifdef NDEBUG
#undef NDEBUG
#endif

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <map>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>
struct test_args {
    std::string model;
    std::string test;
    std::string device = "auto";
};

struct test_params {
    llama_model_ptr model;
};
static llama_model_ptr load_model(const test_args & args) {
    auto mparams = llama_model_default_params();

    ggml_backend_dev_t devs[2] = { nullptr, nullptr };
    if (args.device != "auto") {
        if (args.device == "gpu") {
            devs[0] = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU);
            if (devs[0] == nullptr) {
                fprintf(stderr, "Error: GPU requested but not available\n");
                return nullptr;
            }
            mparams.n_gpu_layers = 999;
        } else if (args.device == "cpu") {
            devs[0] = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
            mparams.n_gpu_layers = 0;
        } else {
            fprintf(stderr, "Error: invalid device '%s'\n", args.device.c_str());
            return nullptr;
        }
        mparams.devices = devs;
        fprintf(stderr, "Using device: %s\n", ggml_backend_dev_name(devs[0]));
    }

    llama_model_ptr res;
    res.reset(llama_model_load_from_file(args.model.c_str(), mparams));
    if (!res) {
        fprintf(stderr, "Warning: failed to load model '%s', skipping test\n", args.model.c_str());
        return nullptr;
    }
    return res;
}
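// Small test harness: owns the llama_context created with the given per-sequence
// backend sampler configs, and tracks the current position of each sequence as
// well as the output index of the last decoded batch.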
struct test_context {
    llama_context_ptr ctx;

    int n_vocab = 0;
    const llama_vocab * vocab = nullptr;

    std::unordered_map<llama_seq_id, int32_t> seq_positions;
    std::unordered_map<llama_seq_id, int32_t> last_batch_info;

    test_context(const test_params & params, std::vector<llama_sampler_seq_config> & configs, int32_t n_seq_max = -1) {
        auto * model = params.model.get();
        GGML_ASSERT(model);
        GGML_ASSERT(!ctx);

        llama_context_params cparams = llama_context_default_params();
        cparams.n_ctx      = 512;
        cparams.n_batch    = 512;
        cparams.samplers   = configs.data();
        cparams.n_samplers = configs.size();

        // If n_seq_max is not specified, calculate it from configs
        if (n_seq_max < 0) {
            int32_t max_seq_id = 0;
            for (const auto & config : configs) {
                max_seq_id = std::max(config.seq_id, max_seq_id);
            }
            cparams.n_seq_max = max_seq_id + 1;
        } else {
            cparams.n_seq_max = n_seq_max;
        }

        ctx.reset(llama_init_from_model(model, cparams));
        if (!ctx) {
            throw std::runtime_error("failed to create context");
        }
        llama_set_warmup(ctx.get(), false);

        vocab   = llama_model_get_vocab(model);
        n_vocab = llama_vocab_n_tokens(vocab);
    }
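    // Tokenize and decode one prompt per sequence in a single batch. Only the
    // last token of each prompt requests an output, and the resulting output
    // indices are recorded in last_batch_info.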
    bool decode(const std::map<llama_seq_id, std::string> & prompts) {
        GGML_ASSERT(ctx);

        last_batch_info.clear();

        llama_batch batch = llama_batch_init(512, 0, prompts.size());

        for (const auto & [seq_id, prompt] : prompts) {
            std::vector<llama_token> tokens;
            tokens.push_back(llama_vocab_bos(vocab));

            std::vector<llama_token> prompt_tokens(32);
            int n_tokens = llama_tokenize(vocab, prompt.c_str(), prompt.length(),
                                          prompt_tokens.data(), prompt_tokens.size(),
                                          false, false);
            if (n_tokens < 0) {
                fprintf(stderr, "Warning: tokenization failed for seq_id %d\n", seq_id);
                llama_batch_free(batch);
                return false;
            }
            for (int i = 0; i < n_tokens; i++) {
                tokens.push_back(prompt_tokens[i]);
            }

            if (seq_positions.find(seq_id) == seq_positions.end()) {
                seq_positions[seq_id] = 0;
            }
            int32_t start_pos = seq_positions[seq_id];

            for (size_t i = 0; i < tokens.size(); i++) {
                common_batch_add(batch, tokens[i], start_pos + i, { seq_id }, i == tokens.size() - 1);
            }
            seq_positions[seq_id] = start_pos + tokens.size();
        }

        printf("Batch contents:\n");
        printf("n_tokens: %d\n", batch.n_tokens);
        for (int i = 0; i < batch.n_tokens; i++) {
            printf("token[%d]: tok=%-5d, pos=%d, n_seq_id=%d, seq_ids=[", i, batch.token[i], batch.pos[i], batch.n_seq_id[i]);
            for (int j = 0; j < batch.n_seq_id[i]; j++) {
                printf("%d%s", batch.seq_id[i][j], j < batch.n_seq_id[i]-1 ? ", " : "");
            }
            printf("], logits=%d\n", batch.logits[i]);
        }

        if (llama_decode(ctx.get(), batch) != 0) {
            fprintf(stderr, "Warning: llama_decode failed\n");
            llama_batch_free(batch);
            return false;
        }

        // Build mapping from seq id to batch token idx
        for (int i = 0; i < batch.n_tokens; i++) {
            if (batch.logits[i]) {
                llama_seq_id seq_id = batch.seq_id[i][0];
                last_batch_info[seq_id] = i;
            }
        }

        llama_batch_free(batch);
        return true;
    }
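    // Return the batch output index of the last decoded token for a sequence,
    // or -1 if the sequence produced no output in the last batch.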
    int32_t idx_for_seq(llama_seq_id seq_id) {
        auto it = last_batch_info.find(seq_id);
        if (it == last_batch_info.end()) {
            fprintf(stderr, "Error: no batch index found for seq_id %d\n", seq_id);
            return -1;
        }
        return it->second;
    }

    void update_batch_info(const llama_batch & batch) {
        last_batch_info.clear();
        for (int i = 0; i < batch.n_tokens; i++) {
            if (batch.logits[i]) {
                llama_seq_id cur_seq = batch.seq_id[i][0];
                last_batch_info[cur_seq] = i;
            }
        }
    }
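    // Decode a single follow-up token for one sequence, or one token per
    // sequence in a single batch, advancing the tracked positions.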
    bool decode_token(llama_token token, llama_seq_id seq_id = 0) {
        GGML_ASSERT(ctx);

        llama_batch batch = llama_batch_init(1, 0, 1);
        int32_t pos = seq_positions[seq_id];
        common_batch_add(batch, token, pos, { seq_id }, true);

        if (llama_decode(ctx.get(), batch) != 0) {
            fprintf(stderr, "Warning: llama_decode failed for token %d in seq %d\n", token, seq_id);
            llama_batch_free(batch);
            return false;
        }

        update_batch_info(batch);
        seq_positions[seq_id]++;

        llama_batch_free(batch);
        return true;
    }

    bool decode_tokens(const std::map<llama_seq_id, llama_token> & seq_tokens) {
        GGML_ASSERT(ctx);

        llama_batch batch = llama_batch_init(seq_tokens.size(), 0, seq_tokens.size());
        for (const auto & [seq_id, token] : seq_tokens) {
            int32_t pos = seq_positions[seq_id];
            common_batch_add(batch, token, pos, { seq_id }, true);
        }

        if (llama_decode(ctx.get(), batch) != 0) {
            fprintf(stderr, "Warning: llama_decode failed for batch tokens\n");
            llama_batch_free(batch);
            return false;
        }

        for (const auto & [seq_id, _] : seq_tokens) {
            seq_positions[seq_id]++;
        }

        update_batch_info(batch);

        llama_batch_free(batch);
        return true;
    }
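    // Convert a token id to its text piece, growing the buffer if needed.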
    std::string token_to_piece(llama_token token, bool special) const {
        std::string piece;
        piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\0'
        const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
        if (n_chars < 0) {
            piece.resize(-n_chars);
            int check = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
            GGML_ASSERT(check == -n_chars);
        } else {
            piece.resize(n_chars);
        }
        return piece;
    }
};
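// Greedy sampling entirely on the backend: the sampled token should be directly
// retrievable with llama_get_sampled_token_ith.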
static void test_backend_greedy_sampling(const test_params & params) {
    const int seq_id = 0;

    struct llama_sampler_chain_params backend_sampler_params = llama_sampler_chain_default_params();
    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_sampler_params));
    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_greedy());

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};

    test_context test_ctx(params, backend_sampler_configs);

    if (!test_ctx.decode({{seq_id, "Some"}})) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
    llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
    printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);

    token = llama_get_sampled_token_ith(test_ctx.ctx.get(), -1);
    printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);

    for (int i = 0; i < 10; i++) {
        int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), loop_idx);
        printf("Generation step %d: token id:%d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str());
        if (!test_ctx.decode_token(token, 0)) {
            GGML_ASSERT(false && "Failed to decode token");
        }
    }
}
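// Top-k filtering on the backend: inspect the filtered logits/candidates and
// then finish sampling on the CPU (hybrid sampling).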
static void test_backend_top_k_sampling(const test_params & params) {
    const int seq_id = 0;
    const int32_t k = 8;

    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_top_k(k));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};

    test_context test_ctx(params, backend_sampler_configs);

    if (!test_ctx.decode({{seq_id, "Hello"}})) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);

    float *  logits   = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx);
    uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
    for (size_t i = 0; i < n_logits; ++i) {
        printf("top_k logit[%zu] = %.6f\n", i, logits[i]);
    }

    llama_token * candidates   = llama_get_sampled_candidates_ith(test_ctx.ctx.get(), batch_idx);
    uint32_t      n_candidates = llama_get_sampled_candidates_count_ith(test_ctx.ctx.get(), batch_idx);
    for (size_t i = 0; i < n_candidates; ++i) {
        printf("top_k candidate[%zu] = %d : %s\n", i, candidates[i],
               test_ctx.token_to_piece(candidates[i], false).c_str());
    }

    // Sample using a CPU sampler to verify that hybrid sampling is possible:
    // first top_k on the backend, then dist on the CPU.
    struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
    GGML_ASSERT(chain->iface->backend_apply != nullptr);
    llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));

    llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);

    printf("backend top-k hybrid sampling test PASSED\n");
}
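// Temperature scaling on the backend for two sequences with different
// temperatures, plus a check that non-positive temperatures collapse the
// output to a single (argmax) logit.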
static void test_backend_temp_sampling(const test_params & params) {
    {
        const float temp_0 = 0.8f;
        struct llama_sampler_chain_params backend_chain_params_0 = llama_sampler_chain_default_params();
        llama_sampler_ptr backend_sampler_chain_0(llama_sampler_chain_init(backend_chain_params_0));
        llama_sampler_chain_add(backend_sampler_chain_0.get(), llama_sampler_init_temp(temp_0));

        const float temp_1 = 0.1f;
        struct llama_sampler_chain_params backend_chain_params_1 = llama_sampler_chain_default_params();
        llama_sampler_ptr backend_sampler_chain_1(llama_sampler_chain_init(backend_chain_params_1));
        llama_sampler_chain_add(backend_sampler_chain_1.get(), llama_sampler_init_temp(temp_1));

        std::vector<llama_sampler_seq_config> backend_sampler_configs = {
            { 0, backend_sampler_chain_0.get() },
            { 1, backend_sampler_chain_1.get() }
        };

        test_context test_ctx(params, backend_sampler_configs);

        if (!test_ctx.decode({{0, "Some where over the"}, {1, "Once upon a"}})) {
            GGML_ASSERT(false && "Failed to decode token");
        }

        // Verify sequence 0
        {
            int32_t batch_idx = test_ctx.idx_for_seq(0);

            int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
            GGML_ASSERT(n_logits == test_ctx.n_vocab);

            // Sample from sequence 0 using a CPU sampler
            struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
            llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
            llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));

            llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
            const std::string token_str = test_ctx.token_to_piece(token, false);
            printf("Sequence 0 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
            GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
        }

        // Verify sequence 1
        {
            int32_t batch_idx = test_ctx.idx_for_seq(1);

            // Sample from sequence 1 using a CPU sampler
            struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
            llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
            llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));

            llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
            const std::string token_str = test_ctx.token_to_piece(token, false);
            printf("Sequence 1 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
            GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
        }
    }

    // lambda for testing non-positive temperature values.
    auto test_argmax_temp = [&](float temp) {
        printf("\nTesting temperature = %.1f\n", temp);
        int seq_id = 0;

        struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
        llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
        llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_temp(temp));

        std::vector<llama_sampler_seq_config> backend_sampler_configs = {
            { seq_id, backend_sampler_chain.get() },
        };

        test_context test_ctx(params, backend_sampler_configs);

        if (!test_ctx.decode({{seq_id, "Once"}})) {
            GGML_ASSERT(false && "Failed to decode token");
        }

        int32_t batch_idx = test_ctx.idx_for_seq(seq_id);

        uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
        GGML_ASSERT(n_logits == 1);
    };

    test_argmax_temp( 0.0f);
    test_argmax_temp(-1.0f);

    printf("backend temp sampling test PASSED\n");
}
static void test_backend_temp_ext_sampling(const test_params & params) {
    {
        int seq_id = 0;
        const float temp     = 0.8f;
        const float delta    = 0.5f;
        const float exponent = 1.5f;

        struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
        llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
        llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_temp_ext(temp, delta, exponent));

        std::vector<llama_sampler_seq_config> backend_sampler_configs = {
            { seq_id, backend_sampler_chain.get() },
        };

        test_context test_ctx(params, backend_sampler_configs);

        if (!test_ctx.decode({{seq_id, "Once upon a"}})) {
            GGML_ASSERT(false && "Failed to decode token");
        }

        // Verify sequence 0
        {
            int32_t batch_idx = test_ctx.idx_for_seq(seq_id);

            int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
            GGML_ASSERT(n_logits == test_ctx.n_vocab);
        }
    }

    // lambda for testing non-positive temp/delta/exponent values.
    auto test_argmax_temp = [&](float temp, float delta, float exponent) {
        printf("\nTesting temperature = %.1f, delta = %.1f, exponent = %.1f\n", temp, delta, exponent);
        int seq_id = 0;

        struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
        llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
        llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_temp_ext(temp, delta, exponent));

        std::vector<llama_sampler_seq_config> backend_sampler_configs = {
            { seq_id, backend_sampler_chain.get() },
        };

        test_context test_ctx(params, backend_sampler_configs);

        if (!test_ctx.decode({{seq_id, "Once"}})) {
            GGML_ASSERT(false && "Failed to decode token");
        }

        int32_t batch_idx = test_ctx.idx_for_seq(seq_id);

        uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
        if (temp <= 0.0f && delta >= 0.0f) {
            GGML_ASSERT(n_logits == 1);
        } else {
            GGML_ASSERT(n_logits == (uint32_t) test_ctx.n_vocab);
        }
    };

    test_argmax_temp( 0.0f, 0.3f, 1.0f); // Greedy (temp=0)
    test_argmax_temp(-1.0f, 0.3f, 2.0f); // Greedy (temp<0)
    test_argmax_temp( 0.8f, 0.0f, 2.0f); // Temperature scaling
    printf("backend temp_ext sampling test PASSED\n");
}
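// Min-p filtering on the backend: the surviving logits are then sampled on the CPU.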
static void test_backend_min_p_sampling(const test_params & params) {
    const int   seq_id = 0;
    const float p      = 0.1f;

    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_min_p(p, 0));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};

    test_context test_ctx(params, backend_sampler_configs);

    if (!test_ctx.decode({{seq_id, "Hello"}})) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);

    float *  logits   = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx);
    uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);

    // Collect the logits that are above the min-p threshold
    std::vector<float> filtered_logits;
    for (size_t i = 0; i < n_logits; ++i) {
        if (logits[i] > -1e9f) {
            filtered_logits.push_back(logits[i]);
            //printf("min_p logit[%zu] = %.6f\n", i, logits[i]);
        }
    }
    GGML_ASSERT(filtered_logits.size() < (size_t) test_ctx.n_vocab);

    // Sample using a CPU sampler to verify that the filtered logits are reasonable
    struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
    llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(88));

    llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
    const std::string token_str = test_ctx.token_to_piece(token, false);
    printf("min-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str());
    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);

    // Decode and sample 10 more tokens
    for (int i = 0; i < 10; i++) {
        int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), loop_idx);
        printf("min-p gen step %d: token id :%5d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str());
        if (!test_ctx.decode_token(token, 0)) {
            GGML_ASSERT(false && "Failed to decode token");
        }
    }
    printf("min-p sampling test PASSED\n");
}
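// Top-p (nucleus) filtering on the backend: the surviving logits are then sampled on the CPU.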
static void test_backend_top_p_sampling(const test_params & params) {
    const int   seq_id = 0;
    const float p      = 0.9f;

    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_top_p(p, 0));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};

    test_context test_ctx(params, backend_sampler_configs);

    if (!test_ctx.decode({{seq_id, "Hello"}})) {
        return;
    }

    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);

    float *  logits   = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx);
    uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);

    // Collect the logits that are above the top-p threshold
    std::vector<float> filtered_logits;
    for (size_t i = 0; i < n_logits; ++i) {
        if (logits[i] > -1e9f) {
            filtered_logits.push_back(logits[i]);
        }
    }
    GGML_ASSERT(filtered_logits.size() < (size_t) test_ctx.n_vocab);
    GGML_ASSERT(filtered_logits.size() > 0);

    // Sample using a CPU sampler to verify that the filtered logits are reasonable
    struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
    llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(88));

    llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
    const std::string token_str = test_ctx.token_to_piece(token, false);
    printf("top-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str());
    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);

    // Decode and sample 10 more tokens
    for (int i = 0; i < 10; i++) {
        int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), loop_idx);
        printf("top-p gen step %d: token id :%5d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str());
        test_ctx.decode_token(token, 0);
    }
    printf("top-p sampling test PASSED\n");
}
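// Two sequences with independent backend sampler chains decoded and generated
// from within the same batch.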
static void test_backend_multi_sequence_sampling(const test_params & params) {
    struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
    llama_sampler_ptr sampler_chain_0(llama_sampler_chain_init(chain_params_0));
    llama_sampler_chain_add(sampler_chain_0.get(), llama_sampler_init_greedy());

    struct llama_sampler_chain_params chain_params_1 = llama_sampler_chain_default_params();
    llama_sampler_ptr sampler_chain_1(llama_sampler_chain_init(chain_params_1));
    llama_sampler_chain_add(sampler_chain_1.get(), llama_sampler_init_temp(0.8f));
    llama_sampler_chain_add(sampler_chain_1.get(), llama_sampler_init_greedy());

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {
        { 0, sampler_chain_0.get() },
        { 1, sampler_chain_1.get() }
    };

    test_context test_ctx(params, backend_sampler_configs);

    std::map<llama_seq_id, std::string> prompts = {
        {0, "Hello"},
        {1, "Some"}
    };
    if (!test_ctx.decode(prompts)) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    // Verify sequence 0
    {
        int32_t batch_idx = test_ctx.idx_for_seq(0);
        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
        const std::string token_str = test_ctx.token_to_piece(token, false);
        printf("Seq 0 sampled token id=%d, string='%s'\n", token, token_str.c_str());
        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
    }

    // Verify sequence 1
    {
        int32_t batch_idx = test_ctx.idx_for_seq(1);
        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
        const std::string token_str = test_ctx.token_to_piece(token, false);
        printf("Seq 1 sampled token id=%d, string='%s'\n", token, token_str.c_str());
        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
    }

    // Generate tokens for each sequence
    printf("\nMulti-sequence generation:\n");
    for (int step = 0; step < 4; step++) {
        std::map<llama_seq_id, llama_token> tokens;

        for (llama_seq_id seq_id : {0, 1}) {
            int32_t idx = test_ctx.idx_for_seq(seq_id);
            llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), idx);
            const std::string token_str = test_ctx.token_to_piece(token, false);
            printf(" Seq %d, step %d: token id=%d, string='%s'\n", seq_id, step, token, token_str.c_str());
            tokens[seq_id] = token;
        }

        // Decode all tokens in a single batch
        if (!test_ctx.decode_tokens(tokens)) {
            GGML_ASSERT(false && "Failed to decode token");
        }
    }
    printf("backend multi-sequence sampling test PASSED\n");
}
static void test_backend_dist_sampling(const test_params & params) {
    const int seq_id = 189;
    const int32_t seed = 88;

    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};

    test_context test_ctx(params, backend_sampler_configs);

    if (!test_ctx.decode({{seq_id, "Some"}})) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
    llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
    printf("dist sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
    //GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx) == nullptr);

    token = llama_get_sampled_token_ith(test_ctx.ctx.get(), -1);
    printf("dist sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);

    printf("backend dist sampling test PASSED\n");
}
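// When a sequence already has a backend dist sampler, sampling the same output
// with a CPU sampler chain is expected to return the same token.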
static void test_backend_dist_sampling_and_cpu(const test_params & params) {
    const int seq_id = 0;
    const int32_t seed = 88;

    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};

    test_context test_ctx(params, backend_sampler_configs);

    if (!test_ctx.decode({{seq_id, "Some"}})) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);

    // Sample using CPU sampler
    struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
    llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));

    llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
    llama_token cpu_token     = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
    printf("dist & cpu sampled id:%d, string:'%s'\n", cpu_token, test_ctx.token_to_piece(cpu_token, false).c_str());
    GGML_ASSERT(backend_token == cpu_token);

    printf("backend dist & cpu sampling test PASSED\n");
}
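// Logit bias applied on the backend: biasing a single token strongly enough
// should make it the sampled token.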
static void test_backend_logit_bias_sampling(const test_params & params) {
    const auto * model = params.model.get();
    const auto * vocab = llama_model_get_vocab(model);
    const int seq_id = 0;

    std::vector<llama_logit_bias> logit_bias;

    // Get the token for the piece "World".
    const std::string piece = "World";
    std::vector<llama_token> tokens(16);
    llama_tokenize(vocab, piece.c_str(), piece.size(), tokens.data(), tokens.size(), false, false);
    llama_token bias_token = tokens[0];

    // TODO: biasing too much here makes the Vulkan sampling fail - should be investigated further
    // https://github.com/ggml-org/llama.cpp/actions/runs/20894267644/job/60030252675?pr=18753#step:3:23350
    //logit_bias.push_back({ bias_token, +100.0f });
    logit_bias.push_back({ bias_token, +10.0f });
    printf("biasing token piece '%s' -> token id %d\n", piece.c_str(), bias_token);

    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_logit_bias(
        llama_vocab_n_tokens(vocab),
        logit_bias.size(),
        logit_bias.data()));
    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(88));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {
        { seq_id, backend_sampler_chain.get() },
    };

    test_context test_ctx(params, backend_sampler_configs);

    if (!test_ctx.decode({{seq_id, "Hello"}})) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), test_ctx.idx_for_seq(seq_id));
    printf("sampled token = %d, expected = %d\n", backend_token, bias_token);
    GGML_ASSERT(backend_token == bias_token);

    printf("backend logit bias sampling test PASSED\n");
}
// This test verifies that two different backend samplers can coexist: one
// sequence samples fully on the backend (dist), while the other only filters
// on the backend (top-k), leaving the final sampling to the CPU.
static void test_backend_mixed_sampling(const test_params & params) {
    struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
    llama_sampler_ptr sampler_chain_0(llama_sampler_chain_init(chain_params_0));
    llama_sampler_chain_add(sampler_chain_0.get(), llama_sampler_init_dist(88));

    int k = 40;
    struct llama_sampler_chain_params chain_params_1 = llama_sampler_chain_default_params();
    llama_sampler_ptr sampler_chain_1(llama_sampler_chain_init(chain_params_1));
    llama_sampler_chain_add(sampler_chain_1.get(), llama_sampler_init_top_k(k));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {
        { 0, sampler_chain_0.get() },
        { 1, sampler_chain_1.get() }
    };

    test_context test_ctx(params, backend_sampler_configs);

    std::map<llama_seq_id, std::string> prompts = {
        {0, "Hello"},
        {1, "Some"}
    };
    if (!test_ctx.decode(prompts)) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    // Verify sequence 0 that used the dist backend sampler.
    {
        int32_t batch_idx = test_ctx.idx_for_seq(0);
        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
        const std::string token_str = test_ctx.token_to_piece(token, false);
        printf("sampled token id=%d, string='%s'\n", token, token_str.c_str());
        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
        //GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx) == nullptr);
        //GGML_ASSERT(llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx) == 0);
    }

    // Verify sequence 1 that used the top-k backend sampler.
    {
        int32_t batch_idx = test_ctx.idx_for_seq(1);
        float * logits = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx);
        GGML_ASSERT(logits != nullptr);
        size_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
        GGML_ASSERT(n_logits == (size_t) k);
        GGML_ASSERT(llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx) == LLAMA_TOKEN_NULL);
    }
    printf("backend mixed sampling test PASSED\n");
}
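// Exercise llama_set_sampler: clear the backend sampler for a sequence
// (falling back to CPU sampling) and then attach a new backend chain.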
static void test_backend_set_sampler(const test_params & params) {
    const int seq_id = 0;
    const int32_t seed = 88;

    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};

    test_context test_ctx(params, backend_sampler_configs);

    if (!test_ctx.decode({{seq_id, "Hello"}})) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);

    // Sample using backend sampler configured above
    llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
    const std::string backend_token_str = test_ctx.token_to_piece(backend_token, false);
    printf("dist sampled token = %d, string='%s'\n", backend_token, backend_token_str.c_str());

    // Now clear the backend sampler for this sequence.
    llama_set_sampler(test_ctx.ctx.get(), seq_id, nullptr);
    printf("Cleared backend sampler for seq_id %d\n", seq_id);

    // Sample using CPU sampler
    struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
    llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));

    std::map<llama_seq_id, llama_token> tokens = { { seq_id, backend_token}, };
    if (!test_ctx.decode_tokens(tokens)) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    // Should not have any sampled token or probs after clearing the backend sampler.
    const int32_t idx = test_ctx.idx_for_seq(seq_id);
    GGML_ASSERT(llama_get_sampled_token_ith(test_ctx.ctx.get(), idx) == LLAMA_TOKEN_NULL);
    GGML_ASSERT(llama_get_sampled_probs_ith(test_ctx.ctx.get(), idx) == nullptr);

    // Sample the token using the CPU sampler chain.
    llama_token token2 = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), seq_id);
    const std::string token2_str = test_ctx.token_to_piece(token2, false);
    printf("CPU sampled token after clearing backend sampler: id=%d, string='%s'\n", token2, token2_str.c_str());

    std::map<llama_seq_id, llama_token> tokens2 = { { seq_id, token2}, };

    // Set a new backend sampler for the sequence.
    struct llama_sampler_chain_params new_backend_chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr new_backend_sampler_chain(llama_sampler_chain_init(new_backend_chain_params));
    llama_sampler_chain_add(new_backend_sampler_chain.get(), llama_sampler_init_top_k(20));
    llama_sampler_chain_add(new_backend_sampler_chain.get(), llama_sampler_init_dist(seed));
    llama_set_sampler(test_ctx.ctx.get(), seq_id, new_backend_sampler_chain.get());

    if (!test_ctx.decode_tokens(tokens2)) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    llama_token new_backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), test_ctx.idx_for_seq(seq_id));
    const std::string new_backend_token_str = test_ctx.token_to_piece(new_backend_token, false);
    printf("dist sampled token = %d, string='%s'\n", new_backend_token, new_backend_token_str.c_str());

    printf("backend set sampler test PASSED\n");
}
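// One batch mixing a backend-sampled sequence with a CPU-sampled sequence,
// followed by clearing and re-adding the backend sampler.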
static void test_backend_cpu_mixed_batch(const test_params & params) {
    // Sequence 0 uses backend sampling
    struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
    llama_sampler_ptr sampler_chain_0(llama_sampler_chain_init(chain_params_0));
    llama_sampler_chain_add(sampler_chain_0.get(), llama_sampler_init_dist(88));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {
        { 0, sampler_chain_0.get() },
    };

    // We need 2 sequences: seq 0 with backend sampling, seq 1 with CPU sampling
    test_context test_ctx(params, backend_sampler_configs, 2);

    std::map<llama_seq_id, std::string> prompts = {
        {0, "Hello"}, // Will use backend sampling
        {1, "Some"}   // Will use CPU sampling
    };
    if (!test_ctx.decode(prompts)) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    // Verify sequence 0 (backend sampled)
    {
        int32_t batch_idx = test_ctx.idx_for_seq(0);
        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
        const std::string token_str = test_ctx.token_to_piece(token, false);
        printf("Seq 0 (backend) sampled token id=%d, string='%s'\n", token, token_str.c_str());
        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
    }

    // Verify sequence 1 (CPU sampled)
    {
        int32_t batch_idx = test_ctx.idx_for_seq(1);
        llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
        GGML_ASSERT(backend_token == LLAMA_TOKEN_NULL);

        struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
        llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
        llama_sampler_chain_add(chain.get(), llama_sampler_init_greedy());

        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
        const std::string token_str = test_ctx.token_to_piece(token, false);
        printf("Seq 1 (CPU) sampled token id=%d, string='%s'\n", token, token_str.c_str());
        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
    }

    // Clear/remove the backend sampler, and sample again
    {
        // clear the backend sampler for seq 0 so that there are no backend
        // samplers.
        llama_set_sampler(test_ctx.ctx.get(), 0, nullptr);

        // Create a CPU sampler and verify we can sample from it.
        struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
        llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
        llama_sampler_chain_add(chain.get(), llama_sampler_init_greedy());

        int32_t batch_idx = test_ctx.idx_for_seq(1);
        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
        if (!test_ctx.decode_token(token, 1)) {
            GGML_ASSERT(false && "Failed to decode token");
        }
    }

    // Set a backend sampler so that we can verify that it can be reset
    {
        struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
        llama_sampler_ptr sampler_chain(llama_sampler_chain_init(chain_params));
        llama_sampler_chain_add(sampler_chain.get(), llama_sampler_init_dist(88));
        llama_set_sampler(test_ctx.ctx.get(), 0, sampler_chain.get());

        if (!test_ctx.decode_token(3834, 0)) {
            GGML_ASSERT(false && "Failed to decode token");
        }

        int32_t batch_idx = test_ctx.idx_for_seq(0);
        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
        const std::string token_str = test_ctx.token_to_piece(token, false);
        printf("re-added backend sampled token id=%d, string='%s'\n", token, token_str.c_str());
        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
    }
    printf("backend-cpu mixed batch test PASSED\n");
}
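// Requesting output for more than one token per sequence is not supported with
// backend samplers, so llama_decode is expected to fail here.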
static void test_backend_max_outputs(const test_params & params) {
    const int seq_id = 0;
    const int32_t seed = 88;

    llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};

    test_context test_ctx(params, backend_sampler_configs);

    llama_batch batch = llama_batch_init(512, 0, 1);

    std::string prompt = "Hello";
    std::vector<llama_token> tokens;
    tokens.push_back(llama_vocab_bos(test_ctx.vocab));

    std::vector<llama_token> prompt_tokens(32);
    int n_tokens = llama_tokenize(test_ctx.vocab, prompt.c_str(), prompt.length(),
                                  prompt_tokens.data(), prompt_tokens.size(),
                                  false, false);
    for (int i = 0; i < n_tokens; i++) {
        tokens.push_back(prompt_tokens[i]);
    }

    for (size_t i = 0; i < tokens.size(); i++) {
        // set all tokens as output to trigger the error
        common_batch_add(batch, tokens[i], i, { seq_id }, true);
    }

    printf(">>> test_max_outputs expected error start:\n");
    const int ret = llama_decode(test_ctx.ctx.get(), batch);
    GGML_ASSERT(ret != 0 && "llama_decode should not succeed with multiple outputs per sequence");
    printf("<<< test_max_outputs expected error end.\n");

    llama_batch_free(batch);
    printf("backend max outputs test PASSED\n");
}
struct backend_test_case {
    std::string name;
    void (*fn)(const test_params &);
    bool enabled_by_default;
};

static const backend_test_case BACKEND_TESTS[] = {
    { "greedy",         test_backend_greedy_sampling,         true },
    { "logit_bias",     test_backend_logit_bias_sampling,     true },
    { "temp",           test_backend_temp_sampling,           true },
    { "temp_ext",       test_backend_temp_ext_sampling,       true },
    { "top_k",          test_backend_top_k_sampling,          true },
    { "multi_sequence", test_backend_multi_sequence_sampling, true },
    { "dist",           test_backend_dist_sampling,           true },
    { "dist_and_cpu",   test_backend_dist_sampling_and_cpu,   true },
    { "set_sampler",    test_backend_set_sampler,             true },
    { "max_outputs",    test_backend_max_outputs,             true },
    { "mixed",          test_backend_mixed_sampling,          true },
    { "min_p",          test_backend_min_p_sampling,          true },
    { "cpu_mixed",      test_backend_cpu_mixed_batch,         true },
    { "top_p",          test_backend_top_p_sampling,          true },
};
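// Parse --model, --test and --device (cpu|gpu|auto); a bare positional argument
// is treated as the model path.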
static test_args parse_cli(int argc, char ** argv) {
    test_args out;

    for (int i = 1; i < argc; ++i) {
        const char * arg = argv[i];

        if (std::strcmp(arg, "--test") == 0) {
            if (i + 1 >= argc) {
                fprintf(stderr, "--test expects a value\n");
                exit(EXIT_FAILURE);
            }
            out.test = argv[++i];
            continue;
        }
        if (std::strncmp(arg, "--test=", 7) == 0) {
            out.test = arg + 7;
            continue;
        }
        if (std::strcmp(arg, "--model") == 0) {
            if (i + 1 >= argc) {
                fprintf(stderr, "--model expects a value\n");
                exit(EXIT_FAILURE);
            }
            out.model = argv[++i];
            continue;
        }
        if (std::strncmp(arg, "--model=", 8) == 0) {
            out.model = arg + 8;
            continue;
        }
        if (std::strcmp(arg, "--device") == 0) {
            if (i + 1 >= argc) {
                fprintf(stderr, "--device expects a value (cpu or gpu)\n");
                exit(EXIT_FAILURE);
            }
            out.device = argv[++i];
            continue;
        }
        if (std::strncmp(arg, "--device=", 9) == 0) {
            out.device = arg + 9;
            continue;
        }
        if (out.model.empty()) {
            out.model = arg;
            continue;
        }
        fprintf(stderr, "Unexpected argument: %s\n", arg);
        exit(EXIT_FAILURE);
    }

    if (out.device != "cpu" && out.device != "gpu" && out.device != "auto") {
        fprintf(stderr, "Invalid device '%s'. Must be 'cpu', 'gpu' or 'auto'\n", out.device.c_str());
        exit(EXIT_FAILURE);
    }
    return out;
}
static std::vector<const backend_test_case *> collect_tests_to_run(const std::string & requested) {
    std::vector<const backend_test_case *> selected;

    if (!requested.empty()) {
        for (const auto & test : BACKEND_TESTS) {
            if (test.name == requested) {
                selected.push_back(&test);
                break;
            }
        }
        if (selected.empty()) {
            fprintf(stderr, "Unknown test '%s'. Available tests:\n", requested.c_str());
            for (const auto & test : BACKEND_TESTS) {
                fprintf(stderr, " %s\n", test.name.c_str());
            }
            exit(EXIT_FAILURE);
        }
    } else {
        for (const auto & test : BACKEND_TESTS) {
            if (test.enabled_by_default) {
                selected.push_back(&test);
            }
        }
    }

    if (selected.empty()) {
        fprintf(stderr, "No backend sampling tests selected. Use --test=<name> to pick one.\n");
    }
    return selected;
}
static void run_tests(const std::vector<const backend_test_case *> & tests, const test_params & args) {
    for (const auto & test : tests) {
        fprintf(stderr, "\n=== %s ===\n", test->name.c_str());
        try {
            test->fn(args);
        } catch (const std::exception & e) {
            fprintf(stderr, "Error running test '%s': %s\n", test->name.c_str(), e.what());
            exit(EXIT_FAILURE);
        }
    }
}
int main(int argc, char ** argv) {
    test_args args = parse_cli(argc, argv);

    if (args.model.empty()) {
        args.model = get_model_or_exit(1, argv);
    }

    {
        std::ifstream file(args.model);
        if (!file.is_open()) {
            fprintf(stderr, "no model '%s' found\n", args.model.c_str());
            return EXIT_FAILURE;
        }
    }
    fprintf(stderr, "using '%s'\n", args.model.c_str());

    llama_backend_init();

    test_params params = {
        /*.model =*/ load_model(args),
    };

    const std::vector<const backend_test_case *> tests = collect_tests_to_run(args.test);
    if (!tests.empty()) {
        run_tests(tests, params);
    }
    return 0;
}