# llama.cpp/gguf-py/scripts/gguf-template.py
"""
gguf-template.py - example file to extract the chat template from the models metadata
"""
from __future__ import annotations
import argparse
import logging
import os
import sys
from pathlib import Path
import jinja2
# Necessary to load the local gguf package
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
sys.path.insert(0, str(Path(__file__).parent.parent))
from gguf.constants import Keys
from gguf.gguf_reader import GGUFReader # noqa: E402
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("gguf-chat-template")
def get_chat_template(model_file: str, verbose: bool = False) -> str:
    """
    Extract the chat template string from a GGUF model file's metadata.

    Args:
        model_file (str): Path to the GGUF model file.
        verbose (bool, optional): If True, log every metadata key present in
            the file before looking up the chat template. Defaults to False.

    Returns:
        str: The decoded chat template, or an empty string when the model
        metadata contains no chat template field.
    """
    reader = GGUFReader(model_file)

    # Available keys
    logger.info("Detected model metadata!")
    if verbose:
        logger.info("Outputting available model fields:")
        for key in reader.fields.keys():
            logger.info(key)

    # Access the 'chat_template' field directly using its key
    chat_template_field = reader.fields.get(Keys.Tokenizer.CHAT_TEMPLATE)
    if chat_template_field is None:
        logger.error("Chat template field not found in model metadata.")
        return ""

    # The last part of the field holds the raw template bytes (a memmap
    # slice); decode it into a regular Python string.
    chat_template_memmap = chat_template_field.parts[-1]
    return chat_template_memmap.tobytes().decode("utf-8")
def display_chat_template(
    chat_template: str, bos_token: str, eos_token: str, render_template: bool = False
):
    """
    Display the chat template to standard output, optionally formatting it using Jinja2.

    Args:
        chat_template (str): The extracted chat template.
        bos_token (str): Beginning-of-sequence token exposed to the template context.
        eos_token (str): End-of-sequence token exposed to the template context.
        render_template (bool, optional): Whether to format the template using Jinja2. Defaults to False.
    """
    logger.info(f"Format Template: {render_template}")
    if not render_template:
        # Display the raw template
        print(chat_template)
        return

    # Render the formatted template using Jinja2 with a context that includes 'bos_token' and 'eos_token'
    env = jinja2.Environment(
        loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True
    )
    template = env.from_string(chat_template)

    # A small example conversation to exercise the template.
    messages = [
        {"role": "system", "content": "I am a helpful assistant."},
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Hello! How may I assist you today?"},
        {"role": "user", "content": "Can you tell me what pickled mayonnaise is?"},
        {"role": "assistant", "content": "Certainly! What would you like to know about it?"},
    ]
    try:
        formatted_template = template.render(
            messages=messages,
            bos_token=bos_token,
            eos_token=eos_token,
        )
    except jinja2.exceptions.UndefinedError:
        # system message is incompatible with set format; retry without it
        formatted_template = template.render(
            messages=messages[1:],
            bos_token=bos_token,
            eos_token=eos_token,
        )
    print(formatted_template)
def main():
    """Command-line entry point: parse arguments, extract and display the chat template."""
    parser = argparse.ArgumentParser(
        description="Extract chat template from a GGUF model file"
    )
    parser.add_argument("model_file", type=str, help="Path to the GGUF model file")
    parser.add_argument(
        "-r",
        "--render-template",
        action="store_true",
        help="Render the chat template using Jinja2",
    )
    parser.add_argument(
        "-b",
        "--bos",
        default="<s>",
        help="Set a bos special token. Default is '<s>'.",
    )
    parser.add_argument(
        "-e",
        "--eos",
        default="</s>",
        help="Set a eos special token. Default is '</s>'.",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="Output model keys",
    )
    args = parser.parse_args()

    chat_template = get_chat_template(args.model_file, args.verbose)
    display_chat_template(
        chat_template, args.bos, args.eos, render_template=args.render_template
    )


if __name__ == "__main__":
    main()