diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index b433c91d85..0bc2ad34c1 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -39,4 +39,5 @@ else() endif() add_subdirectory(fit-params) add_subdirectory(results) + add_subdirectory(expert-profile) endif() diff --git a/tools/expert-profile/CMakeLists.txt b/tools/expert-profile/CMakeLists.txt new file mode 100644 index 0000000000..859bd77a53 --- /dev/null +++ b/tools/expert-profile/CMakeLists.txt @@ -0,0 +1,8 @@ +set(TARGET llama-expert-profile) +add_executable(${TARGET} expert-profile.cpp) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) + +if(LLAMA_TOOLS_INSTALL) + install(TARGETS ${TARGET} RUNTIME) +endif() diff --git a/tools/expert-profile/expert-profile.cpp b/tools/expert-profile/expert-profile.cpp new file mode 100644 index 0000000000..de381ff1f1 --- /dev/null +++ b/tools/expert-profile/expert-profile.cpp @@ -0,0 +1,506 @@ +/** + * expert-profile: NemotronH MoE expert activation profiler (REAP implementation) + * + * Implements the REAP (Router-weighted Expert Activation Pruning) saliency criterion: + * + * REAP(j) = mean over tokens routed to j of: gate_weight(j,t) * ||expert_output(j,t)||_2 + * + * where expert_output is ffn_moe_down (the FFN output BEFORE gate weighting), + * and gate_weight is ffn_moe_weights (post-softmax routing probability). + * + * Intercepts three tensors per MoE layer via ggml eval callback: + * ffn_moe_topk-{il} [n_expert_used, n_tokens] I32 — which experts were selected + * ffn_moe_weights-{il} [1, n_expert_used, n_tokens] F32 — gate weights (softmax probs) + * ffn_moe_down-{il} [n_embd, n_expert_used, n_tokens] F32 — expert outputs (pre-weighting) + * + * Reference: "REAP: Router-weighted Expert Activation Pruning" (arXiv:2510.13999) + * score = mean_{x in X_j}[ g_j(x) * ||f_j(x)||_2 ] (Equation 9) + * + * Usage: + * llama-expert-profile \ + * -m model.gguf --jsonl training-data.jsonl --output expert_stats.json \ + * [--n-experts 128] [--ctx-size 16384] [-ngl 32] [-t 24] [--save-every 1] + */ + +#include "arg.h" +#include "common.h" +#include "log.h" +#include "llama.h" +#include "ggml-backend.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// ─── Per-layer stats ────────────────────────────────────────────────────────── + +struct LayerStats { + int64_t n_experts = 0; + int64_t total_tokens = 0; // tokens processed through this layer + + // Frequency / weighted-frequency (kept for reference/comparison) + std::vector activation_counts; // [n_experts] — how many tokens routed here + std::vector weighted_freq_sum; // [n_experts] — sum of gate weights + + // REAP: running sum and count for computing mean(gate_weight * ||expert_out||_2) + std::vector reap_sum; // [n_experts] — sum of g_j(t)*||f_j(t)||_2 + std::vector ean_sum; // [n_experts] — sum of ||f_j(t)||_2 (EAN, no gate) + + void init(int64_t n) { + n_experts = n; + activation_counts.assign(n, 0); + weighted_freq_sum.assign(n, 0.0); + reap_sum.assign(n, 0.0); + ean_sum.assign(n, 0.0); + } + + // Called once we have all three tensors for a batch. + // expert_ids: [n_expert_used * n_tokens] I32 — flat, column-major: [k + t*n_expert_used] + // gate_weights:[n_expert_used * n_tokens] F32 — same layout + // expert_outs: [n_embd * n_expert_used * n_tokens] F32 — layout: [e + k*n_embd + t*n_embd*n_expert_used] + // i.e. for token t, expert-slot k: out vector starts at t*n_embd*n_expert_used + k*n_embd + void add_batch(const int32_t * expert_ids, + const float * gate_weights, + const float * expert_outs, + int64_t n_expert_used, + int64_t n_tok, + int64_t n_embd) { + total_tokens += n_tok; + for (int64_t t = 0; t < n_tok; ++t) { + for (int64_t k = 0; k < n_expert_used; ++k) { + const int64_t flat = k + t * n_expert_used; + const int32_t eid = expert_ids[flat]; + if (eid < 0 || eid >= n_experts) continue; + + const float gw = gate_weights[flat]; + + // L2 norm of expert output vector for this (token, expert-slot) + const float * vec = expert_outs + t * n_embd * n_expert_used + k * n_embd; + double norm2 = 0.0; + for (int64_t d = 0; d < n_embd; ++d) { + norm2 += (double)vec[d] * (double)vec[d]; + } + const double norm = std::sqrt(norm2); + + activation_counts [eid] += 1; + weighted_freq_sum [eid] += gw; + reap_sum [eid] += gw * norm; // REAP numerator + ean_sum [eid] += norm; // EAN numerator + } + } + } +}; + +// ─── Collector ──────────────────────────────────────────────────────────────── + +struct ExpertCollector { + int64_t n_experts = 128; + + std::map layer_stats; + std::mutex mtx; + + // We need all three tensors before we can compute REAP. + // They arrive in order: topk → weights → down (per the graph build order). + // Store pending topk+weights until down arrives. + struct PendingBatch { + int64_t n_expert_used = 0; + int64_t n_tokens = 0; + std::vector expert_ids; // [n_expert_used * n_tokens] + std::vector gate_weights; // [n_expert_used * n_tokens] + bool has_topk = false; + bool has_weights = false; + }; + std::map pending; // layer_idx → pending + + // Strip device prefix/suffix: "CUDA0#ffn_moe_down-5#0" → "ffn_moe_down-5" + static std::string clean_name(const char * raw) { + const char * p = strchr(raw, '#'); + if (p) { + ++p; + const char * q = strchr(p, '#'); + return q ? std::string(p, q - p) : std::string(p); + } + return raw; + } + + bool wants(struct ggml_tensor * t) { + if (!t->name[0]) return false; + const std::string n = clean_name(t->name); + return (n.compare(0, 13, "ffn_moe_topk-") == 0 || + n.compare(0, 16, "ffn_moe_weights-") == 0 || + n.compare(0, 13, "ffn_moe_down-") == 0); + } + + bool on_tensor(struct ggml_tensor * t) { + const std::string name = clean_name(t->name); + + // Identify tensor type and layer + int il = -1; + bool is_topk = false; + bool is_weights = false; + bool is_down = false; + + if (name.compare(0, 13, "ffn_moe_topk-") == 0) { il = atoi(name.c_str() + 13); is_topk = true; } + else if (name.compare(0, 16, "ffn_moe_weights-") == 0) { il = atoi(name.c_str() + 16); is_weights = true; } + else if (name.compare(0, 13, "ffn_moe_down-") == 0) { il = atoi(name.c_str() + 13); is_down = true; } + else return true; + + if (il < 0) return true; + + // Copy tensor data from (possibly GPU) buffer to host + const size_t nbytes = ggml_nbytes(t); + std::vector buf(nbytes); + ggml_backend_tensor_get(t, buf.data(), 0, nbytes); + + std::lock_guard lk(mtx); + PendingBatch & pb = pending[il]; + + if (is_topk) { + // [n_expert_used, n_tokens] I32 + pb.n_expert_used = t->ne[0]; + pb.n_tokens = t->ne[1]; + pb.expert_ids.resize(pb.n_expert_used * pb.n_tokens); + memcpy(pb.expert_ids.data(), buf.data(), pb.n_expert_used * pb.n_tokens * sizeof(int32_t)); + pb.has_topk = true; + pb.has_weights = false; // reset in case of re-use + + } else if (is_weights) { + // [1, n_expert_used, n_tokens] F32 — flat layout same as topk + if (!pb.has_topk) return true; // shouldn't happen + pb.gate_weights.resize(pb.n_expert_used * pb.n_tokens); + memcpy(pb.gate_weights.data(), buf.data(), pb.n_expert_used * pb.n_tokens * sizeof(float)); + pb.has_weights = true; + + } else if (is_down) { + // [n_embd, n_expert_used, n_tokens] F32 + if (!pb.has_topk || !pb.has_weights) return true; + + const int64_t n_embd = t->ne[0]; + const int64_t n_expert_used = t->ne[1]; + const int64_t n_tokens = t->ne[2]; + + // Sanity check + if (n_expert_used != pb.n_expert_used || n_tokens != pb.n_tokens) { + LOG_ERR("expert-profile: dimension mismatch at layer %d\n", il); + pending.erase(il); + return true; + } + + // Ensure layer stats initialised + auto & ls = layer_stats[il]; + if (ls.n_experts == 0) ls.init(n_experts); + + const float * expert_outs = reinterpret_cast(buf.data()); + ls.add_batch(pb.expert_ids.data(), pb.gate_weights.data(), + expert_outs, n_expert_used, n_tokens, n_embd); + + // Done with this batch for this layer + pending.erase(il); + } + + return true; + } +}; + +// ─── Global collector + C callback ─────────────────────────────────────────── + +static ExpertCollector g_collector; + +static bool expert_eval_callback(struct ggml_tensor * t, bool ask, void * /*user_data*/) { + if (ask) return g_collector.wants(t); + return g_collector.on_tensor(t); +} + +// ─── JSON output ────────────────────────────────────────────────────────────── + +static void save_stats(const std::string & path) { + std::ofstream f(path); + if (!f) { + LOG_ERR("expert-profile: failed to open output file '%s'\n", path.c_str()); + return; + } + + f << "{\n"; + bool first_layer = true; + for (auto & [il, ls] : g_collector.layer_stats) { + if (!first_layer) f << ",\n"; + first_layer = false; + + f << " \"" << il << "\": {\n"; + f << " \"total_tokens\": " << ls.total_tokens << ",\n"; + + // activation_counts + f << " \"activation_counts\": ["; + for (int64_t i = 0; i < ls.n_experts; ++i) { + if (i) f << ", "; + f << ls.activation_counts[i]; + } + f << "],\n"; + + // activation_frequency + f << " \"activation_frequency\": ["; + for (int64_t i = 0; i < ls.n_experts; ++i) { + if (i) f << ", "; + f << ((ls.total_tokens > 0) ? (double)ls.activation_counts[i] / ls.total_tokens : 0.0); + } + f << "],\n"; + + // avg_gate_weight (weighted_freq_sum / activation_counts) + f << " \"avg_gate_weight\": ["; + for (int64_t i = 0; i < ls.n_experts; ++i) { + if (i) f << ", "; + f << ((ls.activation_counts[i] > 0) ? ls.weighted_freq_sum[i] / ls.activation_counts[i] : 0.0); + } + f << "],\n"; + + // ean_mean = ean_sum / activation_counts (EAN criterion, no gate weight) + f << " \"ean_mean\": ["; + for (int64_t i = 0; i < ls.n_experts; ++i) { + if (i) f << ", "; + f << ((ls.activation_counts[i] > 0) ? ls.ean_sum[i] / ls.activation_counts[i] : 0.0); + } + f << "],\n"; + + // reap = reap_sum / activation_counts (REAP criterion, Eq.9) + f << " \"reap\": ["; + for (int64_t i = 0; i < ls.n_experts; ++i) { + if (i) f << ", "; + f << ((ls.activation_counts[i] > 0) ? ls.reap_sum[i] / ls.activation_counts[i] : 0.0); + } + f << "],\n"; + + // never_activated + int64_t never = 0; + for (int64_t i = 0; i < ls.n_experts; ++i) { + if (ls.activation_counts[i] == 0) ++never; + } + f << " \"never_activated\": " << never << "\n"; + f << " }"; + } + f << "\n}\n"; + + LOG_INF("expert-profile: stats saved to '%s' (%zu MoE layers)\n", + path.c_str(), g_collector.layer_stats.size()); +} + +// ─── JSONL input ────────────────────────────────────────────────────────────── + +struct JsonPair { std::string prompt, response; }; + +static bool json_get_string(const std::string & line, const std::string & key, std::string & out) { + std::string search = "\"" + key + "\""; + size_t kpos = line.find(search); + if (kpos == std::string::npos) return false; + size_t colon = line.find(':', kpos + search.size()); + if (colon == std::string::npos) return false; + size_t q1 = line.find('"', colon + 1); + if (q1 == std::string::npos) return false; + out.clear(); + for (size_t i = q1 + 1; i < line.size(); ++i) { + if (line[i] == '\\' && i + 1 < line.size()) { + ++i; + switch (line[i]) { + case '"': out += '"'; break; + case '\\': out += '\\'; break; + case 'n': out += '\n'; break; + case 'r': out += '\r'; break; + case 't': out += '\t'; break; + default: out += line[i]; break; + } + } else if (line[i] == '"') { + return true; + } else { + out += line[i]; + } + } + return false; +} + +static std::vector load_jsonl(const std::string & path) { + std::vector pairs; + std::ifstream f(path); + if (!f) { LOG_ERR("expert-profile: cannot open JSONL file '%s'\n", path.c_str()); return pairs; } + std::string line; + while (std::getline(f, line)) { + if (line.empty()) continue; + JsonPair p; + json_get_string(line, "prompt", p.prompt); + json_get_string(line, "response", p.response); + if (!p.prompt.empty() || !p.response.empty()) pairs.push_back(std::move(p)); + } + return pairs; +} + +// ─── Inference loop ─────────────────────────────────────────────────────────── + +static void run_inference(llama_context * ctx, + const llama_model * model, + const std::vector & pairs, + int max_tokens, + const std::string & output_path, + int save_every) { + const llama_vocab * vocab = llama_model_get_vocab(model); + const bool add_bos = llama_vocab_get_add_bos(vocab); + + llama_batch batch = llama_batch_init(max_tokens, 0, 1); + + for (size_t pi = 0; pi < pairs.size(); ++pi) { + const std::string text = pairs[pi].prompt + "\n" + pairs[pi].response; + + std::vector tokens = common_tokenize(ctx, text, add_bos, true); + if ((int)tokens.size() > max_tokens) tokens.resize(max_tokens); + if (tokens.empty()) continue; + + LOG_INF(" [%zu/%zu] %zu tokens\n", pi + 1, pairs.size(), tokens.size()); + + llama_memory_clear(llama_get_memory(ctx), true); + + common_batch_clear(batch); + for (int i = 0; i < (int)tokens.size(); ++i) { + common_batch_add(batch, tokens[i], i, {0}, false); + } + batch.logits[batch.n_tokens - 1] = true; + + if (llama_decode(ctx, batch) != 0) { + LOG_ERR(" [%zu/%zu] llama_decode failed — skipping\n", pi + 1, pairs.size()); + } + + if (save_every > 0 && (pi + 1) % save_every == 0) { + save_stats(output_path); + } + } + + llama_batch_free(batch); +} + +// ─── CLI ────────────────────────────────────────────────────────────────────── + +int main(int argc, char ** argv) { + std::string model_path; + std::string jsonl_path; + std::string output_path = "expert_stats.json"; + int n_experts = 128; + int ctx_size = 2048; + int n_gpu_layers = 99; + int n_threads = 4; + int save_every = 100; + enum ggml_type kv_type_k = GGML_TYPE_F16; + enum ggml_type kv_type_v = GGML_TYPE_F16; + + auto parse_ggml_type = [](const char * s) -> enum ggml_type { + if (strcmp(s, "f32") == 0) return GGML_TYPE_F32; + if (strcmp(s, "f16") == 0) return GGML_TYPE_F16; + if (strcmp(s, "q8_0") == 0) return GGML_TYPE_Q8_0; + if (strcmp(s, "q4_0") == 0) return GGML_TYPE_Q4_0; + fprintf(stderr, "Unknown KV type '%s', using f16\n", s); return GGML_TYPE_F16; + }; + + for (int i = 1; i < argc; ++i) { + std::string a(argv[i]); + auto next = [&]() -> const char * { + if (i + 1 >= argc) { fprintf(stderr, "Missing argument for %s\n", argv[i]); exit(1); } + return argv[++i]; + }; + if (a == "-m" || a == "--model") model_path = next(); + else if (a == "--jsonl") jsonl_path = next(); + else if (a == "--output") output_path = next(); + else if (a == "--n-experts") n_experts = atoi(next()); + else if (a == "--ctx-size" || a == "-c") ctx_size = atoi(next()); + else if (a == "-ngl" || a == "--n-gpu-layers") n_gpu_layers = atoi(next()); + else if (a == "-t" || a == "--threads") n_threads = atoi(next()); + else if (a == "--type-k") kv_type_k = parse_ggml_type(next()); + else if (a == "--type-v") kv_type_v = parse_ggml_type(next()); + else if (a == "--save-every") save_every = atoi(next()); + else if (a == "-h" || a == "--help") { + fprintf(stderr, + "\nUsage: %s -m model.gguf --jsonl data.jsonl [options]\n" + " --output PATH Output JSON (default: expert_stats.json)\n" + " --n-experts N Experts per layer (default: 128)\n" + " --ctx-size N Context length (default: 2048)\n" + " -ngl N GPU layers (default: 99)\n" + " -t N CPU threads (default: 4)\n" + " --type-k/v TYPE KV cache type: f32/f16/q8_0/q4_0 (default: f16)\n" + " --save-every N Checkpoint every N samples (default: 100)\n\n", argv[0]); + return 0; + } else { + fprintf(stderr, "Unknown argument: %s\n", a.c_str()); return 1; + } + } + + if (model_path.empty()) { fprintf(stderr, "Error: -m required\n"); return 1; } + if (jsonl_path.empty()) { fprintf(stderr, "Error: --jsonl required\n"); return 1; } + + g_collector.n_experts = n_experts; + + LOG_INF("expert-profile: model = %s\n", model_path.c_str()); + LOG_INF("expert-profile: jsonl = %s\n", jsonl_path.c_str()); + LOG_INF("expert-profile: output = %s\n", output_path.c_str()); + LOG_INF("expert-profile: n_experts = %d\n", n_experts); + LOG_INF("expert-profile: ctx_size = %d\n", ctx_size); + LOG_INF("expert-profile: ngl = %d\n", n_gpu_layers); + LOG_INF("expert-profile: criterion = REAP (gate_weight * ||expert_out||_2)\n"); + + auto pairs = load_jsonl(jsonl_path); + if (pairs.empty()) { LOG_ERR("expert-profile: no pairs loaded\n"); return 1; } + LOG_INF("expert-profile: loaded %zu pairs\n", pairs.size()); + + llama_backend_init(); + + // Suppress INFO/WARN spam (CUDA graph warmup etc.), only pass errors through + llama_log_set([](enum ggml_log_level level, const char * text, void *) { + if (level >= GGML_LOG_LEVEL_ERROR) fputs(text, stderr); + }, nullptr); + + llama_model_params mparams = llama_model_default_params(); + mparams.n_gpu_layers = n_gpu_layers; + + llama_model * model = llama_model_load_from_file(model_path.c_str(), mparams); + if (!model) { LOG_ERR("expert-profile: failed to load model\n"); return 1; } + + llama_context_params cparams = llama_context_default_params(); + cparams.n_ctx = ctx_size; + cparams.n_batch = ctx_size; + cparams.n_ubatch = std::min(ctx_size, 512); + cparams.n_threads = n_threads; + cparams.type_k = kv_type_k; + cparams.type_v = kv_type_v; + cparams.cb_eval = expert_eval_callback; + cparams.cb_eval_user_data = nullptr; + + llama_context * ctx = llama_init_from_model(model, cparams); + if (!ctx) { LOG_ERR("expert-profile: failed to create context\n"); return 1; } + + LOG_INF("expert-profile: running forward passes...\n"); + run_inference(ctx, model, pairs, ctx_size, output_path, save_every); + save_stats(output_path); + + // Summary + LOG_INF("\n MoE layers profiled: %zu\n", g_collector.layer_stats.size()); + for (auto & [il, ls] : g_collector.layer_stats) { + // Find top and bottom REAP expert + int64_t top_e = 0, bot_e = 0; + double top_v = 0.0, bot_v = 1e18; + for (int64_t i = 0; i < ls.n_experts; ++i) { + double v = (ls.activation_counts[i] > 0) ? ls.reap_sum[i] / ls.activation_counts[i] : 0.0; + if (v > top_v) { top_v = v; top_e = i; } + if (v < bot_v) { bot_v = v; bot_e = i; } + } + int64_t never = 0; + for (int64_t i = 0; i < ls.n_experts; ++i) + if (ls.activation_counts[i] == 0) ++never; + LOG_INF(" Layer %3d: tokens=%lld never=%lld reap_top=e%lld(%.4f) reap_bot=e%lld(%.4f)\n", + il, (long long)ls.total_tokens, (long long)never, + (long long)top_e, top_v, (long long)bot_e, bot_v); + } + + llama_free(ctx); + llama_model_free(model); + llama_backend_free(); + return 0; +} diff --git a/tools/moe-pruning/README.md b/tools/moe-pruning/README.md new file mode 100644 index 0000000000..a88499ac43 --- /dev/null +++ b/tools/moe-pruning/README.md @@ -0,0 +1,97 @@ +# MoE Expert Pruning Tools for NemotronH + +REAP-style expert pruning for `NVIDIA-Nemotron-3-Nano-30B-A3B` (and other +NemotronH MoE models), implemented in two complementary ways: + +1. **`tools/expert-profile/`** — C++ profiler built into llama.cpp, collects + REAP scores directly from GGUF inference via the ggml eval callback. +2. **`tools/moe-pruning/`** (this directory) — Python scripts to prune the model + using the collected scores, either on a GGUF file directly or on a + HuggingFace BF16 checkpoint. + +--- + +## Inspiration & Prior Art + +This work is a direct implementation of the **REAP** saliency criterion +introduced in: + +> **REAP the Experts: Why Pruning Prevails for One-Shot MoE Compression** +> Mike Lasby, Ivan Lazarevich, Nish Sinnadurai, Sean Lie, Yani Ioannou, Vithursan Thangarasa +> Cerebras Research, 2025 +> arXiv: https://arxiv.org/abs/2510.13999 +> Code: https://github.com/CerebrasResearch/reap + +The REAP score for expert `j` is (Equation 9 of the paper): + +``` +REAP(j) = mean_{t : j ∈ topk(t)} [ g_j(t) · ‖f_j(t)‖₂ ] +``` + +where `g_j(t)` is the router gate weight and `f_j(t)` is the expert FFN output +(pre-weighting) for token `t`. Experts with the lowest REAP score contribute +least to the layer output and are pruned first. + +The original REAP repo targets HuggingFace models via PyTorch hooks on +standard architectures (Qwen3-MoE, Mixtral, DeepSeek-V2, Llama-4, …). + +**What we added / adapted:** + +- `tools/expert-profile/expert-profile.cpp` — llama.cpp C++ implementation + of REAP that intercepts `ffn_moe_topk`, `ffn_moe_weights`, and `ffn_moe_down` + tensors via `ggml_backend_eval_callback`, enabling REAP profiling on any + GGUF-quantised model (Q4_K_M, Q6_K, etc.) without needing full BF16 VRAM. + +- `gguf_prune.py` — prunes the GGUF file **directly**, slicing the expert axis + of the stacked weight tensors (`ffn_up_exps`, `ffn_down_exps`, `ffn_gate_inp`, + `ffn_exp_probs_b`) and patching `{arch}.expert_count` in the metadata. + Quantised blocks are preserved as raw bytes — no dequantise/requantise step. + +- `nemotron_reap.py` — HuggingFace-based alternative: profiles with 4-bit NF4 + on GPU (phase 1) and prunes the BF16 checkpoint on CPU (phase 2). Adds + NemotronH (`NemotronHForCausalLM`) support that the original REAP repo does + not have. + +--- + +## Recommended Workflow (low-VRAM, e.g. RTX 4060 Ti 16 GB) + +``` +┌─────────────────────────────────────────────┐ +│ Phase 1 — Profile (GPU, GGUF Q4, ~15 GB) │ +│ │ +│ llama-expert-profile │ +│ -m nemotron-Q4_K_M.gguf │ +│ --jsonl sample_calibration.jsonl │ +│ --output expert_stats.json │ +│ -ngl 99 --ctx-size 2048 │ +└───────────────────┬─────────────────────────┘ + │ expert_stats.json +┌───────────────────▼─────────────────────────┐ +│ Phase 2 — Prune (CPU, pure Python, ~2 GB) │ +│ │ +│ python gguf_prune.py │ +│ --input nemotron-Q4_K_M.gguf │ +│ --stats expert_stats.json │ +│ --output nemotron-pruned-26e.gguf │ +│ --keep_ratio 0.20 # 26/128 experts │ +└─────────────────────────────────────────────┘ +``` + +At 20 % keep ratio a ~22 GB Q4_K_M becomes ~4.5 GB. + +--- + +## Files + +| File | Description | +|---|---| +| `gguf_prune.py` | GGUF-native pruner — no GPU needed, preserves quantisation | +| `nemotron_reap.py` | HF-based pruner — 4-bit GPU profile + CPU BF16 prune | +| `build_expert_profile.sh` | Build script for `llama-expert-profile` | +| `run_nemotron_profile.sh` | Example profiling run | +| `run_prune.sh` | Example pruning run | +| `run_convert_quantize.sh` | Convert HF → GGUF and quantise | +| `analyze_stats.py` | Visualise and compare expert stats JSON files | +| `sample_calibration.jsonl` | Sample calibration data (prompt+response pairs) | +| `expert_stats_reap.json` | Example stats output from expert-profile | diff --git a/tools/moe-pruning/analyze_stats.py b/tools/moe-pruning/analyze_stats.py new file mode 100644 index 0000000000..2e0821f323 --- /dev/null +++ b/tools/moe-pruning/analyze_stats.py @@ -0,0 +1,284 @@ +#!/usr/bin/env python3 +""" +analyze_stats.py -- Summarize expert_stats.json and model size projections. +Usage: python analyze_stats.py [stats_file] [--keep 0.5] +""" +import json, statistics, argparse + +parser = argparse.ArgumentParser() +parser.add_argument("stats", nargs="?", default="expert_stats_reap.json") +parser.add_argument("--keep", type=float, default=0.5, help="Fraction of experts to keep (default 0.5)") +args = parser.parse_args() + +with open(args.stats) as f: + data = json.load(f) + +layers = sorted(data.keys(), key=int) +n_layers = len(layers) +keep_ratio = args.keep + +# Detect which scoring field is available (new REAP vs old importance_score) +sample_layer = data[layers[0]] +if "reap" in sample_layer: + score_field = "reap" + score_label = "REAP (gate_weight × ||expert_out||₂)" +elif "importance_score" in sample_layer: + score_field = "importance_score" + score_label = "importance_score (freq × avg_gate_weight) [legacy, no EAN]" +else: + raise ValueError(f"No recognised score field in stats. Keys: {list(sample_layer.keys())}") + +# ── Model architecture constants (Nemotron-3-Nano-30B-A3B) ────────────────── +N_EXPERTS = 128 +N_EXPERT_USED = 6 # top-k per token +N_MOE_LAYERS = 23 +N_TOTAL_LAYERS = 53 +# Approximate parameter counts (bf16, billions) +PARAMS_TOTAL_B = 30.0 +PARAMS_MOE_EXPERTS_B = 22.0 # bulk of MoE weight is in expert FFNs +PARAMS_NON_MOE_B = PARAMS_TOTAL_B - PARAMS_MOE_EXPERTS_B + +# ── Header ────────────────────────────────────────────────────────────────── +print("=" * 70) +print(f" Expert Stats Analysis | file: {args.stats}") +print("=" * 70) + +# ── Profiling completeness ─────────────────────────────────────────────────── +sample_tokens = list(data.values())[0]["total_tokens"] +# Each token activates N_EXPERT_USED experts, sum(activation_counts) = total*top_k +# Approximate samples: total_tokens / avg_tokens_per_sample +# We don't know avg, but can infer: total_tokens / (total_tokens / ctx) ≈ ctx chunks +# Better: just report tokens and note the user knows sample count +print(f"\n── Profiling progress ──────────────────────────────────────────────────") +print(f" MoE layers profiled : {n_layers} / {N_MOE_LAYERS}") +print(f" Tokens processed : {sample_tokens:,} (per layer)") +act_sum = sum(data[layers[0]]["activation_counts"]) +assert abs(act_sum / sample_tokens - N_EXPERT_USED) < 0.01, "unexpected top-k" +print(f" top-k confirmed : {N_EXPERT_USED} (sum activations / tokens = {act_sum/sample_tokens:.1f})") + +# ── Per-layer importance score stats ──────────────────────────────────────── +print(f"\n── Per-layer score distribution [{score_label}]") +print(f" {'Layer':>5} {'Min':>9} {'Max':>9} {'Range':>9} {'CV%':>6} {'Never':>5}") +global_cvs = [] +for k in layers: + d = data[k] + s = d[score_field] + mn, mx = min(s), max(s) + cv = statistics.stdev(s) / statistics.mean(s) * 100 + global_cvs.append(cv) + print(f" {k:>5} {mn:>9.5f} {mx:>9.5f} {mx-mn:>9.5f} {cv:>6.3f}% {d['never_activated']:>5}") + +print(f"\n Mean CV across layers : {statistics.mean(global_cvs):.3f}%") +print(f" (CV < 1% = near-uniform; load-balancing is working as designed)") + +# ── Capacity loss sweep across pruning levels ──────────────────────────────── +# Paper (observer.py): REAP[i] = mean(ean_norm * softmax_router_weight) over tokens +# routed to expert i, averaged via OnlineStatsTracker weighted by expert_frequency. +# Our implementation (llama.cpp): same formula but routing weights are the top-k +# gate weights (post-softmax within top-k), not the full softmax over all 128. +# Impact: our weights are slightly higher than the paper's (renormalized to top-k +# only), but relative expert ranking within a layer should be preserved. +# +# IMPORTANT CAVEAT for this model (Nemotron-3-Nano-30B-A3B): +# The model was trained with a strong load-balancing auxiliary loss, so all 128 +# experts have nearly identical activation frequency (~4.69%) AND nearly identical +# REAP scores (Gini ~0.015, top/bottom ratio ~1.1-1.35x). The score distribution +# is a smooth monotone curve with NO natural elbow or gap. +# +# This means: +# - REAP ranking beats random pruning by only ~1pp in mass terms at keep=33% +# - The cut point boundary (rank 42 vs 43) has near-zero gap in most layers +# - REAP paper results on Qwen3-30B-A3B likely had higher Gini (less tight +# load-balancing or more expert specialization in pre-training) +# - For this model, actual quality loss must be measured via eval, not predicted +# from REAP score variance +# +# Metrics reported: +# - kept_mass%: REAP mass in the KEPT experts as % of total (> keep_ratio% = good) +# - vs_random%: how much more mass the REAP-selected set retains vs a random set +# of the same size (= kept_mass% - keep_ratio%). Positive = REAP wins. +# - Rel.gap: score gap at cut / layer score range. Near 0 = no natural cut point. +# - Gini: inequality of score distribution. ~0.015 here = near-uniform. + +def gini(scores): + """Gini coefficient of a list of non-negative values.""" + n = len(scores) + s = sorted(scores) + total = sum(s) + if total == 0: + return 0.0 + cumsum = 0.0 + for i, v in enumerate(s): + cumsum += (2 * (i + 1) - n - 1) * v + return cumsum / (n * total) + +def layer_stats(scores, n_keep): + """Return capacity metrics for a single layer at a given keep count.""" + n = len(scores) + ranked = sorted(range(n), key=lambda i: scores[i], reverse=True) + total = sum(scores) + kept_mass = sum(scores[i] for i in ranked[:n_keep]) + kept_frac = kept_mass / total if total > 0 else 0.0 # fraction of REAP mass kept + random_frac = n_keep / n # uniform expectation + vs_random = kept_frac - random_frac # positive = REAP beats random + score_range = scores[ranked[0]] - scores[ranked[-1]] + gap = scores[ranked[n_keep - 1]] - (scores[ranked[n_keep]] if n_keep < n else 0) + rel_gap = gap / score_range if score_range > 0 else 0.0 + return kept_frac * 100, vs_random * 100, rel_gap + +# Sweep over a range of keep ratios +sweep_ratios = [0.10, 0.20, 0.25, 0.33, 0.40, 0.50, 0.60, 0.75] +if keep_ratio not in sweep_ratios: + sweep_ratios.append(keep_ratio) +sweep_ratios = sorted(set(sweep_ratios)) + +# Per-layer Gini (fixed, independent of keep ratio) +layer_ginis = {k: gini(data[k][score_field]) for k in layers} +mean_gini = statistics.mean(layer_ginis.values()) +worst_gini_layer = max(layer_ginis, key=lambda k: layer_ginis[k]) + +print(f"\n── Score distribution inequality (Gini coefficient) ────────────────────") +print(f" Gini measures how non-uniform REAP scores are within each layer.") +print(f" Gini=0: all experts identical. Gini=1: one expert dominates.") +print(f" With load-balanced MoE, Gini is small — but any Gini > 0 means") +print(f" REAP ranking beats random pruning.") +print(f"") +print(f" {'Layer':>5} {'Gini':>8} {'Score range':>13} {'Max/Min ratio':>14}") +print(f" {'-'*5} {'-'*8} {'-'*13} {'-'*14}") +for k in layers: + s = data[k][score_field] + mn, mx = min(s), max(s) + g = layer_ginis[k] + ratio_mm = mx / mn if mn > 0 else float('inf') + print(f" {k:>5} {g:>8.5f} {mx-mn:>13.5f} {ratio_mm:>13.3f}x") +print(f"") +print(f" Mean Gini : {mean_gini:.5f} (worst layer: {worst_gini_layer})") + +print(f"\n── Capacity retention sweep ─────────────────────────────────────────────") +print(f" Kept mass% = REAP mass in KEPT experts as % of total (higher = better)") +print(f" vs.rand% = Kept mass% minus uniform baseline (keep_ratio%)") +print(f" Positive = REAP beats random. Magnitude = advantage in pp.") +print(f" Rel.gap = score gap at cut / layer score range (higher = cleaner cut)") +print(f" WARNING: near-zero rel.gap and small vs.rand mean eval is the only ground truth.") +print(f"") +print(f" {'Keep':>5} {'Experts':>7} {'Kept mass%':>11} {'vs.rand%':>9} {'Rel.gap avg':>12} {'Worst layer':>11}") +print(f" {'-'*5} {'-'*7} {'-'*11} {'-'*9} {'-'*12} {'-'*11}") + +sweep_results = {} +for ratio in sweep_ratios: + nk = max(1, round(N_EXPERTS * ratio)) + mass_fracs, excesses, rel_gaps = [], [], [] + worst_excess, worst_layer_id = -999.0, None + for k in layers: + scores = data[k][score_field] + mf, exc, rg = layer_stats(scores, nk) + mass_fracs.append(mf) + excesses.append(exc) + rel_gaps.append(rg) + if exc > worst_excess: + worst_excess = exc + worst_layer_id = k + avg_mf = statistics.mean(mass_fracs) + avg_exc = statistics.mean(excesses) + avg_rg = statistics.mean(rel_gaps) + marker = " <--" if abs(ratio - keep_ratio) < 1e-9 else "" + print(f" {ratio:>5.0%} {nk:>7d} {avg_mf:>10.2f}% {avg_exc:>+9.2f}% {avg_rg:>11.4f} layer {worst_layer_id:>3}{marker}") + sweep_results[ratio] = { + "n_keep": nk, "avg_kept_mass": avg_mf, "avg_vs_random": avg_exc, + "avg_rel_gap": avg_rg, "worst_layer_id": worst_layer_id, "worst_vs_random": worst_excess, + } + +print(f"") +print(f" vs.rand% quantifies REAP's advantage over random pruning in REAP-mass terms.") +print(f" For this model it is small (+0.7 to +1.5pp) due to tight load-balancing.") +print(f" Rel.gap near zero means scores are smooth with no natural cut — any threshold") +print(f" is as defensible as another. Actual quality delta requires empirical eval.") + +# ── Expert keep/prune detail at selected keep_ratio ────────────────────────── +n_keep = max(1, round(N_EXPERTS * keep_ratio)) +n_prune = N_EXPERTS - n_keep + +print(f"\n── Expert pruning detail at keep_ratio={keep_ratio:.0%} ({n_keep} keep / {n_prune} prune per layer) ──") +print(f" {'Layer':>5} {'Kept mass%':>11} {'vs.rand%':>9} {'Rel.gap':>9} {'Min kept':>10} {'Max pruned':>11}") +print(f" {'-'*5} {'-'*11} {'-'*9} {'-'*9} {'-'*10} {'-'*11}") + +layer_results = {} +for k in layers: + scores = data[k][score_field] + ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True) + mf, exc, rg = layer_stats(scores, n_keep) + min_kept = scores[ranked[n_keep - 1]] + max_pruned = scores[ranked[n_keep]] if n_prune > 0 else 0 + layer_results[k] = {"mass_frac": mf, "excess": exc, "rel_gap": rg, + "min_kept": min_kept, "max_pruned": max_pruned} + print(f" {k:>5} {mf:>10.2f}% {exc:>+9.2f}% {rg:>9.4f} {min_kept:>10.5f} {max_pruned:>11.5f}") + +avg_mf = statistics.mean(r["mass_frac"] for r in layer_results.values()) +avg_exc = statistics.mean(r["excess"] for r in layer_results.values()) +avg_rg = statistics.mean(r["rel_gap"] for r in layer_results.values()) +print(f" {'AVG':>5} {avg_mf:>10.2f}% {avg_exc:>+9.2f}% {avg_rg:>9.4f}") + +# ── Model size projections ─────────────────────────────────────────────────── +print(f"\n── Model size projections ──────────────────────────────────────────────") + +def model_size(keep): + expert_params = PARAMS_MOE_EXPERTS_B * keep + return PARAMS_NON_MOE_B + expert_params + +original_b = model_size(1.0) +pruned_b = model_size(keep_ratio) +reduction_pct = (1 - pruned_b / original_b) * 100 + +# GGUF sizes at common quant levels (rough: 1B params ≈ quant_bpw/8 GB) +quants = [("Q8_0", 8.0), ("Q5_K_M", 5.5), ("Q4_K_M", 4.5), ("Q3_K_M", 3.35), ("Q2_K", 2.63)] + +print(f" {'':20} {'Original':>10} {'Pruned':>10} {'Saved':>8}") +print(f" {'Parameters (B)':20} {original_b:>10.1f} {pruned_b:>10.1f} {original_b-pruned_b:>8.1f}B") +print(f" {'Reduction':20} {'':>10} {reduction_pct:>9.1f}%") +print() +print(f" Estimated GGUF sizes:") +print(f" {'Quant':10} {'Original':>10} {'Pruned':>10} {'Fits in':>12}") +for name, bpw in quants: + orig_gb = original_b * bpw / 8 + prune_gb = pruned_b * bpw / 8 + # VRAM fit (16GB GPU) + fits = "16GB GPU" if prune_gb <= 15.5 else ("32GB GPU" if prune_gb <= 31 else "CPU/RAM") + print(f" {name:10} {orig_gb:>9.1f}G {prune_gb:>9.1f}G {fits:>12}") + +# ── Active params per token (inference cost) ───────────────────────────────── +print(f"\n── Inference cost (active params per token) ────────────────────────────") +# Active params = non-moe + (n_expert_used/n_experts_kept * moe_expert_params) +# After pruning: router still picks top-k but from n_keep pool +# Active expert params per token = (N_EXPERT_USED / n_keep) * (PARAMS_MOE_EXPERTS_B * keep_ratio) +# But actually active params = N_EXPERT_USED * (params per single expert) +params_per_expert_orig = PARAMS_MOE_EXPERTS_B / N_EXPERTS # B per expert +params_per_expert_pruned = (PARAMS_MOE_EXPERTS_B * keep_ratio) / n_keep # same, just fewer experts + +active_orig = PARAMS_NON_MOE_B + N_EXPERT_USED * params_per_expert_orig * N_MOE_LAYERS / N_TOTAL_LAYERS +active_pruned = PARAMS_NON_MOE_B + N_EXPERT_USED * params_per_expert_pruned * N_MOE_LAYERS / N_TOTAL_LAYERS + +print(f" Original : {active_orig:.2f}B active params/token (same expert size, more choice)") +print(f" Pruned : {active_pruned:.2f}B active params/token (same — top-k still fires {N_EXPERT_USED} experts)") +print(f" Note: active params per token are IDENTICAL — pruning only reduces") +print(f" model file size and memory footprint, not per-token compute.") + +# ── Consistently low-importance experts ────────────────────────────────────── +print(f"\n── Experts consistently ranked low across all layers ───────────────────") +bottom_n = max(1, round(N_EXPERTS * 0.10)) # bottom 10% +low_count = {} +for k in layers: + scores = data[k][score_field] + ranked = sorted(range(len(scores)), key=lambda i: scores[i]) + for eid in ranked[:bottom_n]: + low_count[eid] = low_count.get(eid, 0) + 1 + +consistent = sorted(low_count.items(), key=lambda x: -x[1]) +consistent = [(eid, cnt) for eid, cnt in consistent if cnt >= 3] +print(f" (bottom 10% in >= 3 layers — most dispensable experts globally)") +print(f" Expert ID : layers in bottom 10%") +for eid, cnt in consistent[:20]: + bar = "█" * cnt + print(f" Expert {eid:>3} : {cnt:>2}/{n_layers} {bar}") + +print() +print("=" * 70) diff --git a/tools/moe-pruning/build_expert_profile.sh b/tools/moe-pruning/build_expert_profile.sh new file mode 100644 index 0000000000..0b39604426 --- /dev/null +++ b/tools/moe-pruning/build_expert_profile.sh @@ -0,0 +1,42 @@ +#!/usr/bin/env bash +# build_expert_profile.sh +# Builds llama.cpp with the expert-profile tool in WSL2 with CUDA. +# Run this from the tools/moe-pruning/ directory: bash build_expert_profile.sh + +set -e + +LLAMA_SRC="../.." +BUILD_DIR="$LLAMA_SRC/build_expert" + +echo "=== Building llama.cpp + expert-profile tool ===" +echo " Source : $LLAMA_SRC" +echo " Build : $BUILD_DIR" + +mkdir -p "$BUILD_DIR" +cd "$BUILD_DIR" + +# Configure with CUDA +cmake "$LLAMA_SRC" \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_CUDA=ON \ + -DLLAMA_CURL=OFF \ + -DLLAMA_BUILD_TESTS=OFF \ + -DLLAMA_BUILD_EXAMPLES=OFF \ + -DCMAKE_CUDA_ARCHITECTURES=86 \ + 2>&1 | tail -20 + +# Build only the expert-profile target (fast) +cmake --build . --target llama-expert-profile --config Release -j$(nproc) + +echo "" +echo "=== Build complete ===" +echo " Binary: $BUILD_DIR/tools/expert-profile/llama-expert-profile" +echo "" +echo "=== Usage ===" +echo " $BUILD_DIR/tools/expert-profile/llama-expert-profile \\" +echo " -m ~/nemotron-3-nano-30b-Q4_K_M.gguf \\" +echo " --jsonl ./sample_calibration.jsonl \\" +echo " --output ./expert_stats_reap.json \\" +echo " --n-experts 128 \\" +echo " --ctx-size 16384 \\" +echo " -ngl 99" diff --git a/tools/moe-pruning/extract_ppl.py b/tools/moe-pruning/extract_ppl.py new file mode 100644 index 0000000000..972a32e99d --- /dev/null +++ b/tools/moe-pruning/extract_ppl.py @@ -0,0 +1,41 @@ +import json, os + +base = os.path.dirname(os.path.abspath(__file__)) + +lines = open(os.path.join(base, 'rwsft-training-data.jsonl'), encoding='utf-8').readlines() +split = int(len(lines) * 0.95) + +train_lines = lines[:split] +val_lines = lines[split:] + +train_out = os.path.join(base, 'ppl-eval-train.txt') +val_out = os.path.join(base, 'ppl-eval-val.txt') + +def fmt(s): + # Full prompt+response so the model is conditioned correctly. + # llama-perplexity scores all tokens, but the prompt PPL is identical + # for base vs adapter — the delta is driven by the response tokens. + prompt = s.get('prompt', '').strip() + response = s.get('response', '').strip() + if not response: + return None + if prompt: + return prompt + '\n' + response + return response + +with open(train_out, 'w', encoding='utf-8') as f: + for line in train_lines: + text = fmt(json.loads(line)) + if text: + f.write(text + '\n\n') + +with open(val_out, 'w', encoding='utf-8') as f: + for line in val_lines: + text = fmt(json.loads(line)) + if text: + f.write(text + '\n\n') + +train_chars = len(open(train_out, encoding='utf-8').read()) +val_chars = len(open(val_out, encoding='utf-8').read()) +print(f'train: {len(train_lines)} samples, {train_chars:,} chars -> ppl-eval-train.txt') +print(f'val: {len(val_lines)} samples, {val_chars:,} chars -> ppl-eval-val.txt') diff --git a/tools/moe-pruning/gguf_prune.py b/tools/moe-pruning/gguf_prune.py new file mode 100644 index 0000000000..df3e638ab4 --- /dev/null +++ b/tools/moe-pruning/gguf_prune.py @@ -0,0 +1,260 @@ +""" +gguf-prune: REAP-based expert pruning directly on a GGUF file. + +Slices the expert dimension of the four stacked MoE weight tensors per layer: + blk.{il}.ffn_up_exps [n_embd, intermediate, n_experts] + blk.{il}.ffn_down_exps [intermediate, n_embd, n_experts] + blk.{il}.ffn_gate_inp [n_embd, n_experts] + blk.{il}.ffn_exp_probs_b [n_experts] (score-correction bias, if present) + +Quantized blocks (Q4_K, Q6_K, …) are preserved as raw bytes — slicing the +expert axis (last dim) is safe because each expert is independently quantised +in ggml, so dropping experts = dropping whole quantisation blocks. + +Metadata patched: + {arch}.expert_count → keep_n + (expert_used_count = top-k routing k, NOT touched) + +Usage: + # keep top 20% of experts (26/128) per MoE layer + python gguf_prune.py \\ + --input nemotron.gguf \\ + --stats expert_stats.json \\ + --output nemotron-pruned.gguf \\ + --keep_ratio 0.20 + + # or keep an absolute number + python gguf_prune.py \\ + --input nemotron.gguf \\ + --stats expert_stats.json \\ + --output nemotron-pruned.gguf \\ + --keep_n 32 +""" + +from __future__ import annotations + +import argparse +import json +import re +from pathlib import Path + +import numpy as np +from gguf import GGUFReader, GGUFWriter, GGUFValueType + + +# ── Constants ───────────────────────────────────────────────────────────────── + +# Base tensor names that carry the expert dimension (last axis in ggml layout). +# Some GGUFs append parameter tails like ".weight" / ".bias". +EXPERT_BASE_SUFFIXES = { + "ffn_up_exps", + "ffn_down_exps", + "ffn_gate_inp", +} + + +def is_expert_suffix(suffix: str) -> bool: + """Return True if a tensor suffix is one of the MoE expert tensors to prune.""" + if suffix in ("ffn_exp_probs_b", "exp_probs_b", "exp_probs_b.bias"): + return True + return any(suffix == base or suffix.startswith(base + ".") for base in EXPERT_BASE_SUFFIXES) + + +# ── Helpers ─────────────────────────────────────────────────────────────────── + +def layer_and_suffix(name: str) -> tuple[int, str] | tuple[None, None]: + m = re.match(r"blk\.(\d+)\.(.+)$", name) + if m: + return int(m.group(1)), m.group(2) + return None, None + + +def pick_experts(layer_stats: dict, keep_n: int) -> list[int]: + """ + Return sorted indices of the top `keep_n` experts by REAP score. + Falls back to 'importance_score' (weighted frequency) if 'reap' absent. + """ + if "reap" in layer_stats: + scores = np.array(layer_stats["reap"], dtype=np.float64) + elif "importance_score" in layer_stats: + scores = np.array(layer_stats["importance_score"], dtype=np.float64) + else: + raise KeyError( + "Layer stats has neither 'reap' nor 'importance_score'. " + "Run expert-profile / nemotron_reap.py profile first." + ) + return sorted(np.argsort(scores)[-keep_n:].tolist()) + + +def slice_expert_axis(data: np.ndarray, keep: list[int]) -> np.ndarray: + """ + Slice the expert axis of reader tensor data keeping only `keep` indices. + + GGUFReader reshapes tensors to NumPy with reversed ggml dims, so for MoE + tensors where experts are the last ggml dim, expert is axis 0 in `data`. + This also preserves quantized row-byte alignment (axis -1 is byte-packed + rows for quantized tensors and must not be sliced for expert pruning). + """ + return np.take(data, keep, axis=0) + + +def copy_field(writer: GGUFWriter, field, reader: GGUFReader) -> bool: + """Copy a single metadata field to writer. Returns False if skipped.""" + key = field.name + val_type = field.types[0] + part = field.parts[-1] + + if val_type == GGUFValueType.STRING: + # Preserve raw bytes: GGUF metadata can contain non-UTF8 strings. + writer.add_key_value(key, bytes(part), GGUFValueType.STRING) + elif val_type == GGUFValueType.UINT8: + writer.add_uint8(key, int(part[0])) + elif val_type == GGUFValueType.INT8: + writer.add_int8(key, int(part[0])) + elif val_type == GGUFValueType.UINT16: + writer.add_uint16(key, int(part[0])) + elif val_type == GGUFValueType.INT16: + writer.add_int16(key, int(part[0])) + elif val_type == GGUFValueType.UINT32: + writer.add_uint32(key, int(part[0])) + elif val_type == GGUFValueType.INT32: + writer.add_int32(key, int(part[0])) + elif val_type == GGUFValueType.FLOAT32: + writer.add_float32(key, float(part[0])) + elif val_type == GGUFValueType.UINT64: + writer.add_uint64(key, int(part[0])) + elif val_type == GGUFValueType.INT64: + writer.add_int64(key, int(part[0])) + elif val_type == GGUFValueType.FLOAT64: + writer.add_float64(key, float(part[0])) + elif val_type == GGUFValueType.BOOL: + writer.add_bool(key, bool(part[0])) + elif val_type == GGUFValueType.ARRAY: + elem_type = field.types[1] + if elem_type == GGUFValueType.STRING: + # ReaderField.data stores indices of ARRAY payload items; for + # STRING arrays this points at each string byte payload. + vals = [bytes(field.parts[idx]) for idx in field.data] + writer.add_key_value(key, vals, GGUFValueType.ARRAY, sub_type=GGUFValueType.STRING) + else: + # ReaderField.data stores part-indices, not payload values. + vals = field.contents() + if not isinstance(vals, list): + print(f" WARNING: skipping array field {key!r} (unexpected non-list contents)") + return False + writer.add_array(key, vals) + else: + print(f" WARNING: skipping field {key!r} (unsupported type {val_type})") + return False + return True + + +# ── Main ────────────────────────────────────────────────────────────────────── + +def main(): + ap = argparse.ArgumentParser(description="REAP expert pruning on a GGUF file") + ap.add_argument("--input", required=True, help="Input .gguf path") + ap.add_argument("--stats", required=True, help="expert_stats.json from expert-profile") + ap.add_argument("--output", required=True, help="Output .gguf path") + ap.add_argument("--keep_ratio", type=float, default=None, help="Fraction to keep, e.g. 0.20") + ap.add_argument("--keep_n", type=int, default=None, help="Absolute count to keep, e.g. 32") + ap.add_argument("--n_experts", type=int, default=128, help="Experts per MoE layer in source model") + args = ap.parse_args() + + if args.keep_ratio is None and args.keep_n is None: + ap.error("Provide --keep_ratio or --keep_n") + if args.keep_ratio is not None and args.keep_n is not None: + ap.error("Provide --keep_ratio OR --keep_n, not both") + + keep_n = args.keep_n if args.keep_n is not None else max(1, int(args.n_experts * args.keep_ratio)) + print(f"[gguf-prune] keeping {keep_n}/{args.n_experts} experts per MoE layer") + + # ── Load stats ───────────────────────────────────────────────────────────── + with open(args.stats) as f: + stats = {int(k): v for k, v in json.load(f).items()} + print(f"[gguf-prune] stats loaded for {len(stats)} MoE layers") + + # ── Open source GGUF ─────────────────────────────────────────────────────── + print(f"[gguf-prune] reading {args.input}") + reader = GGUFReader(args.input, mode="r") + + arch_field = reader.get_field("general.architecture") + arch = str(bytes(arch_field.parts[-1]), "utf-8") if arch_field else "nemotron_h_moe" + print(f"[gguf-prune] arch {arch}") + + expert_count_key = f"{arch}.expert_count" + + # ── Compute kept indices per layer ───────────────────────────────────────── + kept: dict[int, list[int]] = {} + for tensor in reader.tensors: + il, suffix = layer_and_suffix(tensor.name) + if il is None or suffix is None or not is_expert_suffix(suffix): + continue + if il in kept: + continue # already computed for this layer + if il not in stats: + print(f" Layer {il:3d}: no stats — keeping ALL {args.n_experts} experts") + kept[il] = list(range(args.n_experts)) + else: + kept[il] = pick_experts(stats[il], keep_n) + never = stats[il].get("never_activated", "?") + crit = "reap" if "reap" in stats[il] else "importance_score" + print(f" Layer {il:3d}: keep {kept[il][:4]}… never_activated={never} criterion={crit}") + + # ── Build output GGUF ────────────────────────────────────────────────────── + print(f"\n[gguf-prune] writing {args.output}") + writer = GGUFWriter(args.output, arch=arch) + + # --- metadata: copy all fields, replace expert_count --- + for field in reader.fields.values(): + # Reader exposes synthetic header fields (GGUF.*) that are not KV + # metadata and must not be copied back as normal keys. + if field.name.startswith("GGUF."): + continue + # Writer already sets general.architecture from ctor; avoid duplicate warning. + if field.name in (expert_count_key, "general.architecture"): + continue # replaced below + copy_field(writer, field, reader) + + writer.add_expert_count(keep_n) + print(f"[gguf-prune] patched {expert_count_key} → {keep_n}") + + # --- tensors --- + n_pruned = 0 + for tensor in reader.tensors: + il, suffix = layer_and_suffix(tensor.name) + is_expert = il is not None and suffix is not None and is_expert_suffix(suffix) + + if is_expert: + assert il is not None + k = kept[il] + data = slice_expert_axis(tensor.data, k) + writer.add_tensor( + tensor.name, + data, + raw_dtype=tensor.tensor_type, + ) + n_pruned += 1 + else: + writer.add_tensor( + tensor.name, + tensor.data, + raw_dtype=tensor.tensor_type, + ) + + writer.write_header_to_file() + writer.write_kv_data_to_file() + writer.write_tensors_to_file(progress=True) + writer.close() + + out = Path(args.output) + size_gb = out.stat().st_size / 1024**3 + print(f"\n[gguf-prune] done") + print(f" Expert tensors sliced : {n_pruned}") + print(f" MoE layers pruned : {len(kept)}") + print(f" Experts per layer : {keep_n}/{args.n_experts}") + print(f" Output size : {size_gb:.2f} GB → {out}") + + +if __name__ == "__main__": + main() diff --git a/tools/moe-pruning/requirements.txt b/tools/moe-pruning/requirements.txt new file mode 100644 index 0000000000..6b21f239f4 --- /dev/null +++ b/tools/moe-pruning/requirements.txt @@ -0,0 +1 @@ +gguf diff --git a/tools/moe-pruning/sample_calibration.jsonl b/tools/moe-pruning/sample_calibration.jsonl new file mode 100644 index 0000000000..44e2627f97 --- /dev/null +++ b/tools/moe-pruning/sample_calibration.jsonl @@ -0,0 +1,8 @@ +{"prompt": "What is the difference between a transformer and an RNN?", "response": "Transformers process all tokens in parallel using self-attention, while RNNs process tokens sequentially and maintain a hidden state. Transformers scale better and handle long-range dependencies more effectively, but require more memory for the attention matrix."} +{"prompt": "Explain the concept of weight quantization in neural networks.", "response": "Weight quantization reduces model size and inference latency by representing parameters in lower precision formats such as INT8 or INT4 instead of FP32. This trades a small amount of accuracy for significant gains in memory efficiency and compute throughput."} +{"prompt": "What is mixture of experts (MoE)?", "response": "MoE is a neural network architecture where the feed-forward layers are replaced by multiple expert sub-networks and a routing mechanism. For each token, only a small subset of experts is activated, allowing the model to have a large parameter count while keeping inference compute constant."} +{"prompt": "Describe the attention mechanism.", "response": "Attention computes a weighted sum of value vectors, where the weights are determined by the compatibility between query and key vectors. Scaled dot-product attention computes scores as Q*K^T/sqrt(d_k), applies softmax to get weights, then multiplies by V."} +{"prompt": "What is GGUF and how does it differ from GGML?", "response": "GGUF is the successor to the GGML file format for storing quantized models. It supports arbitrary key-value metadata, is extensible without breaking backward compatibility, and encodes tensor names and shapes explicitly, making it more robust than the original GGML format."} +{"prompt": "How does LoRA work?", "response": "LoRA (Low-Rank Adaptation) injects trainable rank-decomposition matrices A and B into frozen weight layers. The adapted weight is W + alpha/r * B*A. Since rank r is much smaller than the weight dimensions, only a tiny fraction of parameters are trained."} +{"prompt": "What is perplexity in language modeling?", "response": "Perplexity measures how well a language model predicts a sample text. It is the exponentiated average negative log-likelihood per token: PPL = exp(-1/N * sum log P(token_i)). Lower perplexity indicates a better fit to the data."} +{"prompt": "Explain rotary position embeddings (RoPE).", "response": "RoPE encodes position by rotating query and key vectors in 2D subspaces using a position-dependent rotation matrix. This makes the dot product between Q and K depend only on their relative position, enabling the model to generalise to sequence lengths longer than those seen during training."}