diff --git a/tools/server/tests/unit/test_router.py b/tools/server/tests/unit/test_router.py new file mode 100644 index 0000000000..3b1a811c3e --- /dev/null +++ b/tools/server/tests/unit/test_router.py @@ -0,0 +1,50 @@ +import pytest +from utils import * + +server: ServerProcess + +@pytest.fixture(autouse=True) +def create_server(): + global server + server = ServerPreset.router() + + +@pytest.mark.parametrize( + "model,success", + [ + ("ggml-org/tinygemma3-GGUF:Q8_0", True), + ("non-existent/model", False), + ] +) +def test_chat_completion_stream(model: str, success: bool): + # TODO: make sure the model is in cache (ie. ServerProcess.load_all()) before starting the router server + global server + server.start() + content = "" + ex: ServerError | None = None + try: + res = server.make_stream_request("POST", "/chat/completions", data={ + "model": model, + "max_tokens": 16, + "messages": [ + {"role": "user", "content": "hello"}, + ], + "stream": True, + }) + for data in res: + if data["choices"]: + choice = data["choices"][0] + if choice["finish_reason"] in ["stop", "length"]: + assert "content" not in choice["delta"] + else: + assert choice["finish_reason"] is None + content += choice["delta"]["content"] or '' + except ServerError as e: + ex = e + + if success: + assert ex is None + assert len(content) > 0 + else: + assert ex is not None + assert content == "" diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py index da703c4c51..a9eec74822 100644 --- a/tools/server/tests/utils.py +++ b/tools/server/tests/utils.py @@ -46,7 +46,7 @@ class ServerProcess: debug: bool = False server_port: int = 8080 server_host: str = "127.0.0.1" - model_hf_repo: str = "ggml-org/models" + model_hf_repo: str | None = "ggml-org/models" model_hf_file: str | None = "tinyllamas/stories260K.gguf" model_alias: str = "tinyllama-2" temperature: float = 0.8 @@ -519,9 +519,8 @@ class ServerPreset: server = ServerProcess() server.offline = True # will be downloaded by load_all() # mmproj is already provided by HF registry API - server.model_hf_repo = "ggml-org/tinygemma3-GGUF" - server.model_hf_file = "tinygemma3-Q8_0.gguf" - server.mmproj_url = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/mmproj-tinygemma3.gguf" + server.model_hf_file = None + server.model_hf_repo = "ggml-org/tinygemma3-GGUF:Q8_0" server.model_alias = "tinygemma3" server.n_ctx = 1024 server.n_batch = 32 @@ -529,6 +528,21 @@ class ServerPreset: server.n_predict = 4 server.seed = 42 return server + + @staticmethod + def router() -> ServerProcess: + server = ServerProcess() + # router server has no models + server.model_file = None + server.model_alias = None + server.model_hf_repo = None + server.model_hf_file = None + server.n_ctx = 1024 + server.n_batch = 16 + server.n_slots = 1 + server.n_predict = 16 + server.seed = 42 + return server def parallel_function_calls(function_list: List[Tuple[Callable[..., Any], Tuple[Any, ...]]]) -> List[Any]: