main : print total token count and tokens consumed so far (#4874)

* Token count changes

* Add show token count

* Updating before PR

* Two requested changes

* Move param def posn
pudepiedj, 2 years ago
Commit: 43f76bf1c3
4 changed files with 15 additions and 3 deletions
  1. common/common.cpp (+8 -0)
  2. common/common.h (+1 -1)
  3. examples/main/main.cpp (+5 -1)
  4. llama.cpp (+1 -1)

common/common.cpp (+8 -0)

@@ -630,6 +630,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.ppl_stride = std::stoi(argv[i]);
+        } else if (arg == "-stc" || arg == "--show_token_count") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.token_interval = std::stoi(argv[i]);
         } else if (arg == "--ppl-output-type") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -944,6 +950,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --override-kv KEY=TYPE:VALUE\n");
     printf("                        advanced option to override model metadata by key. may be specified multiple times.\n");
     printf("                        types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
+    printf("  -stc N --show_token_count N\n");
+    printf("                        show consumed tokens every N tokens\n");
     printf("\n");
 #ifndef LOG_DISABLE_LOGS
     log_print_usage();
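
The hunk above follows the usual gpt_params_parse_ex idiom: a flag that takes a value advances the argv index and converts the next argument with std::stoi, marking the parse invalid when the value is missing. A minimal standalone sketch of that same pattern (illustrative only, not part of the commit; the gpt_params plumbing is omitted):

#include <cstdio>
#include <string>

int main(int argc, char ** argv) {
    int  token_interval = 512;   // default, mirroring gpt_params in common/common.h
    bool invalid_param  = false;

    for (int i = 1; i < argc; i++) {
        std::string arg = argv[i];
        if (arg == "-stc" || arg == "--show_token_count") {
            if (++i >= argc) {          // flag given without a value
                invalid_param = true;
                break;
            }
            token_interval = std::stoi(argv[i]);
        }
    }

    if (invalid_param) {
        fprintf(stderr, "error: missing value for -stc/--show_token_count\n");
        return 1;
    }
    printf("token_interval = %d\n", token_interval);
    return 0;
}

With this in place, an invocation along the lines of ./main -m model.gguf -p "..." -stc 256 would report the consumed-token count every 256 tokens instead of the default 512.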

common/common.h (+1 -1)

@@ -64,6 +64,7 @@ struct gpt_params {
     int32_t n_beams                         = 0;     // if non-zero then use beam search of given width.
     int32_t grp_attn_n                      = 1;     // group-attention factor
     int32_t grp_attn_w                      = 512;   // group-attention width
+    int32_t token_interval                  = 512;   // show token count every 512 tokens
     float   rope_freq_base                  = 0.0f;  // RoPE base frequency
     float   rope_freq_scale                 = 0.0f;  // RoPE frequency scaling factor
     float   yarn_ext_factor                 = -1.0f; // YaRN extrapolation mix factor
@@ -242,4 +243,3 @@ void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);
 
 // Dump the KV cache view showing individual sequences in each cell (long output).
 void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
-

examples/main/main.cpp (+5 -1)

@@ -500,7 +500,7 @@ int main(int argc, char ** argv) {
     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict
         if (!embd.empty()) {
-            // Note: n_ctx - 4 here is to match the logic for commandline prompt handling via
+            // Note: (n_ctx - 4) here is to match the logic for commandline prompt handling via
             // --prompt or --file which uses the same value.
             int max_embd_size = n_ctx - 4;
 
@@ -650,6 +650,10 @@ int main(int argc, char ** argv) {
                 n_past += n_eval;
 
                 LOG("n_past = %d\n", n_past);
+                // Display total tokens alongside total time
+                if (n_past % params.token_interval == 0) {
+                    printf("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx);
+                }
             }
 
             if (!embd.empty() && !path_session.empty()) {
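
One observation about the hunk above (not part of the diff): n_past advances by n_eval tokens per batch, so the modulo test only fires when a batch boundary lands exactly on a multiple of token_interval; during plain single-token generation (n_eval == 1) it fires every token_interval tokens as intended. A toy driver sketching the interval print, with illustrative values:

#include <cstdio>

int main() {
    const int n_ctx          = 2048;  // illustrative context size
    const int token_interval = 512;   // default from gpt_params

    for (int n_past = 1; n_past <= n_ctx; n_past++) {  // pretend one token per step
        if (n_past % token_interval == 0) {
            // \033[31m selects red text; \033[0m resets the terminal colour
            printf("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx);
        }
    }
    return 0;
}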

llama.cpp (+1 -1)

@@ -10921,7 +10921,7 @@ void llama_print_timings(struct llama_context * ctx) {
             __func__, timings.t_p_eval_ms, timings.n_p_eval, timings.t_p_eval_ms / timings.n_p_eval, 1e3 / timings.t_p_eval_ms * timings.n_p_eval);
     LLAMA_LOG_INFO("%s:        eval time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
             __func__, timings.t_eval_ms, timings.n_eval, timings.t_eval_ms / timings.n_eval, 1e3 / timings.t_eval_ms * timings.n_eval);
-    LLAMA_LOG_INFO("%s:       total time = %10.2f ms\n", __func__, (timings.t_end_ms - timings.t_start_ms));
+    LLAMA_LOG_INFO("%s:       total time = %10.2f ms / %5d tokens\n", __func__, (timings.t_end_ms - timings.t_start_ms), (timings.n_p_eval + timings.n_eval));
 }
 
 void llama_reset_timings(struct llama_context * ctx) {
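
The amended line pairs the wall-clock total with the total token count (prompt plus generated), which makes an end-to-end throughput figure easy to read off. A hedged sketch of that arithmetic, using a stand-in struct rather than the real llama_timings (field names taken from the log statements above, numbers made up):

#include <cstdio>

struct timings_sketch {            // stand-in for llama_timings
    double t_start_ms, t_end_ms;   // wall-clock bounds
    int    n_p_eval, n_eval;       // prompt tokens, generated tokens
};

int main() {
    const timings_sketch t = { 0.0, 5000.0, 100, 150 };
    const int    n_tokens = t.n_p_eval + t.n_eval;       // 250 tokens total
    const double total_ms = t.t_end_ms  - t.t_start_ms;  // 5000.00 ms
    // extends the new log format with the overall tokens/second this change enables
    printf("total time = %10.2f ms / %5d tokens (%.2f tokens/second)\n",
           total_ms, n_tokens, 1e3 * n_tokens / total_ms);
    return 0;
}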