grader : improve example answers
This commit is contained in:
parent
73e61d5b75
commit
f762a71d56
|
|
@ -54,7 +54,7 @@ class EvalState:
|
|||
### Grading Types
|
||||
- **regex**: Built-in patterns for each dataset
|
||||
- **cli**: External script with `--answer` and `--expected` args
|
||||
- **llm**: LLM-based extraction with configurable server/model
|
||||
- **llm**: LLM-based extraction with few-shot examples and configurable server/model
|
||||
|
||||
## Output Format
|
||||
|
||||
|
|
@ -81,5 +81,7 @@ Complete eval state with task IDs, correctness, prompts, extracted answers, and
|
|||
- Default seed: 1234
|
||||
- Default threads: 32
|
||||
- Prompt truncation: First 43 chars + padding + "..."
|
||||
- Response truncation: Last 10 lines for grading
|
||||
- GPQA requires LLM grader (returns letter A/B/C/D)
|
||||
- Judge model defaults to evaluated model if not specified
|
||||
- Sample answers are defined in the `SAMPLE_ANSWERS` dict and inserted into the LLM grader prompt as few-shot examples
|
||||
|
|
|
|||
|
|
@ -79,7 +79,7 @@ Returns exit code 0 if correct, non-zero if incorrect.
|
|||
### LLM Grader
|
||||
Uses LLM to extract and compare answers:
|
||||
- Configurable server and model
|
||||
- Includes problem context in prompt
|
||||
- Includes few-shot examples from sample answers
|
||||
- Case-insensitive comparison
|
||||
|
||||
## Output
|
||||
|
|
|
|||
|
|
@ -29,6 +29,24 @@ GRADER_PATTERNS = {
|
|||
"winogrande": r'[A-D]',
|
||||
}
|
||||
|
||||
# Example final answers per dataset, looked up by dataset type and listed in
# the LLM grader prompt as few-shot examples of the expected answer format.
SAMPLE_ANSWERS = {
    "aime": ["42", "123", "999"],
    "gsm8k": ["42", "123", "999"],
    "gpqa": ["A", "B", "C"],
}
|
||||
|
||||
TEMPLATE_REGISTRY = {
|
||||
"aime": """{question}
|
||||
Please reason step by step, and put your final answer within \\boxed{{}}.
|
||||
|
|
@ -243,17 +261,19 @@ class Grader:
|
|||
grader_type: str = "llm",
|
||||
grader_script: Optional[str] = None,
|
||||
judge_model_name: Optional[str] = None,
|
||||
judge_server_url: str = ""
|
||||
judge_server_url: str = "",
|
||||
dataset_type: str = "aime"
|
||||
):
|
||||
self.grader_type = grader_type
|
||||
self.grader_script = grader_script
|
||||
self.judge_model_name = judge_model_name
|
||||
self.judge_server_url = judge_server_url
|
||||
self.dataset_type = dataset_type
|
||||
self.pattern = self._get_pattern()
|
||||
|
||||
def _get_pattern(self) -> Optional[str]:
    """Return the built-in regex pattern for the configured dataset.

    Only the "regex" grader type uses a pattern; every other grader type
    gets ``None``.

    Returns:
        The ``GRADER_PATTERNS`` entry for ``self.dataset_type``, or ``None``
        when the grader is not regex-based or no pattern is registered for
        the dataset.
    """
    if self.grader_type != "regex":
        return None
    # GRADER_PATTERNS is keyed by dataset name (e.g. "winogrande"), so the
    # lookup must use the dataset type; using the grader type here would
    # search for the key "regex" instead of a dataset's pattern.
    return GRADER_PATTERNS.get(self.dataset_type)
|
||||
|
||||
def _extract_answer_regex(self, pred: str) -> Optional[str]:
|
||||
|
|
@ -305,10 +325,16 @@ class Grader:
|
|||
return False, None
|
||||
|
||||
def _grade_llm(self, gold: str, pred: str, problem: str) -> Tuple[bool, Optional[str]]:
|
||||
"""Grade using LLM-based extraction"""
|
||||
"""Grade using LLM-based extraction with few-shot examples"""
|
||||
sample_answers = SAMPLE_ANSWERS.get(self.dataset_type, [])
|
||||
sample_examples = "\n".join([
|
||||
f"Example {i+1}: {ans}" for i, ans in enumerate(sample_answers)
|
||||
])
|
||||
|
||||
prompt = f"""Extract the answer from this response:
|
||||
|
||||
Expected answer: {gold}
|
||||
Here are some example answers:
|
||||
{sample_examples}
|
||||
|
||||
===
|
||||
|
||||
|
|
@ -334,7 +360,7 @@ Please provide only the extracted answer, nothing else. If there is no clear ans
|
|||
except Exception as e:
|
||||
return False, None
|
||||
|
||||
def _truncate_response(self, response: str, max_lines: int = 3) -> str:
|
||||
def _truncate_response(self, response: str, max_lines: int = 6) -> str:
|
||||
"""Keep only last N lines of response"""
|
||||
lines = response.split('\n')
|
||||
return '\n'.join(lines[-max_lines:]) if len(lines) > max_lines else response
|
||||
|
|
@ -441,7 +467,7 @@ class Processor:
|
|||
task_state.pred = pred
|
||||
|
||||
# Truncate response to the last 10 lines for grading
|
||||
pred_truncated = self.grader._truncate_response(pred, max_lines=3)
|
||||
pred_truncated = self.grader._truncate_response(pred, max_lines=10)
|
||||
|
||||
# Grade the response
|
||||
is_correct, extracted = self.grader.grade(gold, pred_truncated, prompt)
|
||||
|
|
@ -673,7 +699,8 @@ def main():
|
|||
grader = Grader(
|
||||
grader_type=args.grader_type,
|
||||
grader_script=args.grader_script,
|
||||
judge_model_name=args.judge_model if args.judge_model else args.model
|
||||
judge_model_name=args.judge_model if args.judge_model else args.model,
|
||||
dataset_type=args.dataset
|
||||
)
|
||||
|
||||
if args.grader_type == "llm" and not args.judge_server:
|
||||
|
|
|
|||
Loading…
Reference in New Issue