// test-backend-sampler.cpp
#include "ggml.h"
#include "llama.h"
#include "llama-cpp.h"
#include "get-model.h"
#include "common.h"

#ifdef NDEBUG
#undef NDEBUG
#endif

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <fstream>
#include <map>
#include <string>
#include <string_view>
#include <unordered_map>
#include <vector>

struct backend_cli_args {
    const char * model  = nullptr;
    const char * test   = nullptr;
    const char * device = "cpu";
};
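
// Helper that owns the model and context for a test, tracks the next position
// for each sequence, and records the output index of each sequence in the last
// decoded batch so tests can query sampled results per sequence.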
struct test_model_context {
    llama_model_ptr   model;
    llama_context_ptr ctx;

    int n_vocab = 0;

    std::unordered_map<llama_seq_id, int32_t> seq_positions;
    std::unordered_map<llama_seq_id, int32_t> last_batch_info;

    bool load_model(const backend_cli_args & args) {
        if (model) {
            return true;
        }

        llama_backend_init();

        auto mparams = llama_model_default_params();

        ggml_backend_dev_t devs[2];
        if (std::string_view(args.device) == "gpu") {
            ggml_backend_dev_t gpu = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU);
            if (gpu == nullptr) {
                fprintf(stderr, "Error: GPU requested but not available\n");
                return false;
            }
            devs[0] = gpu;
            devs[1] = nullptr; // null terminator
            mparams.devices = devs;
            mparams.n_gpu_layers = 999;
        } else if (std::string_view(args.device) == "cpu") {
            ggml_backend_dev_t cpu = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
            devs[0] = cpu;
            devs[1] = nullptr; // null terminator
            mparams.devices = devs;
        }
        fprintf(stderr, "Using device: %s\n", ggml_backend_dev_name(devs[0]));

        model.reset(llama_model_load_from_file(args.model, mparams));
        if (!model) {
            fprintf(stderr, "Warning: failed to load model '%s', skipping test\n", args.model);
            return false;
        }

        n_vocab = llama_vocab_n_tokens(get_vocab());
        fprintf(stderr, "Vocabulary size: %d\n", n_vocab);

        return true;
    }

    bool setup(const backend_cli_args & args, std::vector<llama_sampler_seq_config> & configs, int32_t n_seq_max = -1) {
        if (!model) {
            load_model(args);
        }
        if (ctx) {
            return true;
        }

        llama_context_params cparams = llama_context_default_params();
        cparams.n_ctx      = 512;
        cparams.n_batch    = 512;
        cparams.samplers   = configs.data();
        cparams.n_samplers = configs.size();

        // If n_seq_max is not specified, calculate it from the configs
        if (n_seq_max < 0) {
            int32_t max_seq_id = 0;
            for (const auto & config : configs) {
                max_seq_id = std::max(config.seq_id, max_seq_id);
            }
            cparams.n_seq_max = max_seq_id + 1;
        } else {
            cparams.n_seq_max = n_seq_max;
        }

        ctx.reset(llama_init_from_model(model.get(), cparams));
        if (!ctx) {
            fprintf(stderr, "Warning: failed to create context, skipping test\n");
            return false;
        }
        llama_set_warmup(ctx.get(), false);

        return true;
    }

    bool decode(const std::map<llama_seq_id, std::string> & prompts) {
        if (!ctx) {
            fprintf(stderr, "Error: context not initialized, call setup() first\n");
            return false;
        }
        last_batch_info.clear();

        llama_batch batch = llama_batch_init(512, 0, prompts.size());

        auto vocab = get_vocab();
        for (const auto & [seq_id, prompt] : prompts) {
            std::vector<llama_token> tokens;
            tokens.push_back(llama_vocab_bos(vocab));

            std::vector<llama_token> prompt_tokens(32);
            int n_tokens = llama_tokenize(vocab, prompt.c_str(), prompt.length(),
                                          prompt_tokens.data(), prompt_tokens.size(),
                                          false, false);
            if (n_tokens < 0) {
                fprintf(stderr, "Warning: tokenization failed for seq_id %d\n", seq_id);
                llama_batch_free(batch);
                return false;
            }
            for (int i = 0; i < n_tokens; i++) {
                tokens.push_back(prompt_tokens[i]);
            }

            if (seq_positions.find(seq_id) == seq_positions.end()) {
                seq_positions[seq_id] = 0;
            }
            int32_t start_pos = seq_positions[seq_id];
            for (size_t i = 0; i < tokens.size(); i++) {
                common_batch_add(batch, tokens[i], start_pos + i, { seq_id }, i == tokens.size() - 1);
            }
            seq_positions[seq_id] = start_pos + tokens.size();
        }

        printf("Batch contents:\n");
        printf("n_tokens: %d\n", batch.n_tokens);
        for (int i = 0; i < batch.n_tokens; i++) {
            printf("token[%d]: tok=%-5d, pos=%d, n_seq_id=%d, seq_ids=[", i, batch.token[i], batch.pos[i], batch.n_seq_id[i]);
            for (int j = 0; j < batch.n_seq_id[i]; j++) {
                printf("%d%s", batch.seq_id[i][j], j < batch.n_seq_id[i]-1 ? ", " : "");
            }
            printf("], logits=%d\n", batch.logits[i]);
        }

        if (llama_decode(ctx.get(), batch) != 0) {
            fprintf(stderr, "Warning: llama_decode failed\n");
            llama_batch_free(batch);
            return false;
        }

        // Build mapping from seq id to batch token idx
        for (int i = 0; i < batch.n_tokens; i++) {
            if (batch.logits[i]) {
                llama_seq_id seq_id = batch.seq_id[i][0];
                last_batch_info[seq_id] = i;
            }
        }

        llama_batch_free(batch);
        return true;
    }

    int32_t idx_for_seq(llama_seq_id seq_id) {
        auto it = last_batch_info.find(seq_id);
        if (it == last_batch_info.end()) {
            fprintf(stderr, "Error: no batch index found for seq_id %d\n", seq_id);
            return -1;
        }
        return it->second;
    }

    void update_batch_info(const llama_batch & batch) {
        last_batch_info.clear();
        for (int i = 0; i < batch.n_tokens; i++) {
            if (batch.logits[i]) {
                llama_seq_id cur_seq = batch.seq_id[i][0];
                last_batch_info[cur_seq] = i;
            }
        }
    }

    bool decode_token(llama_token token, llama_seq_id seq_id = 0) {
        if (ctx == nullptr) {
            fprintf(stderr, "Error: context not initialized, call setup() first\n");
            return false;
        }

        llama_batch batch = llama_batch_init(1, 0, 1);
        int32_t pos = seq_positions[seq_id];
        common_batch_add(batch, token, pos, { seq_id }, true);

        if (llama_decode(ctx.get(), batch) != 0) {
            fprintf(stderr, "Warning: llama_decode failed for token %d in seq %d\n", token, seq_id);
            llama_batch_free(batch);
            return false;
        }
        update_batch_info(batch);
        seq_positions[seq_id]++;

        llama_batch_free(batch);
        return true;
    }

    bool decode_tokens(const std::map<llama_seq_id, llama_token> & seq_tokens) {
        if (ctx == nullptr) {
            fprintf(stderr, "Error: context not initialized, call setup() first\n");
            return false;
        }

        llama_batch batch = llama_batch_init(seq_tokens.size(), 0, seq_tokens.size());
        for (const auto & [seq_id, token] : seq_tokens) {
            int32_t pos = seq_positions[seq_id];
            common_batch_add(batch, token, pos, { seq_id }, true);
        }

        if (llama_decode(ctx.get(), batch) != 0) {
            fprintf(stderr, "Warning: llama_decode failed for batch tokens\n");
            llama_batch_free(batch);
            return false;
        }
        for (const auto & [seq_id, _] : seq_tokens) {
            seq_positions[seq_id]++;
        }
        update_batch_info(batch);

        llama_batch_free(batch);
        return true;
    }

    std::string token_to_piece(llama_token token, bool special) {
        std::string piece;
        piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\0'
        const int n_chars = llama_token_to_piece(get_vocab(), token, &piece[0], piece.size(), 0, special);
        if (n_chars < 0) {
            piece.resize(-n_chars);
            int check = llama_token_to_piece(get_vocab(), token, &piece[0], piece.size(), 0, special);
            GGML_ASSERT(check == -n_chars);
        } else {
            piece.resize(n_chars);
        }
        return piece;
    }

    void reset() {
        ctx.reset();
        seq_positions.clear();
        last_batch_info.clear();
    }

    const llama_vocab * get_vocab() const {
        return model ? llama_model_get_vocab(model.get()) : nullptr;
    }
};
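
// Verifies that a greedy backend sampler attached to a single sequence returns
// valid token ids, both when queried by output index and with index -1 (last
// output), and that generation can continue for several decode steps.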
static void test_backend_greedy_sampling(const backend_cli_args & args) {
    test_model_context test_ctx;

    const int seq_id = 0;

    struct llama_sampler_chain_params backend_sampler_params = llama_sampler_chain_default_params();
    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_sampler_params));
    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_greedy());

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};

    if (!test_ctx.setup(args, backend_sampler_configs)) {
        return;
    }

    if (!test_ctx.decode({{seq_id, "Some"}})) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);

    llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
    printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);

    token = llama_get_sampled_token_ith(test_ctx.ctx.get(), -1);
    printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);

    for (int i = 0; i < 10; i++) {
        int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), loop_idx);
        printf("Generation step %d: token id:%d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str());
        if (!test_ctx.decode_token(token, 0)) {
            GGML_ASSERT(false && "Failed to decode token");
        }
    }
}
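
// Verifies backend top-k filtering: inspects the filtered logits and candidate
// tokens, then finishes sampling on the CPU with a dist sampler (hybrid sampling).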
static void test_backend_top_k_sampling(const backend_cli_args & args) {
    test_model_context test_ctx;

    const int seq_id = 0;
    const int32_t k = 8;

    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_top_k(k));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};

    if (!test_ctx.setup(args, backend_sampler_configs)) {
        return;
    }

    if (!test_ctx.decode({{seq_id, "Hello"}})) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);

    float *  logits   = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx);
    uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
    for (size_t i = 0; i < n_logits; ++i) {
        printf("top_k logit[%zu] = %.6f\n", i, logits[i]);
    }

    llama_token * candidates   = llama_get_sampled_candidates_ith(test_ctx.ctx.get(), batch_idx);
    uint32_t      n_candidates = llama_get_sampled_candidates_count_ith(test_ctx.ctx.get(), batch_idx);
    for (size_t i = 0; i < n_candidates; ++i) {
        printf("top_k candidate[%zu] = %d : %s\n", i, candidates[i],
               test_ctx.token_to_piece(candidates[i], false).c_str());
    }

    // Sample with a CPU sampler to verify that hybrid sampling is possible:
    // top-k on the backend followed by dist on the CPU.
    struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
    GGML_ASSERT(chain->iface->backend_apply != nullptr);
    llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));

    llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
    const std::string token_str = test_ctx.token_to_piece(token, false);
    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);

    printf("backend top-k hybrid sampling test PASSED\n");
}
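
// Verifies backend temperature scaling on two sequences with different
// temperatures, then checks that non-positive temperatures collapse the
// sampled output to a single (argmax) logit.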
static void test_backend_temp_sampling(const backend_cli_args & args) {
    test_model_context test_ctx;

    {
        const float temp_0 = 0.8f;
        struct llama_sampler_chain_params backend_chain_params_0 = llama_sampler_chain_default_params();
        llama_sampler_ptr backend_sampler_chain_0(llama_sampler_chain_init(backend_chain_params_0));
        llama_sampler_chain_add(backend_sampler_chain_0.get(), llama_sampler_init_temp(temp_0));

        const float temp_1 = 0.1f;
        struct llama_sampler_chain_params backend_chain_params_1 = llama_sampler_chain_default_params();
        llama_sampler_ptr backend_sampler_chain_1(llama_sampler_chain_init(backend_chain_params_1));
        llama_sampler_chain_add(backend_sampler_chain_1.get(), llama_sampler_init_temp(temp_1));

        std::vector<llama_sampler_seq_config> backend_sampler_configs = {
            { 0, backend_sampler_chain_0.get() },
            { 1, backend_sampler_chain_1.get() }
        };

        if (!test_ctx.setup(args, backend_sampler_configs)) {
            return;
        }

        if (!test_ctx.decode({{0, "Some where over the"}, {1, "Once upon a"}})) {
            GGML_ASSERT(false && "Failed to decode token");
        }

        // Verify sequence 0
        {
            int32_t batch_idx = test_ctx.idx_for_seq(0);
            int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
            GGML_ASSERT(n_logits == test_ctx.n_vocab);

            // Sample from sequence 0 using a CPU sampler
            struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
            llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
            llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));

            llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
            const std::string token_str = test_ctx.token_to_piece(token, false);
            printf("Sequence 0 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
            GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
        }

        // Verify sequence 1
        {
            int32_t batch_idx = test_ctx.idx_for_seq(1);

            // Sample from sequence 1 using a CPU sampler
            struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
            llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
            llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));

            llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
            const std::string token_str = test_ctx.token_to_piece(token, false);
            printf("Sequence 1 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
            GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
        }
    }

    // lambda for testing non-positive temperature values.
    auto test_argmax_temp = [&](float temp) {
        printf("\nTesting temperature = %.1f\n", temp);
        test_ctx.reset();

        int seq_id = 0;
        struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
        llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
        llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_temp(temp));

        std::vector<llama_sampler_seq_config> backend_sampler_configs = {
            { seq_id, backend_sampler_chain.get() },
        };

        if (!test_ctx.setup(args, backend_sampler_configs)) {
            return;
        }

        if (!test_ctx.decode({{seq_id, "Once"}})) {
            GGML_ASSERT(false && "Failed to decode token");
        }

        int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
        uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
        GGML_ASSERT(n_logits == 1);
    };

    test_argmax_temp(0.0f);
    test_argmax_temp(-1.0f);

    printf("backend temp sampling test PASSED\n");
}
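
// Verifies backend extended (dynamic) temperature sampling: a positive
// temperature keeps the full logit set, while non-positive temperatures
// collapse the output to a single logit and a zero delta behaves like plain
// temperature scaling.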
static void test_backend_temp_ext_sampling(const backend_cli_args & args) {
    test_model_context test_ctx;

    {
        int seq_id = 0;
        const float temp     = 0.8f;
        const float delta    = 0.5f;
        const float exponent = 1.5f;

        struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
        llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
        llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_temp_ext(temp, delta, exponent));

        std::vector<llama_sampler_seq_config> backend_sampler_configs = {
            { seq_id, backend_sampler_chain.get() },
        };

        if (!test_ctx.setup(args, backend_sampler_configs)) {
            return;
        }

        if (!test_ctx.decode({{seq_id, "Once upon a"}})) {
            GGML_ASSERT(false && "Failed to decode token");
        }

        // Verify sequence 0
        {
            int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
            int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
            GGML_ASSERT(n_logits == test_ctx.n_vocab);
        }
    }

    test_ctx.reset();

    // lambda for testing non-positive temp/delta/exponent values.
    auto test_argmax_temp = [&](float temp, float delta, float exponent) {
        printf("\nTesting temperature = %.1f, delta = %.1f, exponent = %.1f\n", temp, delta, exponent);
        test_ctx.reset();

        int seq_id = 0;
        struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
        llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
        llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_temp_ext(temp, delta, exponent));

        std::vector<llama_sampler_seq_config> backend_sampler_configs = {
            { seq_id, backend_sampler_chain.get() },
        };

        if (!test_ctx.setup(args, backend_sampler_configs)) {
            return;
        }

        if (!test_ctx.decode({{seq_id, "Once"}})) {
            GGML_ASSERT(false && "Failed to decode token");
        }

        int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
        uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
        if (temp <= 0.0f && delta >= 0.0f) {
            GGML_ASSERT(n_logits == 1);
        } else {
            GGML_ASSERT(n_logits == (uint32_t) test_ctx.n_vocab);
        }
    };

    test_argmax_temp(0.0f, 0.3f, 1.0f);  // Greedy (temp=0)
    test_argmax_temp(-1.0f, 0.3f, 2.0f); // Greedy (temp<0)
    test_argmax_temp(0.8f, 0.0f, 2.0f);  // Temperature scaling
    printf("backend temp_ext sampling test PASSED\n");
}
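
// Verifies backend min-p filtering: checks that only a subset of the logits
// survives the threshold, then continues sampling on the CPU for a few steps.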
static void test_backend_min_p_sampling(const backend_cli_args & args) {
    test_model_context test_ctx;

    const int seq_id = 0;
    const float p = 0.1f;

    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_min_p(p, 0));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};

    if (!test_ctx.setup(args, backend_sampler_configs)) {
        return;
    }

    if (!test_ctx.decode({{seq_id, "Hello"}})) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);

    float * logits = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx);
    uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);

    // Collect the logits that are above the min-p threshold
    std::vector<float> filtered_logits;
    for (size_t i = 0; i < n_logits; ++i) {
        if (logits[i] > -1e9f) {
            filtered_logits.push_back(logits[i]);
            //printf("min_p logit[%zu] = %.6f\n", i, logits[i]);
        }
    }
    GGML_ASSERT(filtered_logits.size() < (size_t) test_ctx.n_vocab);

    // Sample with a CPU sampler to verify the filtered results are reasonable
    struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
    llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(88));

    llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
    const std::string token_str = test_ctx.token_to_piece(token, false);
    printf("min-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str());
    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);

    // Decode and sample 10 more tokens
    for (int i = 0; i < 10; i++) {
        int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), loop_idx);
        printf("min-p gen step %d: token id :%5d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str());
        if (!test_ctx.decode_token(token, 0)) {
            GGML_ASSERT(false && "Failed to decode token");
        }
    }
    printf("min-p sampling test PASSED\n");
}
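
// Verifies backend top-p (nucleus) filtering: checks that a non-empty, proper
// subset of logits survives, then continues sampling on the CPU for a few steps.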
static void test_backend_top_p_sampling(const backend_cli_args & args) {
    test_model_context test_ctx;

    const int seq_id = 0;
    const float p = 0.9f;

    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_top_p(p, 0));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};

    if (!test_ctx.setup(args, backend_sampler_configs)) {
        return;
    }

    if (!test_ctx.decode({{seq_id, "Hello"}})) {
        return;
    }

    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);

    float * logits = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx);
    uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);

    // Collect the logits that survived the top-p filter
    std::vector<float> filtered_logits;
    for (size_t i = 0; i < n_logits; ++i) {
        if (logits[i] > -1e9f) {
            filtered_logits.push_back(logits[i]);
        }
    }
    GGML_ASSERT(filtered_logits.size() < (size_t) test_ctx.n_vocab);
    GGML_ASSERT(filtered_logits.size() > 0);

    // Sample with a CPU sampler to verify the filtered results are reasonable
    struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
    llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(88));

    llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
    const std::string token_str = test_ctx.token_to_piece(token, false);
    printf("top-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str());
    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);

    // Decode and sample 10 more tokens
    for (int i = 0; i < 10; i++) {
        int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), loop_idx);
        printf("top-p gen step %d: token id :%5d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str());
        test_ctx.decode_token(token, 0);
    }
    printf("top-p sampling test PASSED\n");
}
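
// Verifies that two sequences can use independent backend sampler chains in
// the same batch, and that per-sequence tokens can be generated and decoded
// together for several steps.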
static void test_backend_multi_sequence_sampling(const backend_cli_args & args) {
    test_model_context test_ctx;

    struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
    llama_sampler_ptr sampler_chain_0(llama_sampler_chain_init(chain_params_0));
    llama_sampler_chain_add(sampler_chain_0.get(), llama_sampler_init_greedy());

    struct llama_sampler_chain_params chain_params_1 = llama_sampler_chain_default_params();
    llama_sampler_ptr sampler_chain_1(llama_sampler_chain_init(chain_params_1));
    llama_sampler_chain_add(sampler_chain_1.get(), llama_sampler_init_temp(0.8f));
    llama_sampler_chain_add(sampler_chain_1.get(), llama_sampler_init_greedy());

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {
        { 0, sampler_chain_0.get() },
        { 1, sampler_chain_1.get() }
    };

    if (!test_ctx.setup(args, backend_sampler_configs)) {
        return;
    }

    std::map<llama_seq_id, std::string> prompts = {
        {0, "Hello"},
        {1, "Some"}
    };

    if (!test_ctx.decode(prompts)) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    // Verify sequence 0
    {
        int32_t batch_idx = test_ctx.idx_for_seq(0);
        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
        const std::string token_str = test_ctx.token_to_piece(token, false);
        printf("Seq 0 sampled token id=%d, string='%s'\n", token, token_str.c_str());
        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
    }

    // Verify sequence 1
    {
        int32_t batch_idx = test_ctx.idx_for_seq(1);
        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
        const std::string token_str = test_ctx.token_to_piece(token, false);
        printf("Seq 1 sampled token id=%d, string='%s'\n", token, token_str.c_str());
        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
    }

    // Generate tokens for each sequence
    printf("\nMulti-sequence generation:\n");
    for (int step = 0; step < 4; step++) {
        std::map<llama_seq_id, llama_token> tokens;
        for (llama_seq_id seq_id : {0, 1}) {
            int32_t idx = test_ctx.idx_for_seq(seq_id);
            llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), idx);
            const std::string token_str = test_ctx.token_to_piece(token, false);
            printf(" Seq %d, step %d: token id=%d, string='%s'\n", seq_id, step, token, token_str.c_str());
            tokens[seq_id] = token;
        }
        // Decode all tokens in a single batch
        if (!test_ctx.decode_tokens(tokens)) {
            GGML_ASSERT(false && "Failed to decode token");
        }
    }
    printf("backend multi-sequence sampling test PASSED\n");
}
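
// Verifies that a backend dist sampler produces a valid token, also when the
// result is queried with index -1, and that a non-zero seq_id (189) works.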
static void test_backend_dist_sampling(const backend_cli_args & args) {
    test_model_context test_ctx;

    const int seq_id = 189;
    const int32_t seed = 88;

    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};

    if (!test_ctx.setup(args, backend_sampler_configs)) {
        return;
    }

    if (!test_ctx.decode({{seq_id, "Some"}})) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);

    llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
    printf("dist sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
    //GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx) == nullptr);

    token = llama_get_sampled_token_ith(test_ctx.ctx.get(), -1);
    printf("dist sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);

    printf("backend dist sampling test PASSED\n");
}
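
// Verifies that running a CPU dist sampler on an output that was already
// sampled by a backend dist sampler returns the same token.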
static void test_backend_dist_sampling_and_cpu(const backend_cli_args & args) {
    test_model_context test_ctx;

    const int seq_id = 0;
    const int32_t seed = 88;

    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};

    if (!test_ctx.setup(args, backend_sampler_configs)) {
        return;
    }

    if (!test_ctx.decode({{seq_id, "Some"}})) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);

    // Sample using CPU sampler
    struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
    llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));

    llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
    llama_token cpu_token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
    printf("dist & cpu sampled id:%d, string:'%s'\n", cpu_token, test_ctx.token_to_piece(cpu_token, false).c_str());
    GGML_ASSERT(backend_token == cpu_token);

    printf("backend dist & cpu sampling test PASSED\n");
}
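
// Verifies backend logit-bias sampling: a large positive bias on one token
// should make the backend sampler return exactly that token.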
static void test_backend_logit_bias_sampling(const backend_cli_args & args) {
    test_model_context test_ctx;

    // Call load_model first to ensure the vocab is loaded and can be accessed.
    if (!test_ctx.load_model(args)) {
        return;
    }

    const int seq_id = 0;

    // Create the logit biases vector.
    std::vector<llama_logit_bias> logit_bias;

    // Get the token for the piece "World".
    const std::string piece = "World";
    std::vector<llama_token> tokens(16);
    llama_tokenize(test_ctx.get_vocab(), piece.c_str(), piece.size(), tokens.data(), tokens.size(), false, false);
    llama_token bias_token = tokens[0];
    logit_bias.push_back({ bias_token, +100.0f });
    printf("biasing token piece '%s' -> token id %d\n", piece.c_str(), bias_token);

    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_logit_bias(
        llama_vocab_n_tokens(test_ctx.get_vocab()),
        logit_bias.size(),
        logit_bias.data()));
    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(88));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {
        { seq_id, backend_sampler_chain.get() },
    };

    if (!test_ctx.setup(args, backend_sampler_configs)) {
        return;
    }

    if (!test_ctx.decode({{seq_id, "Hello"}})) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), test_ctx.idx_for_seq(seq_id));
    const std::string backend_token_str = test_ctx.token_to_piece(backend_token, false);
    printf("logit bias sampled token = %d, string='%s'\n", backend_token, backend_token_str.c_str());
    GGML_ASSERT(backend_token == bias_token);

    printf("backend logit bias sampling test PASSED\n");
}

// This test verifies that it is possible to have two different backend sampler
// configurations: one sequence that samples fully on the backend (dist), and
// another that only filters on the backend (top-k) and leaves the final
// sampling to the CPU.
static void test_backend_mixed_sampling(const backend_cli_args & args) {
    test_model_context test_ctx;

    struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
    llama_sampler_ptr sampler_chain_0(llama_sampler_chain_init(chain_params_0));
    llama_sampler_chain_add(sampler_chain_0.get(), llama_sampler_init_dist(88));

    int k = 40;
    struct llama_sampler_chain_params chain_params_1 = llama_sampler_chain_default_params();
    llama_sampler_ptr sampler_chain_1(llama_sampler_chain_init(chain_params_1));
    llama_sampler_chain_add(sampler_chain_1.get(), llama_sampler_init_top_k(k));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {
        { 0, sampler_chain_0.get() },
        { 1, sampler_chain_1.get() }
    };

    if (!test_ctx.setup(args, backend_sampler_configs)) {
        return;
    }

    std::map<llama_seq_id, std::string> prompts = {
        {0, "Hello"},
        {1, "Some"}
    };

    if (!test_ctx.decode(prompts)) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    // Verify sequence 0 that used the dist backend sampler.
    {
        int32_t batch_idx = test_ctx.idx_for_seq(0);
        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
        const std::string token_str = test_ctx.token_to_piece(token, false);
        printf("sampled token id=%d, string='%s'\n", token, token_str.c_str());
        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
        //GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx) == nullptr);
        //GGML_ASSERT(llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx) == 0);
    }

    // Verify sequence 1 that used the top-k backend sampler.
    {
        int32_t batch_idx = test_ctx.idx_for_seq(1);
        float * logits = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx);
        GGML_ASSERT(logits != nullptr);
        size_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
        GGML_ASSERT(n_logits == (size_t) k);
        GGML_ASSERT(llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx) == LLAMA_TOKEN_NULL);
    }
    printf("backend mixed sampling test PASSED\n");
}
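
// Verifies llama_set_sampler: clearing the backend sampler for a sequence stops
// backend sampling (no sampled token or probs), and attaching a new chain
// re-enables it.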
static void test_backend_set_sampler(const backend_cli_args & args) {
    test_model_context test_ctx;

    const int32_t seed = 88;
    const int seq_id = 0;

    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};

    if (!test_ctx.setup(args, backend_sampler_configs)) {
        return;
    }

    if (!test_ctx.decode({{seq_id, "Hello"}})) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);

    // Sample using backend sampler configured above
    llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
    const std::string backend_token_str = test_ctx.token_to_piece(backend_token, false);
    printf("dist sampled token = %d, string='%s'\n", backend_token, backend_token_str.c_str());

    // Now clear the backend sampler for this sequence.
    llama_set_sampler(test_ctx.ctx.get(), seq_id, nullptr);
    printf("Cleared backend sampler for seq_id %d\n", seq_id);

    // Sample using CPU sampler
    struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
    llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18));

    std::map<llama_seq_id, llama_token> tokens = { { seq_id, backend_token}, };
    if (!test_ctx.decode_tokens(tokens)) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    // Should not have any sampled token or probs after clearing the backend sampler.
    const int32_t idx = test_ctx.idx_for_seq(seq_id);
    GGML_ASSERT(llama_get_sampled_token_ith(test_ctx.ctx.get(), idx) == LLAMA_TOKEN_NULL);
    GGML_ASSERT(llama_get_sampled_probs_ith(test_ctx.ctx.get(), idx) == nullptr);

    // Sample the token using the CPU sampler chain.
    llama_token token2 = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), seq_id);
    const std::string token2_str = test_ctx.token_to_piece(token2, false);
    printf("CPU sampled token after clearing backend sampler: id=%d, string='%s'\n", token2, token2_str.c_str());

    std::map<llama_seq_id, llama_token> tokens2 = { { seq_id, token2}, };

    // Set a new backend sampler for the sequence.
    struct llama_sampler_chain_params new_backend_chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr new_backend_sampler_chain(llama_sampler_chain_init(new_backend_chain_params));
    llama_sampler_chain_add(new_backend_sampler_chain.get(), llama_sampler_init_top_k(20));
    llama_sampler_chain_add(new_backend_sampler_chain.get(), llama_sampler_init_dist(seed));
    llama_set_sampler(test_ctx.ctx.get(), seq_id, new_backend_sampler_chain.get());

    if (!test_ctx.decode_tokens(tokens2)) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    llama_token new_backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), test_ctx.idx_for_seq(seq_id));
    const std::string new_backend_token_str = test_ctx.token_to_piece(new_backend_token, false);
    printf("dist sampled token = %d, string='%s'\n", new_backend_token, new_backend_token_str.c_str());

    printf("backend set sampler test PASSED\n");
}
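
// Verifies a batch that mixes a backend-sampled sequence with a CPU-sampled
// one, and that the backend sampler for a sequence can be removed and later
// re-added.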
static void test_backend_cpu_mixed_batch(const backend_cli_args & args) {
    test_model_context test_ctx;

    // Sequence 0 uses backend sampling
    struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
    llama_sampler_ptr sampler_chain_0(llama_sampler_chain_init(chain_params_0));
    llama_sampler_chain_add(sampler_chain_0.get(), llama_sampler_init_dist(88));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {
        { 0, sampler_chain_0.get() },
    };

    // We need 2 sequences: seq 0 with backend sampling, seq 1 with CPU sampling
    if (!test_ctx.setup(args, backend_sampler_configs, 2)) {
        return;
    }

    std::map<llama_seq_id, std::string> prompts = {
        {0, "Hello"}, // Will use backend sampling
        {1, "Some"}   // Will use CPU sampling
    };

    if (!test_ctx.decode(prompts)) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    // Verify sequence 0 (backend sampled)
    {
        int32_t batch_idx = test_ctx.idx_for_seq(0);
        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
        const std::string token_str = test_ctx.token_to_piece(token, false);
        printf("Seq 0 (backend) sampled token id=%d, string='%s'\n", token, token_str.c_str());
        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
    }

    // Verify sequence 1 (CPU sampled)
    {
        int32_t batch_idx = test_ctx.idx_for_seq(1);
        llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
        GGML_ASSERT(backend_token == LLAMA_TOKEN_NULL);

        struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
        llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
        llama_sampler_chain_add(chain.get(), llama_sampler_init_greedy());

        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
        const std::string token_str = test_ctx.token_to_piece(token, false);
        printf("Seq 1 (CPU) sampled token id=%d, string='%s'\n", token, token_str.c_str());
        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
    }

    // Clear/remove the backend sampler, and sample again
    {
        // Clear the backend sampler for seq 0 so that there are no backend
        // samplers.
        llama_set_sampler(test_ctx.ctx.get(), 0, nullptr);

        // Create a CPU sampler and verify we can sample from it.
        struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
        llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
        llama_sampler_chain_add(chain.get(), llama_sampler_init_greedy());

        int32_t batch_idx = test_ctx.idx_for_seq(1);
        llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx);
        if (!test_ctx.decode_token(token, 1)) {
            GGML_ASSERT(false && "Failed to decode token");
        }
    }

    // Set a backend sampler so that we can verify that it can be reset
    {
        struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
        llama_sampler_ptr sampler_chain(llama_sampler_chain_init(chain_params));
        llama_sampler_chain_add(sampler_chain.get(), llama_sampler_init_dist(88));
        llama_set_sampler(test_ctx.ctx.get(), 0, sampler_chain.get());

        if (!test_ctx.decode_token(3834, 0)) {
            GGML_ASSERT(false && "Failed to decode token");
        }

        int32_t batch_idx = test_ctx.idx_for_seq(0);
        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
        const std::string token_str = test_ctx.token_to_piece(token, false);
        printf("re-added backend sampled token id=%d, string='%s'\n", token, token_str.c_str());
        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
    }

    printf("backend-cpu mixed batch test PASSED\n");
}
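
// Verifies that llama_decode fails when more than one token in a sequence is
// marked as output while a backend sampler is attached to that sequence.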
static void test_backend_max_outputs(const backend_cli_args & args) {
    test_model_context test_ctx;

    const int seq_id = 0;
    const int32_t seed = 88;

    llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params));
    llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};

    if (!test_ctx.setup(args, backend_sampler_configs)) {
        return;
    }

    llama_batch batch = llama_batch_init(512, 0, 1);

    std::string prompt = "Hello";
    std::vector<llama_token> tokens;
    tokens.push_back(llama_vocab_bos(test_ctx.get_vocab()));

    std::vector<llama_token> prompt_tokens(32);
    int n_tokens = llama_tokenize(test_ctx.get_vocab(), prompt.c_str(), prompt.length(),
                                  prompt_tokens.data(), prompt_tokens.size(),
                                  false, false);
    for (int i = 0; i < n_tokens; i++) {
        tokens.push_back(prompt_tokens[i]);
    }

    for (size_t i = 0; i < tokens.size(); i++) {
        // set all tokens as outputs to trigger an error
        common_batch_add(batch, tokens[i], i, { seq_id }, true);
    }

    printf(">>> test_max_outputs expected error start:\n");
    const int ret = llama_decode(test_ctx.ctx.get(), batch);
    GGML_ASSERT(ret != 0 && "llama_decode should not succeed with multiple outputs per sequence");
    printf("<<< test_max_outputs expected error end.\n");

    llama_batch_free(batch);

    printf("backend max outputs test PASSED\n");
}
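
// Registry of named tests. When no --test argument is given, every test that
// is enabled by default is run.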
struct backend_test_case {
    const char * name;
    void (*fn)(const backend_cli_args &);
    bool enabled_by_default;
};

static const backend_test_case BACKEND_TESTS[] = {
    { "greedy",         test_backend_greedy_sampling,         true },
    { "logit_bias",     test_backend_logit_bias_sampling,     true },
    { "temp",           test_backend_temp_sampling,           true },
    { "temp_ext",       test_backend_temp_ext_sampling,       true },
    { "top_k",          test_backend_top_k_sampling,          true },
    { "multi_sequence", test_backend_multi_sequence_sampling, true },
    { "dist",           test_backend_dist_sampling,           true },
    { "dist_and_cpu",   test_backend_dist_sampling_and_cpu,   true },
    { "set_sampler",    test_backend_set_sampler,             true },
    { "max_outputs",    test_backend_max_outputs,             true },
    { "mixed",          test_backend_mixed_sampling,          true },
    { "min_p",          test_backend_min_p_sampling,          true },
    { "cpu_mixed",      test_backend_cpu_mixed_batch,         true },
    { "top_p",          test_backend_top_p_sampling,          true },
};

static backend_cli_args parse_backend_cli(int argc, char ** argv) {
    backend_cli_args out;

    for (int i = 1; i < argc; ++i) {
        const char * arg = argv[i];

        if (std::strcmp(arg, "--test") == 0) {
            if (i + 1 >= argc) {
                fprintf(stderr, "--test expects a value\n");
                exit(EXIT_FAILURE);
            }
            out.test = argv[++i];
            continue;
        }
        if (std::strncmp(arg, "--test=", 7) == 0) {
            out.test = arg + 7;
            continue;
        }
        if (std::strcmp(arg, "--model") == 0) {
            if (i + 1 >= argc) {
                fprintf(stderr, "--model expects a value\n");
                exit(EXIT_FAILURE);
            }
            out.model = argv[++i];
            continue;
        }
        if (std::strncmp(arg, "--model=", 8) == 0) {
            out.model = arg + 8;
            continue;
        }
        if (std::strcmp(arg, "--device") == 0) {
            if (i + 1 >= argc) {
                fprintf(stderr, "--device expects a value (cpu or gpu)\n");
                exit(EXIT_FAILURE);
            }
            out.device = argv[++i];
            continue;
        }
        if (std::strncmp(arg, "--device=", 9) == 0) {
            out.device = arg + 9;
            continue;
        }
        if (!out.model) {
            out.model = arg;
            continue;
        }

        fprintf(stderr, "Unexpected argument: %s\n", arg);
        exit(EXIT_FAILURE);
    }

    if (std::strcmp(out.device, "cpu") != 0 && std::strcmp(out.device, "gpu") != 0) {
        fprintf(stderr, "Invalid device '%s'. Must be 'cpu' or 'gpu'\n", out.device);
        exit(EXIT_FAILURE);
    }

    return out;
}

static std::vector<const backend_test_case *> collect_tests_to_run(const char * requested) {
    std::vector<const backend_test_case *> selected;

    if (requested != nullptr) {
        for (const auto & test : BACKEND_TESTS) {
            if (std::strcmp(test.name, requested) == 0) {
                selected.push_back(&test);
                break;
            }
        }
        if (selected.empty()) {
            fprintf(stderr, "Unknown test '%s'. Available tests:\n", requested);
            for (const auto & test : BACKEND_TESTS) {
                fprintf(stderr, " %s\n", test.name);
            }
            exit(EXIT_FAILURE);
        }
    } else {
        for (const auto & test : BACKEND_TESTS) {
            if (test.enabled_by_default) {
                selected.push_back(&test);
            }
        }
    }

    if (selected.empty()) {
        fprintf(stderr, "No backend sampling tests selected. Use --test=<name> to pick one.\n");
    }

    return selected;
}

static void run_tests(const std::vector<const backend_test_case *> & tests, const backend_cli_args & args) {
    for (const auto * test : tests) {
        fprintf(stderr, "\n=== %s ===\n", test->name);
        test->fn(args);
    }
}

int main(int argc, char ** argv) {
    backend_cli_args args = parse_backend_cli(argc, argv);

    if (args.model == nullptr) {
        args.model = get_model_or_exit(1, argv);
    }

    std::ifstream file(args.model);
    if (!file.is_open()) {
        fprintf(stderr, "no model '%s' found\n", args.model);
        return EXIT_FAILURE;
    }
    fprintf(stderr, "using '%s'\n", args.model);

    ggml_time_init();

    const std::vector<const backend_test_case *> tests = collect_tests_to_run(args.test);
    if (!tests.empty()) {
        run_tests(tests, args);
    }

    return 0;
}