398 lines
17 KiB
Python
398 lines
17 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
grpo_example.py — Minimal GRPO training loop using llama-finetune-qlora --grpo-mode
|
|
|
|
Demonstrates the IPC protocol between the Python driver and the C++ subprocess.
|
|
No external dependencies required — only Python stdlib.
|
|
|
|
Usage:
|
|
python3 grpo_example.py \
|
|
--model /path/to/model-q4_k_m.gguf \
|
|
--lora-out /path/to/output-adapter.gguf \
|
|
[--lora /path/to/resume-adapter.gguf] \
|
|
[--binary /path/to/llama-finetune-qlora] \
|
|
[--n-steps 200] \
|
|
[--n-gen 8] \
|
|
[--rank 16]
|
|
|
|
IPC Protocol (stdout from C++ process):
|
|
[QLORA:READY] — process initialised
|
|
[QLORA:PROMPT_REQ:<step>] — C++ requests a prompt for step N
|
|
[QLORA:GEN:<k>/<n>] <text> — one generation (newlines escaped as \\n)
|
|
[QLORA:REWARD_REQ:<n>] — C++ requests N reward scores
|
|
[QLORA:PROGRESS] step=X/Y loss=Z epoch=A/B
|
|
[QLORA:CHECKPOINT] <path>
|
|
[QLORA:DONE] final_loss=X
|
|
[QLORA:ERROR] <message>
|
|
|
|
Python → C++ stdin:
|
|
PROMPT <escaped_text>
|
|
REWARD <r1> <r2> ... <rN> (group-normalised advantages in [0, 1]; 0.5 = group mean)
|
|
STOP (request graceful shutdown)
|
|
"""
|
|
|
|
import argparse
import logging
import math
import os
import queue
import re
import subprocess
import sys
import threading
import time
from pathlib import Path
from typing import List, Optional, Tuple
|
|
|
|
# Root logger defaults to INFO; the __main__ entry point lowers it to DEBUG
# when --verbose is passed. All script messages go through `log`.
logging.basicConfig(format="%(asctime)s [%(levelname)s] %(message)s", level=logging.INFO)
log = logging.getLogger("grpo_example")
|
|
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
# IPC helpers
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
|
|
# Matches one IPC line: "[QLORA:<TYPE>]<payload>" or "[QLORA:<TYPE>:<seq>]<payload>".
#   group 1 — message type (upper-case letters and underscores only)
#   group 2 — optional sequence field after a second colon, up to the first "]"
#             (e.g. "3" for PROMPT_REQ, "2/8" for GEN); None when absent
#   group 3 — everything after the closing bracket (payload, not yet stripped)
_IPC_RE = re.compile(r"^\[QLORA:([A-Z_]+)(?::([^\]]*))?\](.*)$")
|
|
|
|
|
|
def escape(text: str) -> str:
    r"""Encode *text* so it fits on a single IPC line.

    Backslashes are doubled, and LF / CR become the two-character sequences
    ``\n`` / ``\r``. ``unescape`` performs the inverse transformation.
    """
    # One translation pass: each original character is mapped exactly once,
    # so backslashes introduced for \n / \r are never doubled again.
    table = str.maketrans({"\\": "\\\\", "\n": "\\n", "\r": "\\r"})
    return text.translate(table)
|
|
|
|
|
|
def unescape(text: str) -> str:
    r"""Decode a single-line IPC string produced by ``escape``.

    ``\n`` -> LF, ``\r`` -> CR, ``\\`` -> one backslash. Any other escaped
    character is returned without its backslash. A lone trailing backslash
    is kept literally.
    """
    specials = {"n": "\n", "r": "\r"}
    decoded: List[str] = []
    chars = iter(text)
    for ch in chars:
        if ch != "\\":
            decoded.append(ch)
            continue
        follower = next(chars, None)
        if follower is None:
            # Backslash was the final character: keep it as-is.
            decoded.append("\\")
            break
        decoded.append(specials.get(follower, follower))
    return "".join(decoded)
|
|
|
|
|
|
def parse_ipc(line: str) -> Optional[Tuple[str, str, str]]:
    """Split one stdout line into ``(msg_type, seq, payload)``.

    *seq* is the field after a second colon inside the brackets (``""`` when
    absent) and *payload* is the stripped text after ``]``. Lines that are
    not ``[QLORA:...]`` messages (model output, stray log lines) give ``None``.
    """
    match = _IPC_RE.match(line.strip())
    if match is None:
        return None
    msg_type, seq, payload = match.groups()
    return msg_type, seq or "", payload.strip()
|
|
|
|
|
|
def _stdout_line_queue(proc: subprocess.Popen) -> "queue.Queue[Optional[str]]":
    """Return the queue fed by a background reader of ``proc.stdout``.

    A plain ``readline()`` blocks with no timeout, so stdout is read on a
    daemon thread (started on first use and cached on *proc*). That lets
    callers wait with a deadline. ``None`` is queued once stdout hits EOF.
    """
    lines = getattr(proc, "_ipc_line_queue", None)
    if lines is None:
        lines = queue.Queue()
        stream = proc.stdout

        def _pump() -> None:
            try:
                while True:
                    raw = stream.readline()
                    if not raw:
                        break
                    lines.put(raw)
            except (OSError, ValueError):
                pass  # pipe closed underneath us: treat exactly like EOF
            finally:
                lines.put(None)  # EOF sentinel

        proc._ipc_line_queue = lines
        threading.Thread(target=_pump, name="grpo-stdout-reader", daemon=True).start()
    return lines


def read_ipc(proc: subprocess.Popen, timeout: float = 120.0) -> Optional[Tuple[str, str, str]]:
    """
    Read lines from proc.stdout until an IPC message arrives.

    Non-IPC lines (model output, C++ logs leaked to stdout) are echoed to
    stderr. Returns ``(msg_type, seq, payload)`` for the first IPC line, or
    None on EOF (and on every later call once EOF has been seen).

    Raises TimeoutError if no IPC message arrives within `timeout` seconds,
    including when the subprocess is alive but silent.
    """
    lines = _stdout_line_queue(proc)
    deadline = time.monotonic() + timeout
    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            raise TimeoutError(f"No IPC message within {timeout:.0f}s")

        try:
            line = lines.get(timeout=remaining)
        except queue.Empty:
            raise TimeoutError(f"No IPC message within {timeout:.0f}s") from None

        if line is None:
            lines.put(None)  # keep EOF sticky so later calls see it too
            return None  # EOF

        line = line.rstrip("\n")
        parsed = parse_ipc(line)
        if parsed:
            return parsed
        # Non-IPC — C++ sometimes leaks timing/debug lines to stdout.
        # Print them so the user can see what's happening.
        print(f" [cpp] {line}", file=sys.stderr)
|
|
|
|
|
|
def write_cmd(proc: subprocess.Popen, cmd: str) -> None:
    """Write one newline-terminated command to the subprocess stdin and flush.

    Raises:
        RuntimeError: stdin can no longer be written — the pipe is broken
            (BrokenPipeError, or another OSError such as EINVAL on Windows),
            or the stream has already been closed (ValueError). The original
            exception is chained as ``__cause__``.
    """
    try:
        proc.stdin.write(cmd + "\n")
        proc.stdin.flush()
    except (OSError, ValueError) as err:
        raise RuntimeError("C++ subprocess stdin closed — did it crash?") from err
|
|
|
|
|
|
def wait_for(proc: subprocess.Popen, expected: str, timeout: float = 120.0) -> Tuple[str, str, str]:
    """Block until an IPC message of type *expected* arrives and return it.

    Other message types are skipped (logged at DEBUG). A ``[QLORA:ERROR]``
    is not skipped: it means the subprocess has given up, so it is raised
    with its message instead of being lost before the inevitable EOF.

    Raises:
        TimeoutError: *expected* did not arrive within *timeout* seconds.
        RuntimeError: the subprocess closed stdout first, or reported ERROR.
    """
    deadline = time.monotonic() + timeout
    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            raise TimeoutError(f"Timed out waiting for [{expected}]")
        parsed = read_ipc(proc, timeout=remaining)
        if parsed is None:
            raise RuntimeError(f"Subprocess exited before sending [{expected}]")
        msg_type, seq, payload = parsed
        if msg_type == expected:
            return msg_type, seq, payload
        if msg_type == "ERROR":
            raise RuntimeError(
                f"Subprocess reported an error while waiting for [{expected}]: {payload}"
            )
        log.debug("Ignoring unexpected IPC (%s) while waiting for %s", msg_type, expected)
|
|
|
|
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
# Advantage normalisation (GRPO)
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
|
|
def normalise_rewards(rewards: List[float]) -> List[float]:
    """Turn raw rewards into group-relative advantages in the range [0, 1].

    Each reward is z-scored against the group (population standard
    deviation). The z-score is then mapped through ``0.5 + z / 6`` so that
    roughly +/-3 sigma spans [0, 1], and clamped to that interval.

    If the group variance is at most 1e-8, the standard deviation is taken
    as 1.0. Identical rewards therefore all become 0.5 (no signal) rather
    than NaN. Empty input returns ``[]``.
    """
    count = len(rewards)
    if count == 0:
        return []

    centre = sum(rewards) / count
    spread_sq = sum((value - centre) ** 2 for value in rewards) / count
    scale = math.sqrt(spread_sq) if spread_sq > 1e-8 else 1.0

    advantages = []
    for value in rewards:
        z = (value - centre) / scale
        # Clamp: upper bound applied first, then lower bound.
        advantages.append(max(0.0, min(1.0, 0.5 + z / 6.0)))
    return advantages
|
|
|
|
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
# Example prompt / reward providers
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
|
|
# Replace these with your own logic.
|
|
|
|
# Toy prompt pool, cycled through by training step.
_EXAMPLE_PROMPTS = [
    "Explain the concept of gradient descent in one sentence.",
    "What is the capital of France?",
    "Write a haiku about machine learning.",
    "Describe the difference between SFT and RLHF.",
    "What does GRPO stand for?",
]


def get_prompt(step: int) -> str:
    """Pick the prompt for a 0-indexed training step, wrapping around the pool."""
    _, slot = divmod(step, len(_EXAMPLE_PROMPTS))
    return _EXAMPLE_PROMPTS[slot]
|
|
|
|
|
|
def score_generations(prompt: str, generations: List[str]) -> List[float]:
    """Give each generation a raw heuristic reward (normalised later).

    score = 0.6 * min(1, words / 50) + 0.4 * (unique words / words).
    This toy heuristic favours longer, more varied answers. *prompt* is
    accepted for interface compatibility but not used here. Replace with a
    real reward model / verifier.
    """
    def _heuristic(text: str) -> float:
        tokens = text.split()
        n_tokens = len(tokens)
        length_part = min(1.0, n_tokens / 50.0)
        diversity_part = min(1.0, len(set(tokens)) / max(1, n_tokens))
        return 0.6 * length_part + 0.4 * diversity_part

    return [_heuristic(text) for text in generations]
|
|
|
|
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
# Main GRPO loop
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
|
|
def run_grpo(args: argparse.Namespace):
    """Launch ``llama-finetune-qlora --grpo-mode`` and drive it to completion.

    Builds the command line from *args*, starts the binary with piped
    stdin/stdout (its stderr goes to our stderr) and hands control to
    ``_grpo_loop``.

    * Missing binary: logs an error and exits with status 1.
    * Ctrl-C: sends ``STOP`` so the C++ side can shut down gracefully.
    * Any other exception: kills the child and re-raises.

    In every case stdin is closed and the child is reaped. If it has not
    exited within 30 s it is killed instead of being left running, and the
    original outcome is not masked by ``TimeoutExpired``.
    """
    # Resolve binary
    binary = Path(args.binary)
    if not binary.exists():
        log.error("Binary not found: %s", binary)
        sys.exit(1)

    # Build command
    cmd = [
        str(binary),
        "--model", args.model,
        "--lora-out", args.lora_out,
        "--lora-rank", str(args.rank),
        "--lora-alpha", str(args.rank // 2),
        "-c", str(args.ctx_size),
        "-b", str(args.ctx_size),
        "-ub", "512",
        "-ngl", str(args.ngl),
        "-lr", str(args.lr),
        "--seed", str(args.seed),
        "--grad-checkpoint", "48",
        "--shuffle-dataset",
        "--grpo-mode",
        "--n-gen", str(args.n_gen),
        "--n-steps", str(args.n_steps),
        "--grpo-temp", str(args.temperature),
        "--grpo-max-tokens", str(args.max_tokens),
    ]
    if args.lora:
        cmd += ["--lora", args.lora]
    if args.save_every > 0:
        cmd += ["--save-every", str(args.save_every)]

    log.info("Launching: %s", " ".join(cmd))

    proc = subprocess.Popen(
        cmd,
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=sys.stderr,  # C++ debug/timing logs go directly to our stderr
        text=True,
        bufsize=1,
    )

    try:
        _grpo_loop(proc, args)
    except KeyboardInterrupt:
        log.info("Interrupted — requesting graceful stop")
        try:
            write_cmd(proc, "STOP")
        except Exception as err:  # best effort: the child may already be gone
            log.debug("Could not send STOP: %s", err)
    except Exception as e:
        log.error("GRPO loop error: %s", e)
        proc.kill()
        raise
    finally:
        try:
            proc.stdin.close()
        except (OSError, ValueError):
            pass  # best effort: flushing into a dead child's pipe can fail
        try:
            rc = proc.wait(timeout=30)
        except subprocess.TimeoutExpired:
            log.warning("Subprocess did not exit within 30s — killing it")
            proc.kill()
            rc = proc.wait()
        if rc != 0:
            log.warning("Subprocess exited with code %d", rc)
|
|
|
|
|
|
def _grpo_loop(proc: subprocess.Popen, args: argparse.Namespace):
    """Serve the C++ GRPO process until it reports DONE or closes stdout.

    Waits for READY, then handles messages:

    * PROMPT_REQ: sends the prompt for the step.
    * GEN: collects a generation for the current step.
    * REWARD_REQ: replies with normalised advantages, always exactly the
      requested count.
    * PROGRESS / CHECKPOINT: logged.

    Raises:
        RuntimeError: the subprocess sent [QLORA:ERROR], exited before READY,
            or stdin could not be written.
        TimeoutError: propagated from wait_for/read_ipc when no IPC message
            arrives in time (300 s for READY, 600 s per message afterwards).
    """
    # ── Wait for READY ──────────────────────────────────────────────────────
    log.info("Waiting for subprocess to initialise (model load can take a minute)…")
    wait_for(proc, "READY", timeout=300)
    log.info("Subprocess ready.")

    current_prompt: str = ""
    generations: List[str] = []
    step = 0

    while True:
        parsed = read_ipc(proc, timeout=600)
        if parsed is None:
            log.info("Subprocess exited (EOF).")
            break

        msg_type, seq, payload = parsed

        # ── PROMPT_REQ ──────────────────────────────────────────────────────
        if msg_type == "PROMPT_REQ":
            step = int(seq) if seq else step + 1
            current_prompt = get_prompt(step - 1)
            generations = []
            log.debug("Step %d — sending prompt: %s", step, current_prompt[:60])
            write_cmd(proc, f"PROMPT {escape(current_prompt)}")

        # ── GEN ─────────────────────────────────────────────────────────────
        elif msg_type == "GEN":
            # seq = "k/n"
            parts = seq.split("/")
            k = int(parts[0])
            n = int(parts[1]) if len(parts) > 1 else args.n_gen
            text = unescape(payload)
            generations.append(text)
            log.debug(" Generation %d/%d: %s…", k, n, text[:60].replace("\n", "↵"))

        # ── REWARD_REQ ──────────────────────────────────────────────────────
        elif msg_type == "REWARD_REQ":
            n_expected = int(seq) if seq else len(generations)
            raw_rewards = score_generations(current_prompt, generations)
            advantages = normalise_rewards(raw_rewards)

            if len(generations) != n_expected:
                log.warning(
                    "REWARD_REQ asked for %d rewards but collected %d generations",
                    n_expected, len(generations),
                )
                # The C++ side asked for exactly n_expected values. Pad missing
                # slots with the neutral 0.5 (no signal) and drop any extras.
                advantages = (advantages + [0.5] * n_expected)[:n_expected]

            reward_str = " ".join(f"{a:.6f}" for a in advantages)
            log.debug(" Rewards (raw): %s", [f"{r:.3f}" for r in raw_rewards])
            log.debug(" Advantages: %s", [f"{a:.3f}" for a in advantages])
            write_cmd(proc, f"REWARD {reward_str}")

        # ── PROGRESS ────────────────────────────────────────────────────────
        elif msg_type == "PROGRESS":
            # Format: step=X/Y loss=Z epoch=A/B
            sm = re.search(r"step=(\d+)(?:/(\d+))?", payload)
            # Take the whole token: losses can be negative, use exponents or be nan.
            lm = re.search(r"loss=([^\s,]+)", payload)
            if sm is None:
                step_str = "?"
            elif sm.group(2):
                step_str = f"{sm.group(1)}/{sm.group(2)}"
            else:
                step_str = sm.group(1)
            loss_str = lm.group(1) if lm else "?"
            print(f" step {step_str} loss {loss_str}", flush=True)

        # ── CHECKPOINT ──────────────────────────────────────────────────────
        elif msg_type == "CHECKPOINT":
            log.info("Checkpoint saved: %s", payload.strip())

        # ── DONE ────────────────────────────────────────────────────────────
        elif msg_type == "DONE":
            m = re.search(r"final_loss=([^\s,]+)", payload)
            loss = m.group(1) if m else "?"
            log.info("Training complete. final_loss=%s", loss)
            break

        # ── ERROR ───────────────────────────────────────────────────────────
        elif msg_type == "ERROR":
            log.error("C++ process error: %s", payload.strip())
            raise RuntimeError(f"Training failed: {payload.strip()}")

        else:
            log.debug("Unknown IPC message: [%s] seq=%r payload=%r", msg_type, seq, payload)
|
|
|
|
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
# CLI
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse command-line flags for the GRPO example.

    ``--binary`` defaults to ``<repo>/build/bin/llama-finetune-qlora``. The
    repo root is assumed to be two directories above this script
    (examples/qlora_training). If the script has been copied somewhere
    shallower, the script's own directory is used instead of raising
    IndexError.
    """
    # Default binary: build/bin/ relative to this script's repo root
    script_dir = Path(__file__).resolve().parent
    ancestors = script_dir.parents
    # examples/qlora_training → llama.cpp root (fallback for shallow locations)
    repo_root = ancestors[1] if len(ancestors) > 1 else script_dir
    default_bin = repo_root / "build" / "bin" / "llama-finetune-qlora"

    p = argparse.ArgumentParser(
        description="Minimal GRPO training loop via llama-finetune-qlora --grpo-mode",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    p.add_argument("--model", required=True, help="Base GGUF model path")
    p.add_argument("--lora-out", required=True, help="Output adapter GGUF path")
    p.add_argument("--lora", default=None, help="Resume from existing adapter GGUF")
    p.add_argument("--binary", default=str(default_bin), help="Path to llama-finetune-qlora binary")
    p.add_argument("--rank", type=int, default=16, help="LoRA rank")
    p.add_argument("--n-steps", type=int, default=200, help="Number of GRPO steps")
    p.add_argument("--n-gen", type=int, default=8, help="Generations per prompt")
    p.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    p.add_argument("--ctx-size", type=int, default=4096, help="Context window")
    p.add_argument("--ngl", type=int, default=999, help="GPU layers (-ngl)")
    p.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature")
    p.add_argument("--max-tokens", type=int, default=512, help="Max tokens per generation")
    p.add_argument("--save-every", type=int, default=0, help="Save checkpoint every N steps (0=off)")
    p.add_argument("--seed", type=int, default=42, help="RNG seed")
    p.add_argument("--verbose", action="store_true", help="Enable DEBUG logging")
    return p.parse_args()
|
|
|
|
|
|
def main() -> None:
    """Script entry point: parse flags, honour --verbose, run the GRPO loop."""
    cli_args = parse_args()
    if cli_args.verbose:
        # Lower the root logger so log.debug(...) calls are emitted.
        logging.getLogger().setLevel(logging.DEBUG)
    run_grpo(cli_args)


if __name__ == "__main__":
    main()
|