Clean algorithm for calculate_diff_split; fix buggy expectations
This commit is contained in:
parent
e772822011
commit
24cc1bcd6d
|
|
@ -4,6 +4,7 @@
|
|||
#include "nlohmann/json.hpp"
|
||||
|
||||
#include <cctype>
|
||||
#include <numeric>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
|
|
@ -61,227 +62,128 @@ std::string trim_trailing_newlines(const std::string & str) {
|
|||
return str.substr(0, end);
|
||||
}
|
||||
|
||||
// Helper to find unmatched bracket/tag in a string
|
||||
// Finds an unmatched bracket in a string.
|
||||
// search_backwards=true: finds unclosed opening bracket at end (returns bracket position)
|
||||
// search_backwards=false: finds unopened closing bracket at start (returns position after bracket)
|
||||
static size_t find_unmatched_bracket(const std::string & str, bool search_backwards) {
|
||||
if (str.empty()) {
|
||||
return std::string::npos;
|
||||
}
|
||||
|
||||
// Compute iteration bounds and bracket types based on direction
|
||||
const char * primary_brackets = search_backwards ? "<[" : ">]";
|
||||
|
||||
for (size_t i = 0; i < str.length(); ++i) {
|
||||
// Map iteration index to actual position based on direction
|
||||
size_t pos = search_backwards ? (str.length() - 1 - i) : i;
|
||||
char c = str[pos];
|
||||
|
||||
// Check if this is a primary bracket we're looking for
|
||||
if (c == primary_brackets[0] || c == primary_brackets[1]) {
|
||||
// Get the matching bracket: < matches >, [ matches ], and vice versa
|
||||
char match_bracket = (c == '<' || c == '>') ? (c == '<' ? '>' : '<') : (c == '[' ? ']' : '[');
|
||||
|
||||
// Search for matching bracket in the appropriate range
|
||||
size_t inner_start = search_backwards ? (pos + 1) : 0;
|
||||
size_t inner_end = search_backwards ? str.length() : pos;
|
||||
bool found_match = false;
|
||||
|
||||
for (size_t j = inner_start; j < inner_end; ++j) {
|
||||
if (str[j] == match_bracket) {
|
||||
found_match = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!found_match) {
|
||||
return search_backwards ? pos : (pos + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return std::string::npos;
|
||||
}
|
||||
|
||||
static size_t find_unclosed_bracket_at_end(const std::string & str) {
|
||||
return find_unmatched_bracket(str, true);
|
||||
}
|
||||
|
||||
static size_t find_unopened_bracket_at_start(const std::string & str) {
|
||||
return find_unmatched_bracket(str, false);
|
||||
}
|
||||
|
||||
// Returns true if `s` contains an unmatched bracket.
|
||||
// search_backwards=true: looks for opening bracket without matching closing after it
|
||||
// search_backwards=false: looks for closing bracket without matching opening before it
|
||||
static bool contains_unmatched_bracket(const std::string & s, char opening, char closing, bool search_backwards) {
|
||||
if (s.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
char primary = search_backwards ? opening : closing;
|
||||
|
||||
for (size_t i = 0; i < s.length(); ++i) {
|
||||
// Map iteration index to actual position based on direction
|
||||
size_t pos = search_backwards ? (s.length() - 1 - i) : i;
|
||||
|
||||
if (s[pos] == primary) {
|
||||
// Search for matching bracket in the appropriate range
|
||||
size_t inner_start = search_backwards ? (pos + 1) : 0;
|
||||
size_t inner_end = search_backwards ? s.length() : pos;
|
||||
char match_bracket = search_backwards ? closing : opening;
|
||||
bool found_match = false;
|
||||
|
||||
for (size_t j = inner_start; j < inner_end; ++j) {
|
||||
if (s[j] == match_bracket) {
|
||||
found_match = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!found_match) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool contains_unopened_closing(const std::string & s, char opening, char closing) {
|
||||
return contains_unmatched_bracket(s, opening, closing, false);
|
||||
}
|
||||
|
||||
static bool contains_unclosed_opening(const std::string & s, char opening, char closing) {
|
||||
return contains_unmatched_bracket(s, opening, closing, true);
|
||||
}
|
||||
|
||||
// Moves incomplete tags from prefix/suffix into left/right parts
|
||||
// Only moves tags when we detect the split pattern in BOTH left and right
|
||||
static diff_split fix_tag_boundaries(diff_split result) {
|
||||
// Check if prefix ends with an unclosed bracket/tag
|
||||
// No fixed window: search the entire neighboring strings for matching brackets
|
||||
size_t unclosed_pos = find_unclosed_bracket_at_end(result.prefix);
|
||||
if (unclosed_pos != std::string::npos) {
|
||||
char opening_bracket = result.prefix[unclosed_pos];
|
||||
char closing_bracket = (opening_bracket == '<') ? '>' : ']';
|
||||
|
||||
// Look for the specific closing bracket that matches our opening bracket
|
||||
bool left_has_pattern = contains_unopened_closing(result.left, opening_bracket, closing_bracket);
|
||||
bool right_has_pattern = contains_unopened_closing(result.right, opening_bracket, closing_bracket);
|
||||
bool suffix_has_pattern = contains_unopened_closing(result.suffix, opening_bracket, closing_bracket);
|
||||
|
||||
// Move the tag if both sides satisfy: has pattern OR is empty (and other has pattern)
|
||||
// This handles cases like: left="" right="_begin|>..." or left="stuff>" right="stuff>"
|
||||
bool left_satisfies = left_has_pattern || (result.left.empty() && suffix_has_pattern);
|
||||
bool right_satisfies = right_has_pattern || (result.right.empty() && suffix_has_pattern);
|
||||
|
||||
if (left_satisfies && right_satisfies) {
|
||||
// Move the unclosed tag from prefix to left/right
|
||||
std::string tag_part = result.prefix.substr(unclosed_pos);
|
||||
result.prefix = result.prefix.substr(0, unclosed_pos);
|
||||
result.left = tag_part + result.left;
|
||||
result.right = tag_part + result.right;
|
||||
}
|
||||
}
|
||||
|
||||
// Check if suffix starts with an unopened bracket/tag
|
||||
size_t unopened_end = find_unopened_bracket_at_start(result.suffix);
|
||||
if (unopened_end != std::string::npos) {
|
||||
char closing_bracket =
|
||||
result.suffix[unopened_end - 1]; // -1 because unopened_end is position after the bracket
|
||||
char opening_bracket = (closing_bracket == '>') ? '<' : '[';
|
||||
|
||||
// Check if BOTH left and right have the pattern of unclosed opening bracket at the end
|
||||
bool left_has_pattern = contains_unclosed_opening(result.left, opening_bracket, closing_bracket);
|
||||
bool right_has_pattern = contains_unclosed_opening(result.right, opening_bracket, closing_bracket);
|
||||
bool prefix_has_pattern = contains_unclosed_opening(result.prefix, opening_bracket, closing_bracket);
|
||||
|
||||
// Move the tag if both sides satisfy: has pattern OR is empty (and other has pattern)
|
||||
bool left_satisfies = left_has_pattern || (result.left.empty() && prefix_has_pattern);
|
||||
bool right_satisfies = right_has_pattern || (result.right.empty() && prefix_has_pattern);
|
||||
|
||||
if (left_satisfies && right_satisfies) {
|
||||
// Move the unopened tag from suffix to left/right
|
||||
std::string tag_part = result.suffix.substr(0, unopened_end);
|
||||
result.suffix = result.suffix.substr(unopened_end);
|
||||
result.left = result.left + tag_part;
|
||||
result.right = result.right + tag_part;
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
diff_split calculate_diff_split(const std::string & left, const std::string & right) {
|
||||
diff_split result;
|
||||
|
||||
// Find longest common prefix
|
||||
static size_t common_prefix_len(const std::string & left, const std::string & right) {
|
||||
size_t prefix_len = 0;
|
||||
size_t min_len = std::min(left.length(), right.length());
|
||||
while (prefix_len < min_len && left[prefix_len] == right[prefix_len]) {
|
||||
prefix_len++;
|
||||
}
|
||||
result.prefix = left.substr(0, prefix_len);
|
||||
return prefix_len;
|
||||
}
|
||||
|
||||
// Find longest common suffix, ending no later than the end of the longest common prefix
|
||||
static size_t common_suffix_len(const std::string & left, const std::string & right) {
|
||||
size_t suffix_len = 0;
|
||||
while (suffix_len < min_len - prefix_len) {
|
||||
size_t left_pos = left.length() - 1 - suffix_len;
|
||||
size_t right_pos = right.length() - 1 - suffix_len;
|
||||
size_t min_len = std::min(left.length(), right.length());
|
||||
while (suffix_len < min_len && left[left.length() - 1 - suffix_len] == right[right.length() - 1 - suffix_len]) {
|
||||
suffix_len++;
|
||||
}
|
||||
return suffix_len;
|
||||
}
|
||||
|
||||
// Ensure we're not going into the prefix region
|
||||
if (left_pos < prefix_len || right_pos < prefix_len) {
|
||||
break;
|
||||
diff_split calculate_diff_split(const std::string & left, const std::string & right) {
|
||||
diff_split result;
|
||||
|
||||
auto left_seg = segmentize_markers(left);
|
||||
auto right_seg = segmentize_markers(right);
|
||||
|
||||
if (left_seg.empty()) {
|
||||
result.right = right;
|
||||
return result;
|
||||
}
|
||||
if (right_seg.empty()) {
|
||||
result.left = left;
|
||||
return result;
|
||||
}
|
||||
|
||||
auto left_start = left_seg.begin();
|
||||
auto left_end = --left_seg.end();
|
||||
auto right_start = right_seg.begin();
|
||||
auto right_end = --right_seg.end();
|
||||
|
||||
auto test = [&] () {
|
||||
return left_start != left_end && right_start != right_end;
|
||||
};
|
||||
|
||||
bool left_fully_consumed = false;
|
||||
bool right_fully_consumed = false;
|
||||
|
||||
while (test()) {
|
||||
bool advanced = false;
|
||||
if (*left_start == *right_start) {
|
||||
result.prefix.append(left_start->value);
|
||||
left_start++;
|
||||
right_start++;
|
||||
advanced = true;
|
||||
}
|
||||
|
||||
if (left[left_pos] == right[right_pos]) {
|
||||
suffix_len++;
|
||||
} else {
|
||||
if (*left_end == *right_end) {
|
||||
result.suffix = left_end->value + result.suffix;
|
||||
if (left_start != left_end) {
|
||||
left_end--;
|
||||
} else {
|
||||
left_fully_consumed = true;
|
||||
}
|
||||
if (right_start != right_end) {
|
||||
right_end--;
|
||||
} else {
|
||||
right_fully_consumed = true;
|
||||
}
|
||||
advanced = true;
|
||||
}
|
||||
if (!advanced) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
result.suffix = left.substr(left.length() - suffix_len);
|
||||
|
||||
// Extract the remainders (the parts between prefix and suffix)
|
||||
result.left = left.substr(prefix_len, left.length() - prefix_len - suffix_len);
|
||||
result.right = right.substr(prefix_len, right.length() - prefix_len - suffix_len);
|
||||
|
||||
// Fix tag boundaries by moving incomplete tags to left/right
|
||||
// We iterate because:
|
||||
// 1. fix_tag_boundaries may move content from prefix/suffix to left/right
|
||||
// 2. After that, we find common suffix in left/right to extract
|
||||
// 3. The extracted suffix might contain tag parts that need fixing
|
||||
// We apply fix AFTER suffix extraction to ensure incomplete tags aren't left in suffix
|
||||
diff_split prev_result;
|
||||
do {
|
||||
prev_result = result;
|
||||
|
||||
// First, find and extract any common suffix from left/right
|
||||
size_t suffix_len = 0;
|
||||
size_t min_len = std::min(result.left.length(), result.right.length());
|
||||
while (suffix_len < min_len) {
|
||||
size_t left_pos = result.left.length() - 1 - suffix_len;
|
||||
size_t right_pos = result.right.length() - 1 - suffix_len;
|
||||
if (result.left[left_pos] == result.right[right_pos]) {
|
||||
suffix_len++;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
if (left_start == left_end && right_start != right_end) {
|
||||
if (*left_start == *right_end) {
|
||||
result.suffix = right_end->value + result.suffix;
|
||||
right_end--;
|
||||
left_fully_consumed = true;
|
||||
} else if (*left_start == *right_start) {
|
||||
result.prefix.append(right_start->value);
|
||||
right_start++;
|
||||
left_fully_consumed = true;
|
||||
}
|
||||
|
||||
if (suffix_len > 0) {
|
||||
std::string common_suffix = result.left.substr(result.left.length() - suffix_len);
|
||||
result.suffix = common_suffix + result.suffix;
|
||||
result.left = result.left.substr(0, result.left.length() - suffix_len);
|
||||
result.right = result.right.substr(0, result.right.length() - suffix_len);
|
||||
} else if (right_start == right_end && left_start != left_end) {
|
||||
if (*left_end == *right_start) {
|
||||
result.suffix = left_end->value + result.suffix;
|
||||
left_end--;
|
||||
right_fully_consumed = true;
|
||||
} else if (*left_start == *right_start) {
|
||||
result.prefix.append(left_start->value);
|
||||
left_start++;
|
||||
right_fully_consumed = true;
|
||||
}
|
||||
} else if (left_start == left_end && right_start == right_end && *left_start == *right_start && left_start->type == segment_type::MARKER) {
|
||||
result.prefix.append(right_start->value);
|
||||
left_fully_consumed = true;
|
||||
right_fully_consumed = true;
|
||||
}
|
||||
|
||||
// Then apply fix_tag_boundaries to move incomplete tags from prefix/suffix to left/right
|
||||
result = fix_tag_boundaries(result);
|
||||
auto eat_segment = [](std::string & str, segment & seg) -> std::string { return str.append(seg.value); };
|
||||
|
||||
} while (!(result == prev_result) && result.left != left && result.right != right);
|
||||
bool can_have_text_suffix = left_end->type == segment_type::TEXT && right_end->type == segment_type::TEXT;
|
||||
bool can_have_text_prefix = right_start->type == segment_type::TEXT && left_start->type == segment_type::TEXT;
|
||||
|
||||
std::string remainder_left = std::accumulate(left_start, left_fully_consumed ? left_end : ++left_end, std::string(), eat_segment);
|
||||
std::string remainder_right = std::accumulate(right_start, right_fully_consumed ? right_end : ++right_end, std::string(), eat_segment);
|
||||
|
||||
size_t suffix_len = can_have_text_suffix ? common_suffix_len(remainder_left, remainder_right) : 0;
|
||||
// avoid overlaps between prefix and suffix
|
||||
size_t prefix_len = can_have_text_prefix ? common_prefix_len(remainder_left.substr(0, remainder_left.size() - suffix_len),
|
||||
remainder_right.substr(0, remainder_right.size() - suffix_len)) : 0;
|
||||
|
||||
result.prefix.append(remainder_left.substr(0, prefix_len));
|
||||
result.suffix = remainder_left.substr(remainder_left.length() - suffix_len, suffix_len) + result.suffix;
|
||||
result.left = remainder_left.substr(prefix_len, remainder_left.length() - prefix_len - suffix_len);
|
||||
result.right = remainder_right.substr(prefix_len, remainder_right.length() - prefix_len - suffix_len);
|
||||
|
||||
if (result.left == "" && result.right == "") {
|
||||
// degenerate case, no diff
|
||||
result.prefix = left;
|
||||
result.suffix = "";
|
||||
// pick prefix = all as representation
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
#include "chat-auto-parser-helpers.h"
|
||||
#include "chat-auto-parser.h"
|
||||
#include "chat.h"
|
||||
#include "llama.h"
|
||||
#include "log.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
|
||||
|
|
@ -381,17 +382,14 @@ void differential_analyzer::compare_thinking_enabled(const common_chat_template
|
|||
}
|
||||
}
|
||||
|
||||
// Check for slash-in-tag pattern: <think> vs </think>
|
||||
// diff shows: suffix="think>", left="/", right="" (or vice versa)
|
||||
if (reasoning.start.empty() && reasoning.end.empty()) {
|
||||
if (diff.right.empty() && trim_whitespace(diff.left) == "/") {
|
||||
auto seg_A = segmentize_markers(trim_trailing_whitespace(comparison->output_A));
|
||||
auto seg_B = segmentize_markers(trim_trailing_whitespace(comparison->output_B));
|
||||
if (!seg_A.empty() && !seg_B.empty() && seg_A[seg_A.size() - 1].type == segment_type::MARKER &&
|
||||
seg_B[seg_B.size() - 1].type == segment_type::MARKER) {
|
||||
reasoning.mode = reasoning_mode::FORCED_CLOSED;
|
||||
reasoning.start = seg_B[seg_B.size() - 1].value;
|
||||
reasoning.end = seg_A[seg_A.size() - 1].value;
|
||||
if (!diff.left.empty() && !diff.right.empty()) {
|
||||
auto seg_A = segmentize_markers(trim_trailing_whitespace(diff.left));
|
||||
auto seg_B = segmentize_markers(trim_trailing_whitespace(diff.right));
|
||||
if (seg_A.size() == 1 && seg_B.size() == 1) {
|
||||
reasoning.mode = reasoning_mode::FORCED_CLOSED;
|
||||
reasoning.start = seg_B[0].value;
|
||||
reasoning.end = seg_A[0].value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -739,7 +737,7 @@ void differential_analyzer::analyze_tool_call_format_json_native(const std::stri
|
|||
};
|
||||
// now let's check if we're in an array construction, mark it if so and get out of it
|
||||
if (json_start > 0 && space_or_bracket(true, clean_haystack[json_start - 1])) {
|
||||
for (--json_start; space_or_bracket(true, clean_haystack[json_start]) && json_start >= 0; json_start--) {
|
||||
for (--json_start; space_or_bracket(true, clean_haystack[json_start]) && json_start > 0; json_start--) {
|
||||
if (clean_haystack[json_start] == '[') {
|
||||
format.tools_array_wrapped = true;
|
||||
break;
|
||||
|
|
@ -900,7 +898,9 @@ void differential_analyzer::check_per_call_markers(const common_chat_template &
|
|||
return;
|
||||
}
|
||||
|
||||
std::string second_tool_content = trim_leading_whitespace(one_vs_two->diff.right);
|
||||
diff_split filter_common_call_part = calculate_diff_split(one_vs_two->diff.suffix, one_vs_two->diff.right);
|
||||
|
||||
std::string second_tool_content = trim_leading_whitespace(filter_common_call_part.right);
|
||||
if (!result.section_start.empty() &&
|
||||
second_tool_content.find(result.section_start) == 0) {
|
||||
result.per_call_start = result.section_start;
|
||||
|
|
@ -945,8 +945,6 @@ tool_function_analysis differential_analyzer::extract_function_markers(const com
|
|||
}
|
||||
|
||||
const auto & diff = comparison->diff;
|
||||
LOG_DBG("T3 diff - suffix: '%s'\n", diff.suffix.c_str());
|
||||
LOG_DBG("T3 diff - left: '%s', right: '%s'\n", diff.left.c_str(), diff.right.c_str());
|
||||
|
||||
if (diff.left.find("foofoo") != std::string::npos && diff.right.find("barbar") != std::string::npos) {
|
||||
std::string prefix_marker;
|
||||
|
|
@ -1371,8 +1369,6 @@ tool_id_analysis differential_analyzer::extract_call_id_markers(const common_cha
|
|||
}
|
||||
|
||||
const auto & diff = comparison->diff;
|
||||
LOG_DBG("T6 diff (call_id) - prefix: '%s', suffix: '%s'\n", diff.prefix.c_str(), diff.suffix.c_str());
|
||||
LOG_DBG("T6 diff (call_id) - left: '%s', right: '%s'\n", diff.left.c_str(), diff.right.c_str());
|
||||
|
||||
if (diff.left.empty() && diff.right.empty()) {
|
||||
return result;
|
||||
|
|
@ -1447,7 +1443,6 @@ tool_id_analysis differential_analyzer::extract_call_id_markers(const common_cha
|
|||
for (size_t i = 0; i < suffix_segments.size(); i++) {
|
||||
if (suffix_segments[i].type == segment_type::MARKER) {
|
||||
result.suffix = suffix_segments[i].value;
|
||||
LOG_DBG("T6: call_id_suffix='%s'\n", result.suffix.c_str());
|
||||
break;
|
||||
}
|
||||
// Stop if we hit the args
|
||||
|
|
@ -1468,7 +1463,6 @@ tool_id_analysis differential_analyzer::extract_call_id_markers(const common_cha
|
|||
for (int i = (int) segments.size() - 1; i >= 0; i--) {
|
||||
if (segments[i].type == segment_type::MARKER) {
|
||||
result.prefix = segments[i].value;
|
||||
LOG_DBG("T6: call_id_prefix='%s'\n", result.prefix.c_str());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -379,4 +379,12 @@ struct segment {
|
|||
std::string value;
|
||||
|
||||
segment(segment_type type, std::string value) : type(type), value(std::move(value)) {}
|
||||
|
||||
bool operator==(const segment & other) const {
|
||||
return type == other.type && value == other.value;
|
||||
}
|
||||
|
||||
bool operator!=(const segment & other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
};
|
||||
|
|
|
|||
|
|
@ -217,6 +217,12 @@ static void test_calculate_diff_split_identical(testing & t) {
|
|||
t.assert_equal("left should be empty", "", result.left);
|
||||
t.assert_equal("right should be empty", "", result.right);
|
||||
t.assert_equal("suffix should be empty", "", result.suffix);
|
||||
|
||||
result = calculate_diff_split("<row><row><row><your><boat><gently>", "<row><row><row><your><boat><gently>");
|
||||
t.assert_equal("prefix should be '<row><row><row><your><boat><gently>'", "<row><row><row><your><boat><gently>", result.prefix);
|
||||
t.assert_equal("left should be empty", "", result.left);
|
||||
t.assert_equal("right should be empty", "", result.right);
|
||||
t.assert_equal("suffix should be empty", "", result.suffix);
|
||||
}
|
||||
|
||||
static void test_calculate_diff_split_common_prefix(testing & t) {
|
||||
|
|
@ -894,8 +900,8 @@ static void test_seed_oss_call_count(testing & t) {
|
|||
t.assert_true("T2 right should contain value4", diff.right.find("value4") != std::string::npos);
|
||||
t.assert_true("T2 right should contain second tool_call end", diff.right.find("</seed:tool_call>") != std::string::npos);
|
||||
|
||||
// Suffix should be the eos token
|
||||
t.assert_equal("T2 suffix should be '<seed:eos>'", "<seed:eos>", diff.suffix);
|
||||
// Suffix should end with the eos token
|
||||
t.assert_equal("T2 suffix should end with '<seed:eos>'", "<seed:eos>", diff.suffix.substr(diff.suffix.length() - 10, 10));
|
||||
}
|
||||
|
||||
// T3: Compare different function names
|
||||
|
|
|
|||
Loading…
Reference in New Issue