llama.cpp/gguf-py/scripts/gguf-template.py

141 lines
4.2 KiB
Python

"""
gguf-template.py - example file to extract the chat template from the models metadata
"""
from __future__ import annotations
import argparse
import logging
import os
import sys
from pathlib import Path
import jinja2
# Necessary to load the local gguf package
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
sys.path.insert(0, str(Path(__file__).parent.parent))
from gguf.constants import Keys
from gguf.gguf_reader import GGUFReader # noqa: E402
# Configure logging
logger = logging.getLogger("gguf-chat-template")
def get_chat_template(model_file: str) -> str:
reader = GGUFReader(model_file)
# Available keys
logger.debug("Detected model metadata!")
logger.debug("Outputting available model fields:")
for key in reader.fields.keys():
logger.debug(key)
# Access the 'chat_template' field directly using its key
chat_template_field = reader.fields.get(Keys.Tokenizer.CHAT_TEMPLATE)
if chat_template_field:
# Extract the chat template string from the field
chat_template_memmap = chat_template_field.parts[-1]
chat_template_string = chat_template_memmap.tobytes().decode("utf-8")
return chat_template_string
else:
logger.error("Chat template field not found in model metadata.")
return ""
def render_chat_template(
chat_template: str, bos_token: str, eos_token: str, render_template: bool = False
) -> str:
"""
Display the chat template to standard output, optionally formatting it using Jinja2.
Args:
chat_template (str): The extracted chat template.
render_template (bool, optional): Whether to format the template using Jinja2. Defaults to False.
"""
logger.debug(f"Render Template: {render_template}")
if render_template:
# Render the formatted template using Jinja2 with a context that includes 'bos_token' and 'eos_token'
env = jinja2.Environment(
loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True
)
template = env.from_string(chat_template)
messages = [
{"role": "system", "content": "I am a helpful assistant."},
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hello! How may I assist you today?"},
{"role": "user", "content": "Can you tell me what pickled mayonnaise is?"},
{"role": "assistant", "content": "Certainly! What would you like to know about it?"},
]
try:
formatted_template = template.render(
messages=messages,
bos_token=bos_token,
eos_token=eos_token,
)
except jinja2.exceptions.UndefinedError:
# system message is incompatible with set format
formatted_template = template.render(
messages=messages[1:],
bos_token=bos_token,
eos_token=eos_token,
)
return formatted_template
else:
# Display the raw template
return chat_template
# Example usage:
def main():
parser = argparse.ArgumentParser(
description="Extract chat template from a GGUF model file"
)
parser.add_argument("model_file", type=str, help="Path to the GGUF model file")
parser.add_argument(
"-r",
"--render-template",
action="store_true",
help="Render the chat template using Jinja2",
)
parser.add_argument(
"-b",
"--bos",
default="<s>",
help="Set a bos special token. Default is '<s>'.",
)
parser.add_argument(
"-e",
"--eos",
default="</s>",
help="Set a eos special token. Default is '</s>'.",
)
parser.add_argument(
"-v",
"--verbose",
action="store_true",
help="Output model keys",
)
args = parser.parse_args()
if args.verbose:
logging.basicConfig(level=logging.DEBUG)
else:
logging.basicConfig(level=logging.INFO)
chat_template = get_chat_template(args.model_file)
rendered_template = render_chat_template(
chat_template, args.bos, args.eos, render_template=args.render_template
)
print(rendered_template)
if __name__ == "__main__":
main()