
add alias for chat template (#5858)

Xuan Son Nguyen 1 year ago
commit 4ffcdce2ff
2 changed files with 7 additions and 7 deletions
  1. examples/server/server.cpp (+1 -1)
  2. llama.cpp (+6 -6)

examples/server/server.cpp (+1 -1)

@@ -413,7 +413,7 @@ struct llama_server_context
         int res = llama_chat_apply_template(model, nullptr, chat, 1, true, buf.data(), buf.size());
         if (res < 0) {
             LOG_ERROR("The chat template comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
-            sparams.chat_template = "<|im_start|>"; // llama_chat_apply_template only checks if <|im_start|> exist in the template
+            sparams.chat_template = "chatml";
         }
     }
 
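A minimal sketch of the probe-and-fallback pattern in the hunk above (assuming the llama.h API as of this commit; the probe message, buffer size, and helper name pick_chat_template are illustrative, not the server's actual code):

    #include "llama.h"
    #include <string>
    #include <vector>

    // Returns the template override the server should use: "chatml" if the
    // model's embedded template is unsupported, or "" to keep the model's own.
    static std::string pick_chat_template(const llama_model * model) {
        // Passing tmpl == nullptr asks the library to read the template from
        // the model's metadata (tokenizer.chat_template).
        llama_chat_message chat[] = {{"user", "test"}};
        std::vector<char> buf(1024);
        int32_t res = llama_chat_apply_template(model, nullptr, chat, 1, true,
                                                buf.data(), (int32_t) buf.size());
        if (res < 0) {
            // Unsupported or missing template: fall back to the new "chatml"
            // alias instead of the old "<|im_start|>" marker trick.
            return "chatml";
        }
        return "";
    }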

llama.cpp (+6 -6)

@@ -13282,7 +13282,7 @@ static int32_t llama_chat_apply_template_internal(
     std::string & dest, bool add_ass) {
     // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
     std::stringstream ss;
-    if (tmpl.find("<|im_start|>") != std::string::npos) {
+    if (tmpl == "chatml" || tmpl.find("<|im_start|>") != std::string::npos) {
         // chatml template
         for (auto message : chat) {
             ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
@@ -13290,7 +13290,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|im_start|>assistant\n";
         }
-    } else if (tmpl.find("[INST]") != std::string::npos) {
+    } else if (tmpl == "llama2" || tmpl.find("[INST]") != std::string::npos) {
         // llama2 template and its variants
         // [variant] support system message
         bool support_system_message = tmpl.find("<<SYS>>") != std::string::npos;
@@ -13325,7 +13325,7 @@ static int32_t llama_chat_apply_template_internal(
             }
         }
         // llama2 templates seem to not care about "add_generation_prompt"
-    } else if (tmpl.find("<|user|>") != std::string::npos) {
+    } else if (tmpl == "zephyr" || tmpl.find("<|user|>") != std::string::npos) {
         // zephyr template
         for (auto message : chat) {
             ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
@@ -13333,7 +13333,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|assistant|>\n";
         }
-    } else if (tmpl.find("bos_token + message['role']") != std::string::npos) {
+    } else if (tmpl == "monarch" || tmpl.find("bos_token + message['role']") != std::string::npos) {
         // mlabonne/AlphaMonarch-7B template (the <s> is included inside history)
         for (auto message : chat) {
             std::string bos = (message == chat.front()) ? "" : "<s>"; // skip BOS for first message
@@ -13342,7 +13342,7 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<s>assistant\n";
         }
-    } else if (tmpl.find("<start_of_turn>") != std::string::npos) {
+    } else if (tmpl == "gemma" || tmpl.find("<start_of_turn>") != std::string::npos) {
         // google/gemma-7b-it
         std::string system_prompt = "";
         for (auto message : chat) {
@@ -13389,7 +13389,7 @@ LLAMA_API int32_t llama_chat_apply_template(
         int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
         if (res < 0) {
             // worst case: there is no information about template, we will use chatml by default
-            curr_tmpl = "<|im_start|>"; // see llama_chat_apply_template_internal
+            curr_tmpl = "chatml"; // see llama_chat_apply_template_internal
         } else {
             curr_tmpl = std::string(model_template.data(), model_template.size());
         }
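
For callers outside the server, the effect of this change is that the tmpl argument of llama_chat_apply_template can now be a short alias ("chatml", "llama2", "zephyr", "monarch", "gemma") rather than the full template text. A hedged usage sketch follows (the grow-and-retry loop reflects my reading that the function returns the required length; treat it as illustrative):

    #include "llama.h"
    #include <cstdio>
    #include <vector>

    int main() {
        llama_chat_message chat[] = {
            {"system", "You are a helpful assistant."},
            {"user",   "Hello!"},
        };
        const size_t n_msg = sizeof(chat) / sizeof(chat[0]);

        // The model pointer is only consulted when tmpl is nullptr, so it can
        // be nullptr when an explicit template or alias is given.
        std::vector<char> buf(1024);
        int32_t res = llama_chat_apply_template(nullptr, "chatml", chat, n_msg,
                                                /*add_ass=*/true, buf.data(), (int32_t) buf.size());
        if (res > (int32_t) buf.size()) {
            buf.resize(res);
            res = llama_chat_apply_template(nullptr, "chatml", chat, n_msg,
                                            true, buf.data(), (int32_t) buf.size());
        }
        if (res < 0) {
            fprintf(stderr, "unsupported template\n");
            return 1;
        }
        // Prints the chatml-formatted conversation ending in
        // "<|im_start|>assistant\n" because add_ass is true.
        printf("%.*s", res, buf.data());
        return 0;
    }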