grade : improve regex + logs
This commit is contained in:
parent
52759bf078
commit
db10dda1f3
|
|
@ -325,18 +325,30 @@ class Grader:
|
|||
|
||||
def _get_pattern(self) -> Optional[str]:
    """Return the answer-extraction pattern for this grader, or ``None``.

    Only the ``"regex"`` grader type uses a pattern; every other grader
    type yields ``None``. For the regex grader the pattern is looked up in
    ``GRADER_PATTERNS`` under ``self.dataset_type``.

    Returns:
        The value stored in ``GRADER_PATTERNS`` for the current dataset
        type, or ``None`` when the grader is not a regex grader (or, as
        far as ``.get`` semantics go, when no entry exists for the key).
    """
    if self.grader_type != "regex":
        return None
    # Key the lookup by dataset_type rather than grader_type: grader_type is
    # always "regex" on this path, so using it as the key could never select
    # a dataset-specific pattern.
    # NOTE(review): assumes GRADER_PATTERNS is a mapping keyed by dataset
    # name -- confirm against its definition elsewhere in the file.
    return GRADER_PATTERNS.get(self.dataset_type)
|
||||
|
||||
def _extract_answer_regex(self, pred: str) -> Optional[str]:
|
||||
"""Extract answer using regex pattern"""
|
||||
if not self.pattern:
|
||||
return None
|
||||
|
||||
# For AIME datasets, prioritize boxed answers
|
||||
if self.dataset_type in ["aime", "aime2025"]:
|
||||
boxed_pattern = r'\\boxed{([^}]+)}'
|
||||
boxed_matches = re.findall(boxed_pattern, pred, re.IGNORECASE)
|
||||
if boxed_matches:
|
||||
# Return the last boxed answer found (most likely the final answer)
|
||||
return boxed_matches[-1].strip()
|
||||
|
||||
# For other datasets, search for numbers from the end of the text
|
||||
# This prioritizes numbers that appear later in the response
|
||||
matches = re.findall(self.pattern, pred, re.IGNORECASE)
|
||||
if not matches:
|
||||
return None
|
||||
|
||||
for match in matches:
|
||||
# Process matches from end to start
|
||||
for match in reversed(matches):
|
||||
if isinstance(match, tuple):
|
||||
match = match[0] if match[0] else match[1]
|
||||
extracted = match.strip()
|
||||
|
|
@ -446,7 +458,8 @@ class Processor:
|
|||
judge_model_name: Optional[str] = None,
|
||||
dataset_type: str = "aime",
|
||||
seed: int = 1234,
|
||||
sampling_config: Optional[Dict[str, Any]] = None
|
||||
sampling_config: Optional[Dict[str, Any]] = None,
|
||||
output_file: Optional[Path] = None
|
||||
):
|
||||
self.server_url = server_url
|
||||
self.n_predict = n_predict
|
||||
|
|
@ -459,10 +472,11 @@ class Processor:
|
|||
self.seed = seed
|
||||
self.grader = grader or Grader()
|
||||
self.sampling_config = sampling_config or {"n_predict": n_predict}
|
||||
self.output_file = output_file or Path("llama-eval-state.json")
|
||||
self.eval_state = EvalState(
|
||||
id=dataset_type,
|
||||
tasks=[dataset_type],
|
||||
task_states={},
|
||||
task_states={dataset_type: {}},
|
||||
sampling_config=self.sampling_config
|
||||
)
|
||||
|
||||
|
|
@ -533,8 +547,44 @@ class Processor:
|
|||
task_state.correct = is_correct
|
||||
task_state.extracted = extracted
|
||||
task_state.status = "ok"
|
||||
except Exception as e:
|
||||
task_state.status = f"error: {str(e)}"
|
||||
|
||||
# Log grader request details for debugging
|
||||
grader_log = {
|
||||
"case_id": task_id,
|
||||
"gold": gold,
|
||||
"pred": pred_truncated,
|
||||
"extracted": extracted,
|
||||
"correct": is_correct,
|
||||
"grader_type": self.grader.grader_type
|
||||
}
|
||||
if self.grader.grader_type == "regex" and self.grader.pattern:
|
||||
grader_log["pattern"] = self.grader.pattern
|
||||
if "grader_log" not in self.eval_state.task_states[self.dataset_type]:
|
||||
self.eval_state.task_states[self.dataset_type]["grader_log"] = []
|
||||
self.eval_state.task_states[self.dataset_type]["grader_log"].append(grader_log)
|
||||
|
||||
# Initialize cases dict if it doesn't exist
|
||||
if "cases" not in self.eval_state.task_states[self.dataset_type]:
|
||||
self.eval_state.task_states[self.dataset_type]["cases"] = {}
|
||||
|
||||
# Update eval state with grading details
|
||||
self.eval_state.task_states[self.dataset_type]["cases"][task_id] = {
|
||||
"case_id": task_id,
|
||||
"prompt": prompt,
|
||||
"gold": gold,
|
||||
"pred": pred,
|
||||
"extracted": extracted,
|
||||
"correct": is_correct,
|
||||
"status": "ok"
|
||||
}
|
||||
|
||||
# Save eval state to disk after each task
|
||||
try:
|
||||
self.dump_state(self.output_file)
|
||||
except Exception as dump_error:
|
||||
task_state.status = f"error: {str(e)}; dump error: {str(dump_error)}"
|
||||
except Exception as processing_error:
|
||||
task_state.status = f"error: {str(processing_error)}"
|
||||
|
||||
return task_state
|
||||
|
||||
|
|
@ -621,10 +671,13 @@ class Processor:
|
|||
print(f" Extracted: {task_state.extracted}")
|
||||
print(f" Status: {task_state.status}")
|
||||
|
||||
# Merge existing state with new state to preserve grader_log
|
||||
existing_state = self.eval_state.task_states.get(self.dataset_type, {})
|
||||
self.eval_state.task_states[self.dataset_type] = {
|
||||
"total": total,
|
||||
"correct": correct,
|
||||
"cases": task_states
|
||||
"cases": task_states,
|
||||
**existing_state
|
||||
}
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
|
|
@ -637,7 +690,6 @@ class Processor:
|
|||
"""Dump eval state to JSON file"""
|
||||
with open(output_file, "w") as f:
|
||||
json.dump(asdict(self.eval_state), f, indent=2)
|
||||
print(f"\nEval state dumped to {output_file}")
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
|
@ -785,11 +837,13 @@ def main():
|
|||
judge_server_url=args.judge_server,
|
||||
judge_model_name=args.judge_model,
|
||||
dataset_type=args.dataset,
|
||||
sampling_config=sampling_config
|
||||
sampling_config=sampling_config,
|
||||
output_file=args.output
|
||||
)
|
||||
|
||||
eval_state = processor.process(n_cases=args.n_cases, seed=args.seed)
|
||||
processor.dump_state(args.output)
|
||||
print(f"\nEval state dumped to {args.output}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
Loading…
Reference in New Issue