Procházet zdrojové kódy

allow missing content in message if tool_calls provided (#12293)

Olivier Chafik před 10 měsíci
rodič
revize
87c2630546
2 změnil soubory, kde provedl 31 přidání a 13 odebrání
  1. 16 13
      common/chat.cpp
  2. 15 0
      tests/test-chat.cpp

+ 16 - 13
common/chat.cpp

@@ -60,7 +60,9 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
             }
             msg.role = message.at("role");

-            if (message.contains("content")) {
+            auto has_content = message.contains("content");
+            auto has_tool_calls = message.contains("tool_calls");
+            if (has_content) {
                 const auto & content = message.at("content");
                 if (content.is_string()) {
                     msg.content = content;
@@ -81,19 +83,8 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
                 } else if (!content.is_null()) {
                     throw std::runtime_error("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
                 }
-            } else {
-                throw std::runtime_error("Expected 'content' (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
-            }
-            if (message.contains("reasoning_content")) {
-                msg.reasoning_content = message.at("reasoning_content");
-            }
-            if (message.contains("name")) {
-                msg.tool_name = message.at("name");
-            }
-            if (message.contains("tool_call_id")) {
-                msg.tool_call_id = message.at("tool_call_id");
             }
-            if (message.contains("tool_calls")) {
+            if (has_tool_calls) {
                 for (const auto & tool_call : message.at("tool_calls")) {
                     common_chat_tool_call tc;
                     if (!tool_call.contains("type")) {
@@ -118,6 +109,18 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
                     msg.tool_calls.push_back(tc);
                 }
             }
+            if (!has_content && !has_tool_calls) {
+                throw std::runtime_error("Expected 'content' or 'tool_calls' (ref: https://github.com/ggml-org/llama.cpp/issues/8367 & https://github.com/ggml-org/llama.cpp/issues/12279)");
+            }
+            if (message.contains("reasoning_content")) {
+                msg.reasoning_content = message.at("reasoning_content");
+            }
+            if (message.contains("name")) {
+                msg.tool_name = message.at("name");
+            }
+            if (message.contains("tool_call_id")) {
+                msg.tool_call_id = message.at("tool_call_id");
+            }
 
 
             msgs.push_back(msg);
         }

+ 15 - 0
tests/test-chat.cpp

@@ -480,6 +480,21 @@ static void test_msgs_oaicompat_json_conversion() {
             "]"
         ),
         common_chat_msgs_to_json_oaicompat<json>({message_assist_call_python}).dump(2));
+
+    auto res = common_chat_msgs_parse_oaicompat(json::parse("[{\"role\": \"assistant\", \"tool_calls\": []}]"));
+    assert_equals<size_t>(1, res.size());
+    assert_equals<std::string>(res[0].role, "assistant");
+    assert_equals(true, res[0].content.empty());
+    assert_equals(true, res[0].tool_calls.empty());
+
+    try {
+        common_chat_msgs_parse_oaicompat(json::parse("[{\"role\": \"assistant\"}]"));
+        throw std::runtime_error("Expected exception");
+    } catch (const std::exception & e) {
+        if (std::string(e.what()).find("'content'") == std::string::npos) {
+            throw std::runtime_error("Expected exception about missing 'content'");
+        }
+    }
 }

 static void test_tools_oaicompat_json_conversion() {