Просмотр исходного кода

main : add --conversation / -cnv flag (#7108)

Dawid Potocki 1 год назад
Родитель
Коммит
83330d8cd6
3 изменённых файла: 13 добавлений и 4 удаления
  1. 5 0
      common/common.cpp
  2. 1 0
      common/common.h
  3. 7 4
      examples/main/main.cpp

+ 5 - 0
common/common.cpp

@@ -911,6 +911,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.instruct = true;
         return true;
     }
+    if (arg == "-cnv" || arg == "--conversation") {
+        params.conversation = true;
+        return true;
+    }
     if (arg == "-cml" || arg == "--chatml") {
         params.chatml = true;
         return true;
@@ -1417,6 +1421,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --version             show version and build info\n");
     printf("  -i, --interactive     run in interactive mode\n");
     printf("  --interactive-first   run in interactive mode and wait for input right away\n");
+    printf("  -cnv, --conversation  run in conversation mode (does not print special tokens and suffix/prefix)\n");
     printf("  -ins, --instruct      run in instruction mode (use with Alpaca models)\n");
     printf("  -cml, --chatml        run in chatml mode (use with ChatML-compatible models)\n");
     printf("  --multiline-input     allows you to write or paste multiple lines without ending each in '\\'\n");

+ 1 - 0
common/common.h

@@ -140,6 +140,7 @@ struct gpt_params {
     bool random_prompt     = false; // do not randomize prompt if none provided
     bool use_color         = false; // use color to distinguish generations and inputs
     bool interactive       = false; // interactive mode
+    bool conversation      = false; // conversation mode (does not print special tokens and suffix/prefix)
     bool chatml            = false; // chatml mode (used for models trained on chatml syntax)
     bool prompt_cache_all  = false; // save user input and generations to prompt cache
     bool prompt_cache_ro   = false; // open the prompt cache read-only and do not update it

+ 7 - 4
examples/main/main.cpp

@@ -362,6 +362,9 @@ int main(int argc, char ** argv) {
         params.interactive_first = true;
         params.antiprompt.emplace_back("<|im_start|>user\n");
     }
+    else if (params.conversation) {
+        params.interactive_first = true;
+    }
 
     // enable interactive mode if interactive start is specified
     if (params.interactive_first) {
@@ -733,7 +736,7 @@ int main(int argc, char ** argv) {
         // display text
         if (input_echo && display) {
             for (auto id : embd) {
-                const std::string token_str = llama_token_to_piece(ctx, id);
+                const std::string token_str = llama_token_to_piece(ctx, id, !params.conversation);
                 printf("%s", token_str.c_str());
 
                 if (embd.size() > 1) {
@@ -816,7 +819,7 @@ int main(int argc, char ** argv) {
             if (n_past > 0 && is_interacting) {
                 LOG("waiting for user input\n");
 
-                if (params.instruct || params.chatml) {
+                if (params.conversation || params.instruct || params.chatml) {
                     printf("\n> ");
                 }
 
@@ -826,7 +829,7 @@ int main(int argc, char ** argv) {
                 }
 
                 std::string buffer;
-                if (!params.input_prefix.empty()) {
+                if (!params.input_prefix.empty() && !params.conversation) {
                     LOG("appending input prefix: '%s'\n", params.input_prefix.c_str());
                     printf("%s", params.input_prefix.c_str());
                 }
@@ -850,7 +853,7 @@ int main(int argc, char ** argv) {
                 // Entering an empty line lets the user pass control back
                 if (buffer.length() > 1) {
                     // append input suffix if any
-                    if (!params.input_suffix.empty()) {
+                    if (!params.input_suffix.empty() && !params.conversation) {
                         LOG("appending input suffix: '%s'\n", params.input_suffix.c_str());
                         printf("%s", params.input_suffix.c_str());
                     }