Browse Source

chat : fix translategemma crash on common_chat_format_example (#19019)

Xuan-Son Nguyen 1 week ago
parent
commit
b5b8fa1c8b
1 changed files with 45 additions and 0 deletions
  1. 45 0
      common/chat.cpp

+ 45 - 0
common/chat.cpp

@@ -2650,6 +2650,45 @@ static common_chat_params common_chat_params_init_exaone_moe(const common_chat_t
     return data;
 }
 
+static common_chat_params common_chat_params_init_translate_gemma(const common_chat_template & tmpl, const struct templates_params & inputs) {
+    common_chat_params data;
+
+    // This template does not support tools or reasoning
+    // we just need to transform the messages into the correct schema
+
+    templates_params inputs_new = inputs;
+    json & messages = inputs_new.messages;
+
+    GGML_ASSERT(messages.is_array());
+    for (auto & message : messages) {
+        if (message.contains("role") && message["role"].get<std::string>() != "user") {
+            continue;
+        }
+        if (!message.contains("content")) {
+            message["content"] = json::array();
+        }
+        if (message.contains("content") && !message["content"].is_array()) {
+            auto content_str = message["content"].get<std::string>();
+            // default to en-GB if not specified (to make common_chat_format_example works)
+            auto src_lang = message.contains("source_lang_code") ? message["source_lang_code"].get<std::string>() : "en-GB";
+            auto tgt_lang = message.contains("target_lang_code") ? message["target_lang_code"].get<std::string>() : "en-GB";
+            message["content"] = json::array({
+                json{
+                    {"type", "text"},
+                    {"text", content_str},
+                    {"source_lang_code", src_lang},
+                    {"target_lang_code", tgt_lang},
+                }
+            });
+        }
+    }
+
+    data.prompt = apply(tmpl, inputs_new, std::nullopt, std::nullopt);
+    data.format = COMMON_CHAT_FORMAT_GENERIC;
+
+    return data;
+}
+
 static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
     data.prompt = apply(tmpl, inputs);
@@ -3045,6 +3084,12 @@ static common_chat_params common_chat_templates_apply_jinja(
         return common_chat_params_init_solar_open(tmpl, params);
     }
 
+    // TranslateGemma
+    if (src.find("[source_lang_code]") != std::string::npos &&
+        src.find("[target_lang_code]") != std::string::npos) {
+        return common_chat_params_init_translate_gemma(tmpl, params);
+    }
+
     // Plain handler (no tools)
     if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
         return common_chat_params_init_without_tools(tmpl, params);