agent/openai:nits

This commit is contained in:
ochafik 2024-03-29 17:00:53 +00:00
parent ce2fb0155f
commit ea34bd3e5c
10 changed files with 72 additions and 145 deletions

View File

@ -76,7 +76,7 @@ This example relies on the new [OpenAI compatibility server](../openai).
agent.py → examples.openai → server.cpp agent.py → examples.openai → server.cpp
→ safe_tools.py → safe_tools.py
→ ( run_sandboxed_tools.sh : Docker → fastify.py ) → unsafe_tools.py → code interpreter, etc... → ( run_sandboxed_tools.sh : Docker → fastify.py ) → unsafe_tools.py → code interpreter, etc...
``` ```
The agent can use tools written in Python, or (soon) exposed under OpenAPI endpoints. Only has standard Python deps (e.g. no langchain) The agent can use tools written in Python, or (soon) exposed under OpenAPI endpoints. Only has standard Python deps (e.g. no langchain)

View File

@ -128,7 +128,7 @@ def main(
max_iterations: Optional[int] = 10, max_iterations: Optional[int] = 10,
std_tools: Optional[bool] = False, std_tools: Optional[bool] = False,
auth: Optional[str] = None, auth: Optional[str] = None,
allow_parallel_calls: Optional[bool] = False, parallel_calls: Optional[bool] = True,
verbose: bool = False, verbose: bool = False,
model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf", model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
@ -174,14 +174,14 @@ def main(
"python", "-m", "examples.openai.server", "python", "-m", "examples.openai.server",
"--model", model, "--model", model,
*(['--verbose'] if verbose else []), *(['--verbose'] if verbose else []),
*(['--allow-parallel-calls'] if allow_parallel_calls else []), *(['--parallel-calls'] if parallel_calls else []),
*(['--context-length={context_length}'] if context_length else []), *(['--context-length={context_length}'] if context_length else []),
*([]) *([])
] ]
server_process = subprocess.Popen(cmd, stdout=sys.stderr) server_process = subprocess.Popen(cmd, stdout=sys.stderr)
atexit.register(server_process.kill) atexit.register(server_process.kill)
sleep(5) sleep(5)
tool_functions = [] tool_functions = []
types = {} types = {}
for f in tools: for f in tools:
@ -195,7 +195,7 @@ def main(
if std_tools: if std_tools:
tool_functions.extend(collect_functions(StandardTools)) tool_functions.extend(collect_functions(StandardTools))
response_model = None#str response_model = None#str
if format: if format:
if format in types: if format in types:
@ -207,8 +207,8 @@ def main(
response_model = json.loads(format) response_model = json.loads(format)
except: except:
response_model = eval(format) response_model = eval(format)
result = completion_with_tool_usage( result = completion_with_tool_usage(
model="...", model="...",
endpoint=endpoint, endpoint=endpoint,

View File

@ -41,4 +41,4 @@ def main(files: List[str], host: str = '0.0.0.0', port: int = 8000):
if __name__ == '__main__': if __name__ == '__main__':
typer.run(main) typer.run(main)

View File

@ -11,7 +11,7 @@ script="$( realpath "$1" )"
script_folder="$(dirname "$script")" script_folder="$(dirname "$script")"
shift 1 shift 1
function cleanup { function cleanup {
rm -rf "$BUILD_DIR" rm -rf "$BUILD_DIR"
echo "Deleted $BUILD_DIR" echo "Deleted $BUILD_DIR"
} }

View File

@ -1,15 +1,11 @@
import atexit
from datetime import date from datetime import date
import datetime import datetime
from pydantic import BaseModel
import subprocess import subprocess
import sys import sys
from time import sleep
import time import time
import typer import typer
from pydantic import BaseModel, Json, TypeAdapter from typing import Union, Optional
from annotated_types import MinLen
from typing import Annotated, Callable, List, Union, Literal, Optional, Type, get_args, get_origin
import json, requests
class Duration(BaseModel): class Duration(BaseModel):
seconds: Optional[int] = None seconds: Optional[int] = None
@ -50,7 +46,7 @@ class WaitForDate(BaseModel):
sys.stderr.write(f"Waiting for {days} days and {seconds} seconds until {d}...\n") sys.stderr.write(f"Waiting for {days} days and {seconds} seconds until {d}...\n")
time.sleep(days * 86400 + seconds) time.sleep(days * 86400 + seconds)
sys.stderr.write(f"Reached the target date: {self.until}\n") sys.stderr.write(f"Reached the target date: {self.until}\n")
class StandardTools: class StandardTools:
@ -61,7 +57,7 @@ class StandardTools:
This allows getting additional information, requesting disambiguation, etc. This allows getting additional information, requesting disambiguation, etc.
''' '''
return typer.prompt(question) return typer.prompt(question)
@staticmethod @staticmethod
def wait(_for: Union[WaitForDuration, WaitForDate]) -> None: def wait(_for: Union[WaitForDuration, WaitForDate]) -> None:
''' '''
@ -69,7 +65,7 @@ class StandardTools:
This can be used to wait for a specific duration or until a specific date. This can be used to wait for a specific duration or until a specific date.
''' '''
return _for() return _for()
@staticmethod @staticmethod
def say_out_loud(something: str) -> str: def say_out_loud(something: str) -> str:
""" """

View File

@ -34,7 +34,7 @@ The new [examples/openai/server.py](./server.py):
} }
// Where T is the output JSON schema, or 'any' // Where T is the output JSON schema, or 'any'
``` ```
- Option to publicise schemas to models as TypeScript signatures (as for Functionary) or JSON schema. - Option to publicise schemas to models as TypeScript signatures (as for Functionary) or JSON schema.
- Supports models that require user/assistant alternance (like Mixtral Instruct) by merging system messages into user messages. - Supports models that require user/assistant alternance (like Mixtral Instruct) by merging system messages into user messages.
@ -175,7 +175,7 @@ curl http://localhost:8080/v1/chat/completions \
- Evaluate options for session caching - Evaluate options for session caching
- Pass session id & store / read from file? - Pass session id & store / read from file?
- Support parent session ids for trees of thought? - Support parent session ids for trees of thought?
- Support precaching long prompts from CLI / read session files? - Support precaching long prompts from CLI / read session files?
@ -186,4 +186,4 @@ curl http://localhost:8080/v1/chat/completions \
- Remove non-Python json-schema-to-grammar versions - Remove non-Python json-schema-to-grammar versions
- Reach out to frameworks to advertise new option. - Reach out to frameworks to advertise new option.

View File

@ -1,28 +1,12 @@
from typing import Optional from typing import Optional
from pydantic import BaseModel, Json from pydantic import Json
class LlamaCppServerCompletionRequest(BaseModel): from examples.openai.api import LlamaCppParams
class LlamaCppServerCompletionRequest(LlamaCppParams):
prompt: str prompt: str
stream: Optional[bool] = None stream: Optional[bool] = None
cache_prompt: Optional[bool] = None cache_prompt: Optional[bool] = None
n_predict: Optional[int] = None
top_k: Optional[int] = None
top_p: Optional[float] = None
min_p: Optional[float] = None
tfs_z: Optional[float] = None
typical_p: Optional[float] = None
temperature: Optional[float] = None
dynatemp_range: Optional[float] = None
dynatemp_exponent: Optional[float] = None
repeat_last_n: Optional[int] = None
repeat_penalty: Optional[float] = None
frequency_penalty: Optional[float] = None
presence_penalty: Optional[float] = None
mirostat: Optional[bool] = None
mirostat_tau: Optional[float] = None
mirostat_eta: Optional[float] = None
penalize_nl: Optional[bool] = None
n_keep: Optional[int] = None
seed: Optional[int] = None
grammar: Optional[str] = None grammar: Optional[str] = None
json_schema: Optional[Json] = None json_schema: Optional[Json] = None

View File

@ -1,15 +1,13 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
from functools import wraps
import jinja2 import jinja2
import json import json
from pathlib import Path from pathlib import Path
import random import random
import re import re
import sys import sys
from typing import Any, Dict, Literal, Optional, Tuple, Callable, Union from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
# from typeguard import typechecked
from examples.json_schema_to_grammar import SchemaConverter from examples.json_schema_to_grammar import SchemaConverter
from examples.openai.api import Tool, Message, FunctionCall, ToolCall from examples.openai.api import Tool, Message, FunctionCall, ToolCall
@ -55,7 +53,7 @@ class ChatTemplate(BaseModel):
@property @property
def tool_style(self) -> 'ToolsPromptStyle': def tool_style(self) -> 'ToolsPromptStyle':
return self._tool_style return self._tool_style
def __init__(self, template: str, eos_token: str, bos_token: str): def __init__(self, template: str, eos_token: str, bos_token: str):
super().__init__(template=template super().__init__(template=template
) )
@ -75,7 +73,7 @@ class ChatTemplate(BaseModel):
# self._tool_style = ToolsPromptStyle.TOOLS_MISTRAL # self._tool_style = ToolsPromptStyle.TOOLS_MISTRAL
# TODO: Test whether the template supports formatting tool_calls # TODO: Test whether the template supports formatting tool_calls
delimiter = '<%$[SAMPLE]$%>' delimiter = '<%$[SAMPLE]$%>'
user_msg = Message(role="user", content="Hey") user_msg = Message(role="user", content="Hey")
empty_prompt = self.render([user_msg], add_generation_prompt=True).strip() empty_prompt = self.render([user_msg], add_generation_prompt=True).strip()
@ -112,7 +110,7 @@ class ChatTemplate(BaseModel):
def from_gguf(metadata: GGUFKeyValues): def from_gguf(metadata: GGUFKeyValues):
if Keys.Tokenizer.CHAT_TEMPLATE not in metadata: if Keys.Tokenizer.CHAT_TEMPLATE not in metadata:
raise NotImplementedError(f'Only supporting models with {Keys.Tokenizer.CHAT_TEMPLATE} entry in their GGUF key-values (TODO: add default template, maybe pick llama2\'s?)') raise NotImplementedError(f'Only supporting models with {Keys.Tokenizer.CHAT_TEMPLATE} entry in their GGUF key-values (TODO: add default template, maybe pick llama2\'s?)')
tokens = metadata[Keys.Tokenizer.LIST] tokens = metadata[Keys.Tokenizer.LIST]
return ChatTemplate( return ChatTemplate(
template = metadata[Keys.Tokenizer.CHAT_TEMPLATE], template = metadata[Keys.Tokenizer.CHAT_TEMPLATE],
@ -129,8 +127,6 @@ class ChatTemplate(BaseModel):
eos_token = tokenizer.eos_token) eos_token = tokenizer.eos_token)
def render(self, messages: list[Message], add_generation_prompt: bool, omit_bos: bool = False): def render(self, messages: list[Message], add_generation_prompt: bool, omit_bos: bool = False):
# sys.stderr.write(f'# strict_user_assistant_alternation={self._strict_user_assistant_alternation}\n')
# sys.stderr.write(f'# messages=' + "\n".join(json.dumps(m.model_dump(), indent=2) for m in messages) + '\n')
if self._strict_user_assistant_alternation and any(m.role not in ('user', 'assistant') for m in messages): if self._strict_user_assistant_alternation and any(m.role not in ('user', 'assistant') for m in messages):
new_messages=[] new_messages=[]
i = 0 i = 0
@ -161,8 +157,7 @@ class ChatTemplate(BaseModel):
i += 1 i += 1
# print(f'new_messages={json.dumps(new_messages, indent=2)}') # print(f'new_messages={json.dumps(new_messages, indent=2)}')
messages = new_messages messages = new_messages
# print(f'messages={messages}')
result = self._template.render( result = self._template.render(
messages=messages, messages=messages,
eos_token=self._eos_token, eos_token=self._eos_token,
@ -170,7 +165,6 @@ class ChatTemplate(BaseModel):
raise_exception=raise_exception, raise_exception=raise_exception,
add_generation_prompt=add_generation_prompt, add_generation_prompt=add_generation_prompt,
) )
# sys.stderr.write(f'\n# RENDERED:\n\n{result}\n\n')
return result return result
class ChatHandlerArgs(BaseModel): class ChatHandlerArgs(BaseModel):
@ -192,7 +186,7 @@ class NoToolsChatHandler(ChatHandler):
def __init__(self, args: ChatHandlerArgs): def __init__(self, args: ChatHandlerArgs):
super().__init__(args) super().__init__(args)
assert not args.tools assert not args.tools
if args.response_schema: if args.response_schema:
self.output_format_prompt = Message( self.output_format_prompt = Message(
role="system", role="system",
@ -206,21 +200,20 @@ class NoToolsChatHandler(ChatHandler):
self.output_format_prompt = None self.output_format_prompt = None
self.grammar = None self.grammar = None
# @typechecked
def parse(self, s: str) -> Optional[Message]: def parse(self, s: str) -> Optional[Message]:
return Message(role="assistant", content=s) return Message(role="assistant", content=s)
class ToolCallTagsChatHandler(ChatHandler): class ToolCallTagsChatHandler(ChatHandler):
def __init__(self, args: ChatHandlerArgs, escapes_underscores: bool, allow_parallel_calls: bool): def __init__(self, args: ChatHandlerArgs, escapes_underscores: bool, parallel_calls: bool):
super().__init__(args) super().__init__(args)
converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False) converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
tool_rules = [] tool_rules = []
for tool in self.args.tools: for tool in self.args.tools:
parameters_schema = tool.function.parameters parameters_schema = tool.function.parameters
parameters_schema = converter.resolve_refs(parameters_schema, tool.function.name) parameters_schema = converter.resolve_refs(parameters_schema, tool.function.name)
tool_rules.append(converter.visit( tool_rules.append(converter.visit(
dict( dict(
type="object", type="object",
@ -245,7 +238,7 @@ class ToolCallTagsChatHandler(ChatHandler):
format_literal("<tool_call>") + " space (" + format_literal("<tool_call>") + " space (" +
' | '.join(tool_rules) + ' | '.join(tool_rules) +
") space " + format_literal("</tool_call>"))# + ' space') ") space " + format_literal("</tool_call>"))# + ' space')
# Ideally we'd want a negative lookahead of /<tool\\?_call>/, but it's just too hard to express in GBNF for now. # Ideally we'd want a negative lookahead of /<tool\\?_call>/, but it's just too hard to express in GBNF for now.
# So we just over-constrain the content rule to not contain literals dangerously getting close to <tool_call> # So we just over-constrain the content rule to not contain literals dangerously getting close to <tool_call>
content_rule = converter._add_rule('content', '[^<] | "<" [^t<] | "<t" [^o<]') content_rule = converter._add_rule('content', '[^<] | "<" [^t<] | "<t" [^o<]')
@ -253,22 +246,10 @@ class ToolCallTagsChatHandler(ChatHandler):
converter._add_rule( converter._add_rule(
'root', 'root',
# tool_call_rule) # tool_call_rule)
f'{content_rule}* ({tool_call_rule}+ {content_rule}*)?' if allow_parallel_calls \ f'{content_rule}* ({tool_call_rule}+ {content_rule}*)?' if parallel_calls \
else f'{content_rule}* {tool_call_rule}?') else f'{content_rule}* {tool_call_rule}?')
self.grammar = converter.format_grammar() self.grammar = converter.format_grammar()
# # Constrain the output to be a non-tool-call message (constrained to a JSON schema or not)
# # OR a tool-call message respecting the schema of any of the tools
# converter._add_rule(
# "root",
# converter._format_literal(prefix) + " (" +
# (response_rule or converter.not_literal("<tool_call>")) + " | " +
# converter._format_literal("<tool_call>") + " (" +
# ' | '.join(tool_rules) +
# ") " + converter._format_literal("</tool_call>") +
# ")") # + converter._format_literal(suffix))
# @typechecked
def parse(self, s: str) -> Optional[Message]: def parse(self, s: str) -> Optional[Message]:
s = self.args.chat_template.strip_suffix(s) s = self.args.chat_template.strip_suffix(s)
@ -294,21 +275,14 @@ class ToolCallTagsChatHandler(ChatHandler):
ToolCall( ToolCall(
id=gen_callid(), id=gen_callid(),
function=FunctionCall(**fc))) function=FunctionCall(**fc)))
content = '\n'.join(content).strip() content = '\n'.join(content).strip()
return Message(role="assistant", content=content if content else None, tool_calls=tool_calls) return Message(role="assistant", content=content if content else None, tool_calls=tool_calls)
# if '<tool_call>'.startswith(ls) or ls.startswith('<tool_call>'):
# if ls.startswith('<tool_call>') and ls.endswith('</tool_call>' + suffix):
# tool_call = ls[len('<tool_call>'):-len('</tool_call>' + suffix)]
# return Message(role="assistant", content=None, tool_calls=[json.loads(tool_call)])
# return None
# else:
# return Message(role="assistant", content=s)
class TemplatedToolsChatHandler(ToolCallTagsChatHandler): class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
def __init__(self, args: ChatHandlerArgs, template: str, escapes_underscores=False, allow_parallel_calls=True): def __init__(self, args: ChatHandlerArgs, template: str, parallel_calls: bool, escapes_underscores: bool = False):
super().__init__(args, escapes_underscores=escapes_underscores, allow_parallel_calls=allow_parallel_calls) super().__init__(args, escapes_underscores=escapes_underscores, parallel_calls=parallel_calls)
assert '{tools}' in template, 'Template must contain "{tools}"' assert '{tools}' in template, 'Template must contain "{tools}"'
self.output_format_prompt = Message( self.output_format_prompt = Message(
@ -320,8 +294,8 @@ class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
) )
class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler): class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler):
def __init__(self, args: ChatHandlerArgs, allow_parallel_calls: bool): def __init__(self, args: ChatHandlerArgs, parallel_calls: bool):
super().__init__(args, escapes_underscores=False, allow_parallel_calls=allow_parallel_calls) super().__init__(args, escapes_underscores=False, parallel_calls=parallel_calls)
# Hackily import https://github.com/NousResearch/Hermes-Function-Calling # Hackily import https://github.com/NousResearch/Hermes-Function-Calling
path = str(Path(__file__).parent / "hermes_function_calling") path = str(Path(__file__).parent / "hermes_function_calling")
@ -330,15 +304,15 @@ class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler):
from examples.openai.hermes_function_calling.prompter import PromptManager from examples.openai.hermes_function_calling.prompter import PromptManager
except ImportError: except ImportError:
raise ImportError(f"Please `git clone https://github.com/NousResearch/Hermes-Function-Calling {path}`") raise ImportError(f"Please `git clone https://github.com/NousResearch/Hermes-Function-Calling {path}`")
prompt = PromptManager().generate_prompt(user_prompt=[], tools=[json.dumps(tool) for tool in tools]) prompt = PromptManager().generate_prompt(user_prompt=[], tools=[json.dumps(tool) for tool in args.tools])
assert len(prompt) == 1 and prompt[0]["role"] == "system" assert len(prompt) == 1 and prompt[0]["role"] == "system"
self.output_format_prompt = Message(**prompt[0]) self.output_format_prompt = Message(**prompt[0])
class FunctionaryToolsChatHandler(ChatHandler): class FunctionaryToolsChatHandler(ChatHandler):
def __init__(self, args: ChatHandlerArgs, allow_parallel_calls: bool): def __init__(self, args: ChatHandlerArgs, parallel_calls: bool):
super().__init__(args) super().__init__(args)
# Only allowing a single tool call at a time for now. # Only allowing a single tool call at a time for now.
# Note that if there were more, they'd be separated by a '<|from|>assistant' literal # Note that if there were more, they'd be separated by a '<|from|>assistant' literal
@ -347,7 +321,7 @@ class FunctionaryToolsChatHandler(ChatHandler):
content= '// Supported function definitions that should be called when necessary.\n' + content= '// Supported function definitions that should be called when necessary.\n' +
_tools_typescript_signatures(args.tools) _tools_typescript_signatures(args.tools)
) )
converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False) converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
tool_rules = [ tool_rules = [
converter._add_rule( converter._add_rule(
@ -355,17 +329,6 @@ class FunctionaryToolsChatHandler(ChatHandler):
converter._format_literal(tool.function.name) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' + converter._format_literal(tool.function.name) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' +
converter.visit(tool.function.parameters, tool.function.name + '-args') + ' ' + converter.visit(tool.function.parameters, tool.function.name + '-args') + ' ' +
converter._format_literal('\n')) converter._format_literal('\n'))
# converter.visit(
# dict(
# type="object",
# properties=dict(
# name=dict(const=tool.function.name),
# arguments=tool.function.parameters,
# ),
# required=['name', 'arguments']
# ),
# f'{tool.function.name}-tool-call'
# )
for i, tool in enumerate(self.args.tools) for i, tool in enumerate(self.args.tools)
] ]
@ -378,33 +341,18 @@ class FunctionaryToolsChatHandler(ChatHandler):
tool_call_without_start_rule = converter._add_rule( tool_call_without_start_rule = converter._add_rule(
'tool_call_without_start', 'tool_call_without_start',
' | '.join(tool_rules)) ' | '.join(tool_rules))
# + ' ' +
# converter.not_literal("all", dotall=False) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' + not_from_rule + '*')
tool_call_rule = converter._add_rule('tool_call', f'{start_rule} {tool_call_without_start_rule}') tool_call_rule = converter._add_rule('tool_call', f'{start_rule} {tool_call_without_start_rule}')
# converter._add_rule('root', f'({content_without_start_rule} ({content_rule})* ({tool_call_rule}+ {content_rule}*)? | {tool_call_without_start_rule} (* {tool_call_rule}{content_rule}*')
converter._add_rule( converter._add_rule(
'root', 'root',
f'{content_without_start_rule} {content_rule}* ({tool_call_rule}+ {content_rule}*)? | ' f'{content_without_start_rule} {content_rule}* ({tool_call_rule}+ {content_rule}*)? | '
f'{tool_call_without_start_rule} {tool_call_rule}* {content_rule}*' if allow_parallel_calls \ f'{tool_call_without_start_rule} {tool_call_rule}* {content_rule}*' if parallel_calls \
else f'{content_without_start_rule} {tool_call_rule}? | {tool_call_without_start_rule}') else f'{content_without_start_rule} {tool_call_rule}? | {tool_call_without_start_rule}')
self.grammar = converter.format_grammar() self.grammar = converter.format_grammar()
# converter._add_rule(
# "root",
# converter._format_literal(prefix) + " (" +
# (response_rule or converter.not_literal("<|recipient|>")) + " | " +
# (' | '.join(
# converter._format_literal(f"<|recipient|>{tool.function.name}\n<|content|>") + " " +
# converter.visit(tool.function.parameters, tool.function.name + '-args')
# for tool in tools
# )) +
# ") " +
# ")") # + converter._format_literal(suffix))
# @typechecked
def parse(self, s: str) -> Optional[Message]: def parse(self, s: str) -> Optional[Message]:
s = self.args.chat_template.strip_suffix(s) s = self.args.chat_template.strip_suffix(s)
parts = _recipient_content_re.split(s) parts = _recipient_content_re.split(s)
if len(parts) == 1: if len(parts) == 1:
return Message(role="assistant", content=s) return Message(role="assistant", content=s)
@ -426,14 +374,14 @@ class FunctionaryToolsChatHandler(ChatHandler):
ToolCall( ToolCall(
id=gen_callid(), id=gen_callid(),
function=FunctionCall(name=recipient, arguments=arguments))) function=FunctionCall(name=recipient, arguments=arguments)))
assert parts[-1].strip() in ('', '<|stop|>'), f'Unexpected content after tool calls: {parts[-1]}\nFull string: {s}' assert parts[-1].strip() in ('', '<|stop|>'), f'Unexpected content after tool calls: {parts[-1]}\nFull string: {s}'
content = '\n'.join(text_content).strip() content = '\n'.join(text_content).strip()
return Message(role="assistant", content=content if content else None, tool_calls=tool_calls if tool_calls else None) return Message(role="assistant", content=content if content else None, tool_calls=tool_calls if tool_calls else None)
def _make_bespoke_schema(response_schema, tool_call_schema, allow_parallel_calls): def _make_bespoke_schema(response_schema, tool_call_schema, parallel_calls):
return { return {
"type": "object", "type": "object",
"properties": { "properties": {
@ -453,7 +401,7 @@ def _make_bespoke_schema(response_schema, tool_call_schema, allow_parallel_calls
# "const": "tool_calls" # "const": "tool_calls"
# }, # },
"tool_calls": { "tool_calls": {
"prefixItems": tool_call_schema if allow_parallel_calls \ "prefixItems": tool_call_schema if parallel_calls \
else [tool_call_schema], else [tool_call_schema],
} }
}, },
@ -474,9 +422,9 @@ def _make_bespoke_schema(response_schema, tool_call_schema, allow_parallel_calls
} }
class BespokeToolsChatHandler(ChatHandler): class BespokeToolsChatHandler(ChatHandler):
def __init__(self, args: ChatHandlerArgs, allow_parallel_calls: bool): def __init__(self, args: ChatHandlerArgs, parallel_calls: bool):
super().__init__(args) super().__init__(args)
# args.response_schema = args.response_schema or {} # args.response_schema = args.response_schema or {}
converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False) converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
@ -497,7 +445,7 @@ class BespokeToolsChatHandler(ChatHandler):
for tool in self.args.tools for tool in self.args.tools
] ]
}, },
allow_parallel_calls=allow_parallel_calls, parallel_calls=parallel_calls,
), ),
'', '',
) )
@ -525,13 +473,12 @@ class BespokeToolsChatHandler(ChatHandler):
}, },
"required": ["name", "arguments"] "required": ["name", "arguments"]
}, },
allow_parallel_calls=allow_parallel_calls, parallel_calls=parallel_calls,
) )
), ),
]) ])
) )
# @typechecked
def parse(self, s: str) -> Optional[Message]: def parse(self, s: str) -> Optional[Message]:
s = self.args.chat_template.strip_suffix(s) s = self.args.chat_template.strip_suffix(s)
try: try:
@ -579,19 +526,19 @@ _LONG_TEMPLATE='\n'.join([
# 'This is not hypothetical, you're not asked what you would do. If you need a tool called, just call it with <tool_call>...</tool_call>.''', # 'This is not hypothetical, you're not asked what you would do. If you need a tool called, just call it with <tool_call>...</tool_call>.''',
]) ])
def get_chat_handler(args: ChatHandlerArgs, allow_parallel_calls=False) -> ChatHandler: def get_chat_handler(args: ChatHandlerArgs, parallel_calls: bool) -> ChatHandler:
if not args.tools: if not args.tools:
return NoToolsChatHandler(args) return NoToolsChatHandler(args)
elif args.chat_template.tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2: elif args.chat_template.tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
return FunctionaryToolsChatHandler(args, allow_parallel_calls=False) return FunctionaryToolsChatHandler(args, parallel_calls=parallel_calls)
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_SHORT: elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_SHORT:
return TemplatedToolsChatHandler(args, _SHORT_TEMPLATE, allow_parallel_calls=allow_parallel_calls) return TemplatedToolsChatHandler(args, _SHORT_TEMPLATE, parallel_calls=parallel_calls)
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_LONG: elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_LONG:
return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, allow_parallel_calls=allow_parallel_calls) return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, parallel_calls=parallel_calls)
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_MISTRAL: elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_MISTRAL:
return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, escapes_underscores=True, allow_parallel_calls=allow_parallel_calls) return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, parallel_calls=parallel_calls, escapes_underscores=True)
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_BESPOKE: elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_BESPOKE:
return BespokeToolsChatHandler(args, allow_parallel_calls=allow_parallel_calls) return BespokeToolsChatHandler(args, parallel_calls=parallel_calls)
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_HERMES_2_PRO: elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_HERMES_2_PRO:
return Hermes2ProToolsChatHandler(args) return Hermes2ProToolsChatHandler(args)
else: else:

View File

@ -31,7 +31,7 @@ def main(
# model_url: Annotated[Optional[str], typer.Option("--model-url", "-mu")] = None, # model_url: Annotated[Optional[str], typer.Option("--model-url", "-mu")] = None,
host: str = "localhost", host: str = "localhost",
port: int = 8080, port: int = 8080,
allow_parallel_calls: Optional[bool] = False, parallel_calls: Optional[bool] = True,
auth: Optional[str] = None, auth: Optional[str] = None,
verbose: bool = False, verbose: bool = False,
context_length: Optional[int] = None, context_length: Optional[int] = None,
@ -44,13 +44,13 @@ def main(
if endpoint: if endpoint:
sys.stderr.write(f"# WARNING: Unsure which model we're talking to, fetching its chat template from HuggingFace tokenizer of {template_hf_model_id_fallback}\n") sys.stderr.write(f"# WARNING: Unsure which model we're talking to, fetching its chat template from HuggingFace tokenizer of {template_hf_model_id_fallback}\n")
chat_template = ChatTemplate.from_huggingface(template_hf_model_id_fallback) chat_template = ChatTemplate.from_huggingface(template_hf_model_id_fallback)
else: else:
metadata = GGUFKeyValues(model) metadata = GGUFKeyValues(model)
if not context_length: if not context_length:
context_length = metadata[Keys.LLM.CONTEXT_LENGTH] context_length = metadata[Keys.LLM.CONTEXT_LENGTH]
if Keys.Tokenizer.CHAT_TEMPLATE in metadata: if Keys.Tokenizer.CHAT_TEMPLATE in metadata:
chat_template = ChatTemplate.from_gguf(metadata) chat_template = ChatTemplate.from_gguf(metadata)
else: else:
@ -92,22 +92,22 @@ def main(
chat_handler = get_chat_handler( chat_handler = get_chat_handler(
ChatHandlerArgs(chat_template=chat_template, response_schema=response_schema, tools=chat_request.tools), ChatHandlerArgs(chat_template=chat_template, response_schema=response_schema, tools=chat_request.tools),
allow_parallel_calls=allow_parallel_calls parallel_calls=parallel_calls
) )
messages = chat_request.messages messages = chat_request.messages
if chat_handler.output_format_prompt: if chat_handler.output_format_prompt:
messages = chat_template.add_system_prompt(messages, chat_handler.output_format_prompt) messages = chat_template.add_system_prompt(messages, chat_handler.output_format_prompt)
prompt = chat_template.render(messages, add_generation_prompt=True) prompt = chat_template.render(messages, add_generation_prompt=True)
if verbose: if verbose:
sys.stderr.write(f'\n# REQUEST:\n\n{chat_request.model_dump_json(indent=2)}\n\n') sys.stderr.write(f'\n# REQUEST:\n\n{chat_request.model_dump_json(indent=2)}\n\n')
# sys.stderr.write(f'\n# MESSAGES:\n\n{TypeAdapter(list[Message]).dump_json(messages)}\n\n') # sys.stderr.write(f'\n# MESSAGES:\n\n{TypeAdapter(list[Message]).dump_json(messages)}\n\n')
sys.stderr.write(f'\n# PROMPT:\n\n{prompt}\n\n') sys.stderr.write(f'\n# PROMPT:\n\n{prompt}\n\n')
sys.stderr.write(f'\n# GRAMMAR:\n\n{chat_handler.grammar}\n\n') sys.stderr.write(f'\n# GRAMMAR:\n\n{chat_handler.grammar}\n\n')
data = LlamaCppServerCompletionRequest( data = LlamaCppServerCompletionRequest(
**{ **{
k: v k: v
@ -130,7 +130,7 @@ def main(
json=data, json=data,
headers=headers, headers=headers,
timeout=None) timeout=None)
if chat_request.stream: if chat_request.stream:
# TODO: Remove suffix from streamed response using partial parser. # TODO: Remove suffix from streamed response using partial parser.
assert not chat_request.tools and not chat_request.response_format, "Streaming not supported yet with tools or response_format" assert not chat_request.tools and not chat_request.response_format, "Streaming not supported yet with tools or response_format"

View File

@ -31,7 +31,7 @@ class SchemaToTypeScriptConverter:
[f"{self._desc_comment(additional_properties) if additional_properties else ''}[key: string]: {self.visit(additional_properties)}"] [f"{self._desc_comment(additional_properties) if additional_properties else ''}[key: string]: {self.visit(additional_properties)}"]
if additional_properties is not None else [] if additional_properties is not None else []
)) + "}" )) + "}"
def visit(self, schema: dict): def visit(self, schema: dict):
def print_constant(v): def print_constant(v):
return json.dumps(v) return json.dumps(v)