beam-search.cpp

#include "common.h"
#include "llama.h"

#include <cassert>
#include <cinttypes>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
#include <signal.h>
#include <unistd.h>
#elif defined (_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#   define NOMINMAX
#endif
#include <windows.h>
#include <signal.h>
#endif

// Used for debugging to print out beam tokens.
struct ostream_beam_view {
    llama_context * ctx;
    llama_beam_view beam_view;
};

static std::ostream & operator<<(std::ostream & os, const ostream_beam_view & obv) {
    os << "p(" << obv.beam_view.p << ") eob(" << std::boolalpha << obv.beam_view.eob << ") tokens(";
    for (size_t i = 0 ; i < obv.beam_view.n_tokens ; ++i) {
        os << llama_token_to_piece(obv.ctx, obv.beam_view.tokens[i]);
    }
    return os << ')';
}

// Put here anything you want back in beam_search_callback().
struct beam_search_callback_data {
    llama_context * ctx;
    std::vector<llama_token> response;
};

// In this case, end-of-beam (eob) is equivalent to end-of-sentence (eos) but this need not always be the same.
// For example, eob can be flagged due to maximum token length, stop words, etc.
static bool is_at_eob(const beam_search_callback_data & callback_data, const llama_token * tokens, size_t n_tokens) {
    return n_tokens && llama_token_is_eog(llama_get_model(callback_data.ctx), tokens[n_tokens - 1]);
}

// Function matching type llama_beam_search_callback_fn_t.
// This custom callback example is called each time the beam lengths increase:
//  * Show progress by printing ',' followed by the number of convergent beam tokens, if any.
//  * When all beams converge to a common prefix, they are made available in beams_state.beams[0].
//    This is also called when the stop condition is met.
//    Collect tokens into std::vector<llama_token> response, which is pointed to by callback_data.
static void beam_search_callback(void * callback_data_ptr, llama_beams_state beams_state) {
    auto & callback_data = *static_cast<beam_search_callback_data *>(callback_data_ptr);
    // Mark beams as EOS as needed.
    for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
        llama_beam_view & beam_view = beams_state.beam_views[i];
        if (!beam_view.eob && is_at_eob(callback_data, beam_view.tokens, beam_view.n_tokens)) {
            beam_view.eob = true;
        }
    }
    printf(","); // Show progress
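    // Tokens common to all beams can no longer change; append this converged
    // prefix to the response being collected in callback_data.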
    if (const size_t n = beams_state.common_prefix_length) {
        callback_data.response.resize(callback_data.response.size() + n);
        assert(0u < beams_state.n_beams);
        const llama_token * tokens = beams_state.beam_views[0].tokens;
        std::copy(tokens, tokens + n, callback_data.response.end() - n);
        printf("%zu", n);
    }
    fflush(stdout);
#if 1 // DEBUG: print current beams for this iteration
    std::cout << "\n\nCurrent beams (last_call=" << beams_state.last_call << "):\n";
    for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
        std::cout << "beams[" << i << "]: " << ostream_beam_view{callback_data.ctx, beams_state.beam_views[i]} << std::endl;
    }
#endif
}

int main(int argc, char ** argv)
{
    gpt_params params;
    //params.n_gpu_layers = 200;

    //---------------------------------
    // Print help :
    //---------------------------------

    if ( argc < 2 || argv[1][0] == '-' )
    {
        printf( "Usage: %s MODEL_PATH [BEAM_WIDTH=2] [PROMPT]\n" , argv[0] );
        return 1 ;
    }

    //---------------------------------
    // Load parameters :
    //---------------------------------

    params.model = argv[1];

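    // Beam width: number of candidate sequences kept at each step (defaults to 2).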
    params.n_beams = 2 < argc ? std::stoi(argv[2]) : 2;

    if ( argc > 3 )
    {
        params.prompt = argv[3];
    }

    if ( params.prompt.empty() )
    {
        params.prompt = "### Request:\nHow many countries are there?\n\n### Response:\n";
    }

    //---------------------------------
    // Init LLM :
    //---------------------------------

    llama_backend_init();
    llama_numa_init(params.numa);

    llama_model * model;
    llama_context * ctx;

    std::tie(model, ctx) = llama_init_from_gpt_params( params );

    if ( model == NULL )
    {
        fprintf( stderr , "%s: error: unable to load model\n" , __func__ );
        return 1;
    }

    //---------------------------------
    // Tokenize the prompt :
    //---------------------------------

    std::vector<llama_token> tokens_list = llama_tokenize(ctx, params.prompt, true);

    const size_t max_context_size     = llama_n_ctx( ctx );
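    // Leave a small margin of the context free for generated tokens.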
    const size_t max_tokens_list_size = max_context_size - 4 ;

    if (tokens_list.size() > max_tokens_list_size)
    {
        fprintf( stderr , "%s: error: prompt too long (%zu tokens, max %zu)\n" ,
                 __func__ , tokens_list.size() , max_tokens_list_size );
        return 1;
    }

    fprintf( stderr, "\n\n" );

    // Print the tokens from the prompt :

    for ( auto id : tokens_list )
    {
        std::cout << llama_token_to_piece(ctx, id);
    }
    std::cout << std::flush;

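    // n_past tracks how many tokens have already been evaluated into the KV cache.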
    int n_past = 0;

    if (llama_decode(ctx, llama_batch_get_one(tokens_list.data(), tokens_list.size(), n_past, 0)))
    {
        fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ );
        return 1;
    }
    n_past += tokens_list.size();

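    // Run beam search: beam_search_callback is invoked as the beams grow and
    // collects converged tokens into callback_data.response; n_predict caps
    // the number of newly generated tokens.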
    beam_search_callback_data callback_data{ctx, {}};
    size_t const beam_width = static_cast<size_t>(params.n_beams);
    int const n_predict = 256;
    llama_beam_search(ctx, beam_search_callback, &callback_data, beam_width, n_past, n_predict);

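    // Print the response tokens accumulated by the callback.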
    std::cout << "\n\n";
    for (llama_token const token_id : callback_data.response) {
        std::cout << llama_token_to_piece(ctx, token_id);
    }
    std::cout << std::endl;

    llama_free( ctx );
    llama_free_model( model );
    llama_backend_free();

    return 0;
}