@@ -12508,6 +12508,123 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
return 0;
}
+// trim whitespace from the beginning and end of a string
+static std::string trim(const std::string & str) {
+    size_t start = 0;
+    size_t end = str.size();
+    while (start < end && isspace(static_cast<unsigned char>(str[start]))) {
+        start += 1;
+    }
+    while (end > start && isspace(static_cast<unsigned char>(str[end - 1]))) {
+        end -= 1;
+    }
+    return str.substr(start, end - start);
+}
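+// e.g. trim("  \thello world \n") returns "hello world"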
+
+// Simple version of "llama_apply_chat_template" that only works with strings
+// This function uses heuristic checks to detect commonly used templates. It is not a jinja parser.
+static int32_t llama_chat_apply_template_internal(
+    const std::string & tmpl,
+    const std::vector<const llama_chat_message *> & chat,
+    std::string & dest, bool add_ass) {
+    // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
+    std::stringstream ss;
+    if (tmpl.find("<|im_start|>") != std::string::npos) {
+        // chatml template
+        for (auto message : chat) {
+            ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
+        }
+        if (add_ass) {
+            ss << "<|im_start|>assistant\n";
+        }
+    } else if (tmpl.find("[INST]") != std::string::npos) {
+        // llama2 template and its variants
+        // [variant] support system message
+        bool support_system_message = tmpl.find("<<SYS>>") != std::string::npos;
+        // [variant] space before + after response
+        bool space_around_response = tmpl.find("' ' + eos_token") != std::string::npos;
+        // [variant] add BOS inside history
+        bool add_bos_inside_history = tmpl.find("bos_token + '[INST]") != std::string::npos;
+        // [variant] trim spaces from the input message
+        bool strip_message = tmpl.find("content.strip()") != std::string::npos;
+        // construct the prompt
+        bool is_inside_turn = true; // skip BOS at the beginning
+        ss << "[INST] ";
+        for (auto message : chat) {
+            std::string content = strip_message ? trim(message->content) : message->content;
+            std::string role(message->role);
+            if (!is_inside_turn) {
+                is_inside_turn = true;
+                ss << (add_bos_inside_history ? "<s>[INST] " : "[INST] ");
+            }
+            if (role == "system") {
+                if (support_system_message) {
+                    ss << "<<SYS>>\n" << content << "\n<</SYS>>\n\n";
+                } else {
+                    // if the model does not support a system message, we still include it in the first message, but without <<SYS>>
+                    ss << content << "\n";
+                }
+            } else if (role == "user") {
+                ss << content << " [/INST]";
+            } else {
+                ss << (space_around_response ? " " : "") << content << (space_around_response ? " " : "") << "</s>";
+                is_inside_turn = false;
+            }
+        }
+        // llama2 templates seem to not care about "add_generation_prompt"
+    } else if (tmpl.find("<|user|>") != std::string::npos) {
+        // zephyr template
+        for (auto message : chat) {
+            ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
+        }
+        if (add_ass) {
+            ss << "<|assistant|>\n";
+        }
+    } else {
+        // template not supported
+        return -1;
+    }
+    dest = ss.str();
+    return dest.size();
+}
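+// For reference (illustrative, not exhaustive): a chat of [system:"S", user:"U"] with
+// add_ass == true is rendered by the heuristics above as:
+//   chatml: "<|im_start|>system\nS<|im_end|>\n<|im_start|>user\nU<|im_end|>\n<|im_start|>assistant\n"
+//   llama2 (variant with <<SYS>> support): "[INST] <<SYS>>\nS\n<</SYS>>\n\nU [/INST]"
+//   zephyr: "<|system|>\nS<|endoftext|>\n<|user|>\nU<|endoftext|>\n<|assistant|>\n"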
+
+LLAMA_API int32_t llama_chat_apply_template(
+        const struct llama_model * model,
+        const char * tmpl,
+        const struct llama_chat_message * chat,
+        size_t n_msg,
+        bool add_ass,
+        char * buf,
+        int32_t length) {
+    std::string curr_tmpl(tmpl == nullptr ? "" : tmpl);
+    if (tmpl == nullptr) {
+        GGML_ASSERT(model != nullptr);
+        // load template from model
+        std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
+        std::string template_key = "tokenizer.chat_template";
+        int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
+        if (res < 0) {
+            // worst case: there is no information about the template, so use chatml by default
+            curr_tmpl = "<|im_start|>"; // see llama_chat_apply_template_internal
+        } else {
+            curr_tmpl = std::string(model_template.data(), model_template.size());
+        }
+    }
+    // format the chat to string
+    std::vector<const llama_chat_message *> chat_vec;
+    chat_vec.resize(n_msg);
+    for (size_t i = 0; i < n_msg; i++) {
+        chat_vec[i] = &chat[i];
+    }
+    std::string formatted_chat;
+    int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass);
+    if (res < 0) {
+        return res;
+    }
+    strncpy(buf, formatted_chat.c_str(), length);
+    return res;
+}
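+// Example usage (illustrative sketch only; the messages and buffer size are assumptions
+// for demonstration, and `model` is a previously loaded llama_model):
+//
+//     std::vector<llama_chat_message> msgs = {
+//         {"system", "You are a helpful assistant"},
+//         {"user",   "Hello"},
+//     };
+//     std::vector<char> buf(2048);
+//     int32_t res = llama_chat_apply_template(model, nullptr, msgs.data(), msgs.size(),
+//                                             true, buf.data(), buf.size());
+//     // res < 0 means the template was not recognized; otherwise res is the length of the
+//     // formatted chat (retry with a larger buffer if res > (int32_t) buf.size())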
+
struct llama_timings llama_get_timings(struct llama_context * ctx) {
struct llama_timings result = {
/*.t_start_ms =*/ 1e-3 * ctx->t_start_us,