"""
gguf-prune: REAP-based expert pruning directly on a GGUF file.

Slices the expert dimension of the per-layer MoE tensors, including:

    blk.{il}.ffn_up_exps            [n_embd, intermediate, n_experts]
    blk.{il}.ffn_down_exps          [intermediate, n_embd, n_experts]
    blk.{il}.ffn_gate_inp           [n_embd, n_experts]
    blk.{il}.(ffn_)exp_probs_b      [n_experts]   (score-correction bias, if present)

Quantized blocks (Q4_K, Q6_K, …) are preserved as raw bytes. Slicing the
expert axis (the last ggml dim) is safe because quantisation blocks are packed
along the innermost dim and never span two experts, so dropping an expert
drops whole quantisation blocks.

Metadata patched:

    {arch}.expert_count → keep_n
    ({arch}.expert_used_count, the top-k routing k, is NOT touched)

Usage:

    # keep the top 20% of experts per MoE layer (int(128 * 0.20) = 25 of 128)
    python gguf_prune.py \\
        --input nemotron.gguf \\
        --stats expert_stats.json \\
        --output nemotron-pruned.gguf \\
        --keep_ratio 0.20

    # or keep an absolute number
    python gguf_prune.py \\
        --input nemotron.gguf \\
        --stats expert_stats.json \\
        --output nemotron-pruned.gguf \\
        --keep_n 32
"""

from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import re
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
from gguf import GGUFReader, GGUFWriter, GGUFValueType
|
|
|
|
|
|
# ── Constants ─────────────────────────────────────────────────────────────────

# Base tensor names that carry the expert dimension (last axis in ggml layout).
# Some GGUFs append parameter tails like ".weight" / ".bias".
EXPERT_BASE_SUFFIXES = {
    "ffn_up_exps",
    "ffn_down_exps",
    # Stacked gate projections of gated-FFN MoE models. Absent (and therefore
    # harmless) for models whose experts have only up/down projections, but if
    # present they MUST be sliced too, or their expert axis would no longer
    # match the patched expert_count.
    "ffn_gate_exps",
    "ffn_gate_inp",
}

# Per-expert score-correction bias (one value per expert). Seen both as
# "exp_probs_b" (e.g. "exp_probs_b.bias") and "ffn_exp_probs_b".
EXPERT_BIAS_SUFFIXES = {
    "ffn_exp_probs_b",
    "exp_probs_b",
}


def is_expert_suffix(suffix: str) -> bool:
    """
    Return True if a tensor suffix is one of the MoE expert tensors to prune.

    `suffix` is the part of a tensor name after "blk.<N>.". It matches when
    its first dot-separated component is a known expert base name, so bare
    names ("ffn_up_exps") and names with a parameter tail
    ("ffn_up_exps.weight", "exp_probs_b.bias") are both accepted. Names that
    merely share a prefix, such as shared-expert tensors ("ffn_up_shexp",
    "ffn_gate_inp_shexp"), are not.
    """
    base = suffix.split(".", 1)[0]
    return base in EXPERT_BASE_SUFFIXES or base in EXPERT_BIAS_SUFFIXES


# ── Helpers ───────────────────────────────────────────────────────────────────

# Per-layer tensors are named "blk.<layer>.<suffix>".
_BLOCK_TENSOR_RE = re.compile(r"blk\.(\d+)\.(.+)$")


def layer_and_suffix(name: str) -> tuple[int, str] | tuple[None, None]:
    """
    Split a per-layer tensor name into (layer index, suffix).

    "blk.12.ffn_up_exps.weight" -> (12, "ffn_up_exps.weight").
    Names that do not start with "blk.<digits>." (e.g. "token_embd.weight")
    yield (None, None).
    """
    hit = _BLOCK_TENSOR_RE.match(name)
    if hit is None:
        return None, None
    layer_digits, rest = hit.groups()
    return int(layer_digits), rest


def pick_experts(layer_stats: dict, keep_n: int) -> list[int]:
    """
    Return sorted indices of the top `keep_n` experts by REAP score.

    Falls back to 'importance_score' (weighted frequency) if 'reap' absent.

    NaN scores are ranked below every real score, so an expert without a
    meaningful score is never preferred over one that has one. Ties are broken
    deterministically: a stable sort is used, so among equal scores the
    higher index is kept.

    Args:
        layer_stats: per-layer stats dict holding a 'reap' or
            'importance_score' list with one score per expert.
        keep_n: number of experts to keep; must be in [1, number of scores].

    Returns:
        Ascending list of kept expert indices.

    Raises:
        KeyError: if neither 'reap' nor 'importance_score' is present.
        ValueError: if `keep_n` is outside [1, number of scores]. (Previously
            keep_n <= 0 silently kept every expert because argsort[-0:] is the
            whole array.)
    """
    for criterion in ("reap", "importance_score"):
        if criterion in layer_stats:
            scores = np.asarray(layer_stats[criterion], dtype=np.float64)
            break
    else:
        raise KeyError(
            "Layer stats has neither 'reap' nor 'importance_score'. "
            "Run expert-profile / nemotron_reap.py profile first."
        )

    n_scored = scores.size
    if not 1 <= keep_n <= n_scored:
        raise ValueError(
            f"keep_n={keep_n} must be between 1 and the number of scored experts ({n_scored})"
        )

    # np.argsort places NaN at the end, i.e. it would treat NaN as the
    # *largest* score and keep those experts first. Map NaN to -inf instead.
    ranked = np.where(np.isnan(scores), -np.inf, scores)
    order = np.argsort(ranked, kind="stable")
    return sorted(order[-keep_n:].tolist())


def slice_expert_axis(data: np.ndarray, keep: list[int]) -> np.ndarray:
    """
    Gather the experts listed in `keep` from a reader tensor, dropping the rest.

    GGUFReader exposes tensors as NumPy arrays whose dims are the ggml dims
    reversed. For the MoE tensors the expert dim is the last ggml dim, so it
    is NumPy axis 0 here. Only axis 0 is indexed: for quantized tensors the
    final axis holds byte-packed rows and must stay intact.

    Args:
        data: tensor data as returned by GGUFReader.
        keep: expert indices to retain, in the order they should appear.

    Returns:
        A new array containing only the selected experts along axis 0.
    """
    expert_axis = 0
    return data.take(keep, axis=expert_axis)


def copy_field(writer: GGUFWriter, field, reader: GGUFReader) -> bool:
    """
    Copy a single metadata field from the source reader to `writer`.

    Scalars are re-added with their original GGUF value type. Arrays are
    written with their original element type passed as ``sub_type``. This
    means a UINT32, INT64 or FLOAT64 array keeps that type, instead of having
    its type re-inferred from the first Python value (ints -> INT32,
    floats -> FLOAT32). Strings, whether scalar or inside arrays, are copied
    as raw bytes because GGUF metadata can contain non-UTF-8 strings.

    Args:
        writer: destination GGUFWriter.
        field: a field taken from ``reader.fields``.
        reader: the source GGUFReader (unused; kept for interface compatibility).

    Returns:
        True if the field was written, False if it was skipped. A warning is
        printed for every skipped field: empty arrays (GGUFWriter cannot
        encode them), nested arrays, and unsupported value types.
    """
    key = field.name
    val_type = field.types[0]
    part = field.parts[-1]

    # Scalar value types: (writer method, conversion of the stored value).
    scalar_writers = {
        GGUFValueType.UINT8: (writer.add_uint8, int),
        GGUFValueType.INT8: (writer.add_int8, int),
        GGUFValueType.UINT16: (writer.add_uint16, int),
        GGUFValueType.INT16: (writer.add_int16, int),
        GGUFValueType.UINT32: (writer.add_uint32, int),
        GGUFValueType.INT32: (writer.add_int32, int),
        GGUFValueType.FLOAT32: (writer.add_float32, float),
        GGUFValueType.UINT64: (writer.add_uint64, int),
        GGUFValueType.INT64: (writer.add_int64, int),
        GGUFValueType.FLOAT64: (writer.add_float64, float),
        GGUFValueType.BOOL: (writer.add_bool, bool),
    }

    if val_type == GGUFValueType.STRING:
        # Preserve raw bytes: GGUF metadata can contain non-UTF8 strings.
        writer.add_key_value(key, bytes(part), GGUFValueType.STRING)
        return True

    if val_type in scalar_writers:
        add, convert = scalar_writers[val_type]
        add(key, convert(part[0]))
        return True

    if val_type != GGUFValueType.ARRAY:
        print(f"  WARNING: skipping field {key!r} (unsupported type {val_type})")
        return False

    # ReaderField.data holds one part-index per array item. For an empty array
    # there are no items, and the reader records no element type either
    # (field.types has only one entry), so check emptiness before reading
    # field.types[1].
    if len(field.data) == 0 or len(field.types) < 2:
        print(f"  WARNING: skipping empty array field {key!r} (GGUFWriter cannot write empty arrays)")
        return False

    elem_type = field.types[1]
    if elem_type == GGUFValueType.ARRAY:
        print(f"  WARNING: skipping nested array field {key!r} (not supported)")
        return False

    if elem_type == GGUFValueType.STRING:
        # For STRING arrays each index points at one string's byte payload.
        vals = [bytes(field.parts[idx]) for idx in field.data]
    else:
        vals = field.contents()
        if not isinstance(vals, list):
            print(f"  WARNING: skipping array field {key!r} (unexpected non-list contents)")
            return False

    # Pass the source element type explicitly so the writer does not re-infer it.
    writer.add_key_value(key, vals, GGUFValueType.ARRAY, sub_type=elem_type)
    return True


# ── Main ──────────────────────────────────────────────────────────────────────


def _metadata_int(reader: GGUFReader, key: str) -> int | None:
    """Return scalar metadata `key` as an int, or None if absent or not a scalar number."""
    field = reader.get_field(key)
    if field is None or not field.types:
        return None
    if field.types[0] in (GGUFValueType.STRING, GGUFValueType.ARRAY):
        return None
    return int(field.parts[-1][0])


def _select_kept_experts(
    reader: GGUFReader, stats: dict[int, dict], keep_n: int, n_experts: int
) -> dict[int, list[int]]:
    """
    Choose which experts to keep for every layer that has expert tensors.

    Returns a mapping of layer index -> sorted kept expert indices.

    Raises SystemExit when the file cannot be pruned consistently:
      * an expert tensor's expert axis (NumPy axis 0) is not `n_experts` long;
      * a layer's score list does not have `n_experts` entries;
      * a MoE layer has no stats. GGUF stores a single, global expert_count,
        so leaving one layer unpruned would give tensors whose shape no
        longer matches that count.
    """
    kept: dict[int, list[int]] = {}
    missing: set[int] = set()
    for tensor in reader.tensors:
        il, suffix = layer_and_suffix(tensor.name)
        if il is None or suffix is None or not is_expert_suffix(suffix):
            continue
        n_axis = tensor.data.shape[0]
        if n_axis != n_experts:
            raise SystemExit(
                f"[gguf-prune] {tensor.name}: expert axis has {n_axis} entries, "
                f"expected {n_experts} (check --n_experts)"
            )
        if il in kept or il in missing:
            continue  # already handled for this layer
        if il not in stats:
            missing.add(il)
            continue

        layer_stats = stats[il]
        crit = "reap" if "reap" in layer_stats else "importance_score"
        if crit in layer_stats and len(layer_stats[crit]) != n_experts:
            raise SystemExit(
                f"[gguf-prune] layer {il}: stats '{crit}' has {len(layer_stats[crit])} "
                f"entries, expected {n_experts}"
            )
        kept[il] = pick_experts(layer_stats, keep_n)
        never = layer_stats.get("never_activated", "?")
        print(f"  Layer {il:3d}: keep {kept[il][:4]}…  never_activated={never}  criterion={crit}")

    if missing:
        raise SystemExit(
            f"[gguf-prune] no stats for MoE layer(s) {sorted(missing)}. Every MoE layer "
            f"must be pruned to the same expert count ({keep_n}) because expert_count "
            "is global; re-run profiling so the stats cover all MoE layers."
        )
    return kept


def main():
    """
    CLI entry point: prune MoE experts of a GGUF file using per-layer stats.

    Steps:
      1. Parse arguments and open the source GGUF.
      2. Resolve the per-layer expert count, preferring {arch}.expert_count
         from the file, and the number of experts to keep.
      3. Pick the kept experts per MoE layer from --stats.
      4. Write a new GGUF: copy all metadata (expert_count patched, source
         data alignment preserved) and slice every expert tensor along its
         expert axis.

    Argument or consistency problems exit via argparse's ``ap.error`` or
    SystemExit before any output file is written.
    """
    ap = argparse.ArgumentParser(description="REAP expert pruning on a GGUF file")
    ap.add_argument("--input", required=True, help="Input .gguf path")
    ap.add_argument("--stats", required=True, help="expert_stats.json from expert-profile")
    ap.add_argument("--output", required=True, help="Output .gguf path")
    ap.add_argument("--keep_ratio", type=float, default=None, help="Fraction to keep, e.g. 0.20")
    ap.add_argument("--keep_n", type=int, default=None, help="Absolute count to keep, e.g. 32")
    ap.add_argument(
        "--n_experts",
        type=int,
        default=None,
        help="Experts per MoE layer in source model "
        "(default: {arch}.expert_count from the GGUF, else 128)",
    )
    args = ap.parse_args()

    if args.keep_ratio is None and args.keep_n is None:
        ap.error("Provide --keep_ratio or --keep_n")
    if args.keep_ratio is not None and args.keep_n is not None:
        ap.error("Provide --keep_ratio OR --keep_n, not both")
    if args.keep_ratio is not None and not 0.0 < args.keep_ratio <= 1.0:
        ap.error("--keep_ratio must be in (0, 1]")

    # ── Open source GGUF ───────────────────────────────────────────────────────
    print(f"[gguf-prune] reading {args.input}")
    reader = GGUFReader(args.input, mode="r")

    arch_field = reader.get_field("general.architecture")
    arch = str(bytes(arch_field.parts[-1]), "utf-8") if arch_field else "nemotron_h_moe"
    print(f"[gguf-prune] arch {arch}")

    expert_count_key = f"{arch}.expert_count"

    # ── Resolve expert counts ──────────────────────────────────────────────────
    meta_n_experts = _metadata_int(reader, expert_count_key)
    if args.n_experts is None:
        n_experts = meta_n_experts if meta_n_experts is not None else 128
    else:
        n_experts = args.n_experts
        if meta_n_experts is not None and meta_n_experts != n_experts:
            ap.error(
                f"--n_experts {n_experts} disagrees with {expert_count_key}={meta_n_experts} "
                f"in {args.input}"
            )

    keep_n = args.keep_n if args.keep_n is not None else max(1, int(n_experts * args.keep_ratio))
    if not 1 <= keep_n <= n_experts:
        ap.error(f"--keep_n must be between 1 and {n_experts}, got {keep_n}")

    n_used = _metadata_int(reader, f"{arch}.expert_used_count")
    if n_used is not None and keep_n < n_used:
        # expert_used_count is intentionally not modified by this tool; the
        # output will need it lowered before a loader can route top-k.
        print(
            f"  WARNING: keeping {keep_n} experts but {arch}.expert_used_count={n_used}; "
            "routing needs at least that many experts per layer"
        )
    print(f"[gguf-prune] keeping {keep_n}/{n_experts} experts per MoE layer")

    # ── Load stats ─────────────────────────────────────────────────────────────
    with open(args.stats, encoding="utf-8") as f:
        stats = {int(k): v for k, v in json.load(f).items()}
    print(f"[gguf-prune] stats loaded for {len(stats)} MoE layers")

    # ── Compute kept indices per layer ─────────────────────────────────────────
    kept = _select_kept_experts(reader, stats, keep_n, n_experts)

    # ── Build output GGUF ──────────────────────────────────────────────────────
    print(f"\n[gguf-prune] writing {args.output}")
    writer = GGUFWriter(args.output, arch=arch)

    # general.alignment is copied below as an ordinary KV, but the writer pads
    # tensor data using its own data_alignment. Keep the two in agreement.
    alignment = _metadata_int(reader, "general.alignment")
    if alignment is not None:
        writer.data_alignment = alignment

    # --- metadata: copy all fields, replace expert_count ---
    for field in reader.fields.values():
        # Reader exposes synthetic header fields (GGUF.*) that are not KV
        # metadata and must not be copied back as normal keys.
        if field.name.startswith("GGUF."):
            continue
        # Writer already sets general.architecture from ctor; avoid duplicate warning.
        if field.name in (expert_count_key, "general.architecture"):
            continue  # expert_count is replaced below
        copy_field(writer, field, reader)

    writer.add_expert_count(keep_n)
    print(f"[gguf-prune] patched {expert_count_key} → {keep_n}")

    # --- tensors ---
    n_pruned = 0
    for tensor in reader.tensors:
        il, suffix = layer_and_suffix(tensor.name)
        data = tensor.data
        if il is not None and suffix is not None and is_expert_suffix(suffix):
            data = slice_expert_axis(data, kept[il])
            n_pruned += 1
        writer.add_tensor(tensor.name, data, raw_dtype=tensor.tensor_type)

    writer.write_header_to_file()
    writer.write_kv_data_to_file()
    writer.write_tensors_to_file(progress=True)
    writer.close()

    out = Path(args.output)
    size_gb = out.stat().st_size / 1024**3
    print("\n[gguf-prune] done")
    print(f"  Expert tensors sliced : {n_pruned}")
    print(f"  MoE layers pruned     : {len(kept)}")
    print(f"  Experts per layer     : {keep_n}/{n_experts}")
    print(f"  Output size           : {size_gb:.2f} GB → {out}")


if __name__ == "__main__":
    # main() returns None, so a normal run exits with status 0; argument
    # errors inside main() raise SystemExit on their own.
    raise SystemExit(main())