// simple.cpp
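// A minimal llama.cpp example: load a model, tokenize a prompt, and generate
// text with greedy sampling until the context fills up or the model emits an
// end-of-stream token.
//
// Usage sketch (binary and model names are only placeholders; pass any model
// file your llama.cpp build can load):
//
//   ./simple MODEL_PATH [PROMPT]
//   ./simple ./models/model.bin "Hello my name is"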

#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif

#include "common.h"
#include "llama.h"
#include "build-info.h"

#include <cassert>
#include <cinttypes>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
#include <signal.h>
#include <unistd.h>
#elif defined (_WIN32)
#define WIN32_LEAN_AND_MEAN
#define NOMINMAX
#include <windows.h>
#include <signal.h>
#endif

int main( int argc , char ** argv )
{
    gpt_params params;

    //---------------------------------
    // Print help :
    //---------------------------------

    if ( argc == 1 || argv[1][0] == '-' )
    {
        printf( "usage: %s MODEL_PATH [PROMPT]\n" , argv[0] );
        return 1 ;
    }

    //---------------------------------
    // Load parameters :
    //---------------------------------

    if ( argc >= 2 )
    {
        params.model = argv[1];
    }

    if ( argc >= 3 )
    {
        params.prompt = argv[2];
    }

    if ( params.prompt.empty() )
    {
        params.prompt = "Hello my name is";
    }

    //---------------------------------
    // Init LLM :
    //---------------------------------
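    // llama_init_backend() performs one-time global initialization of the backend;
    // llama_init_from_gpt_params() (a helper from common.h) then loads the model
    // file named in params.model and returns a ready-to-use llama_context.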

    llama_init_backend();

    llama_context * ctx ;

    ctx = llama_init_from_gpt_params( params );

    if ( ctx == NULL )
    {
        fprintf( stderr , "%s: error: unable to load model\n" , __func__ );
        return 1;
    }

    //---------------------------------
    // Tokenize the prompt :
    //---------------------------------
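    // The trailing 'true' asks the tokenizer to prepend the beginning-of-sequence
    // (BOS) token, which the model expects at the start of a fresh prompt.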

    std::vector<llama_token> tokens_list;
    tokens_list = ::llama_tokenize( ctx , params.prompt , true );

    const int max_context_size     = llama_n_ctx( ctx );
    const int max_tokens_list_size = max_context_size - 4 ;

    if ( (int)tokens_list.size() > max_tokens_list_size )
    {
        fprintf( stderr , "%s: error: prompt too long (%d tokens, max %d)\n" ,
             __func__ , (int)tokens_list.size() , max_tokens_list_size );
        return 1;
    }

    fprintf( stderr, "\n\n" );

    // Print the tokens from the prompt :

    for( auto id : tokens_list )
    {
        printf( "%s" , llama_token_to_str( ctx , id ) );
    }

    fflush(stdout);

    //---------------------------------
    // Main prediction loop :
    //---------------------------------

    // The LLM keeps a contextual cache memory of previous token evaluations.
    // Usually, once this cache is full, it is required to recompute a compressed context based on previous
    // tokens (see "infinite text generation via context swapping" in the main example), but in this minimalist
    // example, we will just stop the loop once this cache is full or once an end of stream is detected.
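
    // Each pass of the loop: evaluate the pending tokens, read the logits of the
    // last position, greedily pick the most likely next token, print it, and queue
    // it as the only input of the next pass.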

    while ( llama_get_kv_cache_token_count( ctx ) < max_context_size )
    {
        //---------------------------------
        // Evaluate the tokens :
        //---------------------------------
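        // llama_eval( ctx , tokens , n_tokens , n_past , n_threads ) : n_past is the
        // number of tokens already held in the KV cache, so only the tokens that have
        // not been evaluated yet are passed in each iteration.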

        if ( llama_eval( ctx , tokens_list.data() , tokens_list.size() , llama_get_kv_cache_token_count( ctx ) , params.n_threads ) )
        {
            fprintf( stderr, "%s : failed to eval\n" , __func__ );
            return 1;
        }

        tokens_list.clear();

        //---------------------------------
        // Select the best prediction :
        //---------------------------------
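        // Build one llama_token_data { id , logit , p } entry per vocabulary token from
        // the logits of the last evaluated position; llama_sample_token_greedy() then
        // returns the id with the highest logit.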

        llama_token new_token_id = 0;

        auto logits  = llama_get_logits( ctx );
        auto n_vocab = llama_n_vocab( ctx ); // the size of the LLM vocabulary (in tokens)

        std::vector<llama_token_data> candidates;
        candidates.reserve( n_vocab );

        for( llama_token token_id = 0 ; token_id < n_vocab ; token_id++ )
        {
            candidates.emplace_back( llama_token_data{ token_id , logits[ token_id ] , 0.0f } );
        }

        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

        // Select it using the "Greedy sampling" method :
        new_token_id = llama_sample_token_greedy( ctx , &candidates_p );

        // is it an end of stream ?
        if ( new_token_id == llama_token_eos() )
        {
            fprintf( stderr, " [end of text]\n" );
            break;
        }

        // Print the new token :
        printf( "%s" , llama_token_to_str( ctx , new_token_id ) );
        fflush( stdout );

        // Push this new token for next evaluation :
        tokens_list.push_back( new_token_id );

    } // end of main loop

    llama_free( ctx );

    return 0;
}

// EOF