# llama.cpp/scripts/server-test-model.py
#
# Source-viewer metadata: 203 lines, 7.8 KiB, Python.

import argparse
import json
import requests
import logging
import sys
# All output goes to stdout through a single handler. Its terminator is
# cleared so that streamed chunks can be written back-to-back on one line;
# every message that should end a line carries its own "\n".
handler = logging.StreamHandler(stream=sys.stdout)
handler.terminator = ""
logging.basicConfig(handlers=[handler], format="%(message)s", level=logging.INFO)
logger = logging.getLogger("server-test-model")
def run_query(url, messages, tools=None, stream=False, tool_choice=None):
    """Send one chat-completion request and collect the reply.

    Args:
        url: Full URL of the chat-completions endpoint.
        messages: Chat history, sent as the ``messages`` field.
        tools: Optional list of tool definitions; sent only when non-empty.
        stream: If True, request a server-sent-event stream and echo the
            chunks to the logger as they arrive.
        tool_choice: Optional ``tool_choice`` value; sent only when truthy.

    Returns:
        A dict with ``content`` (str), ``reasoning_content`` (str) and
        ``tool_calls`` (list), or None if the request failed or a
        non-streaming body was not valid JSON.
    """
    payload = _build_payload(messages, tools, stream, tool_choice)
    try:
        response = requests.post(url, json=payload, stream=stream)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        if e.response is not None:
            logger.info(f"Response error: {e} for {e.response.content}\n")
        else:
            logger.info(f"Error connecting to server: {e}\n")
        return None

    # Always release the connection: with stream=True we stop reading at
    # "[DONE]", so the body may be left partially unconsumed.
    try:
        if stream:
            return _read_stream(response, tools)
        return _read_full(response, tools)
    finally:
        response.close()


def _build_payload(messages, tools, stream, tool_choice):
    """Assemble the JSON request body, omitting unset optional fields."""
    payload = {
        "messages": messages,
        "stream": stream,
        "max_tokens": 5000,
    }
    if tools:
        payload["tools"] = tools
    if tool_choice:
        payload["tool_choice"] = tool_choice
    return payload


def _merge_tool_call_deltas(tool_calls, deltas):
    """Fold streamed tool-call fragments into *tool_calls* (mutated in place)."""
    for tc in deltas or []:
        index = tc.get("index")
        if index is None:
            continue
        # Grow the list so that slot `index` exists.
        while len(tool_calls) <= index:
            # "function" is used as the default type.
            tool_calls.append(
                {
                    "id": "",
                    "type": "function",
                    "function": {"name": "", "arguments": ""},
                }
            )
        slot = tool_calls[index]
        # Fragments are concatenated; any field may be absent or null in a
        # given chunk, so treat both as an empty fragment.
        slot["id"] += tc.get("id") or ""
        fn = tc.get("function") or {}
        slot["function"]["name"] += fn.get("name") or ""
        slot["function"]["arguments"] += fn.get("arguments") or ""


def _read_stream(response, tools):
    """Consume an SSE chat-completion stream, echoing text as it arrives."""
    full_content = ""
    reasoning_content = ""
    tool_calls = []
    logger.info(f"--- Streaming response (Tools: {bool(tools)}) ---\n")
    for line in response.iter_lines():
        if not line:
            continue
        decoded_line = line.decode("utf-8")
        if not decoded_line.startswith("data:"):
            continue
        # The SSE format allows one optional space after "data:".
        data_str = decoded_line[len("data:"):]
        if data_str.startswith(" "):
            data_str = data_str[1:]
        if data_str == "[DONE]":
            break
        try:
            data = json.loads(data_str)
        except json.JSONDecodeError:
            logger.info(f"Failed to decode JSON: {data_str}\n")
            continue
        choices = data.get("choices") or []
        if not choices:
            continue
        delta = choices[0].get("delta") or {}
        content_chunk = delta.get("content")
        if content_chunk:
            full_content += content_chunk
            logger.info(content_chunk)
        reasoning_chunk = delta.get("reasoning_content")
        if reasoning_chunk:
            reasoning_content += reasoning_chunk
            # Reasoning text is wrapped in ANSI italic / reset codes.
            logger.info(f"\x1B[3m{reasoning_chunk}\x1B[0m")
        _merge_tool_call_deltas(tool_calls, delta.get("tool_calls"))
    logger.info("\n--- End of Stream ---\n")
    return {
        "content": full_content,
        "reasoning_content": reasoning_content,
        "tool_calls": tool_calls,
    }


def _read_full(response, tools):
    """Parse a non-streaming chat-completion body and log its content.

    Returns None (after logging) when the body is not valid JSON.
    """
    logger.info(f"--- Non-streaming response (Tools: {bool(tools)}) ---\n")
    try:
        data = response.json()
    except ValueError:
        # requests raises a ValueError subclass when the body is not JSON.
        logger.info(f"Failed to decode JSON: {response.text}\n")
        return None
    full_content = ""
    reasoning_content = ""
    tool_calls = []
    choices = data.get("choices") or []
    if choices:
        message = choices[0].get("message") or {}
        # Fields may be present with a JSON null (e.g. "content" on a pure
        # tool-call reply); normalise to empty values instead of None.
        full_content = message.get("content") or ""
        reasoning_content = message.get("reasoning_content") or ""
        tool_calls = message.get("tool_calls") or []
    logger.info(full_content)
    # Leading newline keeps the marker off the end of the model output,
    # since the logging handler adds no line terminator.
    logger.info("\n--- End of Response ---\n")
    return {
        "content": full_content,
        "reasoning_content": reasoning_content,
        "tool_calls": tool_calls,
    }
def test_chat(url, stream):
    """Ask a plain question (no tools) and report whether text came back.

    Logs PASS when the reply has content, WARN when it is empty and FAIL
    when run_query returned nothing; also notes whether reasoning content
    was present in the reply.
    """
    logger.info(f"\n=== Testing Chat (Stream={stream}) ===\n")
    question = [{"role": "user", "content": "What is the capital of France?"}]
    outcome = run_query(url, question, stream=stream)
    if not outcome:
        logger.info("FAIL: No result.\n")
        return

    if outcome["content"]:
        verdict = "PASS: Output received.\n"
    else:
        verdict = "WARN: No content received (valid if strict tool call, but unexpected here).\n"
    logger.info(verdict)

    reasoning = outcome.get("reasoning_content")
    if not reasoning:
        logger.info("INFO: No reasoning content detected (Standard model behavior).\n")
        return
    logger.info(f"INFO: Reasoning content detected ({len(reasoning)} chars).\n")
def test_tool_call(url, stream):
    """Ask a question that should trigger the ``get_weather`` tool and report it.

    Logs PASS plus one line per returned tool call when the reply contains
    tool calls, FAIL (with the text content) when it does not, and FAIL when
    run_query returned nothing. Also notes any reasoning content.
    """
    logger.info(f"\n=== Testing Tool Call (Stream={stream}) ===\n")
    messages = [
        {
            "role": "user",
            "content": "What is the weather in London? Please use the get_weather tool.",
        }
    ]
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        },
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location"],
                },
            },
        }
    ]
    result = run_query(url, messages, tools=tools, tool_choice="auto", stream=stream)
    if not result:
        logger.info("FAIL: Query failed.\n")
        return

    tcs = result.get("tool_calls")
    if tcs:
        # Each message supplies its own newline: the logging handler is
        # configured with an empty terminator, so without "\n" the first
        # tool line would be glued onto this one.
        logger.info("PASS: Tool calls detected.\n")
        for tc in tcs:
            func = tc.get("function") or {}
            logger.info(f" Tool: {func.get('name')}, Args: {func.get('arguments')}\n")
    else:
        logger.info(f"FAIL: No tool calls. Content: {result['content']}\n")
    if result.get("reasoning_content"):
        logger.info(
            f"INFO: Reasoning content detected during tool call ({len(result['reasoning_content'])} chars).\n"
        )
def main():
    """Parse ``--host``/``--port`` and run the chat and tool-call checks.

    Both checks run without streaming first, then both run with streaming.
    """
    cli = argparse.ArgumentParser(description="Test llama-server functionality.")
    cli.add_argument("--host", default="localhost", help="Server host")
    cli.add_argument("--port", default=8080, type=int, help="Server port")
    opts = cli.parse_args()

    endpoint = f"http://{opts.host}:{opts.port}/v1/chat/completions"
    logger.info(f"Testing server at {endpoint}\n")

    # Non-streaming pass first, then the streaming pass.
    for stream in (False, True):
        for check in (test_chat, test_tool_call):
            check(endpoint, stream=stream)
if __name__ == "__main__":
    # main() returns None, so the process exits with status 0 once all
    # checks have run.
    sys.exit(main())