# https://gist.github.com/ochafik/a3d4a5b9e52390544b205f37fb5a0df3
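#
# A minimal OpenAI-compatible chat-completions proxy in front of llama.cpp's
# ./server, adding grammar-constrained tool calling and JSON-schema responses.
# Example invocation (path and filename are illustrative, not canonical):
#
#   python examples/openai/server.py --model models/7B/ggml-model-f16.gguf
#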
# pip install "fastapi[all]" "uvicorn[all]" sse-starlette jsonargparse jinja2 pydantic

import atexit
import json
import subprocess
import sys
import time
from pathlib import Path

# Make the repository root importable so the examples.* modules below resolve.
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from examples.openai.llama_cpp_server_api import LlamaCppServerCompletionRequest
from examples.openai.gguf_kvs import GGUFKeyValues, Keys
from examples.openai.api import ChatCompletionResponse, Choice, Message, ChatCompletionRequest, Usage
from examples.openai.prompting import ChatFormat, make_grammar, make_tools_prompt

import random
from typing import Annotated, Optional

import httpx
import typer
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from starlette.responses import StreamingResponse
from typeguard import typechecked

def generate_id(prefix):
    # OpenAI-style opaque id, e.g. "chatcmpl-2649236983".
    return f"{prefix}{random.randint(0, 1 << 32)}"
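
# CLI entrypoint: serves an OpenAI-style /v1/chat/completions endpoint and
# proxies to a llama.cpp server (spawned locally unless --cpp-server-endpoint
# points at an existing one).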
def main(
    model: Annotated[Optional[Path], typer.Option("--model", "-m")] = Path("models/7B/ggml-model-f16.gguf"),
    # model: Path = Path("/Users/ochafik/AI/Models/Hermes-2-Pro-Mistral-7B.Q8_0.gguf"),
    # model_url: Annotated[Optional[str], typer.Option("--model-url", "-mu")] = None,
    host: str = "localhost",
    port: int = 8080,
    cpp_server_endpoint: Optional[str] = None,
    cpp_server_host: str = "localhost",
    cpp_server_port: Optional[int] = 8081,
):
    import uvicorn

    # Read the chat template and context length straight from the GGUF metadata.
    metadata = GGUFKeyValues(model)
    context_length = metadata[Keys.LLM.CONTEXT_LENGTH]
    chat_format = ChatFormat.from_gguf(metadata)
    # print(chat_format)
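
    # No upstream endpoint given: spawn a local llama.cpp ./server (assumed to
    # have been built into the current working directory) and kill it on exit.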
    if not cpp_server_endpoint:
        sys.stderr.write(f"# Starting C++ server with model {model} on {cpp_server_host}:{cpp_server_port}\n")
        server_process = subprocess.Popen([
            "./server", "-m", str(model),
            "--host", cpp_server_host, "--port", f"{cpp_server_port}",
            # Quantized keys / f16 values for the KV cache.
            "-ctk", "q4_0", "-ctv", "f16",
            "-c", f"{2 * 8192}",
            # "-c", f"{context_length}",
        ], stdout=sys.stderr)
        atexit.register(server_process.kill)
        cpp_server_endpoint = f"http://{cpp_server_host}:{cpp_server_port}"

    app = FastAPI()
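
    # Single OpenAI-compatible route; any OpenAI-style client should work.
    # Hypothetical smoke test once both servers are up:
    #
    #   curl http://localhost:8080/v1/chat/completions \
    #     -H "Content-Type: application/json" \
    #     -d '{"messages": [{"role": "user", "content": "Hello"}]}'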
    @app.post("/v1/chat/completions")
    async def chat_completions(request: Request, chat_request: ChatCompletionRequest):
        headers = {
            "Content-Type": "application/json",
        }
        # Forward the caller's Authorization header to the upstream server.
        if (auth := request.headers.get("Authorization")):
            headers["Authorization"] = auth

        # Only OpenAI's "json_object" response format is supported; its schema,
        # if any, is enforced below through the sampling grammar.
        if chat_request.response_format is not None:
            assert chat_request.response_format.type == "json_object", \
                f"Unsupported response format: {chat_request.response_format.type}"
            response_schema = chat_request.response_format.json_schema or {}
        else:
            response_schema = None

        # Tools are advertised to the model via an injected system prompt.
        messages = chat_request.messages
        if chat_request.tools:
            messages = chat_format.add_system_prompt(messages, make_tools_prompt(chat_format, chat_request.tools))

        # The grammar constrains sampling; the parser maps the raw completion
        # text back to an OpenAI-style message (content and/or tool_calls).
        (grammar, parser) = make_grammar(chat_format, chat_request.tools, response_schema)

        # TODO: Test whether the template supports formatting tool_calls

        prompt = chat_format.render(messages, add_generation_prompt=True)

        sys.stderr.write(f"\n# PROMPT:\n\n{prompt}\n\n")
        sys.stderr.write(f"\n# GRAMMAR:\n\n{grammar}\n\n")
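
        # Build the llama.cpp /completions payload: pass through the OpenAI
        # fields we don't handle here, dropping those already folded into the
        # prompt and grammar.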
        data = LlamaCppServerCompletionRequest(
            **{
                k: v
                for k, v in chat_request.model_dump().items()
                if k not in (
                    "prompt",
                    "tools",
                    "messages",
                    "response_format",
                )
            },
            prompt=prompt,
            grammar=grammar,
        ).model_dump()
        sys.stderr.write(json.dumps(data, indent=2) + "\n")

        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{cpp_server_endpoint}/completions",
                json=data,
                headers=headers,
                timeout=None)
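
        # Note: client.post() above reads the whole upstream body, so the
        # streaming branch below replays buffered SSE bytes rather than truly
        # streaming them; a stricter version would use client.stream(...).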
        if chat_request.stream:
            # TODO: Remove suffix from streamed response using partial parser.
            assert not chat_request.tools and not chat_request.response_format, \
                "Streaming not supported yet with tools or response_format"
            return StreamingResponse(generate_chunks(response), media_type="text/event-stream")
        else:
            result = response.json()
            sys.stderr.write("# RESULT:\n\n" + json.dumps(result, indent=2) + "\n\n")
            if "content" not in result:
                # Upstream errors (no "content" field) are passed through verbatim.
                return JSONResponse(result)

            message = parser(result["content"])
            assert message is not None, f"Failed to parse response:\n{response.text}\n\n"

            prompt_tokens = result["timings"]["prompt_n"]
            completion_tokens = result["timings"]["predicted_n"]
            return JSONResponse(ChatCompletionResponse(
                id=generate_id("chatcmpl-"),
                object="chat.completion",
                created=int(time.time()),
                model=chat_request.model,
                choices=[Choice(
                    index=0,
                    message=message,
                    finish_reason="stop" if message.tool_calls is None else "tool_calls",
                )],
                usage=Usage(
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=prompt_tokens + completion_tokens,
                ),
                system_fingerprint="...",
            ).model_dump())

    # Relay the upstream response body chunk by chunk (SSE passthrough).
    async def generate_chunks(response):
        async for chunk in response.aiter_bytes():
            yield chunk

    uvicorn.run(app, host=host, port=port)

if __name__ == "__main__":
    typer.run(main)