grade : improve regex + logs

This commit is contained in:
Georgi Gerganov 2026-02-16 11:51:36 +02:00
parent 52759bf078
commit db10dda1f3
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
1 changed files with 63 additions and 9 deletions

View File

@ -325,18 +325,30 @@ class Grader:
def _get_pattern(self) -> Optional[str]:
if self.grader_type == "regex":
return GRADER_PATTERNS.get(self.grader_type) # Use grader_type as key
return GRADER_PATTERNS.get(self.dataset_type) # Use dataset_type as key
return None
def _extract_answer_regex(self, pred: str) -> Optional[str]:
"""Extract answer using regex pattern"""
if not self.pattern:
return None
# For AIME datasets, prioritize boxed answers
if self.dataset_type in ["aime", "aime2025"]:
boxed_pattern = r'\\boxed{([^}]+)}'
boxed_matches = re.findall(boxed_pattern, pred, re.IGNORECASE)
if boxed_matches:
# Return the last boxed answer found (most likely the final answer)
return boxed_matches[-1].strip()
# For other datasets, search for numbers from the end of the text
# This prioritizes numbers that appear later in the response
matches = re.findall(self.pattern, pred, re.IGNORECASE)
if not matches:
return None
for match in matches:
# Process matches from end to start
for match in reversed(matches):
if isinstance(match, tuple):
match = match[0] if match[0] else match[1]
extracted = match.strip()
@ -446,7 +458,8 @@ class Processor:
judge_model_name: Optional[str] = None,
dataset_type: str = "aime",
seed: int = 1234,
sampling_config: Optional[Dict[str, Any]] = None
sampling_config: Optional[Dict[str, Any]] = None,
output_file: Optional[Path] = None
):
self.server_url = server_url
self.n_predict = n_predict
@ -459,10 +472,11 @@ class Processor:
self.seed = seed
self.grader = grader or Grader()
self.sampling_config = sampling_config or {"n_predict": n_predict}
self.output_file = output_file or Path("llama-eval-state.json")
self.eval_state = EvalState(
id=dataset_type,
tasks=[dataset_type],
task_states={},
task_states={dataset_type: {}},
sampling_config=self.sampling_config
)
@ -533,8 +547,44 @@ class Processor:
task_state.correct = is_correct
task_state.extracted = extracted
task_state.status = "ok"
except Exception as e:
task_state.status = f"error: {str(e)}"
# Log grader request details for debugging
grader_log = {
"case_id": task_id,
"gold": gold,
"pred": pred_truncated,
"extracted": extracted,
"correct": is_correct,
"grader_type": self.grader.grader_type
}
if self.grader.grader_type == "regex" and self.grader.pattern:
grader_log["pattern"] = self.grader.pattern
if "grader_log" not in self.eval_state.task_states[self.dataset_type]:
self.eval_state.task_states[self.dataset_type]["grader_log"] = []
self.eval_state.task_states[self.dataset_type]["grader_log"].append(grader_log)
# Initialize cases dict if it doesn't exist
if "cases" not in self.eval_state.task_states[self.dataset_type]:
self.eval_state.task_states[self.dataset_type]["cases"] = {}
# Update eval state with grading details
self.eval_state.task_states[self.dataset_type]["cases"][task_id] = {
"case_id": task_id,
"prompt": prompt,
"gold": gold,
"pred": pred,
"extracted": extracted,
"correct": is_correct,
"status": "ok"
}
# Save eval state to disk after each task
try:
self.dump_state(self.output_file)
except Exception as dump_error:
task_state.status = f"error: {str(e)}; dump error: {str(dump_error)}"
except Exception as processing_error:
task_state.status = f"error: {str(processing_error)}"
return task_state
@ -621,10 +671,13 @@ class Processor:
print(f" Extracted: {task_state.extracted}")
print(f" Status: {task_state.status}")
# Merge existing state with new state to preserve grader_log
existing_state = self.eval_state.task_states.get(self.dataset_type, {})
self.eval_state.task_states[self.dataset_type] = {
"total": total,
"correct": correct,
"cases": task_states
"cases": task_states,
**existing_state
}
print(f"\n{'='*60}")
@ -637,7 +690,6 @@ class Processor:
"""Dump eval state to JSON file"""
with open(output_file, "w") as f:
json.dump(asdict(self.eval_state), f, indent=2)
print(f"\nEval state dumped to {output_file}")
def main():
parser = argparse.ArgumentParser(
@ -785,11 +837,13 @@ def main():
judge_server_url=args.judge_server,
judge_model_name=args.judge_model,
dataset_type=args.dataset,
sampling_config=sampling_config
sampling_config=sampling_config,
output_file=args.output
)
eval_state = processor.process(n_cases=args.n_cases, seed=args.seed)
processor.dump_state(args.output)
print(f"\nEval state dumped to {args.output}")
if __name__ == "__main__":
main()