Просмотр исходного кода

Add interactive mode (#61)

* Initial work on interactive mode.

* Improve interactive mode. Make rev. prompt optional.

* Update README to explain interactive mode.

* Fix OS X build
Matvey Soloviev 2 лет назад
Родитель
Сommit
96ea727f47
4 измененных файлов с 170 добавлено и 10 удалено
  1. 23 0
      README.md
  2. 127 10
      main.cpp
  3. 14 0
      utils.cpp
  4. 6 0
      utils.h

+ 23 - 0
README.md

@@ -183,6 +183,29 @@ The number of files generated for each model is as follows:
 
 
 When running the larger models, make sure you have enough disk space to store all the intermediate files.
 When running the larger models, make sure you have enough disk space to store all the intermediate files.
 
 
+### Interactive mode
+
+If you want a more ChatGPT-like experience, you can run in interactive mode by passing `-i` as a parameter.
+In this mode, you can always interrupt generation by pressing Ctrl+C and enter one or more lines of text which will be converted into tokens and appended to the current context. You can also specify a *reverse prompt* with the parameter `-r "reverse prompt string"`. This will result in user input being prompted whenever the exact tokens of the reverse prompt string are encountered in the generation. A typical use is to use a prompt which makes LLaMa emulate a chat between multiple users, say Alice and Bob, and pass `-r "Alice:"`.
+
+Here is an example few-shot interaction, invoked with the command
+```
+./main -m ./models/13B/ggml-model-q4_0.bin -t 8 --repeat_penalty 1.2 --temp 0.9 --top_p 0.9 -n 256 \
+                                           --color -i -r "User:" \
+                                           -p \
+"Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
+
+User: Hello, Bob.
+Bob: Hello. How may I help you today?
+User: Please tell me the largest city in Europe.
+Bob: Sure. The largest city in Europe is London, the capital of the United Kingdom.
+User:"
+```
+Note the use of `--color` to distinguish between user input and generated text.
+
+![image](https://user-images.githubusercontent.com/401380/224572787-d418782f-47b2-49c4-a04e-65bfa7ad4ec0.png)
+
+
 ## Limitations
 ## Limitations
 
 
 - Not sure if my tokenizer is correct. There are a few places where we might have a mistake:
 - Not sure if my tokenizer is correct. There are a few places where we might have a mistake:

+ 127 - 10
main.cpp

@@ -11,6 +11,18 @@
 #include <string>
 #include <string>
 #include <vector>
 #include <vector>
 
 
+#include <signal.h>
+#include <unistd.h>
+
+#define ANSI_COLOR_RED     "\x1b[31m"
+#define ANSI_COLOR_GREEN   "\x1b[32m"
+#define ANSI_COLOR_YELLOW  "\x1b[33m"
+#define ANSI_COLOR_BLUE    "\x1b[34m"
+#define ANSI_COLOR_MAGENTA "\x1b[35m"
+#define ANSI_COLOR_CYAN    "\x1b[36m"
+#define ANSI_COLOR_RESET   "\x1b[0m"
+#define ANSI_BOLD          "\x1b[1m"
+
 // determine number of model parts based on the dimension
 // determine number of model parts based on the dimension
 static const std::map<int, int> LLAMA_N_PARTS = {
 static const std::map<int, int> LLAMA_N_PARTS = {
     { 4096, 1 },
     { 4096, 1 },
@@ -733,6 +745,18 @@ bool llama_eval(
     return true;
     return true;
 }
 }
 
 
+static bool is_interacting = false;
+
+void sigint_handler(int signo) {
+    if (signo == SIGINT) {
+        if (!is_interacting) {
+            is_interacting=true;
+        } else {
+            _exit(130);
+        }
+    }
+}
+
 int main(int argc, char ** argv) {
 int main(int argc, char ** argv) {
     ggml_time_init();
     ggml_time_init();
     const int64_t t_main_start_us = ggml_time_us();
     const int64_t t_main_start_us = ggml_time_us();
@@ -787,6 +811,9 @@ int main(int argc, char ** argv) {
 
 
     params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());
     params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());
 
 
+    // tokenize the reverse prompt
+    std::vector<gpt_vocab::id> antiprompt_inp = ::llama_tokenize(vocab, params.antiprompt, false);
+
     printf("\n");
     printf("\n");
     printf("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
     printf("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
     printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
     printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
@@ -794,6 +821,24 @@ int main(int argc, char ** argv) {
         printf("%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
         printf("%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
     }
     }
     printf("\n");
     printf("\n");
+    if (params.interactive) {
+        struct sigaction sigint_action;
+        sigint_action.sa_handler = sigint_handler;
+        sigemptyset (&sigint_action.sa_mask);
+        sigint_action.sa_flags = 0; 
+        sigaction(SIGINT, &sigint_action, NULL);
+
+        printf("%s: interactive mode on.\n", __func__);
+
+        if(antiprompt_inp.size()) {
+            printf("%s: reverse prompt: '%s'\n", __func__, params.antiprompt.c_str());
+            printf("%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size());
+            for (int i = 0; i < (int) antiprompt_inp.size(); i++) {
+                printf("%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str());
+            }
+            printf("\n");
+        }
+    }
     printf("sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
     printf("sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
     printf("\n\n");
     printf("\n\n");
 
 
@@ -807,7 +852,28 @@ int main(int argc, char ** argv) {
     std::vector<gpt_vocab::id> last_n_tokens(last_n_size);
     std::vector<gpt_vocab::id> last_n_tokens(last_n_size);
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
 
 
-    for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
+
+    if (params.interactive) {
+        printf("== Running in interactive mode. ==\n"
+               " - Press Ctrl+C to interject at any time.\n"
+               " - Press Return to return control to LLaMa.\n"
+               " - If you want to submit another line, end your input in '\\'.\n");
+    }
+
+    int remaining_tokens = params.n_predict;
+    int input_consumed = 0;
+    bool input_noecho = false;
+
+    // prompt user immediately after the starting prompt has been loaded
+    if (params.interactive_start) {
+        is_interacting = true;
+    }
+
+    if (params.use_color) {
+        printf(ANSI_COLOR_YELLOW);
+    }
+
+    while (remaining_tokens > 0) {
         // predict
         // predict
         if (embd.size() > 0) {
         if (embd.size() > 0) {
             const int64_t t_start_us = ggml_time_us();
             const int64_t t_start_us = ggml_time_us();
@@ -823,8 +889,8 @@ int main(int argc, char ** argv) {
         n_past += embd.size();
         n_past += embd.size();
         embd.clear();
         embd.clear();
 
 
-        if (i >= embd_inp.size()) {
-            // sample next token
+        if (embd_inp.size() <= input_consumed) {
+            // out of input, sample next token
             const float top_k = params.top_k;
             const float top_k = params.top_k;
             const float top_p = params.top_p;
             const float top_p = params.top_p;
             const float temp  = params.temp;
             const float temp  = params.temp;
@@ -847,24 +913,74 @@ int main(int argc, char ** argv) {
 
 
             // add it to the context
             // add it to the context
             embd.push_back(id);
             embd.push_back(id);
+
+            // echo this to console
+            input_noecho = false;
+
+            // decrement remaining sampling budget
+            --remaining_tokens;
         } else {
         } else {
             // if here, it means we are still processing the input prompt
             // if here, it means we are still processing the input prompt
-            for (int k = i; k < embd_inp.size(); k++) {
-                embd.push_back(embd_inp[k]);
+            while (embd_inp.size() > input_consumed) {
+                embd.push_back(embd_inp[input_consumed]);
                 last_n_tokens.erase(last_n_tokens.begin());
                 last_n_tokens.erase(last_n_tokens.begin());
-                last_n_tokens.push_back(embd_inp[k]);
+                last_n_tokens.push_back(embd_inp[input_consumed]);
+                ++input_consumed;
                 if (embd.size() > params.n_batch) {
                 if (embd.size() > params.n_batch) {
                     break;
                     break;
                 }
                 }
             }
             }
-            i += embd.size() - 1;
+
+            if (params.use_color && embd_inp.size() <= input_consumed) {
+                printf(ANSI_COLOR_RESET);
+            }
         }
         }
 
 
         // display text
         // display text
-        for (auto id : embd) {
-            printf("%s", vocab.id_to_token[id].c_str());
+        if (!input_noecho) {
+            for (auto id : embd) {
+                printf("%s", vocab.id_to_token[id].c_str());
+            }
+            fflush(stdout);
+        }
+
+        // in interactive mode, and not currently processing queued inputs;
+        // check if we should prompt the user for more
+        if (params.interactive && embd_inp.size() <= input_consumed) {
+            // check for reverse prompt
+            if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) {
+                // reverse prompt found
+                is_interacting = true;
+            }
+            if (is_interacting) {
+                // currently being interactive 
+                bool another_line=true;
+                while (another_line) {
+                    char buf[256] = {0};
+                    int n_read;
+                    if(params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
+                    scanf("%255[^\n]%n%*c", buf, &n_read);
+                    if(params.use_color) printf(ANSI_COLOR_RESET);
+
+                    if (n_read > 0 && buf[n_read-1]=='\\') {
+                        another_line = true;
+                        buf[n_read-1] = '\n';
+                        buf[n_read] = 0;
+                    } else {
+                        another_line = false;
+                        buf[n_read] = '\n';
+                        buf[n_read+1] = 0;
+                    }
+
+                    std::vector<gpt_vocab::id> line_inp = ::llama_tokenize(vocab, buf, false);
+                    embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
+
+                    input_noecho = true; // do not echo this again
+                }
+
+                is_interacting = false;            
+            }
         }
         }
-        fflush(stdout);
 
 
         // end of text token
         // end of text token
         if (embd.back() == 2) {
         if (embd.back() == 2) {
@@ -873,6 +989,7 @@ int main(int argc, char ** argv) {
         }
         }
     }
     }
 
 
+
     // report timing
     // report timing
     {
     {
         const int64_t t_main_end_us = ggml_time_us();
         const int64_t t_main_end_us = ggml_time_us();

+ 14 - 0
utils.cpp

@@ -49,6 +49,15 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.n_batch = std::stoi(argv[++i]);
             params.n_batch = std::stoi(argv[++i]);
         } else if (arg == "-m" || arg == "--model") {
         } else if (arg == "-m" || arg == "--model") {
             params.model = argv[++i];
             params.model = argv[++i];
+        } else if (arg == "-i" || arg == "--interactive") {
+            params.interactive = true;
+        } else if (arg == "--interactive-start") {
+            params.interactive = true;
+            params.interactive_start = true;
+        } else if (arg == "--color") {
+            params.use_color = true;
+        } else if (arg == "-r" || arg == "--reverse-prompt") {
+            params.antiprompt = argv[++i];
         } else if (arg == "-h" || arg == "--help") {
         } else if (arg == "-h" || arg == "--help") {
             gpt_print_usage(argc, argv, params);
             gpt_print_usage(argc, argv, params);
             exit(0);
             exit(0);
@@ -67,6 +76,11 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params) {
     fprintf(stderr, "\n");
     fprintf(stderr, "\n");
     fprintf(stderr, "options:\n");
     fprintf(stderr, "options:\n");
     fprintf(stderr, "  -h, --help            show this help message and exit\n");
     fprintf(stderr, "  -h, --help            show this help message and exit\n");
+    fprintf(stderr, "  -i, --interactive     run in interactive mode\n");
+    fprintf(stderr, "  --interactive-start   run in interactive mode and poll user input at startup\n");
+    fprintf(stderr, "  -r PROMPT, --reverse-prompt PROMPT\n");
+    fprintf(stderr, "                        in interactive mode, poll user input upon seeing PROMPT\n");
+    fprintf(stderr, "  --color               colorise output to distinguish prompt and user input from generations\n");
     fprintf(stderr, "  -s SEED, --seed SEED  RNG seed (default: -1)\n");
     fprintf(stderr, "  -s SEED, --seed SEED  RNG seed (default: -1)\n");
     fprintf(stderr, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
     fprintf(stderr, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
     fprintf(stderr, "  -p PROMPT, --prompt PROMPT\n");
     fprintf(stderr, "  -p PROMPT, --prompt PROMPT\n");

+ 6 - 0
utils.h

@@ -28,6 +28,12 @@ struct gpt_params {
 
 
     std::string model = "models/lamma-7B/ggml-model.bin"; // model path
     std::string model = "models/lamma-7B/ggml-model.bin"; // model path
     std::string prompt;
     std::string prompt;
+
+    bool use_color = false; // use color to distinguish generations and inputs
+
+    bool interactive = false; // interactive mode
+    bool interactive_start = false; // reverse prompt immediately
+    std::string antiprompt = ""; // string upon seeing which more user input is prompted
 };
 };
 
 
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params);