agent/openai:nits
This commit is contained in:
parent
ce2fb0155f
commit
ea34bd3e5c
|
|
@ -76,7 +76,7 @@ This example relies on the new [OpenAI compatibility server](../openai).
|
||||||
agent.py → examples.openai → server.cpp
|
agent.py → examples.openai → server.cpp
|
||||||
→ safe_tools.py
|
→ safe_tools.py
|
||||||
→ ( run_sandboxed_tools.sh : Docker → fastify.py ) → unsafe_tools.py → code interpreter, etc...
|
→ ( run_sandboxed_tools.sh : Docker → fastify.py ) → unsafe_tools.py → code interpreter, etc...
|
||||||
```
|
```
|
||||||
|
|
||||||
The agent can use tools written in Python, or (soon) exposed under OpenAPI endpoints. It only has standard Python deps (e.g. no langchain).
|
The agent can use tools written in Python, or (soon) exposed under OpenAPI endpoints. It only has standard Python deps (e.g. no langchain).
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -128,7 +128,7 @@ def main(
|
||||||
max_iterations: Optional[int] = 10,
|
max_iterations: Optional[int] = 10,
|
||||||
std_tools: Optional[bool] = False,
|
std_tools: Optional[bool] = False,
|
||||||
auth: Optional[str] = None,
|
auth: Optional[str] = None,
|
||||||
allow_parallel_calls: Optional[bool] = False,
|
parallel_calls: Optional[bool] = True,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
|
|
||||||
model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
|
model: Annotated[Optional[Path], typer.Option("--model", "-m")] = "models/7B/ggml-model-f16.gguf",
|
||||||
|
|
@ -174,14 +174,14 @@ def main(
|
||||||
"python", "-m", "examples.openai.server",
|
"python", "-m", "examples.openai.server",
|
||||||
"--model", model,
|
"--model", model,
|
||||||
*(['--verbose'] if verbose else []),
|
*(['--verbose'] if verbose else []),
|
||||||
*(['--allow-parallel-calls'] if allow_parallel_calls else []),
|
*(['--parallel-calls'] if parallel_calls else []),
|
||||||
*(['--context-length={context_length}'] if context_length else []),
|
*(['--context-length={context_length}'] if context_length else []),
|
||||||
*([])
|
*([])
|
||||||
]
|
]
|
||||||
server_process = subprocess.Popen(cmd, stdout=sys.stderr)
|
server_process = subprocess.Popen(cmd, stdout=sys.stderr)
|
||||||
atexit.register(server_process.kill)
|
atexit.register(server_process.kill)
|
||||||
sleep(5)
|
sleep(5)
|
||||||
|
|
||||||
tool_functions = []
|
tool_functions = []
|
||||||
types = {}
|
types = {}
|
||||||
for f in tools:
|
for f in tools:
|
||||||
|
|
@ -195,7 +195,7 @@ def main(
|
||||||
|
|
||||||
if std_tools:
|
if std_tools:
|
||||||
tool_functions.extend(collect_functions(StandardTools))
|
tool_functions.extend(collect_functions(StandardTools))
|
||||||
|
|
||||||
response_model = None#str
|
response_model = None#str
|
||||||
if format:
|
if format:
|
||||||
if format in types:
|
if format in types:
|
||||||
|
|
@ -207,8 +207,8 @@ def main(
|
||||||
response_model = json.loads(format)
|
response_model = json.loads(format)
|
||||||
except:
|
except:
|
||||||
response_model = eval(format)
|
response_model = eval(format)
|
||||||
|
|
||||||
|
|
||||||
result = completion_with_tool_usage(
|
result = completion_with_tool_usage(
|
||||||
model="...",
|
model="...",
|
||||||
endpoint=endpoint,
|
endpoint=endpoint,
|
||||||
|
|
|
||||||
|
|
@ -41,4 +41,4 @@ def main(files: List[str], host: str = '0.0.0.0', port: int = 8000):
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
typer.run(main)
|
typer.run(main)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ script="$( realpath "$1" )"
|
||||||
script_folder="$(dirname "$script")"
|
script_folder="$(dirname "$script")"
|
||||||
shift 1
|
shift 1
|
||||||
|
|
||||||
function cleanup {
|
function cleanup {
|
||||||
rm -rf "$BUILD_DIR"
|
rm -rf "$BUILD_DIR"
|
||||||
echo "Deleted $BUILD_DIR"
|
echo "Deleted $BUILD_DIR"
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,11 @@
|
||||||
import atexit
|
|
||||||
from datetime import date
|
from datetime import date
|
||||||
import datetime
|
import datetime
|
||||||
|
from pydantic import BaseModel
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
from time import sleep
|
|
||||||
import time
|
import time
|
||||||
import typer
|
import typer
|
||||||
from pydantic import BaseModel, Json, TypeAdapter
|
from typing import Union, Optional
|
||||||
from annotated_types import MinLen
|
|
||||||
from typing import Annotated, Callable, List, Union, Literal, Optional, Type, get_args, get_origin
|
|
||||||
import json, requests
|
|
||||||
|
|
||||||
class Duration(BaseModel):
|
class Duration(BaseModel):
|
||||||
seconds: Optional[int] = None
|
seconds: Optional[int] = None
|
||||||
|
|
@ -50,7 +46,7 @@ class WaitForDate(BaseModel):
|
||||||
sys.stderr.write(f"Waiting for {days} days and {seconds} seconds until {d}...\n")
|
sys.stderr.write(f"Waiting for {days} days and {seconds} seconds until {d}...\n")
|
||||||
time.sleep(days * 86400 + seconds)
|
time.sleep(days * 86400 + seconds)
|
||||||
sys.stderr.write(f"Reached the target date: {self.until}\n")
|
sys.stderr.write(f"Reached the target date: {self.until}\n")
|
||||||
|
|
||||||
|
|
||||||
class StandardTools:
|
class StandardTools:
|
||||||
|
|
||||||
|
|
@ -61,7 +57,7 @@ class StandardTools:
|
||||||
This allows getting additional information, requesting disambiguation, etc.
|
This allows getting additional information, requesting disambiguation, etc.
|
||||||
'''
|
'''
|
||||||
return typer.prompt(question)
|
return typer.prompt(question)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def wait(_for: Union[WaitForDuration, WaitForDate]) -> None:
|
def wait(_for: Union[WaitForDuration, WaitForDate]) -> None:
|
||||||
'''
|
'''
|
||||||
|
|
@ -69,7 +65,7 @@ class StandardTools:
|
||||||
This can be used to wait for a specific duration or until a specific date.
|
This can be used to wait for a specific duration or until a specific date.
|
||||||
'''
|
'''
|
||||||
return _for()
|
return _for()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def say_out_loud(something: str) -> str:
|
def say_out_loud(something: str) -> str:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@ The new [examples/openai/server.py](./server.py):
|
||||||
}
|
}
|
||||||
// Where T is the output JSON schema, or 'any'
|
// Where T is the output JSON schema, or 'any'
|
||||||
```
|
```
|
||||||
|
|
||||||
- Option to publicise schemas to models as TypeScript signatures (as for Functionary) or JSON schema.
|
- Option to publicise schemas to models as TypeScript signatures (as for Functionary) or JSON schema.
|
||||||
|
|
||||||
- Supports models that require user/assistant alternation (like Mixtral Instruct) by merging system messages into user messages.
|
- Supports models that require user/assistant alternation (like Mixtral Instruct) by merging system messages into user messages.
|
||||||
|
|
@ -175,7 +175,7 @@ curl http://localhost:8080/v1/chat/completions \
|
||||||
- Evaluate options for session caching
|
- Evaluate options for session caching
|
||||||
|
|
||||||
- Pass session id & store / read from file?
|
- Pass session id & store / read from file?
|
||||||
|
|
||||||
- Support parent session ids for trees of thought?
|
- Support parent session ids for trees of thought?
|
||||||
|
|
||||||
- Support precaching long prompts from CLI / read session files?
|
- Support precaching long prompts from CLI / read session files?
|
||||||
|
|
@ -186,4 +186,4 @@ curl http://localhost:8080/v1/chat/completions \
|
||||||
|
|
||||||
- Remove non-Python json-schema-to-grammar versions
|
- Remove non-Python json-schema-to-grammar versions
|
||||||
|
|
||||||
- Reach out to frameworks to advertise new option.
|
- Reach out to frameworks to advertise new option.
|
||||||
|
|
|
||||||
|
|
@ -1,28 +1,12 @@
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from pydantic import BaseModel, Json
|
from pydantic import Json
|
||||||
|
|
||||||
class LlamaCppServerCompletionRequest(BaseModel):
|
from examples.openai.api import LlamaCppParams
|
||||||
|
|
||||||
|
class LlamaCppServerCompletionRequest(LlamaCppParams):
|
||||||
prompt: str
|
prompt: str
|
||||||
stream: Optional[bool] = None
|
stream: Optional[bool] = None
|
||||||
cache_prompt: Optional[bool] = None
|
cache_prompt: Optional[bool] = None
|
||||||
n_predict: Optional[int] = None
|
|
||||||
top_k: Optional[int] = None
|
|
||||||
top_p: Optional[float] = None
|
|
||||||
min_p: Optional[float] = None
|
|
||||||
tfs_z: Optional[float] = None
|
|
||||||
typical_p: Optional[float] = None
|
|
||||||
temperature: Optional[float] = None
|
|
||||||
dynatemp_range: Optional[float] = None
|
|
||||||
dynatemp_exponent: Optional[float] = None
|
|
||||||
repeat_last_n: Optional[int] = None
|
|
||||||
repeat_penalty: Optional[float] = None
|
|
||||||
frequency_penalty: Optional[float] = None
|
|
||||||
presence_penalty: Optional[float] = None
|
|
||||||
mirostat: Optional[bool] = None
|
|
||||||
mirostat_tau: Optional[float] = None
|
|
||||||
mirostat_eta: Optional[float] = None
|
|
||||||
penalize_nl: Optional[bool] = None
|
|
||||||
n_keep: Optional[int] = None
|
|
||||||
seed: Optional[int] = None
|
|
||||||
grammar: Optional[str] = None
|
grammar: Optional[str] = None
|
||||||
json_schema: Optional[Json] = None
|
json_schema: Optional[Json] = None
|
||||||
|
|
@ -1,15 +1,13 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import wraps
|
|
||||||
import jinja2
|
import jinja2
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
from typing import Any, Dict, Literal, Optional, Tuple, Callable, Union
|
from typing import Optional
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
# from typeguard import typechecked
|
|
||||||
|
|
||||||
from examples.json_schema_to_grammar import SchemaConverter
|
from examples.json_schema_to_grammar import SchemaConverter
|
||||||
from examples.openai.api import Tool, Message, FunctionCall, ToolCall
|
from examples.openai.api import Tool, Message, FunctionCall, ToolCall
|
||||||
|
|
@ -55,7 +53,7 @@ class ChatTemplate(BaseModel):
|
||||||
@property
|
@property
|
||||||
def tool_style(self) -> 'ToolsPromptStyle':
|
def tool_style(self) -> 'ToolsPromptStyle':
|
||||||
return self._tool_style
|
return self._tool_style
|
||||||
|
|
||||||
def __init__(self, template: str, eos_token: str, bos_token: str):
|
def __init__(self, template: str, eos_token: str, bos_token: str):
|
||||||
super().__init__(template=template
|
super().__init__(template=template
|
||||||
)
|
)
|
||||||
|
|
@ -75,7 +73,7 @@ class ChatTemplate(BaseModel):
|
||||||
# self._tool_style = ToolsPromptStyle.TOOLS_MISTRAL
|
# self._tool_style = ToolsPromptStyle.TOOLS_MISTRAL
|
||||||
|
|
||||||
# TODO: Test whether the template supports formatting tool_calls
|
# TODO: Test whether the template supports formatting tool_calls
|
||||||
|
|
||||||
delimiter = '<%$[SAMPLE]$%>'
|
delimiter = '<%$[SAMPLE]$%>'
|
||||||
user_msg = Message(role="user", content="Hey")
|
user_msg = Message(role="user", content="Hey")
|
||||||
empty_prompt = self.render([user_msg], add_generation_prompt=True).strip()
|
empty_prompt = self.render([user_msg], add_generation_prompt=True).strip()
|
||||||
|
|
@ -112,7 +110,7 @@ class ChatTemplate(BaseModel):
|
||||||
def from_gguf(metadata: GGUFKeyValues):
|
def from_gguf(metadata: GGUFKeyValues):
|
||||||
if Keys.Tokenizer.CHAT_TEMPLATE not in metadata:
|
if Keys.Tokenizer.CHAT_TEMPLATE not in metadata:
|
||||||
raise NotImplementedError(f'Only supporting models with {Keys.Tokenizer.CHAT_TEMPLATE} entry in their GGUF key-values (TODO: add default template, maybe pick llama2\'s?)')
|
raise NotImplementedError(f'Only supporting models with {Keys.Tokenizer.CHAT_TEMPLATE} entry in their GGUF key-values (TODO: add default template, maybe pick llama2\'s?)')
|
||||||
|
|
||||||
tokens = metadata[Keys.Tokenizer.LIST]
|
tokens = metadata[Keys.Tokenizer.LIST]
|
||||||
return ChatTemplate(
|
return ChatTemplate(
|
||||||
template = metadata[Keys.Tokenizer.CHAT_TEMPLATE],
|
template = metadata[Keys.Tokenizer.CHAT_TEMPLATE],
|
||||||
|
|
@ -129,8 +127,6 @@ class ChatTemplate(BaseModel):
|
||||||
eos_token = tokenizer.eos_token)
|
eos_token = tokenizer.eos_token)
|
||||||
|
|
||||||
def render(self, messages: list[Message], add_generation_prompt: bool, omit_bos: bool = False):
|
def render(self, messages: list[Message], add_generation_prompt: bool, omit_bos: bool = False):
|
||||||
# sys.stderr.write(f'# strict_user_assistant_alternation={self._strict_user_assistant_alternation}\n')
|
|
||||||
# sys.stderr.write(f'# messages=' + "\n".join(json.dumps(m.model_dump(), indent=2) for m in messages) + '\n')
|
|
||||||
if self._strict_user_assistant_alternation and any(m.role not in ('user', 'assistant') for m in messages):
|
if self._strict_user_assistant_alternation and any(m.role not in ('user', 'assistant') for m in messages):
|
||||||
new_messages=[]
|
new_messages=[]
|
||||||
i = 0
|
i = 0
|
||||||
|
|
@ -161,8 +157,7 @@ class ChatTemplate(BaseModel):
|
||||||
i += 1
|
i += 1
|
||||||
# print(f'new_messages={json.dumps(new_messages, indent=2)}')
|
# print(f'new_messages={json.dumps(new_messages, indent=2)}')
|
||||||
messages = new_messages
|
messages = new_messages
|
||||||
# print(f'messages={messages}')
|
|
||||||
|
|
||||||
result = self._template.render(
|
result = self._template.render(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
eos_token=self._eos_token,
|
eos_token=self._eos_token,
|
||||||
|
|
@ -170,7 +165,6 @@ class ChatTemplate(BaseModel):
|
||||||
raise_exception=raise_exception,
|
raise_exception=raise_exception,
|
||||||
add_generation_prompt=add_generation_prompt,
|
add_generation_prompt=add_generation_prompt,
|
||||||
)
|
)
|
||||||
# sys.stderr.write(f'\n# RENDERED:\n\n{result}\n\n')
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
class ChatHandlerArgs(BaseModel):
|
class ChatHandlerArgs(BaseModel):
|
||||||
|
|
@ -192,7 +186,7 @@ class NoToolsChatHandler(ChatHandler):
|
||||||
def __init__(self, args: ChatHandlerArgs):
|
def __init__(self, args: ChatHandlerArgs):
|
||||||
super().__init__(args)
|
super().__init__(args)
|
||||||
assert not args.tools
|
assert not args.tools
|
||||||
|
|
||||||
if args.response_schema:
|
if args.response_schema:
|
||||||
self.output_format_prompt = Message(
|
self.output_format_prompt = Message(
|
||||||
role="system",
|
role="system",
|
||||||
|
|
@ -206,21 +200,20 @@ class NoToolsChatHandler(ChatHandler):
|
||||||
self.output_format_prompt = None
|
self.output_format_prompt = None
|
||||||
self.grammar = None
|
self.grammar = None
|
||||||
|
|
||||||
# @typechecked
|
|
||||||
def parse(self, s: str) -> Optional[Message]:
|
def parse(self, s: str) -> Optional[Message]:
|
||||||
return Message(role="assistant", content=s)
|
return Message(role="assistant", content=s)
|
||||||
|
|
||||||
class ToolCallTagsChatHandler(ChatHandler):
|
class ToolCallTagsChatHandler(ChatHandler):
|
||||||
def __init__(self, args: ChatHandlerArgs, escapes_underscores: bool, allow_parallel_calls: bool):
|
def __init__(self, args: ChatHandlerArgs, escapes_underscores: bool, parallel_calls: bool):
|
||||||
super().__init__(args)
|
super().__init__(args)
|
||||||
|
|
||||||
converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
|
converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
|
||||||
tool_rules = []
|
tool_rules = []
|
||||||
for tool in self.args.tools:
|
for tool in self.args.tools:
|
||||||
|
|
||||||
parameters_schema = tool.function.parameters
|
parameters_schema = tool.function.parameters
|
||||||
parameters_schema = converter.resolve_refs(parameters_schema, tool.function.name)
|
parameters_schema = converter.resolve_refs(parameters_schema, tool.function.name)
|
||||||
|
|
||||||
tool_rules.append(converter.visit(
|
tool_rules.append(converter.visit(
|
||||||
dict(
|
dict(
|
||||||
type="object",
|
type="object",
|
||||||
|
|
@ -245,7 +238,7 @@ class ToolCallTagsChatHandler(ChatHandler):
|
||||||
format_literal("<tool_call>") + " space (" +
|
format_literal("<tool_call>") + " space (" +
|
||||||
' | '.join(tool_rules) +
|
' | '.join(tool_rules) +
|
||||||
") space " + format_literal("</tool_call>"))# + ' space')
|
") space " + format_literal("</tool_call>"))# + ' space')
|
||||||
|
|
||||||
# Ideally we'd want a negative lookahead of /<tool\\?_call>/, but it's just too hard to express in GBNF for now.
|
# Ideally we'd want a negative lookahead of /<tool\\?_call>/, but it's just too hard to express in GBNF for now.
|
||||||
# So we just over-constrain the content rule to not contain literals dangerously getting close to <tool_call>
|
# So we just over-constrain the content rule to not contain literals dangerously getting close to <tool_call>
|
||||||
content_rule = converter._add_rule('content', '[^<] | "<" [^t<] | "<t" [^o<]')
|
content_rule = converter._add_rule('content', '[^<] | "<" [^t<] | "<t" [^o<]')
|
||||||
|
|
@ -253,22 +246,10 @@ class ToolCallTagsChatHandler(ChatHandler):
|
||||||
converter._add_rule(
|
converter._add_rule(
|
||||||
'root',
|
'root',
|
||||||
# tool_call_rule)
|
# tool_call_rule)
|
||||||
f'{content_rule}* ({tool_call_rule}+ {content_rule}*)?' if allow_parallel_calls \
|
f'{content_rule}* ({tool_call_rule}+ {content_rule}*)?' if parallel_calls \
|
||||||
else f'{content_rule}* {tool_call_rule}?')
|
else f'{content_rule}* {tool_call_rule}?')
|
||||||
self.grammar = converter.format_grammar()
|
self.grammar = converter.format_grammar()
|
||||||
|
|
||||||
# # Constrain the output to be a non-tool-call message (constrained to a JSON schema or not)
|
|
||||||
# # OR a tool-call message respecting the schema of any of the tools
|
|
||||||
# converter._add_rule(
|
|
||||||
# "root",
|
|
||||||
# converter._format_literal(prefix) + " (" +
|
|
||||||
# (response_rule or converter.not_literal("<tool_call>")) + " | " +
|
|
||||||
# converter._format_literal("<tool_call>") + " (" +
|
|
||||||
# ' | '.join(tool_rules) +
|
|
||||||
# ") " + converter._format_literal("</tool_call>") +
|
|
||||||
# ")") # + converter._format_literal(suffix))
|
|
||||||
|
|
||||||
# @typechecked
|
|
||||||
def parse(self, s: str) -> Optional[Message]:
|
def parse(self, s: str) -> Optional[Message]:
|
||||||
s = self.args.chat_template.strip_suffix(s)
|
s = self.args.chat_template.strip_suffix(s)
|
||||||
|
|
||||||
|
|
@ -294,21 +275,14 @@ class ToolCallTagsChatHandler(ChatHandler):
|
||||||
ToolCall(
|
ToolCall(
|
||||||
id=gen_callid(),
|
id=gen_callid(),
|
||||||
function=FunctionCall(**fc)))
|
function=FunctionCall(**fc)))
|
||||||
|
|
||||||
content = '\n'.join(content).strip()
|
content = '\n'.join(content).strip()
|
||||||
return Message(role="assistant", content=content if content else None, tool_calls=tool_calls)
|
return Message(role="assistant", content=content if content else None, tool_calls=tool_calls)
|
||||||
|
|
||||||
# if '<tool_call>'.startswith(ls) or ls.startswith('<tool_call>'):
|
|
||||||
# if ls.startswith('<tool_call>') and ls.endswith('</tool_call>' + suffix):
|
|
||||||
# tool_call = ls[len('<tool_call>'):-len('</tool_call>' + suffix)]
|
|
||||||
# return Message(role="assistant", content=None, tool_calls=[json.loads(tool_call)])
|
|
||||||
# return None
|
|
||||||
# else:
|
|
||||||
# return Message(role="assistant", content=s)
|
|
||||||
|
|
||||||
class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
|
class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
|
||||||
def __init__(self, args: ChatHandlerArgs, template: str, escapes_underscores=False, allow_parallel_calls=True):
|
def __init__(self, args: ChatHandlerArgs, template: str, parallel_calls: bool, escapes_underscores: bool = False):
|
||||||
super().__init__(args, escapes_underscores=escapes_underscores, allow_parallel_calls=allow_parallel_calls)
|
super().__init__(args, escapes_underscores=escapes_underscores, parallel_calls=parallel_calls)
|
||||||
assert '{tools}' in template, 'Template must contain "{tools}"'
|
assert '{tools}' in template, 'Template must contain "{tools}"'
|
||||||
|
|
||||||
self.output_format_prompt = Message(
|
self.output_format_prompt = Message(
|
||||||
|
|
@ -320,8 +294,8 @@ class TemplatedToolsChatHandler(ToolCallTagsChatHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler):
|
class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler):
|
||||||
def __init__(self, args: ChatHandlerArgs, allow_parallel_calls: bool):
|
def __init__(self, args: ChatHandlerArgs, parallel_calls: bool):
|
||||||
super().__init__(args, escapes_underscores=False, allow_parallel_calls=allow_parallel_calls)
|
super().__init__(args, escapes_underscores=False, parallel_calls=parallel_calls)
|
||||||
|
|
||||||
# Hackily import https://github.com/NousResearch/Hermes-Function-Calling
|
# Hackily import https://github.com/NousResearch/Hermes-Function-Calling
|
||||||
path = str(Path(__file__).parent / "hermes_function_calling")
|
path = str(Path(__file__).parent / "hermes_function_calling")
|
||||||
|
|
@ -330,15 +304,15 @@ class Hermes2ProToolsChatHandler(ToolCallTagsChatHandler):
|
||||||
from examples.openai.hermes_function_calling.prompter import PromptManager
|
from examples.openai.hermes_function_calling.prompter import PromptManager
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(f"Please `git clone https://github.com/NousResearch/Hermes-Function-Calling {path}`")
|
raise ImportError(f"Please `git clone https://github.com/NousResearch/Hermes-Function-Calling {path}`")
|
||||||
|
|
||||||
prompt = PromptManager().generate_prompt(user_prompt=[], tools=[json.dumps(tool) for tool in tools])
|
prompt = PromptManager().generate_prompt(user_prompt=[], tools=[json.dumps(tool) for tool in args.tools])
|
||||||
assert len(prompt) == 1 and prompt[0]["role"] == "system"
|
assert len(prompt) == 1 and prompt[0]["role"] == "system"
|
||||||
self.output_format_prompt = Message(**prompt[0])
|
self.output_format_prompt = Message(**prompt[0])
|
||||||
|
|
||||||
class FunctionaryToolsChatHandler(ChatHandler):
|
class FunctionaryToolsChatHandler(ChatHandler):
|
||||||
def __init__(self, args: ChatHandlerArgs, allow_parallel_calls: bool):
|
def __init__(self, args: ChatHandlerArgs, parallel_calls: bool):
|
||||||
super().__init__(args)
|
super().__init__(args)
|
||||||
|
|
||||||
# Only allowing a single tool call at a time for now.
|
# Only allowing a single tool call at a time for now.
|
||||||
# Note that if there were more, they'd be separated by a '<|from|>assistant' literal
|
# Note that if there were more, they'd be separated by a '<|from|>assistant' literal
|
||||||
|
|
||||||
|
|
@ -347,7 +321,7 @@ class FunctionaryToolsChatHandler(ChatHandler):
|
||||||
content= '// Supported function definitions that should be called when necessary.\n' +
|
content= '// Supported function definitions that should be called when necessary.\n' +
|
||||||
_tools_typescript_signatures(args.tools)
|
_tools_typescript_signatures(args.tools)
|
||||||
)
|
)
|
||||||
|
|
||||||
converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
|
converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
|
||||||
tool_rules = [
|
tool_rules = [
|
||||||
converter._add_rule(
|
converter._add_rule(
|
||||||
|
|
@ -355,17 +329,6 @@ class FunctionaryToolsChatHandler(ChatHandler):
|
||||||
converter._format_literal(tool.function.name) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' +
|
converter._format_literal(tool.function.name) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' +
|
||||||
converter.visit(tool.function.parameters, tool.function.name + '-args') + ' ' +
|
converter.visit(tool.function.parameters, tool.function.name + '-args') + ' ' +
|
||||||
converter._format_literal('\n'))
|
converter._format_literal('\n'))
|
||||||
# converter.visit(
|
|
||||||
# dict(
|
|
||||||
# type="object",
|
|
||||||
# properties=dict(
|
|
||||||
# name=dict(const=tool.function.name),
|
|
||||||
# arguments=tool.function.parameters,
|
|
||||||
# ),
|
|
||||||
# required=['name', 'arguments']
|
|
||||||
# ),
|
|
||||||
# f'{tool.function.name}-tool-call'
|
|
||||||
# )
|
|
||||||
for i, tool in enumerate(self.args.tools)
|
for i, tool in enumerate(self.args.tools)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -378,33 +341,18 @@ class FunctionaryToolsChatHandler(ChatHandler):
|
||||||
tool_call_without_start_rule = converter._add_rule(
|
tool_call_without_start_rule = converter._add_rule(
|
||||||
'tool_call_without_start',
|
'tool_call_without_start',
|
||||||
' | '.join(tool_rules))
|
' | '.join(tool_rules))
|
||||||
# + ' ' +
|
|
||||||
# converter.not_literal("all", dotall=False) + ' ' + converter._format_literal('\n<|content|>\n') + ' ' + not_from_rule + '*')
|
|
||||||
tool_call_rule = converter._add_rule('tool_call', f'{start_rule} {tool_call_without_start_rule}')
|
tool_call_rule = converter._add_rule('tool_call', f'{start_rule} {tool_call_without_start_rule}')
|
||||||
# converter._add_rule('root', f'({content_without_start_rule} ({content_rule})* ({tool_call_rule}+ {content_rule}*)? | {tool_call_without_start_rule} (* {tool_call_rule}{content_rule}*')
|
|
||||||
converter._add_rule(
|
converter._add_rule(
|
||||||
'root',
|
'root',
|
||||||
f'{content_without_start_rule} {content_rule}* ({tool_call_rule}+ {content_rule}*)? | '
|
f'{content_without_start_rule} {content_rule}* ({tool_call_rule}+ {content_rule}*)? | '
|
||||||
f'{tool_call_without_start_rule} {tool_call_rule}* {content_rule}*' if allow_parallel_calls \
|
f'{tool_call_without_start_rule} {tool_call_rule}* {content_rule}*' if parallel_calls \
|
||||||
else f'{content_without_start_rule} {tool_call_rule}? | {tool_call_without_start_rule}')
|
else f'{content_without_start_rule} {tool_call_rule}? | {tool_call_without_start_rule}')
|
||||||
|
|
||||||
self.grammar = converter.format_grammar()
|
self.grammar = converter.format_grammar()
|
||||||
# converter._add_rule(
|
|
||||||
# "root",
|
|
||||||
# converter._format_literal(prefix) + " (" +
|
|
||||||
# (response_rule or converter.not_literal("<|recipient|>")) + " | " +
|
|
||||||
# (' | '.join(
|
|
||||||
# converter._format_literal(f"<|recipient|>{tool.function.name}\n<|content|>") + " " +
|
|
||||||
# converter.visit(tool.function.parameters, tool.function.name + '-args')
|
|
||||||
# for tool in tools
|
|
||||||
# )) +
|
|
||||||
# ") " +
|
|
||||||
# ")") # + converter._format_literal(suffix))
|
|
||||||
|
|
||||||
# @typechecked
|
|
||||||
def parse(self, s: str) -> Optional[Message]:
|
def parse(self, s: str) -> Optional[Message]:
|
||||||
s = self.args.chat_template.strip_suffix(s)
|
s = self.args.chat_template.strip_suffix(s)
|
||||||
|
|
||||||
parts = _recipient_content_re.split(s)
|
parts = _recipient_content_re.split(s)
|
||||||
if len(parts) == 1:
|
if len(parts) == 1:
|
||||||
return Message(role="assistant", content=s)
|
return Message(role="assistant", content=s)
|
||||||
|
|
@ -426,14 +374,14 @@ class FunctionaryToolsChatHandler(ChatHandler):
|
||||||
ToolCall(
|
ToolCall(
|
||||||
id=gen_callid(),
|
id=gen_callid(),
|
||||||
function=FunctionCall(name=recipient, arguments=arguments)))
|
function=FunctionCall(name=recipient, arguments=arguments)))
|
||||||
|
|
||||||
|
|
||||||
assert parts[-1].strip() in ('', '<|stop|>'), f'Unexpected content after tool calls: {parts[-1]}\nFull string: {s}'
|
assert parts[-1].strip() in ('', '<|stop|>'), f'Unexpected content after tool calls: {parts[-1]}\nFull string: {s}'
|
||||||
|
|
||||||
content = '\n'.join(text_content).strip()
|
content = '\n'.join(text_content).strip()
|
||||||
return Message(role="assistant", content=content if content else None, tool_calls=tool_calls if tool_calls else None)
|
return Message(role="assistant", content=content if content else None, tool_calls=tool_calls if tool_calls else None)
|
||||||
|
|
||||||
def _make_bespoke_schema(response_schema, tool_call_schema, allow_parallel_calls):
|
def _make_bespoke_schema(response_schema, tool_call_schema, parallel_calls):
|
||||||
return {
|
return {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
|
@ -453,7 +401,7 @@ def _make_bespoke_schema(response_schema, tool_call_schema, allow_parallel_calls
|
||||||
# "const": "tool_calls"
|
# "const": "tool_calls"
|
||||||
# },
|
# },
|
||||||
"tool_calls": {
|
"tool_calls": {
|
||||||
"prefixItems": tool_call_schema if allow_parallel_calls \
|
"prefixItems": tool_call_schema if parallel_calls \
|
||||||
else [tool_call_schema],
|
else [tool_call_schema],
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
@ -474,9 +422,9 @@ def _make_bespoke_schema(response_schema, tool_call_schema, allow_parallel_calls
|
||||||
}
|
}
|
||||||
|
|
||||||
class BespokeToolsChatHandler(ChatHandler):
|
class BespokeToolsChatHandler(ChatHandler):
|
||||||
def __init__(self, args: ChatHandlerArgs, allow_parallel_calls: bool):
|
def __init__(self, args: ChatHandlerArgs, parallel_calls: bool):
|
||||||
super().__init__(args)
|
super().__init__(args)
|
||||||
|
|
||||||
# args.response_schema = args.response_schema or {}
|
# args.response_schema = args.response_schema or {}
|
||||||
converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
|
converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
|
||||||
|
|
||||||
|
|
@ -497,7 +445,7 @@ class BespokeToolsChatHandler(ChatHandler):
|
||||||
for tool in self.args.tools
|
for tool in self.args.tools
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
allow_parallel_calls=allow_parallel_calls,
|
parallel_calls=parallel_calls,
|
||||||
),
|
),
|
||||||
'',
|
'',
|
||||||
)
|
)
|
||||||
|
|
@ -525,13 +473,12 @@ class BespokeToolsChatHandler(ChatHandler):
|
||||||
},
|
},
|
||||||
"required": ["name", "arguments"]
|
"required": ["name", "arguments"]
|
||||||
},
|
},
|
||||||
allow_parallel_calls=allow_parallel_calls,
|
parallel_calls=parallel_calls,
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
])
|
])
|
||||||
)
|
)
|
||||||
|
|
||||||
# @typechecked
|
|
||||||
def parse(self, s: str) -> Optional[Message]:
|
def parse(self, s: str) -> Optional[Message]:
|
||||||
s = self.args.chat_template.strip_suffix(s)
|
s = self.args.chat_template.strip_suffix(s)
|
||||||
try:
|
try:
|
||||||
|
|
@ -579,19 +526,19 @@ _LONG_TEMPLATE='\n'.join([
|
||||||
# 'This is not hypothetical, you're not asked what you would do. If you need a tool called, just call it with <tool_call>...</tool_call>.''',
|
# 'This is not hypothetical, you're not asked what you would do. If you need a tool called, just call it with <tool_call>...</tool_call>.''',
|
||||||
])
|
])
|
||||||
|
|
||||||
def get_chat_handler(args: ChatHandlerArgs, allow_parallel_calls=False) -> ChatHandler:
|
def get_chat_handler(args: ChatHandlerArgs, parallel_calls: bool) -> ChatHandler:
|
||||||
if not args.tools:
|
if not args.tools:
|
||||||
return NoToolsChatHandler(args)
|
return NoToolsChatHandler(args)
|
||||||
elif args.chat_template.tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
|
elif args.chat_template.tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2:
|
||||||
return FunctionaryToolsChatHandler(args, allow_parallel_calls=False)
|
return FunctionaryToolsChatHandler(args, parallel_calls=parallel_calls)
|
||||||
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_SHORT:
|
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_SHORT:
|
||||||
return TemplatedToolsChatHandler(args, _SHORT_TEMPLATE, allow_parallel_calls=allow_parallel_calls)
|
return TemplatedToolsChatHandler(args, _SHORT_TEMPLATE, parallel_calls=parallel_calls)
|
||||||
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_LONG:
|
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_LONG:
|
||||||
return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, allow_parallel_calls=allow_parallel_calls)
|
return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, parallel_calls=parallel_calls)
|
||||||
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_MISTRAL:
|
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_MISTRAL:
|
||||||
return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, escapes_underscores=True, allow_parallel_calls=allow_parallel_calls)
|
return TemplatedToolsChatHandler(args, _LONG_TEMPLATE, parallel_calls=parallel_calls, escapes_underscores=True)
|
||||||
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_BESPOKE:
|
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_BESPOKE:
|
||||||
return BespokeToolsChatHandler(args, allow_parallel_calls=allow_parallel_calls)
|
return BespokeToolsChatHandler(args, parallel_calls=parallel_calls)
|
||||||
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_HERMES_2_PRO:
|
elif args.chat_template.tool_style == ToolsPromptStyle.TOOLS_HERMES_2_PRO:
|
||||||
return Hermes2ProToolsChatHandler(args)
|
return Hermes2ProToolsChatHandler(args)
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@ def main(
|
||||||
# model_url: Annotated[Optional[str], typer.Option("--model-url", "-mu")] = None,
|
# model_url: Annotated[Optional[str], typer.Option("--model-url", "-mu")] = None,
|
||||||
host: str = "localhost",
|
host: str = "localhost",
|
||||||
port: int = 8080,
|
port: int = 8080,
|
||||||
allow_parallel_calls: Optional[bool] = False,
|
parallel_calls: Optional[bool] = True,
|
||||||
auth: Optional[str] = None,
|
auth: Optional[str] = None,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
context_length: Optional[int] = None,
|
context_length: Optional[int] = None,
|
||||||
|
|
@ -44,13 +44,13 @@ def main(
|
||||||
if endpoint:
|
if endpoint:
|
||||||
sys.stderr.write(f"# WARNING: Unsure which model we're talking to, fetching its chat template from HuggingFace tokenizer of {template_hf_model_id_fallback}\n")
|
sys.stderr.write(f"# WARNING: Unsure which model we're talking to, fetching its chat template from HuggingFace tokenizer of {template_hf_model_id_fallback}\n")
|
||||||
chat_template = ChatTemplate.from_huggingface(template_hf_model_id_fallback)
|
chat_template = ChatTemplate.from_huggingface(template_hf_model_id_fallback)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
metadata = GGUFKeyValues(model)
|
metadata = GGUFKeyValues(model)
|
||||||
|
|
||||||
if not context_length:
|
if not context_length:
|
||||||
context_length = metadata[Keys.LLM.CONTEXT_LENGTH]
|
context_length = metadata[Keys.LLM.CONTEXT_LENGTH]
|
||||||
|
|
||||||
if Keys.Tokenizer.CHAT_TEMPLATE in metadata:
|
if Keys.Tokenizer.CHAT_TEMPLATE in metadata:
|
||||||
chat_template = ChatTemplate.from_gguf(metadata)
|
chat_template = ChatTemplate.from_gguf(metadata)
|
||||||
else:
|
else:
|
||||||
|
|
@ -92,22 +92,22 @@ def main(
|
||||||
|
|
||||||
chat_handler = get_chat_handler(
|
chat_handler = get_chat_handler(
|
||||||
ChatHandlerArgs(chat_template=chat_template, response_schema=response_schema, tools=chat_request.tools),
|
ChatHandlerArgs(chat_template=chat_template, response_schema=response_schema, tools=chat_request.tools),
|
||||||
allow_parallel_calls=allow_parallel_calls
|
parallel_calls=parallel_calls
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = chat_request.messages
|
messages = chat_request.messages
|
||||||
if chat_handler.output_format_prompt:
|
if chat_handler.output_format_prompt:
|
||||||
messages = chat_template.add_system_prompt(messages, chat_handler.output_format_prompt)
|
messages = chat_template.add_system_prompt(messages, chat_handler.output_format_prompt)
|
||||||
|
|
||||||
prompt = chat_template.render(messages, add_generation_prompt=True)
|
prompt = chat_template.render(messages, add_generation_prompt=True)
|
||||||
|
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
sys.stderr.write(f'\n# REQUEST:\n\n{chat_request.model_dump_json(indent=2)}\n\n')
|
sys.stderr.write(f'\n# REQUEST:\n\n{chat_request.model_dump_json(indent=2)}\n\n')
|
||||||
# sys.stderr.write(f'\n# MESSAGES:\n\n{TypeAdapter(list[Message]).dump_json(messages)}\n\n')
|
# sys.stderr.write(f'\n# MESSAGES:\n\n{TypeAdapter(list[Message]).dump_json(messages)}\n\n')
|
||||||
sys.stderr.write(f'\n# PROMPT:\n\n{prompt}\n\n')
|
sys.stderr.write(f'\n# PROMPT:\n\n{prompt}\n\n')
|
||||||
sys.stderr.write(f'\n# GRAMMAR:\n\n{chat_handler.grammar}\n\n')
|
sys.stderr.write(f'\n# GRAMMAR:\n\n{chat_handler.grammar}\n\n')
|
||||||
|
|
||||||
data = LlamaCppServerCompletionRequest(
|
data = LlamaCppServerCompletionRequest(
|
||||||
**{
|
**{
|
||||||
k: v
|
k: v
|
||||||
|
|
@ -130,7 +130,7 @@ def main(
|
||||||
json=data,
|
json=data,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
timeout=None)
|
timeout=None)
|
||||||
|
|
||||||
if chat_request.stream:
|
if chat_request.stream:
|
||||||
# TODO: Remove suffix from streamed response using partial parser.
|
# TODO: Remove suffix from streamed response using partial parser.
|
||||||
assert not chat_request.tools and not chat_request.response_format, "Streaming not supported yet with tools or response_format"
|
assert not chat_request.tools and not chat_request.response_format, "Streaming not supported yet with tools or response_format"
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@ class SchemaToTypeScriptConverter:
|
||||||
[f"{self._desc_comment(additional_properties) if additional_properties else ''}[key: string]: {self.visit(additional_properties)}"]
|
[f"{self._desc_comment(additional_properties) if additional_properties else ''}[key: string]: {self.visit(additional_properties)}"]
|
||||||
if additional_properties is not None else []
|
if additional_properties is not None else []
|
||||||
)) + "}"
|
)) + "}"
|
||||||
|
|
||||||
def visit(self, schema: dict):
|
def visit(self, schema: dict):
|
||||||
def print_constant(v):
|
def print_constant(v):
|
||||||
return json.dumps(v)
|
return json.dumps(v)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue