server: add --media-path for local media files (#17697)
* server: add --media-path for local media files * remove unused fn
This commit is contained in:
parent
a96283adc4
commit
13628d8bdb
|
|
@ -2488,12 +2488,29 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||||
"path to save slot kv cache (default: disabled)",
|
"path to save slot kv cache (default: disabled)",
|
||||||
[](common_params & params, const std::string & value) {
|
[](common_params & params, const std::string & value) {
|
||||||
params.slot_save_path = value;
|
params.slot_save_path = value;
|
||||||
|
if (!fs_is_directory(params.slot_save_path)) {
|
||||||
|
throw std::invalid_argument("not a directory: " + value);
|
||||||
|
}
|
||||||
// if doesn't end with DIRECTORY_SEPARATOR, add it
|
// if doesn't end with DIRECTORY_SEPARATOR, add it
|
||||||
if (!params.slot_save_path.empty() && params.slot_save_path[params.slot_save_path.size() - 1] != DIRECTORY_SEPARATOR) {
|
if (!params.slot_save_path.empty() && params.slot_save_path[params.slot_save_path.size() - 1] != DIRECTORY_SEPARATOR) {
|
||||||
params.slot_save_path += DIRECTORY_SEPARATOR;
|
params.slot_save_path += DIRECTORY_SEPARATOR;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||||
|
// --media-path: enables loading media referenced by file:// URLs from a single directory.
// The handler rejects anything that is not an existing directory and stores the path
// with a trailing DIRECTORY_SEPARATOR so a relative file path can be appended directly.
add_opt(common_arg(
    {"--media-path"}, "PATH",
    "directory for loading local media files; files can be accessed via file:// URLs using relative paths (default: disabled)",
    [](common_params & params, const std::string & value) {
        std::string & dir = params.media_path;
        dir = value;
        if (!fs_is_directory(dir)) {
            throw std::invalid_argument("not a directory: " + value);
        }
        // normalize: make sure the stored path ends with DIRECTORY_SEPARATOR
        const bool has_trailing_separator = !dir.empty() && dir.back() == DIRECTORY_SEPARATOR;
        if (!dir.empty() && !has_trailing_separator) {
            dir += DIRECTORY_SEPARATOR;
        }
    }
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||||
add_opt(common_arg(
|
add_opt(common_arg(
|
||||||
{"--models-dir"}, "PATH",
|
{"--models-dir"}, "PATH",
|
||||||
"directory containing models for the router server (default: disabled)",
|
"directory containing models for the router server (default: disabled)",
|
||||||
|
|
|
||||||
|
|
@ -694,7 +694,7 @@ bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_over
|
||||||
|
|
||||||
// Validate if a filename is safe to use
|
// Validate if a filename is safe to use
|
||||||
// To validate a full path, split the path by the OS-specific path separator, and validate each part with this function
|
// To validate a full path, split the path by the OS-specific path separator, and validate each part with this function
|
||||||
bool fs_validate_filename(const std::string & filename) {
|
bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
|
||||||
if (!filename.length()) {
|
if (!filename.length()) {
|
||||||
// Empty filename invalid
|
// Empty filename invalid
|
||||||
return false;
|
return false;
|
||||||
|
|
@ -754,10 +754,14 @@ bool fs_validate_filename(const std::string & filename) {
|
||||||
|| (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs
|
|| (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs
|
||||||
|| c == 0xFFFD // Replacement Character (UTF-8)
|
|| c == 0xFFFD // Replacement Character (UTF-8)
|
||||||
|| c == 0xFEFF // Byte Order Mark (BOM)
|
|| c == 0xFEFF // Byte Order Mark (BOM)
|
||||||
|| c == '/' || c == '\\' || c == ':' || c == '*' // Illegal characters
|
|| c == ':' || c == '*' // Illegal characters
|
||||||
|| c == '?' || c == '"' || c == '<' || c == '>' || c == '|') {
|
|| c == '?' || c == '"' || c == '<' || c == '>' || c == '|') {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
if (!allow_subdirs && (c == '/' || c == '\\')) {
|
||||||
|
// Subdirectories not allowed, reject path separators
|
||||||
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename
|
// Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename
|
||||||
|
|
@ -859,6 +863,11 @@ bool fs_create_directory_with_parents(const std::string & path) {
|
||||||
#endif // _WIN32
|
#endif // _WIN32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns true only when `path` names an existing directory.
//
// Filesystem errors raised while querying the path (for example an over-long name
// or a permission problem) are reported as `false` instead of being thrown as
// std::filesystem::filesystem_error, so callers can treat every failure uniformly
// as "not a directory".
bool fs_is_directory(const std::string & path) {
    std::error_code ec;
    // the error_code overload already yields false for a missing path,
    // so a separate exists() query is not needed
    return std::filesystem::is_directory(std::filesystem::path(path), ec);
}
|
||||||
|
|
||||||
std::string fs_get_cache_directory() {
|
std::string fs_get_cache_directory() {
|
||||||
std::string cache_directory = "";
|
std::string cache_directory = "";
|
||||||
auto ensure_trailing_slash = [](std::string p) {
|
auto ensure_trailing_slash = [](std::string p) {
|
||||||
|
|
|
||||||
|
|
@ -485,6 +485,7 @@ struct common_params {
|
||||||
bool log_json = false;
|
bool log_json = false;
|
||||||
|
|
||||||
std::string slot_save_path;
|
std::string slot_save_path;
|
||||||
|
std::string media_path; // path to directory for loading media files
|
||||||
|
|
||||||
float slot_prompt_similarity = 0.1f;
|
float slot_prompt_similarity = 0.1f;
|
||||||
|
|
||||||
|
|
@ -635,8 +636,9 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
|
||||||
// Filesystem utils
|
// Filesystem utils
|
||||||
//
|
//
|
||||||
|
|
||||||
bool fs_validate_filename(const std::string & filename);
|
bool fs_validate_filename(const std::string & filename, bool allow_subdirs = false);
|
||||||
bool fs_create_directory_with_parents(const std::string & path);
|
bool fs_create_directory_with_parents(const std::string & path);
|
||||||
|
bool fs_is_directory(const std::string & path);
|
||||||
|
|
||||||
std::string fs_get_cache_directory();
|
std::string fs_get_cache_directory();
|
||||||
std::string fs_get_cache_file(const std::string & filename);
|
std::string fs_get_cache_file(const std::string & filename);
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@
|
||||||
|
|
||||||
#include <random>
|
#include <random>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
#include <fstream>
|
||||||
|
|
||||||
json format_error_response(const std::string & message, const enum error_type type) {
|
json format_error_response(const std::string & message, const enum error_type type) {
|
||||||
std::string type_str;
|
std::string type_str;
|
||||||
|
|
@ -774,6 +775,65 @@ json oaicompat_completion_params_parse(const json & body) {
|
||||||
return llama_params;
|
return llama_params;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// media_path always end with '/', see arg.cpp
|
||||||
|
static void handle_media(
|
||||||
|
std::vector<raw_buffer> & out_files,
|
||||||
|
json & media_obj,
|
||||||
|
const std::string & media_path) {
|
||||||
|
std::string url = json_value(media_obj, "url", std::string());
|
||||||
|
if (string_starts_with(url, "http")) {
|
||||||
|
// download remote image
|
||||||
|
// TODO @ngxson : maybe make these params configurable
|
||||||
|
common_remote_params params;
|
||||||
|
params.headers.push_back("User-Agent: llama.cpp/" + build_info);
|
||||||
|
params.max_size = 1024 * 1024 * 10; // 10MB
|
||||||
|
params.timeout = 10; // seconds
|
||||||
|
SRV_INF("downloading image from '%s'\n", url.c_str());
|
||||||
|
auto res = common_remote_get_content(url, params);
|
||||||
|
if (200 <= res.first && res.first < 300) {
|
||||||
|
SRV_INF("downloaded %ld bytes\n", res.second.size());
|
||||||
|
raw_buffer data;
|
||||||
|
data.insert(data.end(), res.second.begin(), res.second.end());
|
||||||
|
out_files.push_back(data);
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error("Failed to download image");
|
||||||
|
}
|
||||||
|
|
||||||
|
} else if (string_starts_with(url, "file://")) {
|
||||||
|
if (media_path.empty()) {
|
||||||
|
throw std::invalid_argument("file:// URLs are not allowed unless --media-path is specified");
|
||||||
|
}
|
||||||
|
// load local image file
|
||||||
|
std::string file_path = url.substr(7); // remove "file://"
|
||||||
|
raw_buffer data;
|
||||||
|
if (!fs_validate_filename(file_path, true)) {
|
||||||
|
throw std::invalid_argument("file path is not allowed: " + file_path);
|
||||||
|
}
|
||||||
|
SRV_INF("loading image from local file '%s'\n", (media_path + file_path).c_str());
|
||||||
|
std::ifstream file(media_path + file_path, std::ios::binary);
|
||||||
|
if (!file) {
|
||||||
|
throw std::invalid_argument("file does not exist or cannot be opened: " + file_path);
|
||||||
|
}
|
||||||
|
data.assign((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
|
||||||
|
out_files.push_back(data);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// try to decode base64 image
|
||||||
|
std::vector<std::string> parts = string_split<std::string>(url, /*separator*/ ',');
|
||||||
|
if (parts.size() != 2) {
|
||||||
|
throw std::runtime_error("Invalid url value");
|
||||||
|
} else if (!string_starts_with(parts[0], "data:image/")) {
|
||||||
|
throw std::runtime_error("Invalid url format: " + parts[0]);
|
||||||
|
} else if (!string_ends_with(parts[0], "base64")) {
|
||||||
|
throw std::runtime_error("url must be base64 encoded");
|
||||||
|
} else {
|
||||||
|
auto base64_data = parts[1];
|
||||||
|
auto decoded_data = base64_decode(base64_data);
|
||||||
|
out_files.push_back(decoded_data);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// used by /chat/completions endpoint
|
// used by /chat/completions endpoint
|
||||||
json oaicompat_chat_params_parse(
|
json oaicompat_chat_params_parse(
|
||||||
json & body, /* openai api json semantics */
|
json & body, /* openai api json semantics */
|
||||||
|
|
@ -860,41 +920,8 @@ json oaicompat_chat_params_parse(
|
||||||
throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
|
throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
|
||||||
}
|
}
|
||||||
|
|
||||||
json image_url = json_value(p, "image_url", json::object());
|
json image_url = json_value(p, "image_url", json::object());
|
||||||
std::string url = json_value(image_url, "url", std::string());
|
handle_media(out_files, image_url, opt.media_path);
|
||||||
if (string_starts_with(url, "http")) {
|
|
||||||
// download remote image
|
|
||||||
// TODO @ngxson : maybe make these params configurable
|
|
||||||
common_remote_params params;
|
|
||||||
params.headers.push_back("User-Agent: llama.cpp/" + build_info);
|
|
||||||
params.max_size = 1024 * 1024 * 10; // 10MB
|
|
||||||
params.timeout = 10; // seconds
|
|
||||||
SRV_INF("downloading image from '%s'\n", url.c_str());
|
|
||||||
auto res = common_remote_get_content(url, params);
|
|
||||||
if (200 <= res.first && res.first < 300) {
|
|
||||||
SRV_INF("downloaded %ld bytes\n", res.second.size());
|
|
||||||
raw_buffer data;
|
|
||||||
data.insert(data.end(), res.second.begin(), res.second.end());
|
|
||||||
out_files.push_back(data);
|
|
||||||
} else {
|
|
||||||
throw std::runtime_error("Failed to download image");
|
|
||||||
}
|
|
||||||
|
|
||||||
} else {
|
|
||||||
// try to decode base64 image
|
|
||||||
std::vector<std::string> parts = string_split<std::string>(url, /*separator*/ ',');
|
|
||||||
if (parts.size() != 2) {
|
|
||||||
throw std::invalid_argument("Invalid image_url.url value");
|
|
||||||
} else if (!string_starts_with(parts[0], "data:image/")) {
|
|
||||||
throw std::invalid_argument("Invalid image_url.url format: " + parts[0]);
|
|
||||||
} else if (!string_ends_with(parts[0], "base64")) {
|
|
||||||
throw std::invalid_argument("image_url.url must be base64 encoded");
|
|
||||||
} else {
|
|
||||||
auto base64_data = parts[1];
|
|
||||||
auto decoded_data = base64_decode(base64_data);
|
|
||||||
out_files.push_back(decoded_data);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// replace this chunk with a marker
|
// replace this chunk with a marker
|
||||||
p["type"] = "text";
|
p["type"] = "text";
|
||||||
|
|
@ -916,6 +943,8 @@ json oaicompat_chat_params_parse(
|
||||||
auto decoded_data = base64_decode(data); // expected to be base64 encoded
|
auto decoded_data = base64_decode(data); // expected to be base64 encoded
|
||||||
out_files.push_back(decoded_data);
|
out_files.push_back(decoded_data);
|
||||||
|
|
||||||
|
// TODO: add audio_url support by reusing handle_media()
|
||||||
|
|
||||||
// replace this chunk with a marker
|
// replace this chunk with a marker
|
||||||
p["type"] = "text";
|
p["type"] = "text";
|
||||||
p["text"] = mtmd_default_marker();
|
p["text"] = mtmd_default_marker();
|
||||||
|
|
|
||||||
|
|
@ -284,6 +284,7 @@ struct oaicompat_parser_options {
|
||||||
bool allow_image;
|
bool allow_image;
|
||||||
bool allow_audio;
|
bool allow_audio;
|
||||||
bool enable_thinking = true;
|
bool enable_thinking = true;
|
||||||
|
std::string media_path;
|
||||||
};
|
};
|
||||||
|
|
||||||
// used by /chat/completions endpoint
|
// used by /chat/completions endpoint
|
||||||
|
|
|
||||||
|
|
@ -788,6 +788,7 @@ struct server_context_impl {
|
||||||
/* allow_image */ mctx ? mtmd_support_vision(mctx) : false,
|
/* allow_image */ mctx ? mtmd_support_vision(mctx) : false,
|
||||||
/* allow_audio */ mctx ? mtmd_support_audio (mctx) : false,
|
/* allow_audio */ mctx ? mtmd_support_audio (mctx) : false,
|
||||||
/* enable_thinking */ enable_thinking,
|
/* enable_thinking */ enable_thinking,
|
||||||
|
/* media_path */ params_base.media_path,
|
||||||
};
|
};
|
||||||
|
|
||||||
// print sample chat example to make it clear which template is used
|
// print sample chat example to make it clear which template is used
|
||||||
|
|
|
||||||
|
|
@ -38,9 +38,11 @@ static server_http_context::handler_t ex_wrapper(server_http_context::handler_t
|
||||||
try {
|
try {
|
||||||
return func(req);
|
return func(req);
|
||||||
} catch (const std::invalid_argument & e) {
|
} catch (const std::invalid_argument & e) {
|
||||||
|
// treat invalid_argument as invalid request (400)
|
||||||
error = ERROR_TYPE_INVALID_REQUEST;
|
error = ERROR_TYPE_INVALID_REQUEST;
|
||||||
message = e.what();
|
message = e.what();
|
||||||
} catch (const std::exception & e) {
|
} catch (const std::exception & e) {
|
||||||
|
// treat other exceptions as server error (500)
|
||||||
error = ERROR_TYPE_SERVER;
|
error = ERROR_TYPE_SERVER;
|
||||||
message = e.what();
|
message = e.what();
|
||||||
} catch (...) {
|
} catch (...) {
|
||||||
|
|
|
||||||
|
|
@ -94,3 +94,34 @@ def test_cors_options(origin: str, cors_header: str, cors_header_value: str):
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
assert cors_header in res.headers
|
assert cors_header in res.headers
|
||||||
assert res.headers[cors_header] == cors_header_value
|
assert res.headers[cors_header] == cors_header_value
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "media_path, image_url, success",
    [
        # media path disabled: file:// URLs must be rejected
        (None, "file://mtmd/test-1.jpeg", False),
        ("../../../tools", "file://mtmd/test-1.jpeg", True),
        # redundant slashes should resolve to the same file as above
        ("../../../tools", "file:////mtmd//test-1.jpeg", True),
        # non-existent file
        ("../../../tools", "file://mtmd/notfound.jpeg", False),
        # directory traversal must not be allowed
        ("../../../tools", "file://../mtmd/test-1.jpeg", False),
    ]
)
def test_local_media_file(media_path, image_url, success):
    """Send a chat request with a file:// image URL and check the HTTP status.

    Expects 200 when the file should be loadable through the configured
    media path, and 400 in every other case.
    """
    server = ServerPreset.tinygemma3()
    server.media_path = media_path
    server.start()

    image_part = {"type": "image_url", "image_url": {"url": image_url}}
    text_part = {"type": "text", "text": "test"}
    payload = {
        "max_tokens": 1,
        "messages": [
            {"role": "user", "content": [text_part, image_part]},
        ],
    }
    res = server.make_request("POST", "/chat/completions", data=payload)

    expected_status = 200 if success else 400
    assert res.status_code == expected_status
|
||||||
|
|
|
||||||
|
|
@ -95,6 +95,7 @@ class ServerProcess:
|
||||||
chat_template_file: str | None = None
|
chat_template_file: str | None = None
|
||||||
server_path: str | None = None
|
server_path: str | None = None
|
||||||
mmproj_url: str | None = None
|
mmproj_url: str | None = None
|
||||||
|
media_path: str | None = None
|
||||||
|
|
||||||
# session variables
|
# session variables
|
||||||
process: subprocess.Popen | None = None
|
process: subprocess.Popen | None = None
|
||||||
|
|
@ -217,6 +218,8 @@ class ServerProcess:
|
||||||
server_args.extend(["--chat-template-file", self.chat_template_file])
|
server_args.extend(["--chat-template-file", self.chat_template_file])
|
||||||
if self.mmproj_url:
|
if self.mmproj_url:
|
||||||
server_args.extend(["--mmproj-url", self.mmproj_url])
|
server_args.extend(["--mmproj-url", self.mmproj_url])
|
||||||
|
if self.media_path:
|
||||||
|
server_args.extend(["--media-path", self.media_path])
|
||||||
|
|
||||||
args = [str(arg) for arg in [server_path, *server_args]]
|
args = [str(arg) for arg in [server_path, *server_args]]
|
||||||
print(f"tests: starting server with: {' '.join(args)}")
|
print(f"tests: starting server with: {' '.join(args)}")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue