add test
parent 5423d42a35
commit 0ef3b61e82
@@ -0,0 +1,50 @@
+import pytest
+from utils import *
+
+server: ServerProcess
+
+@pytest.fixture(autouse=True)
+def create_server():
+    global server
+    server = ServerPreset.router()
+
+
+@pytest.mark.parametrize(
+    "model,success",
+    [
+        ("ggml-org/tinygemma3-GGUF:Q8_0", True),
+        ("non-existent/model", False),
+    ]
+)
+def test_chat_completion_stream(model: str, success: bool):
+    # TODO: make sure the model is in cache (ie. ServerProcess.load_all()) before starting the router server
+    global server
+    server.start()
+    content = ""
+    ex: ServerError | None = None
+    try:
+        res = server.make_stream_request("POST", "/chat/completions", data={
+            "model": model,
+            "max_tokens": 16,
+            "messages": [
+                {"role": "user", "content": "hello"},
+            ],
+            "stream": True,
+        })
+        for data in res:
+            if data["choices"]:
+                choice = data["choices"][0]
+                if choice["finish_reason"] in ["stop", "length"]:
+                    assert "content" not in choice["delta"]
+                else:
+                    assert choice["finish_reason"] is None
+                    content += choice["delta"]["content"] or ''
+    except ServerError as e:
+        ex = e
+
+    if success:
+        assert ex is None
+        assert len(content) > 0
+    else:
+        assert ex is not None
+        assert content == ""
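Note on the TODO above: the test currently relies on the target model already being in the local cache when the router server starts. A minimal sketch of one way to pre-warm that cache, assuming ServerProcess.load_all() behaves as the TODO describes (it is only named there, not shown); the fixture name is made up for illustration:

import pytest
from utils import *

@pytest.fixture(scope="session", autouse=True)
def warm_model_cache():
    # Hypothetical session-wide warm-up: fetch the preset models once up front
    # so the router can resolve "ggml-org/tinygemma3-GGUF:Q8_0" from cache.
    # ServerProcess.load_all() is assumed to exist, as named in the TODO.
    ServerProcess.load_all()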
@@ -46,7 +46,7 @@ class ServerProcess:
     debug: bool = False
     server_port: int = 8080
     server_host: str = "127.0.0.1"
-    model_hf_repo: str = "ggml-org/models"
+    model_hf_repo: str | None = "ggml-org/models"
     model_hf_file: str | None = "tinyllamas/stories260K.gguf"
     model_alias: str = "tinyllama-2"
     temperature: float = 0.8
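Note on the annotation change above: widening model_hf_repo to str | None appears to be what allows presets to clear the field entirely (the router() preset added below sets it to None). A stripped-down stand-in class, not the real ServerProcess, just to show what the wider annotation permits under a type checker:

from dataclasses import dataclass

@dataclass
class _Proc:
    # With a plain `str` annotation, the assignment below would be flagged by
    # mypy/pyright; `str | None` makes the "no model configured" state legal.
    model_hf_repo: str | None = "ggml-org/models"

p = _Proc()
p.model_hf_repo = None  # router-style preset: no model attached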
@@ -519,9 +519,8 @@ class ServerPreset:
         server = ServerProcess()
         server.offline = True # will be downloaded by load_all()
         # mmproj is already provided by HF registry API
-        server.model_hf_repo = "ggml-org/tinygemma3-GGUF"
-        server.model_hf_file = "tinygemma3-Q8_0.gguf"
-        server.mmproj_url = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/mmproj-tinygemma3.gguf"
+        server.model_hf_file = None
+        server.model_hf_repo = "ggml-org/tinygemma3-GGUF:Q8_0"
         server.model_alias = "tinygemma3"
         server.n_ctx = 1024
         server.n_batch = 32
@@ -529,6 +528,21 @@ class ServerPreset:
         server.n_predict = 4
         server.seed = 42
         return server
+
+    @staticmethod
+    def router() -> ServerProcess:
+        server = ServerProcess()
+        # router server has no models
+        server.model_file = None
+        server.model_alias = None
+        server.model_hf_repo = None
+        server.model_hf_file = None
+        server.n_ctx = 1024
+        server.n_batch = 16
+        server.n_slots = 1
+        server.n_predict = 16
+        server.seed = 42
+        return server
 
 
 def parallel_function_calls(function_list: List[Tuple[Callable[..., Any], Tuple[Any, ...]]]) -> List[Any]:
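For comparison with the streaming test added above, a non-streaming request against the same router preset could look roughly like the sketch below. It leans on the make_request helper used elsewhere in this test suite and mirrors the successful parametrized case; this is illustrative only and not part of the commit:

def test_chat_completion_non_stream():
    global server
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "model": "ggml-org/tinygemma3-GGUF:Q8_0",
        "max_tokens": 16,
        "messages": [{"role": "user", "content": "hello"}],
    })
    # make_request is assumed to return the suite's usual response object
    # exposing status_code and a parsed JSON body.
    assert res.status_code == 200
    assert len(res.body["choices"][0]["message"]["content"]) > 0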