1
0
Эх сурвалжийг харах

examples : fix save-load-state + rename llama-util.h

Georgi Gerganov 2 жил өмнө
parent
commit
84ca9c2ecf

+ 41 - 31
examples/save-load-state/save-load-state.cpp

@@ -1,12 +1,9 @@
-#include <vector>
-#include <cstdio>
-#include <chrono>
-
 #include "common.h"
 #include "common.h"
 #include "llama.h"
 #include "llama.h"
-#include "llama.cpp"
 
 
-using namespace std;
+#include <vector>
+#include <cstdio>
+#include <chrono>
 
 
 int main(int argc, char ** argv) {
 int main(int argc, char ** argv) {
     gpt_params params;
     gpt_params params;
@@ -20,21 +17,25 @@ int main(int argc, char ** argv) {
         return 1;
         return 1;
     }
     }
 
 
+    if (params.n_predict < 0) {
+        params.n_predict = 16;
+    }
+
     auto lparams = llama_context_default_params();
     auto lparams = llama_context_default_params();
 
 
-    lparams.n_ctx      = params.n_ctx;
-    lparams.n_parts    = params.n_parts;
-    lparams.seed       = params.seed;
-    lparams.f16_kv     = params.memory_f16;
-    lparams.use_mmap   = params.use_mmap;
-    lparams.use_mlock  = params.use_mlock;
+    lparams.n_ctx     = params.n_ctx;
+    lparams.n_parts   = params.n_parts;
+    lparams.seed      = params.seed;
+    lparams.f16_kv    = params.memory_f16;
+    lparams.use_mmap  = params.use_mmap;
+    lparams.use_mlock = params.use_mlock;
 
 
     auto n_past = 0;
     auto n_past = 0;
-    auto last_n_tokens_data = vector<llama_token>(params.repeat_last_n, 0);
+    auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0);
 
 
     // init
     // init
     auto ctx = llama_init_from_file(params.model.c_str(), lparams);
     auto ctx = llama_init_from_file(params.model.c_str(), lparams);
-    auto tokens = vector<llama_token>(params.n_ctx);
+    auto tokens = std::vector<llama_token>(params.n_ctx);
     auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), tokens.size(), true);
     auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), tokens.size(), true);
 
 
     if (n_prompt_tokens < 1) {
     if (n_prompt_tokens < 1) {
@@ -43,23 +44,25 @@ int main(int argc, char ** argv) {
     }
     }
 
 
     // evaluate prompt
     // evaluate prompt
-
     llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads);
     llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads);
 
 
     last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens);
     last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens);
     n_past += n_prompt_tokens;
     n_past += n_prompt_tokens;
 
 
+    const size_t state_size = llama_get_state_size(ctx);
+    uint8_t * state_mem = new uint8_t[state_size];
+
     // Save state (rng, logits, embedding and kv_cache) to file
     // Save state (rng, logits, embedding and kv_cache) to file
-    FILE *fp_write = fopen("dump_state.bin", "wb");
-    auto state_size = llama_get_state_size(ctx);
-    auto state_mem = new uint8_t[state_size];
-    llama_copy_state_data(ctx, state_mem); // could also copy directly to memory mapped file
-    fwrite(state_mem, 1, state_size, fp_write);
-    fclose(fp_write);
+    {
+        FILE *fp_write = fopen("dump_state.bin", "wb");
+        llama_copy_state_data(ctx, state_mem); // could also copy directly to memory mapped file
+        fwrite(state_mem, 1, state_size, fp_write);
+        fclose(fp_write);
+    }
 
 
     // save state (last tokens)
     // save state (last tokens)
-    auto last_n_tokens_data_saved = vector<llama_token>(last_n_tokens_data);
-    auto n_past_saved = n_past;
+    const auto last_n_tokens_data_saved = std::vector<llama_token>(last_n_tokens_data);
+    const auto n_past_saved = n_past;
 
 
     // first run
     // first run
     printf("\n%s", params.prompt.c_str());
     printf("\n%s", params.prompt.c_str());
@@ -75,6 +78,7 @@ int main(int argc, char ** argv) {
         auto next_token = llama_sample_token(ctx, &candidates_p);
         auto next_token = llama_sample_token(ctx, &candidates_p);
         auto next_token_str = llama_token_to_str(ctx, next_token);
         auto next_token_str = llama_token_to_str(ctx, next_token);
         last_n_tokens_data.push_back(next_token);
         last_n_tokens_data.push_back(next_token);
+
         printf("%s", next_token_str);
         printf("%s", next_token_str);
         if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) {
         if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
@@ -88,18 +92,21 @@ int main(int argc, char ** argv) {
     llama_free(ctx);
     llama_free(ctx);
 
 
     // load new model
     // load new model
-
     auto ctx2 = llama_init_from_file(params.model.c_str(), lparams);
     auto ctx2 = llama_init_from_file(params.model.c_str(), lparams);
 
 
     // Load state (rng, logits, embedding and kv_cache) from file
     // Load state (rng, logits, embedding and kv_cache) from file
-    FILE *fp_read = fopen("dump_state.bin", "rb");
-    auto state_size2 = llama_get_state_size(ctx2);
-    if (state_size != state_size2) {
-        fprintf(stderr, "\n%s : failed to validate state size\n", __func__);
+    {
+        FILE *fp_read = fopen("dump_state.bin", "rb");
+        if (state_size != llama_get_state_size(ctx2)) {
+            fprintf(stderr, "\n%s : failed to validate state size\n", __func__);
+            return 1;
+        }
+        fread(state_mem, 1, state_size, fp_read);
+        llama_set_state_data(ctx2, state_mem);  // could also read directly from memory mapped file
+        fclose(fp_read);
     }
     }
-    fread(state_mem, 1, state_size, fp_read);
-    llama_set_state_data(ctx2, state_mem);  // could also read directly from memory mapped file
-    fclose(fp_read);
+
+    delete[] state_mem;
 
 
     // restore state (last tokens)
     // restore state (last tokens)
     last_n_tokens_data = last_n_tokens_data_saved;
     last_n_tokens_data = last_n_tokens_data_saved;
@@ -118,6 +125,7 @@ int main(int argc, char ** argv) {
         auto next_token = llama_sample_token(ctx2, &candidates_p);
         auto next_token = llama_sample_token(ctx2, &candidates_p);
         auto next_token_str = llama_token_to_str(ctx2, next_token);
         auto next_token_str = llama_token_to_str(ctx2, next_token);
         last_n_tokens_data.push_back(next_token);
         last_n_tokens_data.push_back(next_token);
+
         printf("%s", next_token_str);
         printf("%s", next_token_str);
         if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) {
         if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
@@ -125,6 +133,8 @@ int main(int argc, char ** argv) {
         }
         }
         n_past += 1;
         n_past += 1;
     }
     }
+
     printf("\n\n");
     printf("\n\n");
+
     return 0;
     return 0;
 }
 }

+ 0 - 1
llama_util.h → llama-util.h

@@ -430,5 +430,4 @@ struct llama_ctx_buffer {
 typedef llama_buffer llama_ctx_buffer;
 typedef llama_buffer llama_ctx_buffer;
 #endif
 #endif
 
 
-
 #endif
 #endif

+ 1 - 2
llama.cpp

@@ -5,7 +5,7 @@
 #include <cstdio>
 #include <cstdio>
 #endif
 #endif
 
 
-#include "llama_util.h"
+#include "llama-util.h"
 #include "llama.h"
 #include "llama.h"
 
 
 #include "ggml.h"
 #include "ggml.h"
@@ -33,7 +33,6 @@
 #define LLAMA_USE_SCRATCH
 #define LLAMA_USE_SCRATCH
 #define LLAMA_MAX_SCRATCH_BUFFERS 16
 #define LLAMA_MAX_SCRATCH_BUFFERS 16
 
 
-
 // available llama models
 // available llama models
 enum e_model {
 enum e_model {
     MODEL_UNKNOWN,
     MODEL_UNKNOWN,