lexer
This commit is contained in:
parent
8d8030142e
commit
15b7c50e95
|
|
@ -0,0 +1,79 @@
|
|||
#include "common.h"
|
||||
#include <chat-peg-parser.h>
|
||||
#include <sstream>
|
||||
|
||||
namespace jinja {
|
||||
|
||||
struct compiler {
|
||||
common_chat_peg_native_builder builder;
|
||||
common_peg_parser root;
|
||||
|
||||
compiler() : root(builder.choice()) {
|
||||
auto & p = builder;
|
||||
|
||||
auto ws = p.rule("ws", p.chars("[ \t]", 0, -1));
|
||||
auto num = p.rule("num", p.chars("[0-9]", 1, -1));
|
||||
|
||||
//
|
||||
// expressions
|
||||
//
|
||||
|
||||
auto expression = p.choice();
|
||||
|
||||
auto var_name = p.rule("var_name", p.chars("[a-zA-Z_]", 1, -1) << p.chars("[a-zA-Z0-9_]", 0, -1));
|
||||
expression |= var_name;
|
||||
|
||||
// value
|
||||
auto p_int = p.rule("value_int", num);
|
||||
auto p_flt = p.rule("value_flt", num << "." << p.optional(num));
|
||||
auto p_str = p.rule("value_str",
|
||||
p.json_string() |
|
||||
p.literal("'") + p.chars("[^']*", 0, -1) + p.literal("'")
|
||||
);
|
||||
|
||||
expression |= p_int;
|
||||
expression |= p_flt;
|
||||
expression |= p_str;
|
||||
|
||||
// function calls
|
||||
auto p_args = p.rule("args", expression << ws << p.zero_or_more("," << ws << expression));
|
||||
auto p_func = p.rule("func", ws << var_name << ws << "(" << ws << p_args << ws << ")");
|
||||
expression |= p_func;
|
||||
|
||||
// indexing
|
||||
auto p_idx = p.rule("idx", ws << "[" << ws << expression << ws << "]");
|
||||
expression |= p_idx;
|
||||
|
||||
// set
|
||||
auto p_set = p.rule("set", "set " << ws << var_name << ws << "=" << expression);
|
||||
expression |= p_set;
|
||||
|
||||
// if, else, endif
|
||||
auto p_if = p.rule("if", "if " << ws << expression << ws);
|
||||
auto p_else = p.rule("else", "else " << ws << expression << ws);
|
||||
auto p_endif = p.rule("endif", p.literal("endif"));
|
||||
|
||||
expression |= p_if;
|
||||
expression |= p_else;
|
||||
expression |= p_endif;
|
||||
|
||||
expression = p.space() + expression + p.space();
|
||||
|
||||
//
|
||||
// root
|
||||
//
|
||||
|
||||
// auto strip = p.rule("strip", "-" << expression << "-");
|
||||
auto print = p.rule("print", "{{" << (expression) << "}}");
|
||||
auto ctrl = p.rule("ctrl", "{%" << (expression) << "%}");
|
||||
|
||||
root |= print;
|
||||
root |= ctrl;
|
||||
root |= p.rule("text", p.negate(root));
|
||||
|
||||
root = p.one_or_more(root);
|
||||
root += p.end();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace jinja
|
||||
|
|
@ -0,0 +1,336 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <regex>
|
||||
#include <stdexcept>
|
||||
#include <cctype>
|
||||
#include <functional>
|
||||
|
||||
// #define JJ_DEBUG(msg, ...) printf("jinja-lexer: " msg "\n", __VA_ARGS__)
|
||||
#define JJ_DEBUG(msg, ...) // no-op
|
||||
|
||||
namespace jinja {
|
||||
|
||||
struct preprocess_options {
|
||||
bool trim_blocks = false;
|
||||
bool lstrip_blocks = false;
|
||||
};
|
||||
|
||||
struct token {
|
||||
enum type {
|
||||
undefined,
|
||||
text, // The text between Jinja statements or expressions
|
||||
|
||||
numeric_literal, // e.g., 123, 1.0
|
||||
string_literal, // 'string'
|
||||
identifier, // Variables, functions, statements, booleans, etc.
|
||||
equals, // =
|
||||
open_paren, // (
|
||||
close_paren, // )
|
||||
open_statement, // {%
|
||||
close_statement, // %}
|
||||
open_expression, // {{
|
||||
close_expression, // }}
|
||||
open_square_bracket, // [
|
||||
close_square_bracket, // ]
|
||||
open_curly_bracket, // {
|
||||
close_curly_bracket, // }
|
||||
comma, // ,
|
||||
dot, // .
|
||||
colon, // :
|
||||
pipe, // |
|
||||
|
||||
call_operator, // ()
|
||||
additive_binary_operator, // + - ~
|
||||
multiplicative_binary_operator, // * / %
|
||||
comparison_binary_operator, // < > <= >= == !=
|
||||
unary_operator, // ! - +
|
||||
comment, // {# ... #}
|
||||
};
|
||||
type t;
|
||||
std::string value;
|
||||
};
|
||||
|
||||
struct lexer {
|
||||
const std::map<char, char> escape_chars = {
|
||||
{'n', '\n'},
|
||||
{'t', '\t'},
|
||||
{'r', '\r'},
|
||||
{'b', '\b'},
|
||||
{'f', '\f'},
|
||||
{'v', '\v'},
|
||||
{'\\', '\\'},
|
||||
{'\'', '\''},
|
||||
{'\"', '\"'},
|
||||
};
|
||||
|
||||
static bool is_word(char c) {
|
||||
return std::isalnum(static_cast<unsigned char>(c)) || c == '_';
|
||||
}
|
||||
|
||||
static bool is_integer(char c) {
|
||||
return std::isdigit(static_cast<unsigned char>(c));
|
||||
}
|
||||
|
||||
const std::vector<std::pair<std::string, token::type>> ordered_mapping_table = {
|
||||
// Control sequences
|
||||
{"{%", token::open_statement},
|
||||
{"%}", token::close_statement},
|
||||
{"{{", token::open_expression},
|
||||
{"}}", token::close_expression},
|
||||
// Single character tokens
|
||||
{"(", token::open_paren},
|
||||
{")", token::close_paren},
|
||||
{"{", token::open_curly_bracket},
|
||||
{"}", token::close_curly_bracket},
|
||||
{"[", token::open_square_bracket},
|
||||
{"]", token::close_square_bracket},
|
||||
{",", token::comma},
|
||||
{".", token::dot},
|
||||
{":", token::colon},
|
||||
{"|", token::pipe},
|
||||
// Comparison operators
|
||||
{"<=", token::comparison_binary_operator},
|
||||
{">=", token::comparison_binary_operator},
|
||||
{"==", token::comparison_binary_operator},
|
||||
{"!=", token::comparison_binary_operator},
|
||||
{"<", token::comparison_binary_operator},
|
||||
{">", token::comparison_binary_operator},
|
||||
// Arithmetic operators
|
||||
{"+", token::additive_binary_operator},
|
||||
{"-", token::additive_binary_operator},
|
||||
{"~", token::additive_binary_operator},
|
||||
{"*", token::multiplicative_binary_operator},
|
||||
{"/", token::multiplicative_binary_operator},
|
||||
{"%", token::multiplicative_binary_operator},
|
||||
// Assignment operator
|
||||
{"=", token::equals},
|
||||
};
|
||||
|
||||
std::string preprocess(const std::string& template_str, const preprocess_options& options) const {
|
||||
std::string result = template_str;
|
||||
// According to https://jinja.palletsprojects.com/en/3.0.x/templates/#whitespace-control
|
||||
|
||||
// In the default configuration:
|
||||
// - a single trailing newline is stripped if present
|
||||
// - other whitespace (spaces, tabs, newlines etc.) is returned unchanged
|
||||
if (!result.empty() && result.back() == '\n') {
|
||||
result.pop_back();
|
||||
}
|
||||
|
||||
if (options.lstrip_blocks) {
|
||||
// The lstrip_blocks option can also be set to strip tabs and spaces from the
|
||||
// beginning of a line to the start of a block. (Nothing will be stripped if
|
||||
// there are other characters before the start of the block.)
|
||||
// result = std::regex_replace(result, std::regex(R"((?m)^[ \t]*(\{[#%-]))"), "$1");
|
||||
throw std::runtime_error("lstrip_blocks option is not implemented yet");
|
||||
}
|
||||
|
||||
if (options.trim_blocks) {
|
||||
// If an application configures Jinja to trim_blocks, the first newline after
|
||||
// a template tag is removed automatically (like in PHP).
|
||||
result = std::regex_replace(result, std::regex(R"(([#%-]\})\n)"), "$1");
|
||||
}
|
||||
|
||||
// Handle whitespace control with - in tags
|
||||
result = std::regex_replace(result, std::regex(R"(-%\}\s*)"), "%}");
|
||||
result = std::regex_replace(result, std::regex(R"(\s*\{%-)"), "{%");
|
||||
result = std::regex_replace(result, std::regex(R"(-\}\}\s*)"), "}}");
|
||||
result = std::regex_replace(result, std::regex(R"(\s*\{\{-)"), "{{");
|
||||
result = std::regex_replace(result, std::regex(R"(-#\}\s*)"), "#}");
|
||||
result = std::regex_replace(result, std::regex(R"(\s*\{\#-)"), "{#");
|
||||
|
||||
// Handle custom transformers-specific `generation` tag
|
||||
// See https://github.com/huggingface/transformers/pull/30650 for more information.
|
||||
// result = std::regex_replace(result, std::regex(R"((?s)\{%\s*generation\s*%\}.+?\{%\s*endgeneration\s*%\})"), "");
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<token> tokenize(const std::string & input, const preprocess_options & options = {}) {
|
||||
std::vector<token> tokens;
|
||||
std::string src = preprocess(input, options);
|
||||
JJ_DEBUG("preprocessed input: '%s'", src.c_str());
|
||||
|
||||
size_t pos = 0;
|
||||
size_t curly_bracket_depth = 0;
|
||||
|
||||
using pred = std::function<bool(char)>;
|
||||
auto consume_while = [&](pred predicate) -> std::string {
|
||||
std::string str;
|
||||
while (predicate(src[pos])) {
|
||||
// check for escape char
|
||||
if (src[pos] == '\\') {
|
||||
// consume backslash
|
||||
++pos;
|
||||
// check for end of input
|
||||
if (pos >= src.size()) {
|
||||
throw std::runtime_error("lexer: unexpected end of input after escape character");
|
||||
}
|
||||
// add escaped char
|
||||
char escaped_char = src[pos++];
|
||||
if (escape_chars.find(escaped_char) == escape_chars.end()) {
|
||||
throw std::runtime_error(std::string("lexer: unknown escape character \\") + escaped_char);
|
||||
}
|
||||
char unescaped_char = escape_chars.at(escaped_char);
|
||||
str += unescaped_char;
|
||||
continue;
|
||||
}
|
||||
|
||||
str += src[pos++];
|
||||
if (pos > src.size()) {
|
||||
throw std::runtime_error("lexer: unexpected end of input during consume_while");
|
||||
}
|
||||
}
|
||||
return str;
|
||||
};
|
||||
|
||||
auto next_pos_is = [&](std::initializer_list<char> chars) -> bool {
|
||||
if (pos + 1 >= src.size()) return false;
|
||||
for (char c : chars) {
|
||||
if (src[pos + 1] == c) return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
while (pos < src.size()) {
|
||||
JJ_DEBUG("lexer main loop at pos %zu: '%s...'", pos, src.substr(pos, 10).c_str());
|
||||
|
||||
// First, consume all text that is outside of a Jinja statement or expression
|
||||
token::type last_token_type = tokens.empty()
|
||||
? token::undefined
|
||||
: tokens.back().t;
|
||||
if (last_token_type == token::undefined ||
|
||||
last_token_type == token::close_statement ||
|
||||
last_token_type == token::close_expression ||
|
||||
last_token_type == token::comment) {
|
||||
std::string text;
|
||||
while (pos < src.size() &&
|
||||
// Keep going until we hit the next Jinja statement or expression
|
||||
!(
|
||||
src[pos] == '{' &&
|
||||
next_pos_is( {'%', '{', '#'} )
|
||||
)) {
|
||||
text += src[pos++];
|
||||
}
|
||||
JJ_DEBUG("consumed text: '%s'", text.c_str());
|
||||
if (!text.empty()) {
|
||||
tokens.push_back({token::text, text});
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Possibly consume a comment
|
||||
if (src[pos] == '{' && next_pos_is( {'#'} )) {
|
||||
pos += 2; // Skip the opening {#
|
||||
std::string comment;
|
||||
while (!(src[pos] == '#' && next_pos_is( {'}'} ))) {
|
||||
if (pos + 2 >= src.size()) {
|
||||
throw std::runtime_error("lexer: missing end of comment tag");
|
||||
}
|
||||
comment += src[pos++];
|
||||
}
|
||||
JJ_DEBUG("consumed comment: '%s'", comment.c_str());
|
||||
tokens.push_back({token::comment, comment});
|
||||
pos += 2; // Skip the closing #}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Consume (and ignore) all whitespace inside Jinja statements or expressions
|
||||
consume_while([](char c) { return std::isspace(static_cast<unsigned char>(c)); });
|
||||
|
||||
if (pos >= src.size()) break;
|
||||
|
||||
char ch = src[pos];
|
||||
|
||||
// Check for unary operators
|
||||
if (ch == '-' || ch == '+') {
|
||||
token::type last_token_type = tokens.empty() ? token::undefined : tokens.back().t;
|
||||
if (last_token_type == token::text || last_token_type == token::undefined) {
|
||||
throw std::runtime_error(std::string("lexer: unexpected character: ") + ch);
|
||||
}
|
||||
switch (last_token_type) {
|
||||
case token::identifier:
|
||||
case token::numeric_literal:
|
||||
case token::string_literal:
|
||||
case token::close_paren:
|
||||
case token::close_square_bracket:
|
||||
// Part of a binary operator
|
||||
// a - 1, 1 - 1, true - 1, "apple" - 1, (1) - 1, a[1] - 1
|
||||
// Continue parsing normally
|
||||
break;
|
||||
default: {
|
||||
// Is part of a unary operator
|
||||
// (-1), [-1], (1 + -1), not -1, -apple
|
||||
++pos; // Consume the operator
|
||||
|
||||
// Check for numbers following the unary operator
|
||||
std::string num = consume_while(is_integer);
|
||||
std::string value = std::string(1, ch) + num;
|
||||
token::type t = num.empty() ? token::unary_operator : token::numeric_literal;
|
||||
JJ_DEBUG("consumed unary operator or numeric literal: '%s'", value.c_str());
|
||||
tokens.push_back({t, value});
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try to match one of the tokens in the mapping table
|
||||
bool matched = false;
|
||||
for (const auto & [seq, typ] : ordered_mapping_table) {
|
||||
// Inside an object literal, don't treat "}}" as expression-end
|
||||
if (seq == "}}" && curly_bracket_depth > 0) {
|
||||
continue;
|
||||
}
|
||||
if (pos + seq.size() <= src.size() && src.substr(pos, seq.size()) == seq) {
|
||||
tokens.push_back({typ, seq});
|
||||
if (typ == token::open_expression) {
|
||||
curly_bracket_depth = 0;
|
||||
} else if (typ == token::open_curly_bracket) {
|
||||
++curly_bracket_depth;
|
||||
} else if (typ == token::close_curly_bracket) {
|
||||
--curly_bracket_depth;
|
||||
}
|
||||
pos += seq.size();
|
||||
matched = true;
|
||||
break; // continue main loop
|
||||
}
|
||||
}
|
||||
if (matched) continue; // continue main loop
|
||||
|
||||
// Strings
|
||||
if (ch == '\'' || ch == '"') {
|
||||
++pos; // Skip opening quote
|
||||
std::string str = consume_while([ch](char c) { return c != ch; });
|
||||
tokens.push_back({token::string_literal, str});
|
||||
++pos; // Skip closing quote
|
||||
continue;
|
||||
}
|
||||
|
||||
// Numbers
|
||||
if (is_integer(ch)) {
|
||||
std::string num = consume_while(is_integer);
|
||||
if (pos < src.size() && src[pos] == '.' && pos + 1 < src.size() && is_integer(src[pos + 1])) {
|
||||
++pos; // Consume '.'
|
||||
std::string frac = consume_while(is_integer);
|
||||
num += "." + frac;
|
||||
}
|
||||
tokens.push_back({token::numeric_literal, num});
|
||||
continue;
|
||||
}
|
||||
|
||||
// Identifiers
|
||||
if (is_word(ch)) {
|
||||
std::string word = consume_while(is_word);
|
||||
tokens.push_back({token::identifier, word});
|
||||
continue;
|
||||
}
|
||||
|
||||
throw std::runtime_error(std::string("lexer: unexpected character: ") + ch);
|
||||
}
|
||||
|
||||
return tokens;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace jinja
|
||||
|
|
@ -186,6 +186,7 @@ endif()
|
|||
llama_build_and_test(test-chat-parser.cpp)
|
||||
llama_build_and_test(test-chat-peg-parser.cpp peg-parser/simple-tokenize.cpp)
|
||||
llama_build_and_test(test-chat-template.cpp)
|
||||
llama_build_and_test(test-chat-jinja.cpp)
|
||||
llama_build_and_test(test-json-partial.cpp)
|
||||
llama_build_and_test(test-log.cpp)
|
||||
llama_build_and_test(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,55 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
#include <regex>
|
||||
#include <iostream>
|
||||
|
||||
#undef NDEBUG
|
||||
#include <cassert>
|
||||
|
||||
#include "peg-parser.h"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "jinja/jinja-compiler.h"
|
||||
#include "jinja/jinja-lexer.h"
|
||||
|
||||
int main(void) {
|
||||
std::string contents = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\\n' + message['content'] | trim + '<end_of_turn>\\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\\n'}}{% endif %}";
|
||||
|
||||
std::cout << "=== INPUT ===\n" << contents << "\n\n";
|
||||
|
||||
jinja::lexer lexer;
|
||||
jinja::preprocess_options options;
|
||||
options.trim_blocks = true;
|
||||
options.lstrip_blocks = false;
|
||||
auto tokens = lexer.tokenize(contents, options);
|
||||
for (const auto & tok : tokens) {
|
||||
std::cout << "token: type=" << static_cast<int>(tok.t) << " text='" << tok.value << "'\n";
|
||||
}
|
||||
|
||||
// jinja::compiler compiler;
|
||||
// compiler.builder.set_root(compiler.root);
|
||||
// auto parser = compiler.builder.build();
|
||||
|
||||
// auto grammar = build_grammar([&](const common_grammar_builder & builder0) {
|
||||
// parser.build_grammar(builder0);
|
||||
// });
|
||||
// printf("== GRAMMAR ==\n");
|
||||
// printf("%s\n", grammar.c_str());
|
||||
|
||||
// // printf("== DUMP ==\n");
|
||||
// // printf("%s\n", parser.dump(compiler.root.id()).c_str());
|
||||
|
||||
// printf("== PARSE ==\n");
|
||||
|
||||
// common_peg_parse_context ctx(contents);
|
||||
// const auto result = parser.parse(ctx);
|
||||
// if (!result.success()) {
|
||||
// throw std::runtime_error("failed to parse, type = " + std::to_string(result.type));
|
||||
// }
|
||||
|
||||
// ctx.ast.visit(result, [&](const common_peg_ast_node & node) {
|
||||
// printf("node: rule='%s' text='%s'\n", node.rule.c_str(), std::string(node.text).c_str());
|
||||
// });
|
||||
|
||||
return 0;
|
||||
}
|
||||
Loading…
Reference in New Issue