Explorar o código

`tool-call`: fix non-tool-calling grammar crashes w/ Qwen / Hermes 2 templates (#12900)

* `tool-call`: don't call common_chat_params_init_hermes_2_pro when there aren't tools (or when there's a schema)

* test all chat formats w/o tools
Olivier Chafik hai 9 meses
pai
achega
b6930ebc42
Modificáronse 2 ficheiros con 10 adicións e 1 borrado
  1. common/chat.cpp (+1, −1)
  2. tests/test-chat.cpp (+9, −0)

+ 1 - 1
common/chat.cpp

@@ -1622,7 +1622,7 @@ static common_chat_params common_chat_templates_apply_jinja(
     }

     // Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
-    if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null()) {
+    if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null() && params.tools.is_array() && params.json_schema.is_null()) {
         return common_chat_params_init_hermes_2_pro(tmpl, params);
     }


+ 9 - 0
tests/test-chat.cpp

@@ -569,6 +569,7 @@ static void test_template_output_parsers() {
     {
         // Not supported yet
         auto tmpls = read_templates("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja");
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
     }
     {
@@ -665,6 +666,7 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja");
         std::vector<std::string> end_tokens{ "<|im_end|>" };

+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
         assert_equals(
             COMMON_CHAT_FORMAT_HERMES_2_PRO,
@@ -793,6 +795,7 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja");
         std::vector<std::string>   end_tokens{ "<|eom_id|>", "<|eot_id|>" };

+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
                       common_chat_templates_apply(tmpls.get(), inputs_tools_builtin).format);
@@ -815,6 +818,7 @@ static void test_template_output_parsers() {
         std::vector<std::string>   end_tokens{ "<|eom_id|>", "<|eot_id|>" };

         assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);

         test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
         test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
@@ -824,6 +828,8 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/meetkai-functionary-medium-v3.1.jinja");
         std::vector<std::string>   end_tokens{ "<|eom_id|>", "<|eot_id|>" };

+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY,
+                      common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
                       common_chat_templates_apply(tmpls.get(), inputs_tools).format);

@@ -851,6 +857,7 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja");
         std::vector<std::string>   end_tokens{ "<|eot_id|>" };

+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_templates_apply(tmpls.get(), inputs_tools).format);

         test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
@@ -862,6 +869,7 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja");
         std::vector<std::string>   end_tokens{ "<|end▁of▁sentence|>" };

+        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1,                   common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1,                   common_chat_templates_apply(tmpls.get(), inputs_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);

@@ -891,6 +899,7 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/llama-cpp-deepseek-r1.jinja");
         std::vector<std::string>   end_tokens{ "<|end▁of▁sentence|>" };

+        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1,                   common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1,                   common_chat_templates_apply(tmpls.get(), inputs_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);