@@ -167,50 +167,47 @@ static T json_value(const json &body, const std::string &key, const T &default_v
         : default_value;
 }
 
-inline std::string format_llama2(std::vector<json> messages)
-{
-    std::ostringstream output;
-    bool is_inside_turn = false;
+// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
+inline bool verify_custom_template(const std::string & tmpl) {
+    llama_chat_message chat[] = {{"user", "test"}};
+    std::vector<char> buf(1);
+    int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, buf.data(), buf.size());
+    return res >= 0;
+}
 
-    for (auto it = messages.begin(); it != messages.end(); ++it) {
-        if (!is_inside_turn) {
-            output << "[INST] ";
-        }
-        std::string role = json_value(*it, "role", std::string("user"));
-        std::string content = json_value(*it, "content", std::string(""));
-        if (role == "system") {
-            output << "<<SYS>>\n" << content << "\n<<SYS>>\n\n";
-            is_inside_turn = true;
-        } else if (role == "user") {
-            output << content << " [/INST]";
-            is_inside_turn = true;
-        } else {
-            output << " " << content << " </s>";
-            is_inside_turn = false;
-        }
+// Format given chat. If tmpl is empty, we take the template from model metadata
+inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages)
+{
+    size_t alloc_size = 0;
+    // vector holding all allocated string to be passed to llama_chat_apply_template
+    std::vector<std::string> str(messages.size() * 2);
+    std::vector<llama_chat_message> chat(messages.size());
+
+    for (size_t i = 0; i < messages.size(); ++i) {
+        auto &curr_msg = messages[i];
+        str[i*2 + 0] = json_value(curr_msg, "role", std::string(""));
+        str[i*2 + 1] = json_value(curr_msg, "content", std::string(""));
+        alloc_size += str[i*2 + 1].length();
+        chat[i].role = str[i*2 + 0].c_str();
+        chat[i].content = str[i*2 + 1].c_str();
     }
 
-    LOG_VERBOSE("format_llama2", {{"text", output.str()}});
+    const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
+    std::vector<char> buf(alloc_size * 2);
 
-    return output.str();
-}
-
-inline std::string format_chatml(std::vector<json> messages)
-{
-    std::ostringstream chatml_msgs;
+    // run the first time to get the total output length
+    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
 
-    for (auto it = messages.begin(); it != messages.end(); ++it) {
-        chatml_msgs << "<|im_start|>"
-                    << json_value(*it, "role", std::string("user")) << '\n';
-        chatml_msgs << json_value(*it, "content", std::string(""))
-                    << "<|im_end|>\n";
+    // if it turns out that our buffer is too small, we resize it
+    if ((size_t) res > buf.size()) {
+        buf.resize(res);
+        res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
     }
 
-    chatml_msgs << "<|im_start|>assistant" << '\n';
-
-    LOG_VERBOSE("format_chatml", {{"text", chatml_msgs.str()}});
+    std::string formatted_chat(buf.data(), res);
+    LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}});
 
-    return chatml_msgs.str();
+    return formatted_chat;
 }
 
 //
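
For reviewers, a rough usage sketch of the two helpers added above. The function name build_prompt, the error handling, and the request-body shape are illustrative assumptions, not part of this patch; it only assumes the json alias and the helpers declared in this header, plus llama.h.

// Hypothetical call site (illustration only, not part of this diff).
// Assumes utils.hpp (verify_custom_template / format_chat), llama.h, and the
// server's `using json = nlohmann::json;` alias are already in scope.
#include <stdexcept>
#include <string>
#include <vector>

static std::string build_prompt(const struct llama_model * model,
                                const std::string & chat_template,  // value of --chat-template, "" = use model metadata
                                const json & body) {                // OpenAI-style request body (assumed shape)
    // Reject a custom template that llama_chat_apply_template cannot handle up front.
    if (!chat_template.empty() && !verify_custom_template(chat_template)) {
        throw std::runtime_error("the supplied --chat-template is not supported");
    }
    // "messages" is the usual [{"role": ..., "content": ...}, ...] array.
    std::vector<json> messages = body.at("messages").get<std::vector<json>>();
    return format_chat(model, chat_template, messages);
}

Design note: format_chat deliberately calls llama_chat_apply_template twice, first with a buffer sized at 2x the total content length and again only if the reported output length exceeds that buffer, so the common case needs a single pass.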