1
0

speculative-simple.cpp 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. #include "arg.h"
  2. #include "common.h"
  3. #include "sampling.h"
  4. #include "speculative.h"
  5. #include "log.h"
  6. #include "llama.h"
  7. #include <cstdio>
  8. #include <cstring>
  9. #include <string>
  10. #include <vector>
  11. int main(int argc, char ** argv) {
  12. common_params params;
  13. if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
  14. return 1;
  15. }
  16. if (params.n_predict < -1) {
  17. LOG_ERR("%s: --n-predict must be >= -1\n", __func__);
  18. return 1;
  19. }
  20. common_init();
  21. if (params.speculative.model.path.empty()) {
  22. LOG_ERR("%s: --model-draft is required\n", __func__);
  23. return 1;
  24. }
  25. // init llama.cpp
  26. llama_backend_init();
  27. llama_numa_init(params.numa);
  28. llama_model * model_tgt = NULL;
  29. //llama_model * model_dft = NULL;
  30. llama_context * ctx_tgt = NULL;
  31. llama_context * ctx_dft = NULL;
  32. // load the target model
  33. auto llama_init_tgt = common_init_from_params(params);
  34. model_tgt = llama_init_tgt->model();
  35. ctx_tgt = llama_init_tgt->context();
  36. const llama_vocab * vocab = llama_model_get_vocab(model_tgt);
  37. // load the draft model
  38. params.devices = params.speculative.devices;
  39. params.model = params.speculative.model;
  40. params.n_ctx = params.speculative.n_ctx;
  41. params.n_batch = params.speculative.n_ctx > 0 ? params.speculative.n_ctx : params.n_batch;
  42. params.n_gpu_layers = params.speculative.n_gpu_layers;
  43. if (params.speculative.cpuparams.n_threads > 0) {
  44. params.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
  45. }
  46. params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
  47. params.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
  48. auto llama_init_dft = common_init_from_params(params);
  49. //model_dft = llama_init_dft->model();
  50. ctx_dft = llama_init_dft->context();
  51. if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
  52. LOG_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params.speculative.model.path.c_str(), params.model.path.c_str());
  53. }
  54. // Tokenize the prompt
  55. std::vector<llama_token> inp;
  56. inp = common_tokenize(ctx_tgt, params.prompt, true, true);
  57. if (llama_n_ctx(ctx_tgt) < (uint32_t) inp.size()) {
  58. LOG_ERR("%s: the prompt exceeds the context size (%d tokens, ctx %d)\n", __func__, (int) inp.size(), llama_n_ctx(ctx_tgt));
  59. return 1;
  60. }
  61. if (llama_n_batch(ctx_tgt) < (uint32_t) inp.size()) {
  62. LOG_ERR("%s: the prompt exceeds the batch size (%d tokens, batch %d)\n", __func__, (int) inp.size(), llama_n_batch(ctx_tgt));
  63. return 1;
  64. }
  65. LOG("\n\n");
  66. for (auto id : inp) {
  67. LOG("%s", common_token_to_piece(ctx_tgt, id).c_str());
  68. }
  69. // how many tokens to draft each time
  70. int n_draft = params.speculative.n_max;
  71. int n_draft_min = params.speculative.n_min;
  72. float p_min = params.speculative.p_min;
  73. int n_predict = 0;
  74. int n_drafted = 0;
  75. int n_accept = 0;
  76. // used to determine end of generation
  77. bool has_eos = false;
  78. // ================================================
  79. // everything until here is standard initialization
  80. // the relevant stuff for speculative decoding starts here
  81. const auto t_enc_start = ggml_time_us();
  82. // target model sampling context
  83. struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling);
  84. // eval the prompt
  85. llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1));
  86. // note: keep the last token separate!
  87. llama_token id_last = inp.back();
  88. // all tokens currently in the target context
  89. llama_tokens prompt_tgt(inp.begin(), inp.end() - 1);
  90. prompt_tgt.reserve(llama_n_ctx(ctx_tgt));
  91. int n_past = inp.size() - 1;
  92. // init the speculator
  93. struct common_speculative_params params_spec;
  94. params_spec.n_draft = n_draft;
  95. params_spec.n_reuse = llama_n_ctx(ctx_dft) - n_draft;
  96. params_spec.p_min = p_min;
  97. struct common_speculative * spec = common_speculative_init(ctx_tgt, ctx_dft);
  98. for (auto &pair : params.speculative.replacements) {
  99. common_speculative_add_replacement_tgt_dft(spec, pair.first.c_str(), pair.second.c_str());
  100. }
  101. llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
  102. const auto t_enc_end = ggml_time_us();
  103. const auto t_dec_start = ggml_time_us();
  104. while (true) {
  105. // optionally, generate draft tokens that can be appended to the target batch
  106. //
  107. // this is the most important part of the speculation. the more probable tokens that are provided here
  108. // the better the performance will be. in theory, this computation can be performed asynchronously and even
  109. // offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens
  110. // from a cache or lookup tables.
  111. //
  112. llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last);
  113. //LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str());
  114. // always have a token to evaluate from before - id_last
  115. common_batch_clear(batch_tgt);
  116. common_batch_add (batch_tgt, id_last, n_past++, { 0 }, true);
  117. // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
  118. {
  119. // do not waste time on small drafts
  120. if (draft.size() < (size_t) n_draft_min) {
  121. draft.clear();
  122. }
  123. for (size_t i = 0; i < draft.size(); ++i) {
  124. common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
  125. }
  126. //LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str());
  127. llama_decode(ctx_tgt, batch_tgt);
  128. }
  129. // sample from the full target batch and return the accepted tokens based on the target sampler
  130. //
  131. // for each token to be accepted, the sampler would have to sample that same token
  132. // in such cases, instead of decoding the sampled token as we normally do, we simply continue with the
  133. // available logits from the batch and sample the next token until we run out of logits or the sampler
  134. // disagrees with the draft
  135. //
  136. const auto ids = common_sampler_sample_and_accept_n(smpl, ctx_tgt, draft);
  137. //LOG_DBG("ids: %s\n", string_from(ctx_tgt, ids).c_str());
  138. GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
  139. n_past += ids.size() - 1;
  140. n_drafted += draft.size(); // note: we ignore the discarded small drafts
  141. n_accept += ids.size() - 1;
  142. n_predict += ids.size();
  143. // process the accepted tokens and update contexts
  144. //
  145. // this is the standard token post-processing that we normally do
  146. // in this case, we do it for a group of accepted tokens at once
  147. //
  148. for (size_t i = 0; i < ids.size(); ++i) {
  149. prompt_tgt.push_back(id_last);
  150. id_last = ids[i];
  151. if (llama_vocab_is_eog(vocab, id_last)) {
  152. has_eos = true;
  153. break;
  154. }
  155. const std::string token_str = common_token_to_piece(ctx_tgt, id_last);
  156. if (params.use_color && i + 1 < ids.size()) {
  157. LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
  158. } else {
  159. LOG("%s", token_str.c_str());
  160. }
  161. }
  162. LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d)\n", (int) ids.size() - 1, (int) draft.size(), id_last);
  163. {
  164. LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
  165. llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, n_past, -1);
  166. }
  167. if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
  168. break;
  169. }
  170. }
  171. auto t_dec_end = ggml_time_us();
  172. const int n_input = inp.size();
  173. LOG("\n\n");
  174. LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
  175. LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
  176. LOG_INF("\n");
  177. LOG_INF("n_draft = %d\n", n_draft);
  178. LOG_INF("n_predict = %d\n", n_predict);
  179. LOG_INF("n_drafted = %d\n", n_drafted);
  180. LOG_INF("n_accept = %d\n", n_accept);
  181. LOG_INF("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
  182. LOG_INF("\n");
  183. LOG_INF("draft:\n\n");
  184. llama_perf_context_print(ctx_dft);
  185. LOG_INF("\n");
  186. LOG_INF("target:\n\n");
  187. common_perf_print(ctx_tgt, smpl);
  188. llama_batch_free(batch_tgt);
  189. common_sampler_free(smpl);
  190. common_speculative_free(spec);
  191. llama_backend_free();
  192. LOG("\n\n");
  193. return 0;
  194. }