diff --git a/common/http.h b/common/http.h index e8ed56f952..d3daccd6bf 100644 --- a/common/http.h +++ b/common/http.h @@ -7,6 +7,7 @@ struct common_http_url { std::string user; std::string password; std::string host; + int port; std::string path; }; @@ -47,6 +48,20 @@ static common_http_url common_http_parse_url(const std::string & url) { parts.host = rest; parts.path = "/"; } + + auto colon_pos = parts.host.find(':'); + + if (colon_pos != std::string::npos) { + parts.port = std::stoi(parts.host.substr(colon_pos + 1)); + parts.host = parts.host.substr(0, colon_pos); + } else if (parts.scheme == "http") { + parts.port = 80; + } else if (parts.scheme == "https") { + parts.port = 443; + } else { + throw std::runtime_error("unsupported URL scheme: " + parts.scheme); + } + return parts; } @@ -68,7 +83,7 @@ static std::pair common_http_client(const std: } #endif - httplib::Client cli(parts.scheme + "://" + parts.host); + httplib::Client cli(parts.scheme + "://" + parts.host + ":" + std::to_string(parts.port)); if (!parts.user.empty()) { cli.set_basic_auth(parts.user, parts.password); diff --git a/tools/server/server-cors-proxy.h b/tools/server/server-cors-proxy.h index bca50b53df..c412d4c252 100644 --- a/tools/server/server-cors-proxy.h +++ b/tools/server/server-cors-proxy.h @@ -30,12 +30,13 @@ static server_http_res_ptr proxy_request(const server_http_req & req, std::strin throw std::runtime_error("unsupported URL scheme in target URL: " + parsed_url.scheme); } - SRV_INF("proxying %s request to %s://%s%s\n", method.c_str(), parsed_url.scheme.c_str(), parsed_url.host.c_str(), parsed_url.path.c_str()); + SRV_INF("proxying %s request to %s://%s:%i%s\n", method.c_str(), parsed_url.scheme.c_str(), parsed_url.host.c_str(), parsed_url.port, parsed_url.path.c_str()); auto proxy = std::make_unique( method, + parsed_url.scheme, parsed_url.host, - parsed_url.scheme == "http" ? 80 : 443, + parsed_url.port, parsed_url.path, req.headers, req.body, diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index 5f87ba9a21..c13d48a363 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -783,6 +783,7 @@ server_http_res_ptr server_models::proxy_request(const server_http_req & req, co } auto proxy = std::make_unique( method, + "http", CHILD_ADDR, meta->port, proxy_path, @@ -1079,6 +1080,7 @@ static bool should_strip_proxy_header(const std::string & header_name) { server_http_proxy::server_http_proxy( const std::string & method, + const std::string & scheme, const std::string & host, int port, const std::string & path, @@ -1092,7 +1094,7 @@ server_http_proxy::server_http_proxy( auto cli = std::make_shared(host, port); auto pipe = std::make_shared>(); - if (port == 443) { + if (scheme == "https") { #ifdef CPPHTTPLIB_OPENSSL_SUPPORT cli.reset(new httplib::SSLClient(host, port)); #else diff --git a/tools/server/server-models.h b/tools/server/server-models.h index 78abc8d72a..2b392f299a 100644 --- a/tools/server/server-models.h +++ b/tools/server/server-models.h @@ -180,6 +180,7 @@ struct server_http_proxy : server_http_res { std::function cleanup = nullptr; public: server_http_proxy(const std::string & method, + const std::string & scheme, const std::string & host, int port, const std::string & path, diff --git a/tools/server/tests/unit/test_proxy.py b/tools/server/tests/unit/test_proxy.py new file mode 100644 index 0000000000..b7c3326187 --- /dev/null +++ b/tools/server/tests/unit/test_proxy.py @@ -0,0 +1,41 @@ +import pytest +from utils import * + +server = ServerPreset.tinyllama2() + + +@pytest.fixture(autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama2() + + +def test_mcp_no_proxy(): + global server + server.webui_mcp_proxy = False + server.start() + + res = server.make_request("GET", "/cors-proxy") + assert res.status_code == 404 + + +def test_mcp_proxy(): + global server + server.webui_mcp_proxy = True + server.start() + + url = f"http://{server.server_host}:{server.server_port}/cors-proxy?url=http://example.com" + res = requests.get(url) + assert res.status_code == 200 + assert "Example Domain" in res.text + + +def test_mcp_proxy_custom_port(): + global server + server.webui_mcp_proxy = True + server.start() + + # try getting the server's models API via the proxy + res = server.make_request("GET", f"/cors-proxy?url=http://{server.server_host}:{server.server_port}/models") + assert res.status_code == 200 + assert "data" in res.body diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py index 5002999d9b..db357d876b 100644 --- a/tools/server/tests/utils.py +++ b/tools/server/tests/utils.py @@ -102,6 +102,7 @@ class ServerProcess: mmproj_url: str | None = None media_path: str | None = None sleep_idle_seconds: int | None = None + webui_mcp_proxy: bool = False # session variables process: subprocess.Popen | None = None @@ -236,6 +237,8 @@ class ServerProcess: server_args.extend(["--media-path", self.media_path]) if self.sleep_idle_seconds is not None: server_args.extend(["--sleep-idle-seconds", self.sleep_idle_seconds]) + if self.webui_mcp_proxy: + server_args.append("--webui-mcp-proxy") args = [str(arg) for arg in [server_path, *server_args]] print(f"tests: starting server with: {' '.join(args)}")