* llama-server: add router multi-model tests (#17704)

Add 4 test cases for the model router:
- test_router_unload_model: explicit model unloading
- test_router_models_max_evicts_lru: LRU eviction with --models-max
- test_router_no_models_autoload: --no-models-autoload flag behavior
- test_router_api_key_required: API key authentication

Tests use async model loading with polling and skip gracefully when not enough models are available for eviction testing.

utils.py changes:
- Add models_max, models_dir, no_models_autoload attributes to ServerProcess
- Handle JSONDecodeError for non-JSON error responses (fall back to text)

* llama-server: update test models to new HF repos

* add offline

* llama-server: fix router LRU eviction test and add preloading

Fix the eviction test: load 2 models first, verify their state, then load a 3rd to trigger eviction. The previous logic loaded all 3 at once, causing the first model to be evicted before verification could occur.

Add a module fixture to preload models via ServerPreset.load_all() and mark the test presets as offline so they use cached models.

* llama-server: fix split model download on Windows

---------

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
parent 1257491047
commit e7c2cf1356
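Note: the commit message above refers to a module-scoped fixture that preloads models via ServerPreset.load_all(); that fixture itself is not visible in the hunks below. A minimal sketch of what such a fixture could look like, assuming load_all() takes no arguments and only warms the local model cache (the fixture name and import are assumptions, not the code from this commit):

import pytest

from utils import ServerPreset  # assumed import; the presets live in the tests' utils.py


@pytest.fixture(scope="module", autouse=True)
def preload_models():
    # Assumed behavior: download/cache every preset model once per test module
    # so presets marked offline can start without network access.
    ServerPreset.load_all()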
@@ -65,6 +65,7 @@ def test_server_slots():
 
 def test_load_split_model():
     global server
+    server.offline = False
     server.model_hf_repo = "ggml-org/models"
     server.model_hf_file = "tinyllamas/split/stories15M-q8_0-00001-of-00003.gguf"
     server.model_alias = "tinyllama-split"
@@ -17,7 +17,6 @@ def create_server():
     ]
 )
 def test_router_chat_completion_stream(model: str, success: bool):
-    # TODO: make sure the model is in cache (ie. ServerProcess.load_all()) before starting the router server
     global server
     server.start()
     content = ""
@@ -48,3 +47,148 @@ def test_router_chat_completion_stream(model: str, success: bool):
     else:
         assert ex is not None
         assert content == ""
+
+
+def _get_model_status(model_id: str) -> str:
+    res = server.make_request("GET", "/models")
+    assert res.status_code == 200
+    for item in res.body.get("data", []):
+        if item.get("id") == model_id or item.get("model") == model_id:
+            return item["status"]["value"]
+    raise AssertionError(f"Model {model_id} not found in /models response")
+
+
+def _wait_for_model_status(model_id: str, desired: set[str], timeout: int = 60) -> str:
+    deadline = time.time() + timeout
+    last_status = None
+    while time.time() < deadline:
+        last_status = _get_model_status(model_id)
+        if last_status in desired:
+            return last_status
+        time.sleep(1)
+    raise AssertionError(
+        f"Timed out waiting for {model_id} to reach {desired}, last status: {last_status}"
+    )
+
+
+def _load_model_and_wait(
+    model_id: str, timeout: int = 60, headers: dict | None = None
+) -> None:
+    load_res = server.make_request(
+        "POST", "/models/load", data={"model": model_id}, headers=headers
+    )
+    assert load_res.status_code == 200
+    assert isinstance(load_res.body, dict)
+    assert load_res.body.get("success") is True
+    _wait_for_model_status(model_id, {"loaded"}, timeout=timeout)
+
+
+def test_router_unload_model():
+    global server
+    server.start()
+    model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
+
+    _load_model_and_wait(model_id)
+
+    unload_res = server.make_request("POST", "/models/unload", data={"model": model_id})
+    assert unload_res.status_code == 200
+    assert unload_res.body.get("success") is True
+    _wait_for_model_status(model_id, {"unloaded"})
+
+
+def test_router_models_max_evicts_lru():
+    global server
+    server.models_max = 2
+    server.start()
+
+    candidate_models = [
+        "ggml-org/tinygemma3-GGUF:Q8_0",
+        "ggml-org/test-model-stories260K",
+        "ggml-org/test-model-stories260K-infill",
+    ]
+
+    # Load only the first 2 models to fill the cache
+    first, second, third = candidate_models[:3]
+
+    _load_model_and_wait(first, timeout=120)
+    _load_model_and_wait(second, timeout=120)
+
+    # Verify both models are loaded
+    assert _get_model_status(first) == "loaded"
+    assert _get_model_status(second) == "loaded"
+
+    # Load the third model - this should trigger LRU eviction of the first model
+    _load_model_and_wait(third, timeout=120)
+
+    # Verify eviction: third is loaded, first was evicted
+    assert _get_model_status(third) == "loaded"
+    assert _get_model_status(first) == "unloaded"
+
+
+def test_router_no_models_autoload():
+    global server
+    server.no_models_autoload = True
+    server.start()
+    model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
+
+    res = server.make_request(
+        "POST",
+        "/v1/chat/completions",
+        data={
+            "model": model_id,
+            "messages": [{"role": "user", "content": "hello"}],
+            "max_tokens": 4,
+        },
+    )
+    assert res.status_code == 400
+    assert "error" in res.body
+
+    _load_model_and_wait(model_id)
+
+    success_res = server.make_request(
+        "POST",
+        "/v1/chat/completions",
+        data={
+            "model": model_id,
+            "messages": [{"role": "user", "content": "hello"}],
+            "max_tokens": 4,
+        },
+    )
+    assert success_res.status_code == 200
+    assert "error" not in success_res.body
+
+
+def test_router_api_key_required():
+    global server
+    server.api_key = "sk-router-secret"
+    server.start()
+
+    model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
+    auth_headers = {"Authorization": f"Bearer {server.api_key}"}
+
+    res = server.make_request(
+        "POST",
+        "/v1/chat/completions",
+        data={
+            "model": model_id,
+            "messages": [{"role": "user", "content": "hello"}],
+            "max_tokens": 4,
+        },
+    )
+    assert res.status_code == 401
+    assert res.body.get("error", {}).get("type") == "authentication_error"
+
+    _load_model_and_wait(model_id, headers=auth_headers)
+
+    authed = server.make_request(
+        "POST",
+        "/v1/chat/completions",
+        headers=auth_headers,
+        data={
+            "model": model_id,
+            "messages": [{"role": "user", "content": "hello"}],
+            "max_tokens": 4,
+        },
+    )
+    assert authed.status_code == 200
+    assert "error" not in authed.body
@@ -7,6 +7,7 @@ import subprocess
 import os
 import re
 import json
+from json import JSONDecodeError
 import sys
 import requests
 import time
@@ -83,6 +84,9 @@ class ServerProcess:
     pooling: str | None = None
     draft: int | None = None
     api_key: str | None = None
+    models_dir: str | None = None
+    models_max: int | None = None
+    no_models_autoload: bool | None = None
     lora_files: List[str] | None = None
     enable_ctx_shift: int | None = False
     draft_min: int | None = None
@@ -143,6 +147,10 @@ class ServerProcess:
             server_args.extend(["--hf-repo", self.model_hf_repo])
         if self.model_hf_file:
             server_args.extend(["--hf-file", self.model_hf_file])
+        if self.models_dir:
+            server_args.extend(["--models-dir", self.models_dir])
+        if self.models_max is not None:
+            server_args.extend(["--models-max", self.models_max])
         if self.n_batch:
             server_args.extend(["--batch-size", self.n_batch])
         if self.n_ubatch:
@@ -204,6 +212,8 @@ class ServerProcess:
             server_args.extend(["--draft-min", self.draft_min])
         if self.no_webui:
             server_args.append("--no-webui")
+        if self.no_models_autoload:
+            server_args.append("--no-models-autoload")
         if self.jinja:
             server_args.append("--jinja")
         else:
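A short usage snippet mirroring what the router tests above already do, shown here only to make the attribute-to-flag mapping of these two hunks explicit (the import path and the models_dir value are illustrative assumptions):

from utils import ServerProcess  # assumed import path

server = ServerProcess()
server.models_max = 2              # forwarded as: --models-max 2
server.no_models_autoload = True   # forwarded as: --no-models-autoload
server.models_dir = "./models"     # forwarded as: --models-dir ./models (example path)
server.start()                     # builds server_args including the flags above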
@@ -295,7 +305,13 @@
         result = ServerResponse()
         result.headers = dict(response.headers)
         result.status_code = response.status_code
-        result.body = response.json() if parse_body else None
+        if parse_body:
+            try:
+                result.body = response.json()
+            except JSONDecodeError:
+                result.body = response.text
+        else:
+            result.body = None
         print("Response from server", json.dumps(result.body, indent=2))
         return result
 
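As a worked illustration of the fallback introduced here, a self-contained sketch (parse_body is a hypothetical helper, not part of utils.py; make_request applies the same pattern to response.json() / response.text):

import json
from json import JSONDecodeError


def parse_body(text: str):
    # Prefer JSON; fall back to the raw text for non-JSON error responses,
    # mirroring the try/except added to ServerProcess.make_request above.
    try:
        return json.loads(text)
    except JSONDecodeError:
        return text


assert parse_body('{"success": true}') == {"success": True}
assert parse_body("model is loading") == "model is loading"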
@@ -434,8 +450,9 @@ class ServerPreset:
     @staticmethod
     def tinyllama2() -> ServerProcess:
         server = ServerProcess()
-        server.model_hf_repo = "ggml-org/models"
-        server.model_hf_file = "tinyllamas/stories260K.gguf"
+        server.offline = True # will be downloaded by load_all()
+        server.model_hf_repo = "ggml-org/test-model-stories260K"
+        server.model_hf_file = None
         server.model_alias = "tinyllama-2"
         server.n_ctx = 512
         server.n_batch = 32
@@ -479,8 +496,8 @@ class ServerPreset:
     def tinyllama_infill() -> ServerProcess:
         server = ServerProcess()
         server.offline = True # will be downloaded by load_all()
-        server.model_hf_repo = "ggml-org/models"
-        server.model_hf_file = "tinyllamas/stories260K-infill.gguf"
+        server.model_hf_repo = "ggml-org/test-model-stories260K-infill"
+        server.model_hf_file = None
         server.model_alias = "tinyllama-infill"
         server.n_ctx = 2048
         server.n_batch = 1024
@@ -537,6 +554,7 @@ class ServerPreset:
     @staticmethod
     def router() -> ServerProcess:
         server = ServerProcess()
+        server.offline = True # will be downloaded by load_all()
         # router server has no models
         server.model_file = None
         server.model_alias = None