lookahead.cpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488
#include "common.h"
#include "llama.h"

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <string>
#include <tuple>
#include <vector>
  7. struct ngram_data {
  8. bool active = false;
  9. llama_seq_id seq_id = -1;
  10. std::vector<int> i_batch;
  11. std::vector<llama_token> tokens;
  12. };
  13. // n-gram container
  14. struct ngram_container {
  15. ngram_container(int n_vocab, int N, int G) {
  16. cnt.resize(n_vocab);
  17. head.resize(n_vocab);
  18. tokens.resize(n_vocab * G * (N - 1));
  19. }
  20. int n_total = 0;
  21. std::vector<int> cnt;
  22. std::vector<int> head;
  23. // [n_vocab][G][N - 1]
  24. // for each token of the vocab, keep a ring-buffer of capacity G of n-grams of size N - 1
  25. std::vector<llama_token> tokens;
  26. };
  27. int main(int argc, char ** argv) {
  28. gpt_params params;
  29. if (gpt_params_parse(argc, argv, params) == false) {
  30. return 1;
  31. }
  32. const int W = 15; // lookahead window
  33. const int N = 5; // n-gram size
  34. const int G = 15; // max verification n-grams
  35. const bool dump_kv_cache = params.dump_kv_cache;
  36. #ifndef LOG_DISABLE_LOGS
  37. log_set_target(log_filename_generator("lookahead", "log"));
  38. LOG_TEE("Log start\n");
  39. log_dump_cmdline(argc, argv);
  40. #endif // LOG_DISABLE_LOGS
  41. // init llama.cpp
  42. llama_backend_init();
  43. llama_numa_init(params.numa);
  44. llama_model * model = NULL;
  45. llama_context * ctx = NULL;
  46. // load the target model
  47. std::tie(model, ctx) = llama_init_from_gpt_params(params);
  48. // Tokenize the prompt
  49. const bool add_bos = llama_should_add_bos_token(model);
  50. LOG("add_bos tgt: %d\n", add_bos);
  51. std::vector<llama_token> inp;
  52. std::vector<llama_token> all;
  53. inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
  54. all = inp;
  55. const int max_context_size = llama_n_ctx(ctx);
  56. const int max_tokens_list_size = max_context_size - 4;
  57. if ((int) inp.size() > max_tokens_list_size) {
  58. fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size);
  59. return 1;
  60. }
  61. fprintf(stderr, "\n\n");
  62. for (auto id : inp) {
  63. fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str());
  64. }
  65. fflush(stderr);
  66. const int n_input = inp.size();
  67. const auto t_enc_start = ggml_time_us();
  68. // eval the prompt
  69. llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1, 0, 0));
  70. llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
  71. for (int s = 1; s < W + G + 1; ++s) {
  72. llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
  73. }
  74. const auto t_enc_end = ggml_time_us();
  75. int n_predict = 0;
  76. int n_accept = 0;
  77. int n_past = inp.size();
  78. llama_token id = 0;
  79. // used to determine end of generation
  80. bool has_eos = false;
  81. // for each decoded batch, we have at most W + G + 1 distinct sequences:
  82. // seq_id == 0 : the current input token
  83. // seq_id [1, W] : tokens from the past N - 1 Jacobi iterations
  84. // seq_id [W + 1, W + G] : verification n-grams
  85. llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);
  86. // target model sampling context
  87. struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
  88. // verification n-grams
  89. std::vector<ngram_data> ngrams_cur(G);
  90. // tokens for the past N - 1 Jacobi iterations
  91. std::vector<llama_token> tokens_j_prev(W);
  92. std::vector<std::vector<llama_token>> tokens_j(N - 1);
  93. for (int j = 0; j < N - 1; j++) {
  94. tokens_j[j].resize(W);
  95. for (int i = 0; i < W; i++) {
  96. // there are different ways to init these tokens
  97. if (0) {
  98. // initialize randomly from the prompt tokens
  99. tokens_j[j][i] = all[1 + rand() % (all.size() - 1)];
  100. } else {
  101. // initialize with a sequence of increasing numbers
  102. tokens_j[j][i] = 100 + i;
  103. }
  104. }
  105. }
  106. std::vector<llama_seq_id> seq_id_look;
  107. // the input token belongs both to all sequences
  108. std::vector<llama_seq_id> seq_id_all(W + G + 1);
  109. for (int i = 0; i < W + G + 1; i++) {
  110. seq_id_all[i] = i;
  111. }
  112. // here we keep adding new n-grams as we go
  113. ngram_container ngrams_observed(llama_n_vocab(model), N, G);
  114. // debug
  115. struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, W + G + 1);
  116. const auto t_dec_start = ggml_time_us();
  117. // sample first token
  118. {
  119. id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0);
  120. llama_sampling_accept(ctx_sampling, ctx, id, true);
  121. {
  122. const std::string token_str = llama_token_to_piece(ctx, id);
  123. printf("%s", token_str.c_str());
  124. fflush(stdout);
  125. }
  126. }
  127. while (true) {
  128. // debug
  129. if (dump_kv_cache) {
  130. llama_kv_cache_view_update(ctx, &kvc_view);
  131. dump_kv_cache_view_seqs(kvc_view, 40);
  132. }
  133. // build the mask from https://lmsys.org/blog/2023-11-21-lookahead-decoding/
  134. //
  135. // Example for W = 5, N = 4, G = 2:
  136. // (I = input, L = lookahead, V = verification)
  137. //
  138. // Batch: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
  139. // T: -2 -2 -2 -2 -1 -1 -1 -1 -1 0 0 0 0 0 0
  140. // Info: I L L L L L L L L L L L L L L V V V V V V
  141. // Pos: 0 1 2 3 4 1 2 3 4 5 2 3 4 5 6 1 2 3 1 2 3 (+ n_past)
  142. // Logits: 1 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1
  143. // ---------------------------------------------------------------------
  144. // Seq: 0
  145. // 1 1 1
  146. // 2 2 2 2
  147. // 3 3 3 3 3
  148. // 4 4 4 4 4 4
  149. // 5 5 5 5 5 5 5
  150. // 6 6 6 6
  151. // 7 7 7 7
  152. // ---------------------------------------------------------------------
  153. // | | | | | | | | | | |
  154. // V V V V V | | | | | |
  155. // j_tokens | | | | | |
  156. // V V V V V V
  157. // id
  158. {
  159. llama_batch_clear(batch);
  160. // current token - first token of the first level
  161. llama_batch_add(batch, id, n_past, seq_id_all, true);
  162. // verification n-grams - queue this before the lookahead tokens for less KV cache fragmentation
  163. {
  164. const int g_cur = ngrams_observed.cnt[id];
  165. ngrams_cur.resize(g_cur);
  166. for (int g = 0; g < g_cur; g++) {
  167. ngrams_cur[g].active = true;
  168. ngrams_cur[g].tokens.resize(N);
  169. ngrams_cur[g].i_batch.resize(N);
  170. ngrams_cur[g].seq_id = W + 1 + g;
  171. ngrams_cur[g].i_batch[0] = 0;
  172. ngrams_cur[g].tokens [0] = id;
  173. }
  174. for (int j = 0; j < N - 1; j++) {
  175. for (int g = 0; g < g_cur; g++) {
  176. const int idx = id*(N - 1)*G + g*(N - 1);
  177. const llama_token t = ngrams_observed.tokens[idx + j];
  178. ngrams_cur[g].tokens [j + 1] = t;
  179. ngrams_cur[g].i_batch[j + 1] = batch.n_tokens;
  180. llama_batch_add(batch, t, n_past + j + 1, { W + 1 + g }, true);
  181. }
  182. }
  183. }
  184. // fill the remaining W - 1 tokens for the first level
  185. for (int i = 1; i < W; i++) {
  186. seq_id_look.resize(W - i);
  187. for (int j = 0; j < W - i; j++) {
  188. seq_id_look[j] = i + j + 1;
  189. }
  190. llama_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false);
  191. }
  192. // fill the rest of the levels
  193. for (int j = 1; j < N - 1; j++) {
  194. for (int i = 0; i < W; i++) {
  195. llama_batch_add(batch, tokens_j[j][i], n_past + j + i, { i + 1 }, j == N - 2);
  196. }
  197. }
  198. }
  199. if (llama_decode(ctx, batch) != 0) {
  200. fprintf(stderr, "\n\n%s: error: llama_decode failed - increase KV cache size\n", __func__);
  201. return 1;
  202. }
  203. int seq_id_best = 0;
  204. for (int v = 0; v < N; ++v) {
  205. int i_batch = 0;
  206. // if no active ngrams are left, it means the sampled token does not pass the verification
  207. if (v > 0) {
  208. for (int g = 0; g < (int) ngrams_cur.size(); g++) {
  209. if (ngrams_cur[g].active) {
  210. i_batch = ngrams_cur[g].i_batch[v];
  211. seq_id_best = ngrams_cur[g].seq_id;
  212. ++n_accept;
  213. break;
  214. }
  215. }
  216. // no more matches -> create a new batch
  217. if (i_batch == 0) {
  218. break;
  219. }
  220. }
  221. // sample the next token
  222. id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch);
  223. llama_sampling_accept(ctx_sampling, ctx, id, true);
  224. // print
  225. {
  226. const std::string token_str = llama_token_to_piece(ctx, id);
  227. if (v == 0) {
  228. printf("%s", token_str.c_str());
  229. } else {
  230. // print light cyan
  231. printf("\033[0;96m%s\033[0m", token_str.c_str());
  232. }
  233. fflush(stdout);
  234. if (id == llama_token_eos(model)) {
  235. has_eos = true;
  236. }
  237. all.push_back(id);
  238. }
  239. ++n_predict;
  240. ++n_past;
  241. if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
  242. break;
  243. }
  244. // verify across active n-grams
  245. for (int g = 0; g < (int) ngrams_cur.size(); g++) {
  246. if (ngrams_cur[g].active) {
  247. if (v == N - 1) {
  248. ngrams_cur[g].active = false;
  249. } else {
  250. if (id != ngrams_cur[g].tokens[v + 1]) {
  251. ngrams_cur[g].active = false;
  252. }
  253. }
  254. }
  255. }
  256. // print known n-grams starting with token id (debug)
  257. if (0 && v == 0) {
  258. if (ngrams_observed.cnt[id] > 0) {
  259. printf("\n - %d n-grams starting with '%s'\n", ngrams_observed.cnt[id], llama_token_to_piece(ctx, id).c_str());
  260. }
  261. for (int i = 0; i < ngrams_observed.cnt[id]; i++) {
  262. printf(" - ngram %2d: ", i);
  263. const int idx = id*(N - 1)*G + i*(N - 1);
  264. for (int j = 0; j < N - 1; j++) {
  265. const std::string token_str = llama_token_to_piece(ctx, ngrams_observed.tokens[idx + j]);
  266. printf("%s", token_str.c_str());
  267. }
  268. printf("\n");
  269. }
  270. }
  271. // update lookahead tokens
  272. {
  273. for (int i = 0; i < W; i++) {
  274. tokens_j_prev[i] = tokens_j[0][i];
  275. }
  276. for (int j = 0; j < N - 2; j++) {
  277. tokens_j[j] = tokens_j[j + 1];
  278. }
  279. if (v == 0) {
  280. // sample from the last level
  281. for (int i = 0; i < W; i++) {
  282. tokens_j[N - 2][i] = llama_sampling_sample(ctx_sampling, ctx, NULL, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
  283. }
  284. } else {
  285. for (int i = 0; i < W; i++) {
  286. // there are different ways to init these tokens
  287. if (0) {
  288. // random init
  289. tokens_j[N - 2][i] = all[1 + rand() % (all.size() - 1)];
  290. } else {
  291. // init from the previous level
  292. tokens_j[N - 2][i] = tokens_j[0][i];
  293. }
  294. }
  295. }
  296. }
  297. // update observed ngrams
  298. if (v == 0) {
  299. // the first token of the n-gram is determined by the index in the container so it is not stored
  300. std::vector<llama_token> ngram(N - 1);
  301. // n-gram generation
  302. // ref: https://github.com/hao-ai-lab/LookaheadDecoding/issues/14#issuecomment-1826198518
  303. for (int f = 0; f < W; ++f) {
  304. const int ft = tokens_j_prev[f]; // first token of the n-gram
  305. for (int j = 0; j < N - 1; ++j) {
  306. ngram[j] = tokens_j[j][f];
  307. }
  308. // filter-out repeating n-grams
  309. {
  310. bool is_unique = true;
  311. for (int k = 0; k < ngrams_observed.cnt[ft]; ++k) {
  312. const int idx = ft*(N - 1)*G + k*(N - 1);
  313. bool is_match = true;
  314. for (int j = 0; j < N - 1; ++j) {
  315. if (ngrams_observed.tokens[idx + j] != ngram[j]) {
  316. is_match = false;
  317. break;
  318. }
  319. }
  320. if (is_match) {
  321. is_unique = false;
  322. break;
  323. }
  324. }
  325. if (!is_unique) {
  326. continue;
  327. }
  328. }
  329. const int head = ngrams_observed.head[ft];
  330. const int idx = ft*(N - 1)*G + head*(N - 1);
  331. for (int i = 0; i < N - 1; i++) {
  332. ngrams_observed.tokens[idx + i] = ngram[i];
  333. }
  334. ngrams_observed.cnt[ft] = std::min(G, ngrams_observed.cnt[ft] + 1);
  335. ngrams_observed.head[ft] = (head + 1) % G;
  336. ngrams_observed.n_total++;
  337. }
  338. }
  339. }
  340. if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
  341. break;
  342. }
  343. // KV cache management
  344. // if no verification token matched, we simply remove all cells from this batch -> no fragmentation
  345. llama_kv_cache_seq_rm(ctx, -1, n_past, -1);
  346. if (seq_id_best != 0) {
  347. // if a verification token matched, we keep the best sequence and remove the rest
  348. // this leads to some KV cache fragmentation
  349. llama_kv_cache_seq_keep(ctx, seq_id_best);
  350. llama_kv_cache_seq_cp (ctx, seq_id_best, 0, -1, -1);
  351. llama_kv_cache_seq_rm (ctx, seq_id_best, -1, -1);
  352. for (int s = 1; s < W + G + 1; ++s) {
  353. llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
  354. }
  355. }
  356. }
  357. auto t_dec_end = ggml_time_us();
  358. LOG_TEE("\n\n");
  359. LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
  360. LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
  361. LOG_TEE("\n");
  362. LOG_TEE("W = %2d\n", W);
  363. LOG_TEE("N = %2d\n", N);
  364. LOG_TEE("G = %2d\n", G);
  365. LOG_TEE("\n");
  366. LOG_TEE("n_predict = %d\n", n_predict);
  367. LOG_TEE("n_accept = %d\n", n_accept);
  368. llama_print_timings(ctx);
  369. llama_kv_cache_view_free(&kvc_view);
  370. llama_sampling_free(ctx_sampling);
  371. llama_batch_free(batch);
  372. llama_free(ctx);
  373. llama_free_model(model);
  374. llama_backend_free();
  375. fprintf(stderr, "\n\n");
  376. return 0;
  377. }