# llama.cpp/examples/openai/server.py
# (gist export metadata: 150 lines, 5.8 KiB, Python)
# https://gist.github.com/ochafik/a3d4a5b9e52390544b205f37fb5a0df3
# pip install "fastapi[all]" "uvicorn[all]" sse-starlette jsonargparse jinja2 pydantic
import json, sys, subprocess, atexit
from pathlib import Path
import time
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from examples.openai.llama_cpp_server_api import LlamaCppServerCompletionRequest
from examples.openai.gguf_kvs import GGUFKeyValues, Keys
from examples.openai.api import ChatCompletionResponse, Choice, Message, ChatCompletionRequest, Usage
from examples.openai.prompting import ChatFormat, make_grammar, make_tools_prompt
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
import httpx
import random
from starlette.responses import StreamingResponse
from typing import Annotated, Optional
import typer
from typeguard import typechecked
def generate_id(prefix):
    """Return a pseudo-random identifier such as ``chatcmpl-123456``.

    Args:
        prefix: String prepended to the random numeric suffix
            (e.g. ``"chatcmpl-"``).

    Returns:
        ``prefix`` followed by a uniform random integer in ``[0, 2**32)``.

    Note:
        Uses the non-cryptographic ``random`` module — these ids are for
        display/correlation only, not for anything security-sensitive.
    """
    # randrange(1 << 32) yields a uniform 32-bit value in [0, 2**32).
    # The previous randint(0, 1 << 32) had an inclusive upper bound and
    # could also emit 2**32 itself (off-by-one for a 32-bit id).
    return f"{prefix}{random.randrange(1 << 32)}"
def main(
    # NOTE: default is a str, not Path("...") — typer coerces CLI input to
    # Path via the annotation, but the *default* stays a plain str here.
    model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
    # model: Path = Path("/Users/ochafik/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf"),
    # model_url: Annotated[Optional[str], typer.Option("--model-url", "-mu")] = None,
    host: str = "localhost",
    port: int = 8080,
    cpp_server_endpoint: Optional[str] = None,
    cpp_server_host: str = "localhost",
    cpp_server_port: Optional[int] = 8081,
):
    """Run an OpenAI-compatible ``/v1/chat/completions`` HTTP proxy.

    Serves a FastAPI app on ``host:port`` that translates OpenAI-style chat
    requests into llama.cpp ``/completions`` requests (prompt rendering,
    tool-call grammar, response parsing). If ``cpp_server_endpoint`` is not
    given, a local ``./server`` subprocess is spawned on
    ``cpp_server_host:cpp_server_port`` and killed at exit.
    """
    import uvicorn

    # Read chat-template / context metadata straight from the GGUF file.
    metadata = GGUFKeyValues(model)
    context_length = metadata[Keys.LLM.CONTEXT_LENGTH]
    chat_format = ChatFormat.from_gguf(metadata)
    # print(chat_format)

    # No upstream endpoint given: launch our own llama.cpp server child
    # process and point the proxy at it.
    if not cpp_server_endpoint:
        sys.stderr.write(f"# Starting C++ server with model {model} on {cpp_server_host}:{cpp_server_port}\n")
        server_process = subprocess.Popen([
            "./server", "-m", model,
            "--host", cpp_server_host, "--port", f'{cpp_server_port}',
            # Quantized K cache, f16 V cache — trades accuracy for memory.
            '-ctk', 'q4_0', '-ctv', 'f16',
            # Hard-coded 16k context; the model's own context_length above
            # is computed but deliberately not used (see commented line).
            "-c", f"{2*8192}",
            # "-c", f"{context_length}",
        ], stdout=sys.stderr)
        # Best-effort cleanup: kill the child when this process exits.
        atexit.register(server_process.kill)
        cpp_server_endpoint = f"http://{cpp_server_host}:{cpp_server_port}"

    app = FastAPI()

    @app.post("/v1/chat/completions")
    async def chat_completions(request: Request, chat_request: ChatCompletionRequest):
        """Translate one OpenAI chat request into a llama.cpp completion."""
        headers = {
            "Content-Type": "application/json",
        }
        # Forward the caller's Authorization header to the C++ server, if any.
        if (auth := request.headers.get("Authorization")):
            headers["Authorization"] = auth

        # NOTE(review): request validation below uses `assert`, which is
        # stripped under `python -O` and surfaces as a 500, not a 4xx —
        # consider raising fastapi.HTTPException instead.
        if chat_request.response_format is not None:
            assert chat_request.response_format.type == "json_object", f"Unsupported response format: {chat_request.response_format.type}"
            response_schema = chat_request.response_format.json_schema or {}
        else:
            response_schema = None

        # When tools are declared, inject a tool-describing system prompt.
        messages = chat_request.messages
        if chat_request.tools:
            messages = chat_format.add_system_prompt(messages, make_tools_prompt(chat_format, chat_request.tools))

        # grammar constrains generation (GBNF); parser turns the raw
        # completion text back into a structured Message (tool calls etc.).
        (grammar, parser) = make_grammar(chat_format, chat_request.tools, response_schema)

        # TODO: Test whether the template supports formatting tool_calls
        prompt = chat_format.render(messages, add_generation_prompt=True)
        sys.stderr.write(f'\n# PROMPT:\n\n{prompt}\n\n')
        sys.stderr.write(f'\n# GRAMMAR:\n\n{grammar}\n\n')

        # Pass through all OpenAI request fields except those we rewrite
        # (prompt/grammar replace messages/tools/response_format).
        data = LlamaCppServerCompletionRequest(
            **{
                k: v
                for k, v in chat_request.model_dump().items()
                if k not in (
                    "prompt",
                    "tools",
                    "messages",
                    "response_format",
                )
            },
            prompt=prompt,
            grammar=grammar,
        ).model_dump()
        sys.stderr.write(json.dumps(data, indent=2) + "\n")

        async with httpx.AsyncClient() as client:
            # timeout=None: generation can take arbitrarily long.
            response = await client.post(
                f"{cpp_server_endpoint}/completions",
                json=data,
                headers=headers,
                timeout=None)
            if chat_request.stream:
                # TODO: Remove suffix from streamed response using partial parser.
                assert not chat_request.tools and not chat_request.response_format, "Streaming not supported yet with tools or response_format"
                # NOTE(review): client.post() above has already buffered the
                # full upstream body, so this "stream" replays a completed
                # response; true pass-through streaming would need
                # client.stream(...) with a lifetime outliving this handler.
                return StreamingResponse(generate_chunks(response), media_type="text/event-stream")
            else:
                result = response.json()
                sys.stderr.write("# RESULT:\n\n" + json.dumps(result, indent=2) + "\n\n")
                # No 'content' key: upstream error payload — relay it as-is.
                if 'content' not in result:
                    # print(json.dumps(result, indent=2))
                    return JSONResponse(result)
                # print(json.dumps(result.get('content'), indent=2))

                # Parse raw completion text into a Message (may extract
                # tool_calls per the grammar we requested).
                message = parser(result["content"])
                assert message is not None, f"Failed to parse response:\n{response.text}\n\n"

                # Token counts come from llama.cpp's timings block.
                prompt_tokens=result['timings']['prompt_n']
                completion_tokens=result['timings']['predicted_n']
                return JSONResponse(ChatCompletionResponse(
                    id=generate_id('chatcmpl-'),
                    object="chat.completion",
                    created=int(time.time()),
                    model=chat_request.model,
                    choices=[Choice(
                        index=0,
                        message=message,
                        # OpenAI semantics: tool_calls present => "tool_calls".
                        finish_reason="stop" if message.tool_calls is None else "tool_calls",
                    )],
                    usage=Usage(
                        prompt_tokens=prompt_tokens,
                        completion_tokens=completion_tokens,
                        total_tokens=prompt_tokens + completion_tokens,
                    ),
                    system_fingerprint='...'
                ).model_dump())

    async def generate_chunks(response):
        # Re-yield the (already-buffered) upstream SSE bytes unchanged.
        async for chunk in response.aiter_bytes():
            yield chunk

    uvicorn.run(app, host=host, port=port)
if __name__ == "__main__":
    # typer builds the CLI (flags, help) from `main`'s annotated signature.
    typer.run(main)