simple.cpp 4.6 KB

#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif

#include "common.h"
#include "llama.h"
#include "build-info.h"

#include <cassert>
#include <cinttypes>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
#include <signal.h>
#include <unistd.h>
#elif defined (_WIN32)
#define WIN32_LEAN_AND_MEAN
#define NOMINMAX
#include <windows.h>
#include <signal.h>
#endif
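
//
// Minimal llama.cpp example: load a model, tokenize a prompt, then greedily
// sample one token at a time until an end-of-stream token is produced or the
// context (KV cache) is full.
//
// Usage, as printed by the help message below: MODEL_PATH [PROMPT]
//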
int main(int argc, char ** argv)
{
    gpt_params params;

    //---------------------------------
    // Print help :
    //---------------------------------

    if ( argc == 1 || argv[1][0] == '-' )
    {
        printf( "usage: %s MODEL_PATH [PROMPT]\n" , argv[0] );
        return 1 ;
    }

    //---------------------------------
    // Load parameters :
    //---------------------------------

    if ( argc >= 2 )
    {
        params.model = argv[1];
    }

    if ( argc >= 3 )
    {
        params.prompt = argv[2];
    }

    if ( params.prompt.empty() )
    {
        params.prompt = "Hello my name is";
    }

    //---------------------------------
    // Init LLM :
    //---------------------------------
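
    // llama_backend_init() initializes the ggml backend (optionally with NUMA
    // support); llama_init_from_gpt_params() then loads the model file and
    // creates an inference context from the parsed parameters.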
    llama_backend_init(params.numa);

    llama_model * model;
    llama_context * ctx;

    std::tie(model, ctx) = llama_init_from_gpt_params( params );

    if ( model == NULL )
    {
        fprintf( stderr , "%s: error: unable to load model\n" , __func__ );
        return 1;
    }

    //---------------------------------
    // Tokenize the prompt :
    //---------------------------------

    std::vector<llama_token> tokens_list;
    tokens_list = ::llama_tokenize( ctx , params.prompt , true );

    const int max_context_size     = llama_n_ctx( ctx );
    const int max_tokens_list_size = max_context_size - 4 ;
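    // keep a small safety margin below the full context window so that at
    // least a few tokens can still be generated after the prompt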

    if ( (int)tokens_list.size() > max_tokens_list_size )
    {
        fprintf( stderr , "%s: error: prompt too long (%d tokens, max %d)\n" ,
             __func__ , (int)tokens_list.size() , max_tokens_list_size );
        return 1;
    }

    fprintf( stderr, "\n\n" );

    // Print the tokens from the prompt :

    for( auto id : tokens_list )
    {
        printf( "%s" , llama_token_to_str( ctx , id ) );
    }

    fflush(stdout);

    //---------------------------------
    // Main prediction loop :
    //---------------------------------

    // The LLM keeps a contextual cache memory of previous token evaluations.
    // Usually, once this cache is full, it is required to recompute a compressed context based on previous
    // tokens (see "infinite text generation via context swapping" in the main example), but in this minimalist
    // example, we will just stop the loop once this cache is full or once an end of stream is detected.

    while ( llama_get_kv_cache_token_count( ctx ) < max_context_size )
    {
        //---------------------------------
        // Evaluate the tokens :
        //---------------------------------
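
        // llama_eval() runs the model on the tokens in tokens_list, starting at
        // position n_past = llama_get_kv_cache_token_count( ctx ) (the number of
        // tokens already in the KV cache), using params.n_threads threads.
        // On the first iteration this evaluates the whole prompt; afterwards,
        // only the single token sampled below.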
        if ( llama_eval( ctx , tokens_list.data() , tokens_list.size() , llama_get_kv_cache_token_count( ctx ) , params.n_threads ) )
        {
            fprintf( stderr, "%s : failed to eval\n" , __func__ );
            return 1;
        }
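
        // The evaluated tokens are now stored in the model's KV cache, so the
        // list can be cleared; only the next sampled token has to be fed in on
        // the following iteration.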
        tokens_list.clear();

        //---------------------------------
        // Select the best prediction :
        //---------------------------------

        llama_token new_token_id = 0;

        auto logits  = llama_get_logits( ctx );
        auto n_vocab = llama_n_vocab( ctx ); // the size of the LLM vocabulary (in tokens)

        std::vector<llama_token_data> candidates;
        candidates.reserve( n_vocab );

        for( llama_token token_id = 0 ; token_id < n_vocab ; token_id++ )
        {
            candidates.emplace_back( llama_token_data{ token_id , logits[ token_id ] , 0.0f } );
        }
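
        // Wrap the per-token logits in a llama_token_data_array so the sampling
        // API can operate on them; greedy sampling simply picks the token with
        // the highest logit.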
        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

        // Select it using the "Greedy sampling" method :
        new_token_id = llama_sample_token_greedy( ctx , &candidates_p );

        // is it an end of stream ?
        if ( new_token_id == llama_token_eos() )
        {
            fprintf(stderr, " [end of text]\n");
            break;
        }

        // Print the new token :
        printf( "%s" , llama_token_to_str( ctx , new_token_id ) );
        fflush( stdout );

        // Push this new token for next evaluation :
        tokens_list.push_back( new_token_id );

    } // wend of main loop

    llama_free( ctx );
    llama_free_model( model );

    llama_backend_free();

    return 0;
}

// EOF