
Fix new line issue with chat template, disable template when in-prefix/suffix is set (#8203)

* preserve new line llama_chat_format_single

* disable chat template if in-prefix/suffix is set

* remove redundant change
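
The newline fix can be seen in isolation: llama_chat_format_single computes the chunk for a new message as the difference between the chat formatted with and without that message, and the trailing newline the template emits after the previous turn used to fall outside the chunk that actually reaches the context. A minimal sketch of the idea (the helper name format_single_sketch and the example strings are illustrative, not part of the patch):

    // Sketch of the incremental formatting used by llama_chat_format_single:
    // the chunk for the newest message is the fully formatted chat minus the
    // formatted history, plus the trailing newline that the history's template
    // emitted but the model itself never generated.
    #include <cassert>
    #include <string>

    static std::string format_single_sketch(const std::string & fmt_past, const std::string & fmt_full) {
        std::string chunk = fmt_full.substr(fmt_past.size());
        if (!fmt_past.empty() && fmt_past.back() == '\n') {
            chunk = "\n" + chunk; // re-add the newline that separates the turns
        }
        return chunk;
    }

    int main() {
        // chatml-style strings mirroring the updated expectation in test-chat-template.cpp
        std::string fmt_past = "<|im_start|>assistant\nHello<|im_end|>\n";
        std::string fmt_full = fmt_past + "<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n";
        assert(format_single_sketch(fmt_past, fmt_full) ==
               "\n<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n");
        return 0;
    }
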
Xuan Son Nguyen, 1 year ago
Commit 9ef0780062
4 changed files with 23 additions and 9 deletions
  1. common/common.cpp: 13 additions, 3 deletions
  2. common/common.h: 1 addition, 0 deletions
  3. examples/main/main.cpp: 7 additions, 4 deletions
  4. tests/test-chat-template.cpp: 2 additions, 2 deletions

+ 13 - 3
common/common.cpp

@@ -1014,16 +1014,19 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
     }
     if (arg == "--in-prefix-bos") {
         params.input_prefix_bos = true;
+        params.enable_chat_template = false;
         return true;
     }
     if (arg == "--in-prefix") {
         CHECK_ARG
         params.input_prefix = argv[i];
+        params.enable_chat_template = false;
         return true;
     }
     if (arg == "--in-suffix") {
         CHECK_ARG
         params.input_suffix = argv[i];
+        params.enable_chat_template = false;
         return true;
     }
     if (arg == "--spm-infill") {
@@ -1406,7 +1409,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
                                                                         "halt generation at PROMPT, return control in interactive mode\n"
                                                                         "can be specified more than once for multiple prompts" });
     options.push_back({ "main",        "-sp,   --special",              "special tokens output enabled (default: %s)", params.special ? "true" : "false" });
-    options.push_back({ "main",        "-cnv,  --conversation",         "run in conversation mode (does not print special tokens and suffix/prefix) (default: %s)", params.conversation ? "true" : "false" });
+    options.push_back({ "main",        "-cnv,  --conversation",         "run in conversation mode (does not print special tokens and suffix/prefix, use default chat template) (default: %s)", params.conversation ? "true" : "false" });
     options.push_back({ "main infill", "-i,    --interactive",          "run in interactive mode (default: %s)", params.interactive ? "true" : "false" });
     options.push_back({ "main infill", "-if,   --interactive-first",    "run in interactive mode and wait for input right away (default: %s)", params.interactive_first ? "true" : "false" });
     options.push_back({ "main infill", "-mli,  --multiline-input",      "allows you to write or paste multiple lines without ending each in '\\'" });
@@ -2668,12 +2671,19 @@ std::string llama_chat_format_single(const struct llama_model * model,
         const std::vector<llama_chat_msg> & past_msg,
         const llama_chat_msg & new_msg,
         bool add_ass) {
+    std::ostringstream ss;
     auto fmt_past_msg = llama_chat_apply_template(model, tmpl, past_msg, false);
     std::vector<llama_chat_msg> chat_new(past_msg);
+    // if the past_msg ends with a newline, we must preserve it in the formatted version
+    if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
+        ss << "\n";
+    };
+    // format chat with new_msg
     chat_new.push_back(new_msg);
     auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass);
-    auto formatted = fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
-    return formatted;
+    // get the diff part
+    ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
+    return ss.str();
 }
 
 std::string llama_chat_format_example(const struct llama_model * model,

+ 1 - 0
common/common.h

@@ -200,6 +200,7 @@ struct gpt_params {
     std::string public_path   = "";
     std::string chat_template = "";
     std::string system_prompt = "";
+    bool enable_chat_template = true;
 
     std::vector<std::string> api_keys;
 

+ 7 - 4
examples/main/main.cpp

@@ -261,7 +261,7 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> embd_inp;
 
     {
-        auto prompt = params.conversation
+        auto prompt = (params.conversation && params.enable_chat_template)
             ? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode
             : params.prompt;
         if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
@@ -810,7 +810,9 @@ int main(int argc, char ** argv) {
                         is_antiprompt = true;
                     }
 
-                    chat_add_and_format(model, chat_msgs, "assistant", assistant_ss.str());
+                    if (params.enable_chat_template) {
+                        chat_add_and_format(model, chat_msgs, "assistant", assistant_ss.str());
+                    }
                     is_interacting = true;
                     printf("\n");
                 }
@@ -872,12 +874,13 @@ int main(int argc, char ** argv) {
                         string_process_escapes(buffer);
                     }
 
-                    std::string user_inp = params.conversation
+                    bool format_chat = params.conversation && params.enable_chat_template;
+                    std::string user_inp = format_chat
                         ? chat_add_and_format(model, chat_msgs, "user", std::move(buffer))
                         : std::move(buffer);
                     // TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
                     const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
-                    const auto line_inp = ::llama_tokenize(ctx, user_inp,            false, params.conversation);
+                    const auto line_inp = ::llama_tokenize(ctx, user_inp,            false, format_chat);
                     const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true);
 
                     LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());
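
The gating above interacts with the new enable_chat_template flag: since --in-prefix/--in-suffix now switch the template off, a user turn in conversation mode is either a template-formatted chunk or the raw text wrapped by the prefix/suffix. A rough, string-level illustration of that assembly (the helper build_user_turn is hypothetical; main.cpp performs the equivalent concatenation at the token level):

    #include <string>

    // Hypothetical helper, not part of main.cpp: shows which pieces make up one user turn.
    static std::string build_user_turn(bool format_chat,
                                       const std::string & templated,   // what chat_add_and_format() returns
                                       const std::string & in_prefix,   // params.input_prefix
                                       const std::string & raw_input,   // the line typed by the user
                                       const std::string & in_suffix) { // params.input_suffix
        // when the template is enabled, prefix/suffix are empty (setting them disables the template),
        // so only one of the two paths ever provides the wrapping around the user text
        return in_prefix + (format_chat ? templated : raw_input) + in_suffix;
    }
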

+ 2 - 2
tests/test-chat-template.cpp

@@ -142,9 +142,9 @@ int main(void) {
         std::cout << "fmt_single(" << tmpl << ")\n" << output << "\n-------------------------\n";
         return output;
     };
-    assert(fmt_single("chatml") == "<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n");
+    assert(fmt_single("chatml") == "\n<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n");
     assert(fmt_single("llama2") == "[INST] How are you [/INST]");
-    assert(fmt_single("gemma") == "<start_of_turn>user\nHow are you<end_of_turn>\n<start_of_turn>model\n");
+    assert(fmt_single("gemma") == "\n<start_of_turn>user\nHow are you<end_of_turn>\n<start_of_turn>model\n");
     assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
 
     return 0;