Server: Change Invalid Schema from Server Error (500) to User Error (400) (#17572)

* Make invalid schema a user error (400)

* Move invalid_argument exception handler to ex_wrapper

* Fix test

* Simplify test back to original pattern
This commit is contained in:
Chad Voegele 2025-12-02 10:33:50 -06:00 committed by GitHub
parent e148380c7c
commit c4357dcc35
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 44 additions and 38 deletions

View File

@ -163,7 +163,7 @@ common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::strin
if (tool_choice == "required") {
return COMMON_CHAT_TOOL_CHOICE_REQUIRED;
}
throw std::runtime_error("Invalid tool_choice: " + tool_choice);
throw std::invalid_argument("Invalid tool_choice: " + tool_choice);
}
bool common_chat_templates_support_enable_thinking(const common_chat_templates * chat_templates) {
@ -186,17 +186,17 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
try {
if (!messages.is_array()) {
throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump());
throw std::invalid_argument("Expected 'messages' to be an array, got " + messages.dump());
}
for (const auto & message : messages) {
if (!message.is_object()) {
throw std::runtime_error("Expected 'message' to be an object, got " + message.dump());
throw std::invalid_argument("Expected 'message' to be an object, got " + message.dump());
}
common_chat_msg msg;
if (!message.contains("role")) {
throw std::runtime_error("Missing 'role' in message: " + message.dump());
throw std::invalid_argument("Missing 'role' in message: " + message.dump());
}
msg.role = message.at("role");
@ -209,11 +209,11 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
} else if (content.is_array()) {
for (const auto & part : content) {
if (!part.contains("type")) {
throw std::runtime_error("Missing content part type: " + part.dump());
throw std::invalid_argument("Missing content part type: " + part.dump());
}
const auto & type = part.at("type");
if (type != "text") {
throw std::runtime_error("Unsupported content part type: " + type.dump());
throw std::invalid_argument("Unsupported content part type: " + type.dump());
}
common_chat_msg_content_part msg_part;
msg_part.type = type;
@ -221,25 +221,25 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
msg.content_parts.push_back(msg_part);
}
} else if (!content.is_null()) {
throw std::runtime_error("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
throw std::invalid_argument("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
}
}
if (has_tool_calls) {
for (const auto & tool_call : message.at("tool_calls")) {
common_chat_tool_call tc;
if (!tool_call.contains("type")) {
throw std::runtime_error("Missing tool call type: " + tool_call.dump());
throw std::invalid_argument("Missing tool call type: " + tool_call.dump());
}
const auto & type = tool_call.at("type");
if (type != "function") {
throw std::runtime_error("Unsupported tool call type: " + tool_call.dump());
throw std::invalid_argument("Unsupported tool call type: " + tool_call.dump());
}
if (!tool_call.contains("function")) {
throw std::runtime_error("Missing tool call function: " + tool_call.dump());
throw std::invalid_argument("Missing tool call function: " + tool_call.dump());
}
const auto & fc = tool_call.at("function");
if (!fc.contains("name")) {
throw std::runtime_error("Missing tool call name: " + tool_call.dump());
throw std::invalid_argument("Missing tool call name: " + tool_call.dump());
}
tc.name = fc.at("name");
tc.arguments = fc.at("arguments");
@ -250,7 +250,7 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
}
}
if (!has_content && !has_tool_calls) {
throw std::runtime_error("Expected 'content' or 'tool_calls' (ref: https://github.com/ggml-org/llama.cpp/issues/8367 & https://github.com/ggml-org/llama.cpp/issues/12279)");
throw std::invalid_argument("Expected 'content' or 'tool_calls' (ref: https://github.com/ggml-org/llama.cpp/issues/8367 & https://github.com/ggml-org/llama.cpp/issues/12279)");
}
if (message.contains("reasoning_content")) {
msg.reasoning_content = message.at("reasoning_content");
@ -353,18 +353,18 @@ std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & too
try {
if (!tools.is_null()) {
if (!tools.is_array()) {
throw std::runtime_error("Expected 'tools' to be an array, got " + tools.dump());
throw std::invalid_argument("Expected 'tools' to be an array, got " + tools.dump());
}
for (const auto & tool : tools) {
if (!tool.contains("type")) {
throw std::runtime_error("Missing tool type: " + tool.dump());
throw std::invalid_argument("Missing tool type: " + tool.dump());
}
const auto & type = tool.at("type");
if (!type.is_string() || type != "function") {
throw std::runtime_error("Unsupported tool type: " + tool.dump());
throw std::invalid_argument("Unsupported tool type: " + tool.dump());
}
if (!tool.contains("function")) {
throw std::runtime_error("Missing tool function: " + tool.dump());
throw std::invalid_argument("Missing tool function: " + tool.dump());
}
const auto & function = tool.at("function");

View File

@ -974,7 +974,7 @@ public:
void check_errors() {
if (!_errors.empty()) {
throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n"));
throw std::invalid_argument("JSON schema conversion failed:\n" + string_join(_errors, "\n"));
}
if (!_warnings.empty()) {
fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str());

View File

@ -1375,7 +1375,7 @@ int main() {
try {
tc.verify(json_schema_to_grammar(nlohmann::ordered_json::parse(tc.schema), true));
tc.verify_status(SUCCESS);
} catch (const std::runtime_error & ex) {
} catch (const std::invalid_argument & ex) {
fprintf(stderr, "Error: %s\n", ex.what());
tc.verify_status(FAILURE);
}

View File

@ -819,26 +819,26 @@ json oaicompat_chat_params_parse(
auto schema_wrapper = json_value(response_format, "json_schema", json::object());
json_schema = json_value(schema_wrapper, "schema", json::object());
} else if (!response_type.empty() && response_type != "text") {
throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
throw std::invalid_argument("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
}
}
// get input files
if (!body.contains("messages")) {
throw std::runtime_error("'messages' is required");
throw std::invalid_argument("'messages' is required");
}
json & messages = body.at("messages");
if (!messages.is_array()) {
throw std::runtime_error("Expected 'messages' to be an array");
throw std::invalid_argument("Expected 'messages' to be an array");
}
for (auto & msg : messages) {
std::string role = json_value(msg, "role", std::string());
if (role != "assistant" && !msg.contains("content")) {
throw std::runtime_error("All non-assistant messages must contain 'content'");
throw std::invalid_argument("All non-assistant messages must contain 'content'");
}
if (role == "assistant") {
if (!msg.contains("content") && !msg.contains("tool_calls")) {
throw std::runtime_error("Assistant message must contain either 'content' or 'tool_calls'!");
throw std::invalid_argument("Assistant message must contain either 'content' or 'tool_calls'!");
}
if (!msg.contains("content")) {
continue; // avoid errors with no content
@ -850,7 +850,7 @@ json oaicompat_chat_params_parse(
}
if (!content.is_array()) {
throw std::runtime_error("Expected 'content' to be a string or an array");
throw std::invalid_argument("Expected 'content' to be a string or an array");
}
for (auto & p : content) {
@ -884,11 +884,11 @@ json oaicompat_chat_params_parse(
// try to decode base64 image
std::vector<std::string> parts = string_split<std::string>(url, /*separator*/ ',');
if (parts.size() != 2) {
throw std::runtime_error("Invalid image_url.url value");
throw std::invalid_argument("Invalid image_url.url value");
} else if (!string_starts_with(parts[0], "data:image/")) {
throw std::runtime_error("Invalid image_url.url format: " + parts[0]);
throw std::invalid_argument("Invalid image_url.url format: " + parts[0]);
} else if (!string_ends_with(parts[0], "base64")) {
throw std::runtime_error("image_url.url must be base64 encoded");
throw std::invalid_argument("image_url.url must be base64 encoded");
} else {
auto base64_data = parts[1];
auto decoded_data = base64_decode(base64_data);
@ -911,7 +911,7 @@ json oaicompat_chat_params_parse(
std::string format = json_value(input_audio, "format", std::string());
// while we also support flac, we don't allow it here so we matches the OAI spec
if (format != "wav" && format != "mp3") {
throw std::runtime_error("input_audio.format must be either 'wav' or 'mp3'");
throw std::invalid_argument("input_audio.format must be either 'wav' or 'mp3'");
}
auto decoded_data = base64_decode(data); // expected to be base64 encoded
out_files.push_back(decoded_data);
@ -922,7 +922,7 @@ json oaicompat_chat_params_parse(
p.erase("input_audio");
} else if (type != "text") {
throw std::runtime_error("unsupported content[].type");
throw std::invalid_argument("unsupported content[].type");
}
}
}
@ -940,7 +940,7 @@ json oaicompat_chat_params_parse(
inputs.enable_thinking = opt.enable_thinking;
if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
if (body.contains("grammar")) {
throw std::runtime_error("Cannot use custom grammar constraints with tools.");
throw std::invalid_argument("Cannot use custom grammar constraints with tools.");
}
llama_params["parse_tool_calls"] = true;
}
@ -959,7 +959,7 @@ json oaicompat_chat_params_parse(
} else if (enable_thinking_kwarg == "false") {
inputs.enable_thinking = false;
} else if (!enable_thinking_kwarg.empty() && enable_thinking_kwarg[0] == '"') {
throw std::runtime_error("invalid type for \"enable_thinking\" (expected boolean, got string)");
throw std::invalid_argument("invalid type for \"enable_thinking\" (expected boolean, got string)");
}
// if the assistant message appears at the end of list, we do not add end-of-turn token
@ -972,14 +972,14 @@ json oaicompat_chat_params_parse(
/* sanity check, max one assistant message at the end of the list */
if (!inputs.messages.empty() && inputs.messages.back().role == "assistant"){
throw std::runtime_error("Cannot have 2 or more assistant messages at the end of the list.");
throw std::invalid_argument("Cannot have 2 or more assistant messages at the end of the list.");
}
/* TODO: test this properly */
inputs.reasoning_format = COMMON_REASONING_FORMAT_NONE;
if ( inputs.enable_thinking ) {
throw std::runtime_error("Assistant response prefill is incompatible with enable_thinking.");
throw std::invalid_argument("Assistant response prefill is incompatible with enable_thinking.");
}
inputs.add_generation_prompt = true;
@ -1020,18 +1020,18 @@ json oaicompat_chat_params_parse(
// Handle "n" field
int n_choices = json_value(body, "n", 1);
if (n_choices != 1) {
throw std::runtime_error("Only one completion choice is allowed");
throw std::invalid_argument("Only one completion choice is allowed");
}
// Handle "logprobs" field
// TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future
if (json_value(body, "logprobs", false)) {
if (has_tools && stream) {
throw std::runtime_error("logprobs is not supported with tools + stream");
throw std::invalid_argument("logprobs is not supported with tools + stream");
}
llama_params["n_probs"] = json_value(body, "top_logprobs", 20);
} else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) {
throw std::runtime_error("top_logprobs requires logprobs to be set to true");
throw std::invalid_argument("top_logprobs requires logprobs to be set to true");
}
// Copy remaining properties to llama_params

View File

@ -34,18 +34,24 @@ static inline void signal_handler(int signal) {
static server_http_context::handler_t ex_wrapper(server_http_context::handler_t func) {
return [func = std::move(func)](const server_http_req & req) -> server_http_res_ptr {
std::string message;
error_type error;
try {
return func(req);
} catch (const std::invalid_argument & e) {
error = ERROR_TYPE_INVALID_REQUEST;
message = e.what();
} catch (const std::exception & e) {
error = ERROR_TYPE_SERVER;
message = e.what();
} catch (...) {
error = ERROR_TYPE_SERVER;
message = "unknown error";
}
auto res = std::make_unique<server_http_res>();
res->status = 500;
try {
json error_data = format_error_response(message, ERROR_TYPE_SERVER);
json error_data = format_error_response(message, error);
res->status = json_value(error_data, "code", 500);
res->data = safe_json_to_str({{ "error", error_data }});
SRV_WRN("got exception: %s\n", res->data.c_str());

View File

@ -199,7 +199,7 @@ def test_completion_with_response_format(response_format: dict, n_predicted: int
choice = res.body["choices"][0]
assert match_regex(re_content, choice["message"]["content"])
else:
assert res.status_code != 200
assert res.status_code == 400
assert "error" in res.body