Initial tool call support
This commit is contained in:
parent
a2c35998ce
commit
9f09745e05
|
|
@ -1090,30 +1090,52 @@ json convert_responses_to_chatcmpl(const json & response_body) {
|
|||
}
|
||||
|
||||
if (input_value.is_string()) {
|
||||
// #responses_create-input-text_input
|
||||
chatcmpl_messages.push_back({
|
||||
{"role", "user"},
|
||||
{"content", input_value},
|
||||
});
|
||||
} else if (input_value.is_array()) {
|
||||
for (const auto & input_message : input_value) {
|
||||
if (!input_message.contains("content")) {
|
||||
throw std::invalid_argument("'content' is required");
|
||||
// #responses_create-input-input_item_list
|
||||
|
||||
const auto exists_and_is_array = [](const json & j, const char * key) -> bool {
|
||||
return j.contains(key) && j.at(key).is_array();
|
||||
};
|
||||
const auto exists_and_is_string = [](const json & j, const char * key) -> bool {
|
||||
return j.contains(key) && j.at(key).is_string();
|
||||
};
|
||||
|
||||
for (json item : input_value) {
|
||||
if (exists_and_is_string(item, "content")) {
|
||||
// #responses_create-input-input_item_list-input_message-content-text_input
|
||||
// Only "Input message" contains item["content"]::string
|
||||
// After converting item["content"]::string to item["content"]::array,
|
||||
// we can treat "Input message" as sum of "Item-Input message" and "Item-Output message"
|
||||
item["content"] = json::array({
|
||||
json {
|
||||
{"text", item.at("content")},
|
||||
{"type", "input_text"}
|
||||
}
|
||||
});
|
||||
}
|
||||
const json content = input_message.at("content");
|
||||
|
||||
if (content.is_string()) {
|
||||
chatcmpl_messages.push_back(input_message);
|
||||
} else if (content.is_array()) {
|
||||
json new_content = json::array();
|
||||
if (exists_and_is_array(item, "content") &&
|
||||
exists_and_is_string(item, "role") &&
|
||||
(item.at("role") == "user" ||
|
||||
item.at("role") == "system" ||
|
||||
item.at("role") == "developer")
|
||||
) {
|
||||
// #responses_create-input-input_item_list-item-input_message
|
||||
json chatcmpl_content = json::array();
|
||||
|
||||
for (const auto & input_item : content) {
|
||||
for (const json & input_item : item.at("content")) {
|
||||
const std::string type = json_value(input_item, "type", std::string());
|
||||
|
||||
if (type == "input_text") {
|
||||
if (!input_item.contains("text")) {
|
||||
throw std::invalid_argument("'Input text' requires 'text'");
|
||||
}
|
||||
new_content.push_back({
|
||||
chatcmpl_content.push_back({
|
||||
{"text", input_item.at("text")},
|
||||
{"type", "text"}
|
||||
});
|
||||
|
|
@ -1124,7 +1146,7 @@ json convert_responses_to_chatcmpl(const json & response_body) {
|
|||
if (!input_item.contains("image_url")) {
|
||||
throw std::invalid_argument("'image_url' is required");
|
||||
}
|
||||
new_content.push_back({
|
||||
chatcmpl_content.push_back({
|
||||
{"image_url", json {{"url", input_item.at("image_url")}}},
|
||||
{"type", "image_url"}
|
||||
});
|
||||
|
|
@ -1136,7 +1158,7 @@ json convert_responses_to_chatcmpl(const json & response_body) {
|
|||
if (!input_item.contains("file_data") || !input_item.contains("filename")) {
|
||||
throw std::invalid_argument("Both 'file_data' and 'filename' are required");
|
||||
}
|
||||
new_content.push_back({
|
||||
chatcmpl_content.push_back({
|
||||
{"file", json {
|
||||
{"file_data", input_item.at("file_data")},
|
||||
{"filename", input_item.at("filename")}}},
|
||||
|
|
@ -1147,21 +1169,87 @@ json convert_responses_to_chatcmpl(const json & response_body) {
|
|||
}
|
||||
}
|
||||
|
||||
json new_input_message = input_message;
|
||||
new_input_message["content"] = new_content;
|
||||
item["content"] = chatcmpl_content;
|
||||
|
||||
chatcmpl_messages.push_back(new_input_message);
|
||||
chatcmpl_messages.push_back(item);
|
||||
} else if (exists_and_is_array(item, "content") &&
|
||||
exists_and_is_string(item, "role") &&
|
||||
item.at("role") == "assistant" &&
|
||||
exists_and_is_string(item, "status") &&
|
||||
(item.at("status") == "in_progress" ||
|
||||
item.at("status") == "completed" ||
|
||||
item.at("status") == "incomplete") &&
|
||||
exists_and_is_string(item, "type") &&
|
||||
item.at("type") == "message"
|
||||
) {
|
||||
// #responses_create-input-input_item_list-item-output_message
|
||||
json chatcmpl_content = json::array();
|
||||
|
||||
for (const auto & output_text : item.at("content")) {
|
||||
const std::string type = json_value(output_text, "type", std::string());
|
||||
if (type != "output_text") {
|
||||
throw std::invalid_argument("'type' must be 'output_text'");
|
||||
}
|
||||
if (!exists_and_is_string(output_text, "text")) {
|
||||
throw std::invalid_argument("'Output text' requires 'text'");
|
||||
}
|
||||
// Ignore annotations and logprobs for now
|
||||
chatcmpl_content.push_back({
|
||||
{"text", output_text.at("text")},
|
||||
{"type", "text"}
|
||||
});
|
||||
}
|
||||
|
||||
item.erase("status");
|
||||
item.erase("type");
|
||||
item["content"] = chatcmpl_content;
|
||||
chatcmpl_messages.push_back(item);
|
||||
} else if (exists_and_is_string(item, "arguments") &&
|
||||
exists_and_is_string(item, "call_id") &&
|
||||
exists_and_is_string(item, "name") &&
|
||||
exists_and_is_string(item, "type") &&
|
||||
item.at("type") == "function_call"
|
||||
) {
|
||||
// #responses_create-input-input_item_list-item-function_tool_call
|
||||
chatcmpl_messages.push_back(json {
|
||||
{"role", "assistant"},
|
||||
{"tool_calls", json::array({ json {
|
||||
{"function", json {
|
||||
{"arguments", item.at("arguments")},
|
||||
{"name", item.at("name")}
|
||||
}},
|
||||
{"id", item.at("call_id")},
|
||||
{"type", "function"}
|
||||
}})},
|
||||
});
|
||||
} else if (exists_and_is_string(item, "call_id") &&
|
||||
(exists_and_is_string(item, "output") || exists_and_is_array(item, "output")) &&
|
||||
exists_and_is_string(item, "type") &&
|
||||
item.at("type") == "function_call_output"
|
||||
) {
|
||||
// #responses_create-input-input_item_list-item-function_tool_call_output
|
||||
if (item.at("output").is_string()) {
|
||||
chatcmpl_messages.push_back(json {
|
||||
{"content", item.at("output")},
|
||||
{"role", "tool"},
|
||||
{"tool_call_id", item.at("call_id")}
|
||||
});
|
||||
} else {
|
||||
json chatcmpl_outputs = item.at("output");
|
||||
for (json & chatcmpl_output : chatcmpl_outputs) {
|
||||
if (!chatcmpl_output.contains("type") || chatcmpl_output.at("type") != "input_text") {
|
||||
throw std::invalid_argument("Output of tool call should be 'Input text'");
|
||||
}
|
||||
chatcmpl_output["type"] = "text";
|
||||
}
|
||||
chatcmpl_messages.push_back(json {
|
||||
{"content", chatcmpl_outputs},
|
||||
{"role", "tool"},
|
||||
{"tool_call_id", item.at("call_id")}
|
||||
});
|
||||
}
|
||||
} else {
|
||||
throw std::invalid_argument("'content' must be a string or array of objects");
|
||||
}
|
||||
|
||||
const std::string role = json_value(input_message, "role", std::string());
|
||||
if (role != "user" && role != "assistant" && role != "system" && role != "developer") {
|
||||
throw std::invalid_argument("'role' must be one of user, assistant, system, or developer");
|
||||
}
|
||||
|
||||
if (input_message.contains("type") && input_message.at("type") != "message") {
|
||||
throw std::invalid_argument("If 'type' is defined, it should be 'message'");
|
||||
throw std::invalid_argument("Cannot determine type of 'item'");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
|
@ -1172,6 +1260,30 @@ json convert_responses_to_chatcmpl(const json & response_body) {
|
|||
chatcmpl_body.erase("input");
|
||||
chatcmpl_body["messages"] = chatcmpl_messages;
|
||||
|
||||
if (response_body.contains("tools")) {
|
||||
if (!response_body.at("tools").is_array()) {
|
||||
throw std::invalid_argument("'tools' must be an array of objects");
|
||||
}
|
||||
json chatcmpl_tools = json::array();
|
||||
for (json resp_tool : response_body.at("tools")) {
|
||||
json chatcmpl_tool;
|
||||
|
||||
if (json_value(resp_tool, "type", std::string()) != "function") {
|
||||
throw std::invalid_argument("'type' of tool must be 'function'");
|
||||
}
|
||||
resp_tool.erase("type");
|
||||
chatcmpl_tool["type"] = "function";
|
||||
|
||||
if (!resp_tool.contains("strict")) {
|
||||
resp_tool["strict"] = true;
|
||||
}
|
||||
chatcmpl_tool["function"] = resp_tool;
|
||||
chatcmpl_tools.push_back(chatcmpl_tool);
|
||||
}
|
||||
chatcmpl_body.erase("tools");
|
||||
chatcmpl_body["tools"] = chatcmpl_tools;
|
||||
}
|
||||
|
||||
if (response_body.contains("max_output_tokens")) {
|
||||
chatcmpl_body.erase("max_output_tokens");
|
||||
chatcmpl_body["max_tokens"] = response_body["max_output_tokens"];
|
||||
|
|
|
|||
|
|
@ -851,44 +851,69 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp() {
|
|||
|
||||
json server_task_result_cmpl_final::to_json_oaicompat_resp_stream() {
|
||||
json server_sent_events = json::array();
|
||||
json output = json::array();
|
||||
|
||||
server_sent_events.push_back(json {
|
||||
{"event", "response.output_text.done"},
|
||||
{"data", json {
|
||||
{"type", "response.output_text.done"},
|
||||
{"text", oaicompat_msg.content}
|
||||
}}
|
||||
});
|
||||
for (const common_chat_tool_call & tool_call : oaicompat_msg.tool_calls) {
|
||||
server_sent_events.push_back(json {
|
||||
{"event", "response.output_item.done"},
|
||||
{"data", json {
|
||||
{"type", "response.output_item.done"},
|
||||
{"item", json {
|
||||
{"type", "function_call"},
|
||||
{"status", "completed"},
|
||||
{"arguments", tool_call.arguments},
|
||||
{"call_id", "call_dummy_id"},
|
||||
{"name", tool_call.name}
|
||||
}}
|
||||
}}
|
||||
});
|
||||
output.push_back({
|
||||
{"type", "function_call"},
|
||||
{"status", "completed"},
|
||||
{"arguments", tool_call.arguments},
|
||||
{"name", tool_call.name}
|
||||
});
|
||||
}
|
||||
|
||||
const json part = {
|
||||
{"type", "output_text"},
|
||||
{"annotations", json::array()},
|
||||
{"logprobs", json::array()},
|
||||
{"text", oaicompat_msg.content}
|
||||
};
|
||||
if (oaicompat_msg.content != "") {
|
||||
server_sent_events.push_back(json {
|
||||
{"event", "response.output_text.done"},
|
||||
{"data", json {
|
||||
{"type", "response.output_text.done"},
|
||||
{"text", oaicompat_msg.content}
|
||||
}}
|
||||
});
|
||||
|
||||
server_sent_events.push_back(json {
|
||||
{"event", "response.content_part.done"},
|
||||
{"data", json {
|
||||
{"type", "response.content_part.done"},
|
||||
{"part", part}
|
||||
}}
|
||||
});
|
||||
const json part = {
|
||||
{"type", "output_text"},
|
||||
{"annotations", json::array()},
|
||||
{"logprobs", json::array()},
|
||||
{"text", oaicompat_msg.content}
|
||||
};
|
||||
|
||||
const json item = {
|
||||
{"type", "message"},
|
||||
{"status", "completed"},
|
||||
{"content", json::array({part})},
|
||||
{"role", "assistant"}
|
||||
};
|
||||
server_sent_events.push_back(json {
|
||||
{"event", "response.content_part.done"},
|
||||
{"data", json {
|
||||
{"type", "response.content_part.done"},
|
||||
{"part", part}
|
||||
}}
|
||||
});
|
||||
const json item = {
|
||||
{"type", "message"},
|
||||
{"status", "completed"},
|
||||
{"content", json::array({part})},
|
||||
{"role", "assistant"}
|
||||
};
|
||||
|
||||
server_sent_events.push_back(json {
|
||||
{"event", "response.output_item.done"},
|
||||
{"data", json {
|
||||
{"type", "response.output_item.done"},
|
||||
{"item", item}
|
||||
}}
|
||||
});
|
||||
server_sent_events.push_back(json {
|
||||
{"event", "response.output_item.done"},
|
||||
{"data", json {
|
||||
{"type", "response.output_item.done"},
|
||||
{"item", item}
|
||||
}}
|
||||
});
|
||||
output.push_back(item);
|
||||
}
|
||||
|
||||
std::time_t t = std::time(0);
|
||||
server_sent_events.push_back(json {
|
||||
|
|
@ -896,11 +921,12 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp_stream() {
|
|||
{"data", json {
|
||||
{"type", "response.completed"},
|
||||
{"response", json {
|
||||
{"id", "resp_dummy_id"},
|
||||
{"object", "response"},
|
||||
{"created_at", t},
|
||||
{"status", "completed"},
|
||||
{"model", oaicompat_model},
|
||||
{"output", json::array({item})},
|
||||
{"output", output},
|
||||
{"usage", json {
|
||||
{"input_tokens", n_prompt_tokens},
|
||||
{"output_tokens", n_decoded},
|
||||
|
|
@ -1191,6 +1217,28 @@ json server_task_result_cmpl_partial::to_json_oaicompat_resp() {
|
|||
}}
|
||||
});
|
||||
}
|
||||
if (!diff.tool_call_delta.name.empty()) {
|
||||
deltas.push_back(json {
|
||||
{"event", "response.output_item.added"},
|
||||
{"data", json {
|
||||
{"type", "response.output_item.added"},
|
||||
{"item", json {
|
||||
{"type", "function_call"},
|
||||
{"status", "in_progress"},
|
||||
{"name", diff.tool_call_delta.name}
|
||||
}}
|
||||
}}
|
||||
});
|
||||
}
|
||||
if (!diff.tool_call_delta.arguments.empty()) {
|
||||
deltas.push_back(json {
|
||||
{"event", "response.function_call_arguments.delta"},
|
||||
{"data", json {
|
||||
{"type", "response.function_call_arguments.delta"},
|
||||
{"delta", diff.tool_call_delta.arguments}
|
||||
}}
|
||||
});
|
||||
}
|
||||
if (!diff.content_delta.empty()) {
|
||||
deltas.push_back(json {
|
||||
{"event", "response.output_text.delta"},
|
||||
|
|
|
|||
Loading…
Reference in New Issue