diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index 5b697eb949..faee1559cf 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -574,6 +574,16 @@ struct vm { value_string gather_string_parts(const value & val) { value_string parts = mk_val(); gather_string_parts_recursive(val, parts); + // join consecutive parts with the same type + auto & p = parts->val_str.parts; + for (size_t i = 1; i < p.size(); ) { + if (p[i].is_input == p[i - 1].is_input) { + p[i - 1].val += p[i].val; + p.erase(p.begin() + i); + } else { + i++; + } + } return parts; } }; diff --git a/tests/test-chat-jinja.cpp b/tests/test-chat-jinja.cpp index 50401b56bb..86fe8f1f15 100644 --- a/tests/test-chat-jinja.cpp +++ b/tests/test-chat-jinja.cpp @@ -56,7 +56,8 @@ std::string DEFAULT_JSON = R"({ ], "bos_token": "", "eos_token": "", - "tools": [] + "tools": [], + "add_generation_prompt": true })"; int main(int argc, char ** argv) { @@ -181,7 +182,7 @@ void run_single(std::string contents, json input) { auto parts = vm.gather_string_parts(results); std::cout << "\n=== RESULTS ===\n"; - for (const auto & part : parts.get()->val_str.parts) { + for (const auto & part : parts->as_string().parts) { std::cout << (part.is_input ? "DATA" : "TMPL") << ": " << part.val << "\n"; } }