From f762a71d56fbde9627d5ef75661a703ce9a3d519 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 16 Feb 2026 10:51:41 +0200 Subject: [PATCH] grader : improve example answers --- examples/llama-eval/IMPLEMENTATION.md | 4 ++- examples/llama-eval/README.md | 2 +- examples/llama-eval/llama-eval.py | 41 ++++++++++++++++++++++----- 3 files changed, 38 insertions(+), 9 deletions(-) diff --git a/examples/llama-eval/IMPLEMENTATION.md b/examples/llama-eval/IMPLEMENTATION.md index c9542f005d..9ca7972882 100644 --- a/examples/llama-eval/IMPLEMENTATION.md +++ b/examples/llama-eval/IMPLEMENTATION.md @@ -54,7 +54,7 @@ class EvalState: ### Grading Types - **regex**: Built-in patterns for each dataset - **cli**: External script with `--answer` and `--expected` args -- **llm**: LLM-based extraction with configurable server/model +- **llm**: LLM-based extraction with few-shot examples and configurable server/model ## Output Format @@ -81,5 +81,7 @@ Complete eval state with task IDs, correctness, prompts, extracted answers, and - Default seed: 1234 - Default threads: 32 - Prompt truncation: First 43 chars + padding + "..." +- Response truncation: Last 10 lines for grading - GPQA requires LLM grader (returns letter A/B/C/D) - Judge model defaults to evaluated model if not specified +- Sample answers defined in SAMPLE_ANSWERS dict for few-shot learning diff --git a/examples/llama-eval/README.md b/examples/llama-eval/README.md index 89408db823..8ad3ee2823 100644 --- a/examples/llama-eval/README.md +++ b/examples/llama-eval/README.md @@ -79,7 +79,7 @@ Returns exit code 0 if correct, non-zero if incorrect. ### LLM Grader Uses LLM to extract and compare answers: - Configurable server and model -- Includes problem context in prompt +- Includes few-shot examples from sample answers - Case-insensitive comparison ## Output diff --git a/examples/llama-eval/llama-eval.py b/examples/llama-eval/llama-eval.py index 7396261bff..a45bddf222 100755 --- a/examples/llama-eval/llama-eval.py +++ b/examples/llama-eval/llama-eval.py @@ -29,6 +29,24 @@ GRADER_PATTERNS = { "winogrande": r'[A-D]', } +SAMPLE_ANSWERS = { + "aime": [ + "42", + "123", + "999" + ], + "gsm8k": [ + "42", + "123", + "999" + ], + "gpqa": [ + "A", + "B", + "C" + ], +} + TEMPLATE_REGISTRY = { "aime": """{question} Please reason step by step, and put your final answer within \\boxed{{}}. @@ -243,17 +261,19 @@ class Grader: grader_type: str = "llm", grader_script: Optional[str] = None, judge_model_name: Optional[str] = None, - judge_server_url: str = "" + judge_server_url: str = "", + dataset_type: str = "aime" ): self.grader_type = grader_type self.grader_script = grader_script self.judge_model_name = judge_model_name self.judge_server_url = judge_server_url + self.dataset_type = dataset_type self.pattern = self._get_pattern() def _get_pattern(self) -> Optional[str]: if self.grader_type == "regex": - return GRADER_PATTERNS.get("aime") # Default to aime pattern + return GRADER_PATTERNS.get(self.grader_type) # Use grader_type as key return None def _extract_answer_regex(self, pred: str) -> Optional[str]: @@ -305,10 +325,16 @@ class Grader: return False, None def _grade_llm(self, gold: str, pred: str, problem: str) -> Tuple[bool, Optional[str]]: - """Grade using LLM-based extraction""" + """Grade using LLM-based extraction with few-shot examples""" + sample_answers = SAMPLE_ANSWERS.get(self.dataset_type, []) + sample_examples = "\n".join([ + f"Example {i+1}: {ans}" for i, ans in enumerate(sample_answers) + ]) + prompt = f"""Extract the answer from this response: -Expected answer: {gold} +Here are some example answers: +{sample_examples} === @@ -334,7 +360,7 @@ Please provide only the extracted answer, nothing else. If there is no clear ans except Exception as e: return False, None - def _truncate_response(self, response: str, max_lines: int = 3) -> str: + def _truncate_response(self, response: str, max_lines: int = 6) -> str: """Keep only last N lines of response""" lines = response.split('\n') return '\n'.join(lines[-max_lines:]) if len(lines) > max_lines else response @@ -441,7 +467,7 @@ class Processor: task_state.pred = pred # Truncate response to last 2-3 lines for grading - pred_truncated = self.grader._truncate_response(pred, max_lines=3) + pred_truncated = self.grader._truncate_response(pred, max_lines=10) # Grade the response is_correct, extracted = self.grader.grade(gold, pred_truncated, prompt) @@ -673,7 +699,8 @@ def main(): grader = Grader( grader_type=args.grader_type, grader_script=args.grader_script, - judge_model_name=args.judge_model if args.judge_model else args.model + judge_model_name=args.judge_model if args.judge_model else args.model, + dataset_type=args.dataset ) if args.grader_type == "llm" and not args.judge_server: