From 8156d549f6b57c5c0a9d3ed61b6e344cf016a5f2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 2 Feb 2026 19:45:04 +0200 Subject: [PATCH] sim : fix answer matching --- examples/llama-eval/llama-eval-new.py | 3 +- examples/llama-eval/llama-server-simulator.py | 59 +++++++++++-------- 2 files changed, 36 insertions(+), 26 deletions(-) diff --git a/examples/llama-eval/llama-eval-new.py b/examples/llama-eval/llama-eval-new.py index d3c318e151..3f202a952b 100755 --- a/examples/llama-eval/llama-eval-new.py +++ b/examples/llama-eval/llama-eval-new.py @@ -28,8 +28,7 @@ GRADER_PATTERNS = { } TEMPLATE_REGISTRY = { - "aime": """ -{question} + "aime": """{question} Please reason step by step, and put your final answer within \\boxed{{}}. """, } diff --git a/examples/llama-eval/llama-server-simulator.py b/examples/llama-eval/llama-server-simulator.py index 4958683013..210683953e 100755 --- a/examples/llama-eval/llama-server-simulator.py +++ b/examples/llama-eval/llama-server-simulator.py @@ -19,25 +19,28 @@ cache_dir = Path.home() / ".cache" / "huggingface" / "datasets" cache_dir.mkdir(parents=True, exist_ok=True) os.environ["HF_DATASETS_CACHE"] = str(cache_dir) -def levenshtein_distance(s1: str, s2: str) -> int: - """Calculate Levenshtein distance between two strings""" - if len(s1) < len(s2): - return levenshtein_distance(s2, s1) +def dice(s1: str, s2: str) -> float: + """Calculate Dice coefficient between two strings based on bigram overlap.""" + if not s1 and not s2: + return 1.0 - if len(s2) == 0: - return len(s1) + def _bigrams(s: str): + return [s[i : i + 2] for i in range(len(s) - 1)] - previous_row = range(len(s2) + 1) - for i, c1 in enumerate(s1): - current_row = [i + 1] - for j, c2 in enumerate(s2): - insertions = previous_row[j + 1] + 1 - deletions = current_row[j] + 1 - substitutions = previous_row[j] + (c1 != c2) - current_row.append(min(insertions, deletions, substitutions)) - previous_row = current_row + bigrams1 = _bigrams(s1) + bigrams2 = _bigrams(s2) - return previous_row[-1] + if not bigrams1 and not bigrams2: + return 1.0 + + from collections import Counter + + freq1 = Counter(bigrams1) + freq2 = Counter(bigrams2) + + intersection = sum(min(freq1[bg], freq2[bg]) for bg in freq1) + dice_coeff = 2 * intersection / (len(bigrams1) + len(bigrams2)) + return dice_coeff def debug_log(message: str): """Log debug messages to both stdout and a file""" @@ -54,6 +57,12 @@ class EvalState: task_states: Dict[str, Dict] sampling_config: Dict +def normalize_number(s: str) -> Optional[int]: + match = re.match(r"\d+", s) # match digits from the start + if not match: + return None + return int(match.group(0)) + class AimeDataset: def __init__(self, split: str = "train"): self.split = split @@ -75,7 +84,7 @@ class AimeDataset: def find_question(self, request_text: str) -> Optional[Dict]: best_match = None - best_distance = float('inf') + best_distance = -1 best_index = -1 for i, question in enumerate(self.questions): @@ -97,16 +106,14 @@ class AimeDataset: # Calculate Levenshtein distance for partial matches # Only consider if request is at least 50% of question length if len(request_lower) >= len(question_lower) * 0.5: - distance = levenshtein_distance(question_lower, request_lower) - # Normalize distance by length - normalized_distance = distance / len(question_lower) + distance = dice(question_lower, request_lower) - if normalized_distance < best_distance: - best_distance = normalized_distance + if distance > best_distance: + best_distance = distance best_match = question best_index = i - if best_match and best_distance < 0.3: # Threshold for partial match + if best_match and best_distance > 0.3: # Threshold for partial match debug_log(f"DEBUG: Found best partial match at index {best_index} with distance {best_distance:.3f}") return best_match @@ -114,7 +121,11 @@ class AimeDataset: return None def get_answer(self, question: Dict) -> str: - return str(question["answer"]) + answer = question["answer"] + if isinstance(answer, str): + normalized = normalize_number(answer) + return str(normalized) if normalized is not None else answer + return str(answer) class Simulator: def __init__(