#!/usr/bin/env python3
"""Deep analysis of WHY ffn_down is hard to quantize.

Compares structural properties of all weight and activation tensors.
"""

import numpy as np
import struct
import sys
import os

DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "data")

def load_f32_tensor(name, data_dir=None):
    """Load a 2-D float32 tensor from a raw binary dump.

    File layout: two native-endian int64 header fields (nrow, ncol)
    followed by nrow * ncol float32 values in row-major order.

    Args:
        name: file name of the dump.
        data_dir: directory containing the dump; defaults to DATA_DIR
            (the project-level ../data directory).

    Returns:
        np.ndarray of shape (nrow, ncol), dtype float32 (read-only view
        of the file bytes via np.frombuffer).

    Raises:
        AssertionError: if the payload size does not match the header.
        OSError: if the file cannot be opened.
    """
    path = os.path.join(DATA_DIR if data_dir is None else data_dir, name)
    with open(path, "rb") as f:
        # "qq" = two native int64 values: row count, column count.
        nrow, ncol = struct.unpack("qq", f.read(16))
        data = np.frombuffer(f.read(), dtype=np.float32)
    assert len(data) == nrow * ncol, f"Expected {nrow * ncol}, got {len(data)}"
    return data.reshape(nrow, ncol)
def stats(label, arr):
    """Print a battery of distribution statistics for *arr* and return a summary.

    Flattens the input, prints location/spread/shape-of-distribution
    metrics plus percentiles and a near-zero sparsity fraction, and
    returns the headline numbers as a dict
    (keys: mean, std, skew, kurt, min, max).
    """
    a = arr.ravel()
    mu, sigma = a.mean(), a.std()
    lo, hi = a.min(), a.max()
    # Standardized values; the epsilon guards against a zero spread.
    z = (a - mu) / (sigma + 1e-10)
    skew = np.mean(z**3)        # third standardized moment
    kurt = np.mean(z**4) - 3.0  # excess kurtosis: 0 for a Gaussian

    print(f" {label}:")
    print(f" shape={arr.shape}, n={len(a)}")
    print(f" mean={mu:.6f}, std={sigma:.6f}")
    print(f" min={lo:.6f}, max={hi:.6f}")
    print(f" median={np.median(a):.6f}")
    print(
        f" |mean|/std = {abs(mu) / (sigma + 1e-10):.4f} (offset-to-spread ratio)"
    )
    print(f" skewness={skew:.4f}, excess_kurtosis={kurt:.4f}")

    pcts = np.percentile(a, [0.1, 1, 5, 25, 50, 75, 95, 99, 99.9])
    print(
        f" percentiles: 0.1%={pcts[0]:.4f}, 1%={pcts[1]:.4f}, 5%={pcts[2]:.4f}, "
        f"25%={pcts[3]:.4f}, 50%={pcts[4]:.4f}, 75%={pcts[5]:.4f}, "
        f"95%={pcts[6]:.4f}, 99%={pcts[7]:.4f}, 99.9%={pcts[8]:.4f}"
    )

    # Sparsity: fraction of entries negligible relative to the spread.
    near_zero = np.sum(np.abs(a) < 0.001 * sigma) / len(a)
    print(f" fraction |x| < 0.001*std: {near_zero:.4f}")

    return {
        "mean": mu,
        "std": sigma,
        "skew": skew,
        "kurt": kurt,
        "min": lo,
        "max": hi,
    }
# ============================================================================
# 1. BASIC WEIGHT TENSOR COMPARISON
# ============================================================================
print("=" * 80)
print("SECTION 1: WEIGHT TENSOR GLOBAL STATISTICS")
print("=" * 80)

# name -> (dump file, human-readable shape/role description).
# Descriptions are informational only; actual shapes come from the file headers.
tensors = {
    "ffn_gate": ("blk_0_ffn_gate_weight.f32bin", "9728x2560 (wide→narrow proj)"),
    "ffn_up": ("blk_0_ffn_up_weight.f32bin", "9728x2560 (wide→narrow proj)"),
    "ffn_down": ("blk_0_ffn_down_weight.f32bin", "2560x9728 (narrow→wide proj)"),
    "attn_q": ("blk_0_attn_q_weight.f32bin", "4096x2560"),
    "attn_k": ("blk_0_attn_k_weight.f32bin", "1024x2560"),
    "attn_v": ("blk_0_attn_v_weight.f32bin", "1024x2560"),
    "attn_out": ("blk_0_attn_output_weight.f32bin", "2560x4096"),
}

# Cache of successfully loaded weight matrices; reused by Sections 2 and 3.
weight_data = {}
for name, (fname, desc) in tensors.items():
    try:
        W = load_f32_tensor(fname)
        print(f"\n{'─' * 70}")
        print(f" {name} [{desc}] — file: {fname}")
        weight_data[name] = W
        stats(name, W)
    except Exception as e:
        # A missing/odd dump only skips this tensor; the report keeps going.
        print(f" {name}: SKIP ({e})")
# ============================================================================
# 2. ROW-LEVEL STATISTICS (each row is a neuron output)
# ============================================================================
print("\n" + "=" * 80)
print("SECTION 2: ROW-LEVEL VARIABILITY (per-neuron weight statistics)")
print("=" * 80)
print(" Each row of the weight matrix produces one output dimension.")
print(" High row-to-row variability in mean/std means the quantizer")
print(" must handle very different distributions across rows.\n")

# For each loaded weight matrix, summarize how the per-row statistics
# (mean / std / peak-to-peak range) vary from row to row. A high
# coefficient of variation (CV) means rows look statistically different.
for name, W in weight_data.items():
    per_row_mean = W.mean(axis=1)
    per_row_std = W.std(axis=1)
    per_row_span = W.max(axis=1) - W.min(axis=1)

    print(f"\n {name} ({W.shape[0]} rows × {W.shape[1]} cols):")
    print(
        f" Row means: mean={per_row_mean.mean():.6f}, std={per_row_mean.std():.6f}, "
        f"range=[{per_row_mean.min():.6f}, {per_row_mean.max():.6f}]"
    )
    print(
        f" Row stds: mean={per_row_std.mean():.6f}, std={per_row_std.std():.6f}, "
        f"range=[{per_row_std.min():.6f}, {per_row_std.max():.6f}]"
    )
    print(f" Row ranges: mean={per_row_span.mean():.6f}, std={per_row_span.std():.6f}")
    print(
        f" RowMeans CV (std/mean): {per_row_mean.std() / (abs(per_row_mean.mean()) + 1e-10):.4f}"
    )
    print(f" RowStds CV: {per_row_std.std() / (per_row_std.mean() + 1e-10):.4f}")
# ============================================================================
# 3. GROUP-LEVEL ANALYSIS (16-element groups, like Q2_K)
# ============================================================================
print("\n" + "=" * 80)
print("SECTION 3: GROUP-LEVEL ANALYSIS (16-element groups)")
print("=" * 80)
print(" Quantization works on 16-element groups. Key question:")
print(" How much does each group need its own OFFSET (dmin)?\n")

GS = 16  # quantization group size (elements per scale/offset pair); reused below

for name, W in weight_data.items():
    # Look at first 256 rows for speed
    nr = min(W.shape[0], 256)
    nc = W.shape[1]

    group_means = []
    group_stds = []
    group_ranges = []
    group_offsets = []  # |mean| / range — how important is the offset

    # Pass 1: per-group descriptive statistics.
    for r in range(nr):
        for g_start in range(0, nc, GS):
            g = W[r, g_start : g_start + GS]
            gm = g.mean()
            gs = g.std()
            gr = g.max() - g.min()

            group_means.append(gm)
            group_stds.append(gs)
            group_ranges.append(gr)
            # Offset importance: how large is the group mean relative to its range?
            # If this is high, offset (dmin) matters a lot
            if gr > 1e-10:
                group_offsets.append(abs(gm) / gr)
            else:
                group_offsets.append(0)

    gm = np.array(group_means)
    gs = np.array(group_stds)
    gr = np.array(group_ranges)
    go = np.array(group_offsets)

    print(f"\n {name} ({len(group_means)} groups):")
    print(
        f" Group mean: mean={gm.mean():.6f}, std={gm.std():.6f}, "
        f"range=[{gm.min():.6f}, {gm.max():.6f}]"
    )
    print(f" Group std: mean={gs.mean():.6f}, std={gs.std():.6f}")
    print(f" Group range: mean={gr.mean():.6f}, std={gr.std():.6f}")
    print(f" *** OFFSET IMPORTANCE (|group_mean| / range) ***")
    print(
        f" mean={go.mean():.4f}, median={np.median(go):.4f}, "
        f"p90={np.percentile(go, 90):.4f}, max={go.max():.4f}"
    )
    print(f" fraction with offset > 0.1: {np.mean(go > 0.1):.3f}")
    print(f" fraction with offset > 0.2: {np.mean(go > 0.2):.3f}")
    print(f" fraction with offset > 0.3: {np.mean(go > 0.3):.3f}")

    # Pass 2: how well does zeroing the min (Q2_K style, clamping min to 0)
    # work vs keeping the actual min?
    mse_no_offset = 0  # Assume uniform 4 levels [0,1,2,3] * scale
    mse_with_offset = 0  # Assume uniform 4 levels [0,1,2,3] * scale + offset

    for r in range(nr):
        for g_start in range(0, nc, GS):
            g = W[r, g_start : g_start + GS]
            gmin = g.min()
            gmax = g.max()
            gr = gmax - gmin
            if gr < 1e-10:
                continue

            # No offset: reconstruct as idx * (gmax/3) with no additive term,
            # so any negative offset the group has is simply lost.
            # (The original code branched on gmin > 0 here, but both branches
            # were identical — the branch was dead code and has been removed.)
            scale_no = gmax / 3.0

            # With offset: use the group's actual min/max
            scale_w = gr / 3.0
            min_w = gmin

            for val in g:
                # No offset quantization
                norm_no = val / (scale_no + 1e-10)
                idx_no = max(0, min(3, int(round(norm_no))))
                recon_no = scale_no * idx_no
                mse_no_offset += (val - recon_no) ** 2

                # With offset quantization
                norm_w = (val - min_w) / (scale_w + 1e-10)
                idx_w = max(0, min(3, int(round(norm_w))))
                recon_w = min_w + scale_w * idx_w
                mse_with_offset += (val - recon_w) ** 2

    total_elements = nr * nc
    rmse_no = np.sqrt(mse_no_offset / total_elements)
    rmse_w = np.sqrt(mse_with_offset / total_elements)
    improvement = (rmse_no - rmse_w) / rmse_no * 100
    print(f" Quant RMSE (no offset): {rmse_no:.6f}")
    print(f" Quant RMSE (with offset): {rmse_w:.6f}")
    print(f" Offset benefit: {improvement:.1f}% RMSE reduction")
# ============================================================================
# 4. ACTIVATION ANALYSIS
# ============================================================================
print("\n" + "=" * 80)
print("SECTION 4: ACTIVATION DISTRIBUTION COMPARISON")
print("=" * 80)

# label -> activation dump file. The labels double as the keys of act_data
# and are looked up again by the SwiGLU comparison in Section 8.
activations = {
    "ffn_input (gate/up)": "act_blk0_ffn_input.f32bin",
    "ffn_down_input (swiglu)": "act_blk0_ffn_down_input.f32bin",
    "attn_input (q/k/v)": "act_blk0_attn_input.f32bin",
    "attn_output_input": "act_blk0_attn_output_input.f32bin",
}

# Cache of successfully loaded activation matrices, reused by Sections 5 and 8.
act_data = {}
for label, path in activations.items():
    try:
        tensor = load_f32_tensor(path)
        act_data[label] = tensor
        print(f"\n{'─' * 70}")
        print(f" {label} — {path}")
        stats(label, tensor)
    except Exception as err:
        # A missing dump only skips this entry; the report keeps going.
        print(f" {label}: SKIP ({err})")
# ============================================================================
# 5. THE CRITICAL QUESTION: PER-DIMENSION ACTIVATION MAGNITUDE
# ============================================================================
print("\n" + "=" * 80)
print("SECTION 5: PER-DIMENSION ACTIVATION POWER (per-column RMS)")
print("=" * 80)
print(" If activation dimensions have very different magnitudes,")
print(" the quantization error in each weight dimension is weighted differently.")
print(" Dimensions with high activation power amplify weight errors.\n")

# act_data was populated in Section 4; each entry is (tokens × dims).
for name, A in act_data.items():
    col_rms = np.sqrt(np.mean(A**2, axis=0))  # RMS per column (dimension)
    print(f"\n {name} ({A.shape[1]} dimensions):")
    print(f" Col RMS: mean={col_rms.mean():.6f}, std={col_rms.std():.6f}")
    print(f" Col RMS range: [{col_rms.min():.6f}, {col_rms.max():.6f}]")
    print(f" Col RMS CV (std/mean): {col_rms.std() / (col_rms.mean() + 1e-10):.4f}")
    print(f" Max/Min ratio: {col_rms.max() / (col_rms.min() + 1e-10):.1f}x")

    # Top 10 and bottom 10 dimensions by power (only the first 5 are printed).
    top10 = np.argsort(col_rms)[-10:][::-1]
    bot10 = np.argsort(col_rms)[:10]
    print(
        f" Top-10 dims by RMS: {[(int(d), f'{col_rms[d]:.4f}') for d in top10[:5]]}..."
    )
    print(
        f" Bot-10 dims by RMS: {[(int(d), f'{col_rms[d]:.4f}') for d in bot10[:5]]}..."
    )

    # How much do the top 10% of dimensions contribute to total power?
    # Power = squared RMS, so this measures energy concentration.
    total_power = np.sum(col_rms**2)
    sorted_power = np.sort(col_rms**2)[::-1]
    top10pct = int(len(col_rms) * 0.1)
    top10pct_power = np.sum(sorted_power[:top10pct])
    top1pct = max(1, int(len(col_rms) * 0.01))
    top1pct_power = np.sum(sorted_power[:top1pct])
    print(
        f" Top 10% of dims contribute {top10pct_power / total_power * 100:.1f}% of total power"
    )
    print(
        f" Top 1% of dims contribute {top1pct_power / total_power * 100:.1f}% of total power"
    )
# ============================================================================
# 6. CROSS-CORRELATION: WEIGHT ERROR × ACTIVATION POWER
# ============================================================================
print("\n" + "=" * 80)
print("SECTION 6: WHERE DO WEIGHT ERRORS MEET HIGH ACTIVATION POWER?")
print("=" * 80)
print(" For each weight dimension, compute: activation_rms[dim] × weight_error[dim]")
print(" This tellsus which dimensions contribute most to matmul error.\n")

# Focus on ffn_down vs ffn_gate for comparison
# (name, weight dump, matching activation dump). This list is reused by Section 7.
focus = [
    ("ffn_down", "blk_0_ffn_down_weight.f32bin", "act_blk0_ffn_down_input.f32bin"),
    ("ffn_gate", "blk_0_ffn_gate_weight.f32bin", "act_blk0_ffn_input.f32bin"),
    ("ffn_up", "blk_0_ffn_up_weight.f32bin", "act_blk0_ffn_input.f32bin"),
    ("attn_q", "blk_0_attn_q_weight.f32bin", "act_blk0_attn_input.f32bin"),
]

for name, wfile, afile in focus:
    W = load_f32_tensor(wfile)
    A = load_f32_tensor(afile)

    # The weight's input dimension (columns) must match the activation width.
    if W.shape[1] != A.shape[1]:
        print(f" {name}: dim mismatch W={W.shape[1]} vs A={A.shape[1]}, SKIP")
        continue

    nc = W.shape[1]

    # Per-column activation RMS
    act_rms = np.sqrt(np.mean(A**2, axis=0))

    # Per-column weight std and range (how "hard" to quantize)
    w_std = W.std(axis=0)
    w_range = W.max(axis=0) - W.min(axis=0)

    # Per-column weight kurtosis (heavy tails = harder to quantize)
    w_kurt = (
        np.mean(((W - W.mean(axis=0)) / (W.std(axis=0) + 1e-10)) ** 4, axis=0) - 3.0
    )

    # Weight error proxy: with 2-bit uniform quant on 16-element groups
    # Higher variance columns → more error
    nr = min(W.shape[0], 256)  # cap rows for speed

    # Simple Q2_K-style error estimate per dimension:
    # For each group of 16 in the column direction, quantize and measure error
    dim_mse = np.zeros(nc)
    for g_start in range(0, nc, GS):
        g_end = min(g_start + GS, nc)
        for r in range(nr):
            g = W[r, g_start:g_end]
            gmin = min(g.min(), 0)  # Q2_K clamps min to ≤0
            gmax = g.max()
            gr = gmax - gmin
            if gr < 1e-10:
                continue  # constant group quantizes exactly; contributes no error
            scale = gr / 3.0  # 4 levels (2 bits) span the group range
            for i, val in enumerate(g):
                norm = (val - gmin) / scale
                idx = max(0, min(3, int(round(norm))))
                recon = gmin + scale * idx
                dim_mse[g_start + i] += (val - recon) ** 2

    dim_rmse = np.sqrt(dim_mse / nr)

    # The key metric: dimension-level contribution to matmul error
    # matmul_error_contribution[d] ≈ act_rms[d] * weight_rmse[d]
    matmul_contrib = act_rms * dim_rmse

    print(f"\n {name} ({nc} dimensions):")
    print(
        f" act_rms: mean={act_rms.mean():.4f}, CV={act_rms.std() / act_rms.mean():.4f}"
    )
    print(
        f" w_rmse: mean={dim_rmse.mean():.6f}, CV={dim_rmse.std() / (dim_rmse.mean() + 1e-10):.4f}"
    )
    print(
        f" matmul_contrib: mean={matmul_contrib.mean():.6f}, "
        f"std={matmul_contrib.std():.6f}"
    )

    # Correlation between activation power and weight error
    corr = np.corrcoef(act_rms, dim_rmse)[0, 1]
    print(f" CORRELATION act_rms ↔ weight_rmse: {corr:.4f}")
    print(f" (>0 means high-power dims are also hard to quantize — BAD)")

    # Top contributors to matmul error (argsort ascending, so take the tail)
    top_dims = np.argsort(matmul_contrib)[-20:][::-1]
    print(f" Top-5 error-contributing dimensions:")
    for d in top_dims[:5]:
        print(
            f" dim {d}: act_rms={act_rms[d]:.4f}, w_rmse={dim_rmse[d]:.6f}, "
            f"contrib={matmul_contrib[d]:.6f}, w_std={w_std[d]:.6f}, w_kurt={w_kurt[d]:.2f}"
        )

    # Distribution of matmul contributions: how concentrated is the error?
    total_contrib = matmul_contrib.sum()
    sorted_contrib = np.sort(matmul_contrib)[::-1]
    for pct in [0.01, 0.05, 0.10, 0.25]:
        n = max(1, int(nc * pct))
        print(
            f" Top {pct * 100:.0f}% dims: {sorted_contrib[:n].sum() / total_contrib * 100:.1f}% "
            f"of total matmul error"
        )
# ============================================================================
# 7. THE STRUCTURAL ASYMMETRY: COLUMN DIRECTION GROUP ANALYSIS
# ============================================================================
print("\n" + "=" * 80)
print("SECTION 7: STRUCTURAL ASYMMETRY — COLUMN vs ROW GROUPING")
print("=" * 80)
print(" Quantization groups along the ROW (inner dim). For ffn_down,")
print(" each row has 9728 elements (38 groups of 256).")
print(" For ffn_gate, each row has 2560 elements (10 groups of 256).")
print(" More groups = more metadata (scales/offsets) relative to data bits.\n")

# Reuses the `focus` list from Section 6 (afile is unused here).
for name, wfile, afile in focus:
    W = load_f32_tensor(wfile)
    nc = W.shape[1]
    n_groups_per_row = nc // 256  # super-blocks per row

    print(f"\n {name}: {nc} cols → {n_groups_per_row} super-blocks per row")
    print(f" Groups per row: {nc // GS} (16-element groups)")
    print(
        f" With Q2_K (2.625 bpw): {n_groups_per_row * 2} scale+offset bytes per row"
    )

    # How much do group means vary WITHIN a row?
    nr = min(W.shape[0], 64)  # sample of rows, for speed
    intra_row_mean_var = []
    for r in range(nr):
        group_means = []
        for g_start in range(0, nc, GS):
            group_means.append(W[r, g_start : g_start + GS].mean())
        group_means = np.array(group_means)
        # std of the 16-element group means within this one row
        intra_row_mean_var.append(group_means.std())

    print(
        f" Intra-row group mean variability (avg across rows): "
        f"mean={np.mean(intra_row_mean_var):.6f}"
    )

    # How much does the sign of group means vary?
    # Mixed signs mean a single shared offset per row would not suffice.
    pos_frac = 0
    neg_frac = 0
    total_groups = 0
    for r in range(nr):
        for g_start in range(0, nc, GS):
            gm = W[r, g_start : g_start + GS].mean()
            if gm > 0.001:
                pos_frac += 1
            elif gm < -0.001:
                neg_frac += 1
            total_groups += 1
    print(
        f" Group mean sign: {pos_frac / total_groups * 100:.1f}% positive, "
        f"{neg_frac / total_groups * 100:.1f}% negative, "
        f"{(1 - pos_frac / total_groups - neg_frac / total_groups) * 100:.1f}% near-zero"
    )
# ============================================================================
# 8. THE SWIGLU EFFECT: WHY ffn_down INPUT IS SPECIAL
# ============================================================================
print("\n" + "=" * 80)
print("SECTION 8: THE SWIGLU EFFECT — ffn_down ACTIVATION STRUCTURE")
print("=" * 80)
print(" ffn_down's activation is the SwiGLU output: silu(gate) * up")
print(" This creates a specific activation pattern that differs from")
print(" raw FFN input (RMSNorm output).\n")

# Only run if both activation dumps were loaded successfully in Section 4.
if "ffn_input (gate/up)" in act_data and "ffn_down_input (swiglu)" in act_data:
    A_in = act_data["ffn_input (gate/up)"]
    A_swiglu = act_data["ffn_down_input (swiglu)"]

    print(f" FFN input (RMSNorm output): {A_in.shape}")
    print(f" SwiGLU output: {A_swiglu.shape}")

    # Per-token analysis (first 3 tokens; rows are assumed to be tokens —
    # consistent with the (tokens × dims) layout used in Section 5)
    for t in range(min(A_swiglu.shape[0], 3)):
        tok_in = A_in[t]
        tok_sw = A_swiglu[t]
        print(f"\n Token {t}:")
        print(
            f" FFN input: mean={tok_in.mean():.6f}, std={tok_in.std():.6f}, "
            f"|max|={np.abs(tok_in).max():.6f}"
        )
        print(
            f" SwiGLU out: mean={tok_sw.mean():.6f}, std={tok_sw.std():.6f}, "
            f"|max|={np.abs(tok_sw).max():.6f}"
        )

        # SwiGLU creates lots of near-zero values (silu suppresses negatives)
        frac_nearzero_sw = np.mean(np.abs(tok_sw) < 0.01 * tok_sw.std())
        frac_nearzero_in = np.mean(np.abs(tok_in) < 0.01 * tok_in.std())
        print(
            f" Near-zero fraction: FFN input={frac_nearzero_in:.3f}, "
            f"SwiGLU={frac_nearzero_sw:.3f}"
        )

        # Sparsity pattern: how often the SwiGLU output is negative
        frac_neg = np.mean(tok_sw < 0)
        print(f" SwiGLU negative fraction: {frac_neg:.3f}")

    # Dimension-level analysis of SwiGLU
    print(f"\n Dimension-level SwiGLU properties:")
    dim_mean_sw = A_swiglu.mean(axis=0)
    dim_std_sw = A_swiglu.std(axis=0)
    dim_sparsity = np.mean(A_swiglu < 0, axis=0)  # fraction of tokens negative per dim

    print(f" Dim mean range: [{dim_mean_sw.min():.6f}, {dim_mean_sw.max():.6f}]")
    print(f" Dim std range: [{dim_std_sw.min():.6f}, {dim_std_sw.max():.6f}]")
    print(
        f" Dim negative fraction: mean={dim_sparsity.mean():.3f}, "
        f"range=[{dim_sparsity.min():.3f}, {dim_sparsity.max():.3f}]"
    )

    # Highly sparse dimensions (mostly near-zero after SwiGLU)
    high_sparsity = np.sum(dim_sparsity > 0.7)
    low_sparsity = np.sum(dim_sparsity < 0.3)
    print(f" Dims with >70% negative tokens: {high_sparsity}/{len(dim_sparsity)}")
    print(f" Dims with <30% negative tokens: {low_sparsity}/{len(dim_sparsity)}")
# ============================================================================
# 9. QUANTIZATION NOISE × ACTIVATION POWER: THE MATMUL ERROR DECOMPOSITION
# ============================================================================
print("\n" + "=" * 80)
print("SECTION 9: MATMUL ERROR DECOMPOSITION")
print("=" * 80)
print(
    " matmul_error ≈ sum over groups of (activation_power_in_group × "
    "weight_mse_in_group)"
)
print(
    " If activation power is concentrated in groups with high weight error, "
    "matmul error explodes.\n"
)

# For ffn_down specifically, compare where activation power sits vs weight error
W_down = load_f32_tensor("blk_0_ffn_down_weight.f32bin")
A_swiglu = load_f32_tensor("act_blk0_ffn_down_input.f32bin")

W_gate = load_f32_tensor("blk_0_ffn_gate_weight.f32bin")
A_ffn_in = load_f32_tensor("act_blk0_ffn_input.f32bin")

for label, W, A in [("ffn_down", W_down, A_swiglu), ("ffn_gate", W_gate, A_ffn_in)]:
    nc = W.shape[1]
    nr = min(W.shape[0], 128)  # row sample, for speed

    # Compute per-superblock (256 columns) activation power and weight error
    n_sb = nc // 256
    sb_act_power = np.zeros(n_sb)
    sb_weight_mse = np.zeros(n_sb)

    for sb in range(n_sb):
        s = sb * 256
        e = s + 256
        # Activation power: mean squared activation in this region
        sb_act_power[sb] = np.mean(A[:, s:e] ** 2)

        # Weight MSE: Q2_K-style uniform quant error
        mse = 0
        count = 0
        for r in range(nr):
            for g in range(0, 256, GS):
                gvals = W[r, s + g : s + g + GS]
                gmin = min(gvals.min(), 0)  # Q2_K clamps min to ≤0
                gmax = gvals.max()
                gr = gmax - gmin
                if gr < 1e-10:
                    continue  # constant group: no quantization error
                scale = gr / 3.0  # 4 levels (2 bits) span the group range
                for v in gvals:
                    norm = (v - gmin) / scale
                    idx = max(0, min(3, int(round(norm))))
                    recon = gmin + scale * idx
                    mse += (v - recon) ** 2
                    count += 1
        sb_weight_mse[sb] = mse / max(count, 1)

    # Correlation between activation power and weight error across super-blocks
    # (computed on RMS values; requires enough non-degenerate super-blocks)
    valid = sb_act_power > 1e-10
    if valid.sum() > 10:
        corr = np.corrcoef(np.sqrt(sb_act_power[valid]), np.sqrt(sb_weight_mse[valid]))[
            0, 1
        ]
    else:
        corr = 0

    print(f"\n {label}:")
    print(f" Super-blocks: {n_sb}")
    print(
        f" act_power: mean={sb_act_power.mean():.6f}, "
        f"std={np.sqrt(sb_act_power.var()):.6f}, "
        f"range=[{sb_act_power.min():.6f}, {sb_act_power.max():.6f}]"
    )
    print(
        f" weight_mse: mean={sb_weight_mse.mean():.6f}, "
        f"range=[{sb_weight_mse.min():.6f}, {sb_weight_mse.max():.6f}]"
    )
    print(f" CORRELATION (act_power ↔ weight_mse): {corr:.4f}")

    # Show top-5 super-blocks by contribution to matmul error
    contrib = sb_act_power * sb_weight_mse
    top5 = np.argsort(contrib)[-5:][::-1]
    print(f" Top-5 error-contributing super-blocks (of {n_sb}):")
    for idx in top5:
        print(
            f" SB {idx * 256}-{(idx + 1) * 256 - 1}: act_power={sb_act_power[idx]:.6f}, "
            f"weight_mse={sb_weight_mse[idx]:.6f}, contrib={contrib[idx]:.6f}"
        )

print("\n" + "=" * 80)
print("ANALYSIS COMPLETE")
print("=" * 80)