Explorar o código

Server: Change Invalid Schema from Server Error (500) to User Error (400) (#17572)

* Make invalid schema a user error (400)

* Move invalid_argument exception handler to ex_wrapper

* Fix test

* Simplify test back to original pattern
Chad Voegele hai 1 mes
pai
achega
c4357dcc35

+ 16 - 16
common/chat.cpp

@@ -163,7 +163,7 @@ common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::strin
     if (tool_choice == "required") {
     if (tool_choice == "required") {
         return COMMON_CHAT_TOOL_CHOICE_REQUIRED;
         return COMMON_CHAT_TOOL_CHOICE_REQUIRED;
     }
     }
-    throw std::runtime_error("Invalid tool_choice: " + tool_choice);
+    throw std::invalid_argument("Invalid tool_choice: " + tool_choice);
 }
 }
 
 
 bool common_chat_templates_support_enable_thinking(const common_chat_templates * chat_templates) {
 bool common_chat_templates_support_enable_thinking(const common_chat_templates * chat_templates) {
@@ -186,17 +186,17 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
     try {
     try {
 
 
         if (!messages.is_array()) {
         if (!messages.is_array()) {
-            throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump());
+            throw std::invalid_argument("Expected 'messages' to be an array, got " + messages.dump());
         }
         }
 
 
         for (const auto & message : messages) {
         for (const auto & message : messages) {
             if (!message.is_object()) {
             if (!message.is_object()) {
-                throw std::runtime_error("Expected 'message' to be an object, got " + message.dump());
+                throw std::invalid_argument("Expected 'message' to be an object, got " + message.dump());
             }
             }
 
 
             common_chat_msg msg;
             common_chat_msg msg;
             if (!message.contains("role")) {
             if (!message.contains("role")) {
-                throw std::runtime_error("Missing 'role' in message: " + message.dump());
+                throw std::invalid_argument("Missing 'role' in message: " + message.dump());
             }
             }
             msg.role = message.at("role");
             msg.role = message.at("role");
 
 
@@ -209,11 +209,11 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
                 } else if (content.is_array()) {
                 } else if (content.is_array()) {
                     for (const auto & part : content) {
                     for (const auto & part : content) {
                         if (!part.contains("type")) {
                         if (!part.contains("type")) {
-                            throw std::runtime_error("Missing content part type: " + part.dump());
+                            throw std::invalid_argument("Missing content part type: " + part.dump());
                         }
                         }
                         const auto & type = part.at("type");
                         const auto & type = part.at("type");
                         if (type != "text") {
                         if (type != "text") {
-                            throw std::runtime_error("Unsupported content part type: " + type.dump());
+                            throw std::invalid_argument("Unsupported content part type: " + type.dump());
                         }
                         }
                         common_chat_msg_content_part msg_part;
                         common_chat_msg_content_part msg_part;
                         msg_part.type = type;
                         msg_part.type = type;
@@ -221,25 +221,25 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
                         msg.content_parts.push_back(msg_part);
                         msg.content_parts.push_back(msg_part);
                     }
                     }
                 } else if (!content.is_null()) {
                 } else if (!content.is_null()) {
-                    throw std::runtime_error("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
+                    throw std::invalid_argument("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
                 }
                 }
             }
             }
             if (has_tool_calls) {
             if (has_tool_calls) {
                 for (const auto & tool_call : message.at("tool_calls")) {
                 for (const auto & tool_call : message.at("tool_calls")) {
                     common_chat_tool_call tc;
                     common_chat_tool_call tc;
                     if (!tool_call.contains("type")) {
                     if (!tool_call.contains("type")) {
-                        throw std::runtime_error("Missing tool call type: " + tool_call.dump());
+                        throw std::invalid_argument("Missing tool call type: " + tool_call.dump());
                     }
                     }
                     const auto & type = tool_call.at("type");
                     const auto & type = tool_call.at("type");
                     if (type != "function") {
                     if (type != "function") {
-                        throw std::runtime_error("Unsupported tool call type: " + tool_call.dump());
+                        throw std::invalid_argument("Unsupported tool call type: " + tool_call.dump());
                     }
                     }
                     if (!tool_call.contains("function")) {
                     if (!tool_call.contains("function")) {
-                        throw std::runtime_error("Missing tool call function: " + tool_call.dump());
+                        throw std::invalid_argument("Missing tool call function: " + tool_call.dump());
                     }
                     }
                     const auto & fc = tool_call.at("function");
                     const auto & fc = tool_call.at("function");
                     if (!fc.contains("name")) {
                     if (!fc.contains("name")) {
-                        throw std::runtime_error("Missing tool call name: " + tool_call.dump());
+                        throw std::invalid_argument("Missing tool call name: " + tool_call.dump());
                     }
                     }
                     tc.name = fc.at("name");
                     tc.name = fc.at("name");
                     tc.arguments = fc.at("arguments");
                     tc.arguments = fc.at("arguments");
@@ -250,7 +250,7 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
                 }
                 }
             }
             }
             if (!has_content && !has_tool_calls) {
             if (!has_content && !has_tool_calls) {
-                throw std::runtime_error("Expected 'content' or 'tool_calls' (ref: https://github.com/ggml-org/llama.cpp/issues/8367 & https://github.com/ggml-org/llama.cpp/issues/12279)");
+                throw std::invalid_argument("Expected 'content' or 'tool_calls' (ref: https://github.com/ggml-org/llama.cpp/issues/8367 & https://github.com/ggml-org/llama.cpp/issues/12279)");
             }
             }
             if (message.contains("reasoning_content")) {
             if (message.contains("reasoning_content")) {
                 msg.reasoning_content = message.at("reasoning_content");
                 msg.reasoning_content = message.at("reasoning_content");
@@ -353,18 +353,18 @@ std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & too
     try {
     try {
         if (!tools.is_null()) {
         if (!tools.is_null()) {
             if (!tools.is_array()) {
             if (!tools.is_array()) {
-                throw std::runtime_error("Expected 'tools' to be an array, got " + tools.dump());
+                throw std::invalid_argument("Expected 'tools' to be an array, got " + tools.dump());
             }
             }
             for (const auto & tool : tools) {
             for (const auto & tool : tools) {
                 if (!tool.contains("type")) {
                 if (!tool.contains("type")) {
-                    throw std::runtime_error("Missing tool type: " + tool.dump());
+                    throw std::invalid_argument("Missing tool type: " + tool.dump());
                 }
                 }
                 const auto & type = tool.at("type");
                 const auto & type = tool.at("type");
                 if (!type.is_string() || type != "function") {
                 if (!type.is_string() || type != "function") {
-                    throw std::runtime_error("Unsupported tool type: " + tool.dump());
+                    throw std::invalid_argument("Unsupported tool type: " + tool.dump());
                 }
                 }
                 if (!tool.contains("function")) {
                 if (!tool.contains("function")) {
-                    throw std::runtime_error("Missing tool function: " + tool.dump());
+                    throw std::invalid_argument("Missing tool function: " + tool.dump());
                 }
                 }
 
 
                 const auto & function = tool.at("function");
                 const auto & function = tool.at("function");

+ 1 - 1
common/json-schema-to-grammar.cpp

@@ -974,7 +974,7 @@ public:
 
 
     void check_errors() {
     void check_errors() {
         if (!_errors.empty()) {
         if (!_errors.empty()) {
-            throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n"));
+            throw std::invalid_argument("JSON schema conversion failed:\n" + string_join(_errors, "\n"));
         }
         }
         if (!_warnings.empty()) {
         if (!_warnings.empty()) {
             fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str());
             fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str());

+ 1 - 1
tests/test-json-schema-to-grammar.cpp

@@ -1375,7 +1375,7 @@ int main() {
         try {
         try {
             tc.verify(json_schema_to_grammar(nlohmann::ordered_json::parse(tc.schema), true));
             tc.verify(json_schema_to_grammar(nlohmann::ordered_json::parse(tc.schema), true));
             tc.verify_status(SUCCESS);
             tc.verify_status(SUCCESS);
-        } catch (const std::runtime_error & ex) {
+        } catch (const std::invalid_argument & ex) {
             fprintf(stderr, "Error: %s\n", ex.what());
             fprintf(stderr, "Error: %s\n", ex.what());
             tc.verify_status(FAILURE);
             tc.verify_status(FAILURE);
         }
         }

+ 18 - 18
tools/server/server-common.cpp

@@ -819,26 +819,26 @@ json oaicompat_chat_params_parse(
             auto schema_wrapper = json_value(response_format, "json_schema", json::object());
             auto schema_wrapper = json_value(response_format, "json_schema", json::object());
             json_schema = json_value(schema_wrapper, "schema", json::object());
             json_schema = json_value(schema_wrapper, "schema", json::object());
         } else if (!response_type.empty() && response_type != "text") {
         } else if (!response_type.empty() && response_type != "text") {
-            throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
+            throw std::invalid_argument("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
         }
         }
     }
     }
 
 
     // get input files
     // get input files
     if (!body.contains("messages")) {
     if (!body.contains("messages")) {
-        throw std::runtime_error("'messages' is required");
+        throw std::invalid_argument("'messages' is required");
     }
     }
     json & messages = body.at("messages");
     json & messages = body.at("messages");
     if (!messages.is_array()) {
     if (!messages.is_array()) {
-        throw std::runtime_error("Expected 'messages' to be an array");
+        throw std::invalid_argument("Expected 'messages' to be an array");
     }
     }
     for (auto & msg : messages) {
     for (auto & msg : messages) {
         std::string role = json_value(msg, "role", std::string());
         std::string role = json_value(msg, "role", std::string());
         if (role != "assistant" && !msg.contains("content")) {
         if (role != "assistant" && !msg.contains("content")) {
-            throw std::runtime_error("All non-assistant messages must contain 'content'");
+            throw std::invalid_argument("All non-assistant messages must contain 'content'");
         }
         }
         if (role == "assistant") {
         if (role == "assistant") {
             if (!msg.contains("content") && !msg.contains("tool_calls")) {
             if (!msg.contains("content") && !msg.contains("tool_calls")) {
-                throw std::runtime_error("Assistant message must contain either 'content' or 'tool_calls'!");
+                throw std::invalid_argument("Assistant message must contain either 'content' or 'tool_calls'!");
             }
             }
             if (!msg.contains("content")) {
             if (!msg.contains("content")) {
                 continue; // avoid errors with no content
                 continue; // avoid errors with no content
@@ -850,7 +850,7 @@ json oaicompat_chat_params_parse(
         }
         }
 
 
         if (!content.is_array()) {
         if (!content.is_array()) {
-            throw std::runtime_error("Expected 'content' to be a string or an array");
+            throw std::invalid_argument("Expected 'content' to be a string or an array");
         }
         }
 
 
         for (auto & p : content) {
         for (auto & p : content) {
@@ -884,11 +884,11 @@ json oaicompat_chat_params_parse(
                     // try to decode base64 image
                     // try to decode base64 image
                     std::vector<std::string> parts = string_split<std::string>(url, /*separator*/ ',');
                     std::vector<std::string> parts = string_split<std::string>(url, /*separator*/ ',');
                     if (parts.size() != 2) {
                     if (parts.size() != 2) {
-                        throw std::runtime_error("Invalid image_url.url value");
+                        throw std::invalid_argument("Invalid image_url.url value");
                     } else if (!string_starts_with(parts[0], "data:image/")) {
                     } else if (!string_starts_with(parts[0], "data:image/")) {
-                        throw std::runtime_error("Invalid image_url.url format: " + parts[0]);
+                        throw std::invalid_argument("Invalid image_url.url format: " + parts[0]);
                     } else if (!string_ends_with(parts[0], "base64")) {
                     } else if (!string_ends_with(parts[0], "base64")) {
-                        throw std::runtime_error("image_url.url must be base64 encoded");
+                        throw std::invalid_argument("image_url.url must be base64 encoded");
                     } else {
                     } else {
                         auto base64_data = parts[1];
                         auto base64_data = parts[1];
                         auto decoded_data = base64_decode(base64_data);
                         auto decoded_data = base64_decode(base64_data);
@@ -911,7 +911,7 @@ json oaicompat_chat_params_parse(
                 std::string format = json_value(input_audio, "format", std::string());
                 std::string format = json_value(input_audio, "format", std::string());
                 // while we also support flac, we don't allow it here so we matches the OAI spec
                 // while we also support flac, we don't allow it here so we matches the OAI spec
                 if (format != "wav" && format != "mp3") {
                 if (format != "wav" && format != "mp3") {
-                    throw std::runtime_error("input_audio.format must be either 'wav' or 'mp3'");
+                    throw std::invalid_argument("input_audio.format must be either 'wav' or 'mp3'");
                 }
                 }
                 auto decoded_data = base64_decode(data); // expected to be base64 encoded
                 auto decoded_data = base64_decode(data); // expected to be base64 encoded
                 out_files.push_back(decoded_data);
                 out_files.push_back(decoded_data);
@@ -922,7 +922,7 @@ json oaicompat_chat_params_parse(
                 p.erase("input_audio");
                 p.erase("input_audio");
 
 
             } else if (type != "text") {
             } else if (type != "text") {
-                throw std::runtime_error("unsupported content[].type");
+                throw std::invalid_argument("unsupported content[].type");
             }
             }
         }
         }
     }
     }
@@ -940,7 +940,7 @@ json oaicompat_chat_params_parse(
     inputs.enable_thinking       = opt.enable_thinking;
     inputs.enable_thinking       = opt.enable_thinking;
     if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
     if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
         if (body.contains("grammar")) {
         if (body.contains("grammar")) {
-            throw std::runtime_error("Cannot use custom grammar constraints with tools.");
+            throw std::invalid_argument("Cannot use custom grammar constraints with tools.");
         }
         }
         llama_params["parse_tool_calls"] = true;
         llama_params["parse_tool_calls"] = true;
     }
     }
@@ -959,7 +959,7 @@ json oaicompat_chat_params_parse(
     } else if (enable_thinking_kwarg == "false") {
     } else if (enable_thinking_kwarg == "false") {
         inputs.enable_thinking = false;
         inputs.enable_thinking = false;
     } else if (!enable_thinking_kwarg.empty() && enable_thinking_kwarg[0] == '"') {
     } else if (!enable_thinking_kwarg.empty() && enable_thinking_kwarg[0] == '"') {
-        throw std::runtime_error("invalid type for \"enable_thinking\" (expected boolean, got string)");
+        throw std::invalid_argument("invalid type for \"enable_thinking\" (expected boolean, got string)");
     }
     }
 
 
     // if the assistant message appears at the end of list, we do not add end-of-turn token
     // if the assistant message appears at the end of list, we do not add end-of-turn token
@@ -972,14 +972,14 @@ json oaicompat_chat_params_parse(
 
 
         /* sanity check, max one assistant message at the end of the list */
         /* sanity check, max one assistant message at the end of the list */
         if (!inputs.messages.empty() && inputs.messages.back().role == "assistant"){
         if (!inputs.messages.empty() && inputs.messages.back().role == "assistant"){
-            throw std::runtime_error("Cannot have 2 or more assistant messages at the end of the list.");
+            throw std::invalid_argument("Cannot have 2 or more assistant messages at the end of the list.");
         }
         }
 
 
         /* TODO: test this properly */
         /* TODO: test this properly */
         inputs.reasoning_format = COMMON_REASONING_FORMAT_NONE;
         inputs.reasoning_format = COMMON_REASONING_FORMAT_NONE;
 
 
         if ( inputs.enable_thinking ) {
         if ( inputs.enable_thinking ) {
-            throw std::runtime_error("Assistant response prefill is incompatible with enable_thinking.");
+            throw std::invalid_argument("Assistant response prefill is incompatible with enable_thinking.");
         }
         }
 
 
         inputs.add_generation_prompt = true;
         inputs.add_generation_prompt = true;
@@ -1020,18 +1020,18 @@ json oaicompat_chat_params_parse(
     // Handle "n" field
     // Handle "n" field
     int n_choices = json_value(body, "n", 1);
     int n_choices = json_value(body, "n", 1);
     if (n_choices != 1) {
     if (n_choices != 1) {
-        throw std::runtime_error("Only one completion choice is allowed");
+        throw std::invalid_argument("Only one completion choice is allowed");
     }
     }
 
 
     // Handle "logprobs" field
     // Handle "logprobs" field
     // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future
     // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future
     if (json_value(body, "logprobs", false)) {
     if (json_value(body, "logprobs", false)) {
         if (has_tools && stream) {
         if (has_tools && stream) {
-            throw std::runtime_error("logprobs is not supported with tools + stream");
+            throw std::invalid_argument("logprobs is not supported with tools + stream");
         }
         }
         llama_params["n_probs"] = json_value(body, "top_logprobs", 20);
         llama_params["n_probs"] = json_value(body, "top_logprobs", 20);
     } else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) {
     } else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) {
-        throw std::runtime_error("top_logprobs requires logprobs to be set to true");
+        throw std::invalid_argument("top_logprobs requires logprobs to be set to true");
     }
     }
 
 
     // Copy remaining properties to llama_params
     // Copy remaining properties to llama_params

+ 7 - 1
tools/server/server.cpp

@@ -34,18 +34,24 @@ static inline void signal_handler(int signal) {
 static server_http_context::handler_t ex_wrapper(server_http_context::handler_t func) {
 static server_http_context::handler_t ex_wrapper(server_http_context::handler_t func) {
     return [func = std::move(func)](const server_http_req & req) -> server_http_res_ptr {
     return [func = std::move(func)](const server_http_req & req) -> server_http_res_ptr {
         std::string message;
         std::string message;
+        error_type error;
         try {
         try {
             return func(req);
             return func(req);
+        } catch (const std::invalid_argument & e) {
+            error = ERROR_TYPE_INVALID_REQUEST;
+            message = e.what();
         } catch (const std::exception & e) {
         } catch (const std::exception & e) {
+            error = ERROR_TYPE_SERVER;
             message = e.what();
             message = e.what();
         } catch (...) {
         } catch (...) {
+            error = ERROR_TYPE_SERVER;
             message = "unknown error";
             message = "unknown error";
         }
         }
 
 
         auto res = std::make_unique<server_http_res>();
         auto res = std::make_unique<server_http_res>();
         res->status = 500;
         res->status = 500;
         try {
         try {
-            json error_data = format_error_response(message, ERROR_TYPE_SERVER);
+            json error_data = format_error_response(message, error);
             res->status = json_value(error_data, "code", 500);
             res->status = json_value(error_data, "code", 500);
             res->data = safe_json_to_str({{ "error", error_data }});
             res->data = safe_json_to_str({{ "error", error_data }});
             SRV_WRN("got exception: %s\n", res->data.c_str());
             SRV_WRN("got exception: %s\n", res->data.c_str());

+ 1 - 1
tools/server/tests/unit/test_chat_completion.py

@@ -199,7 +199,7 @@ def test_completion_with_response_format(response_format: dict, n_predicted: int
         choice = res.body["choices"][0]
         choice = res.body["choices"][0]
         assert match_regex(re_content, choice["message"]["content"])
         assert match_regex(re_content, choice["message"]["content"])
     else:
     else:
-        assert res.status_code != 200
+        assert res.status_code == 400
         assert "error" in res.body
         assert "error" in res.body