فهرست منبع

Add chatml fallback for cpp `llama_chat_apply_template` (#8160)

* add chatml fallback for cpp `llama_chat_apply_template`

* remove redundant code
Xuan Son Nguyen 1 سال پیش
والد
کامیت
16791b8f0b
2فایلهای تغییر یافته به همراه20 افزوده شده و 1 حذف شده
  1. 18 1
      common/common.cpp
  2. 2 0
      common/common.h

+ 18 - 1
common/common.cpp

@@ -2618,6 +2618,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
         const std::vector<llama_chat_msg> & msgs,
         bool add_ass) {
     int alloc_size = 0;
+    bool fallback = false; // indicate if we must fallback to default chatml
     std::vector<llama_chat_message> chat;
     for (auto & msg : msgs) {
         chat.push_back({msg.role.c_str(), msg.content.c_str()});
@@ -2630,10 +2631,26 @@ std::string llama_chat_apply_template(const struct llama_model * model,
     // run the first time to get the total output length
     int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
 
+    // error: chat template is not supported
+    if (res < 0) {
+        if (ptr_tmpl != nullptr) {
+            // if the custom "tmpl" is not supported, we throw an error
+            // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
+            throw std::runtime_error("this custom template is not supported");
+        } else {
+            // If the built-in template is not supported, we default to chatml
+            res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+            fallback = true;
+        }
+    }
+
     // if it turns out that our buffer is too small, we resize it
     if ((size_t) res > buf.size()) {
         buf.resize(res);
-        res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+        res = llama_chat_apply_template(
+            fallback ? nullptr : model,
+            fallback ? "chatml" : ptr_tmpl,
+            chat.data(), chat.size(), add_ass, buf.data(), buf.size());
     }
 
     std::string formatted_chat(buf.data(), res);

+ 2 - 0
common/common.h

@@ -380,6 +380,8 @@ struct llama_chat_msg {
 bool llama_chat_verify_template(const std::string & tmpl);
 
 // CPP wrapper for llama_chat_apply_template
+// If the built-in template is not supported, we default to chatml
+// If the custom "tmpl" is not supported, we throw an error
 std::string llama_chat_apply_template(const struct llama_model * model,
         const std::string & tmpl,
         const std::vector<llama_chat_msg> & chat,