Comment consistency pass and cleanup.
This commit is contained in:
parent
c2f2ff7814
commit
f603c036ec
|
|
@ -430,13 +430,16 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
|
|||
#define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
|
||||
#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
|
||||
|
||||
// E8M0 shared exponent to float. Canonical source: ggml_mxfp_e8m0_to_fp32() in ggml-common.h.
|
||||
// E8M0 shared exponent to float: returns 2^(x - 127).
|
||||
// Canonical source: ggml_mxfp_e8m0_to_fp32() in ggml-common.h.
|
||||
// Kept here because ggml-impl.h cannot depend on ggml-common.h IMPL section.
|
||||
// NaN (x == 0xFF) is not handled — callers guarantee valid exponents.
|
||||
static inline float ggml_e8m0_to_fp32(uint8_t x) {
|
||||
uint32_t bits;
|
||||
if (x == 0) {
|
||||
bits = 0x00400000; // 2^(-127)
|
||||
bits = 0x00400000; // denorm: 0.5 * 2^(-126) = 2^(-127)
|
||||
} else {
|
||||
// normalized: exponent x placed in bits [30:23], mantissa = 0 → 2^(x-127)
|
||||
bits = (uint32_t) x << 23;
|
||||
}
|
||||
float result;
|
||||
|
|
@ -444,12 +447,16 @@ static inline float ggml_e8m0_to_fp32(uint8_t x) {
|
|||
return result;
|
||||
}
|
||||
|
||||
// E8M0 to float/2. Canonical source: ggml_mxfp_e8m0_to_fp32_half() in ggml-common.h.
|
||||
// E8M0 to float/2: returns 2^(x - 128). Equal to ggml_e8m0_to_fp32(x) / 2.
|
||||
// Useful with MXFP4 because the E2M1 kvalues table stores 2 * float_value.
|
||||
// Canonical source: ggml_mxfp_e8m0_to_fp32_half() in ggml-common.h.
|
||||
static inline float ggml_e8m0_to_fp32_half(uint8_t x) {
|
||||
uint32_t bits;
|
||||
if (x < 2) {
|
||||
// x=0 → 2^(-128), x=1 → 2^(-127): denormal bit patterns
|
||||
bits = 0x00200000 << x;
|
||||
} else {
|
||||
// 0.5 * 2^(x-127) = 2^(x-128): normalized with exponent (x-1)
|
||||
bits = (uint32_t)(x - 1) << 23;
|
||||
}
|
||||
float result;
|
||||
|
|
@ -460,37 +467,42 @@ static inline float ggml_e8m0_to_fp32_half(uint8_t x) {
|
|||
#define GGML_E8M0_TO_FP32(x) ggml_e8m0_to_fp32(x)
|
||||
#define GGML_E8M0_TO_FP32_HALF(x) ggml_e8m0_to_fp32_half(x)
|
||||
|
||||
// UE4M3: unsigned, 4 exp bits (bias=7), 3 mantissa bits
|
||||
// Returns value * 0.5 to match kvalues_mxfp4 convention (kvalues = 2 * E2M1_float)
|
||||
// UE4M3 (unsigned E4M3): 4 exponent bits (bias 7), 3 mantissa bits, no sign.
|
||||
// Range: [0, 448], with 0x7F = NaN treated as zero.
|
||||
// Returns value * 0.5 to match kvalues_mxfp4 convention (kvalues = 2 * E2M1_float).
|
||||
static inline float ggml_ue4m3_to_fp32(uint8_t x) {
|
||||
if (x == 0 || x == 0x7F) {
|
||||
return 0.0f;
|
||||
return 0.0f; // zero and NaN → 0
|
||||
}
|
||||
int exp = (x >> 3) & 0xF;
|
||||
int man = x & 0x7;
|
||||
float raw;
|
||||
if (exp == 0) {
|
||||
// subnormal: value = man * 2^(1 - bias - mantissa_bits) = man * 2^(-9)
|
||||
raw = ldexpf((float) man, -9);
|
||||
} else {
|
||||
// normalized: value = (1 + man/8) * 2^(exp - 7)
|
||||
raw = ldexpf(1.0f + (float) man / 8.0f, exp - 7);
|
||||
}
|
||||
return raw * 0.5f;
|
||||
}
|
||||
|
||||
// Float32 to UE4M3 with round-to-nearest (ties round up, not to even).
|
||||
// Clamps to [0, 448]. Max representable: 0x7E (exp=15, man=6) = (1 + 6/8) * 2^8 = 448.
|
||||
static inline uint8_t ggml_fp32_to_ue4m3(float x) {
|
||||
if (!(x > 0.0f)) {
|
||||
return 0;
|
||||
return 0; // negative, zero, NaN → 0
|
||||
}
|
||||
if (x > 448.0f) {
|
||||
x = 448.0f;
|
||||
x = 448.0f; // clamp to max representable
|
||||
}
|
||||
uint32_t bits;
|
||||
memcpy(&bits, &x, 4);
|
||||
int fp32_exp = ((bits >> 23) & 0xFF) - 127;
|
||||
int fp32_man = (bits >> 20) & 0x7;
|
||||
int ue4m3_exp = fp32_exp + 7;
|
||||
int fp32_man = (bits >> 20) & 0x7; // top 3 mantissa bits
|
||||
int ue4m3_exp = fp32_exp + 7; // rebias: FP32 bias 127 → UE4M3 bias 7
|
||||
if (ue4m3_exp <= 0) {
|
||||
// subnormal: value = man * 2^-9, man = round(x * 2^9)
|
||||
// subnormal: value = man * 2^(-9), so man = round(x * 512)
|
||||
int man = (int) (x * 512.0f + 0.5f);
|
||||
if (man > 7) {
|
||||
man = 7;
|
||||
|
|
@ -501,15 +513,17 @@ static inline uint8_t ggml_fp32_to_ue4m3(float x) {
|
|||
return (uint8_t) man;
|
||||
}
|
||||
if (ue4m3_exp >= 15) {
|
||||
return 0x7E;
|
||||
return 0x7E; // max normal
|
||||
}
|
||||
// round-to-nearest using bit 19 (first bit below UE4M3 mantissa)
|
||||
int round_bit = (bits >> 19) & 1;
|
||||
int ue4m3_man = fp32_man + round_bit;
|
||||
if (ue4m3_man > 7) {
|
||||
// mantissa overflow → carry into exponent
|
||||
ue4m3_man = 0;
|
||||
ue4m3_exp++;
|
||||
if (ue4m3_exp >= 15) {
|
||||
return 0x7E;
|
||||
return 0x7E; // max normal
|
||||
}
|
||||
}
|
||||
return (uint8_t) ((ue4m3_exp << 3) | ue4m3_man);
|
||||
|
|
|
|||
|
|
@ -257,96 +257,31 @@ void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_REST
|
|||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// MXFP Element Conversion Functions
|
||||
// ============================================================================
|
||||
//
|
||||
// Reference implementations for OCP Microscaling (MX) format element types.
|
||||
// Spec: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
|
||||
//
|
||||
// All converters use IEEE-754 bit manipulation via memcpy (C99 safe, no strict
|
||||
// aliasing issues). Quantization uses round-to-nearest-even (RNE) per MX spec.
|
||||
//
|
||||
// These functions are exposed in ggml-quants.h for use by CPU backends and tests.
|
||||
// GPU backends (CUDA, Vulkan, Metal) provide their own optimized versions using
|
||||
// hardware intrinsics (e.g., __nv_cvt_float_to_fp8, SIMD groups, LUT lookups).
|
||||
//
|
||||
// Key design decisions validated empirically on CUDA (Qwen3-Coder-30B-A3B):
|
||||
//
|
||||
// 1. SATURATION, NOT NaN PROPAGATION: FP8 E4M3 saturates to max (0x7E = 448)
|
||||
// rather than producing NaN. The single NaN encoding (0x7F) is avoided.
|
||||
// This matches the MX spec behavior and prevents NaN corruption in KV caches.
|
||||
//
|
||||
// 2. MX FP6 HAS NO NaN/Inf: Unlike IEEE-754, the MX spec defines exp=max as a
|
||||
// valid normal value for FP6 types. Dequantizers must NOT special-case it.
|
||||
//
|
||||
// 3. RNE ROUNDING IN SUBNORMALS: Both normal and subnormal paths use proper
|
||||
// round-to-nearest-even with sticky bit tracking. This was a P0 bug fix —
|
||||
// truncation caused measurable PPL regression.
|
||||
//
|
||||
// 4. E3M2 SUBNORMAL SCALE: mant * 2^(1-bias-m) = mant * 2^(-4) = mant/16.
|
||||
// NOT mant/4. This was a critical bug — the exponent bias and mantissa width
|
||||
// both affect the subnormal multiplier.
|
||||
//
|
||||
// ====================== MXFP element conversions (wrappers around ggml-common.h)
|
||||
|
||||
// FP8 E4M3: 1 sign, 4 exponent (bias 7), 3 mantissa
|
||||
// Max finite: 448 (exp=15, mant=6), NaN: exp=15, mant=7
|
||||
// Thin wrappers around canonical implementations in ggml-common.h.
|
||||
// Verified bit-for-bit identical by test-mxfp-converters.
|
||||
// FP8 E4M3: 1 sign, 4 exp (bias 7), 3 mantissa. Max finite: 448, NaN: 0x7F (saturated to max).
|
||||
float fp8_e4m3_to_float(uint8_t v) { return ggml_mxfp_fp8_e4m3_to_float(v); }
|
||||
uint8_t float_to_fp8_e4m3_rn(float x) { return ggml_mxfp_float_to_fp8_e4m3(x); }
|
||||
|
||||
// ============================================================================
|
||||
// MXFP Quantization Infrastructure
|
||||
// ============================================================================
|
||||
//
|
||||
// The MX format uses a shared E8M0 exponent per block of 32 elements. Choosing
|
||||
// the optimal exponent is critical for quantization quality.
|
||||
//
|
||||
// The OCP MX v1.0 spec (§5.3) specifies floor(log2(amax)) for the shared exponent.
|
||||
// We improve on this with an MSE-optimal ±1 search that tests 3 candidate exponents
|
||||
// {e-1, e, e+1} around round(log2(amax)) and picks whichever minimizes the total
|
||||
// round-trip quantization error for the block. This consistently improves perplexity
|
||||
// by 0.05-0.2 across all MX types versus floor-only or round-only approaches.
|
||||
//
|
||||
// The round(log2(amax)) base is computed via IEEE-754 integer bit extraction rather
|
||||
// than log2f(), avoiding GPU Special Function Unit (SFU) bottlenecks. The rounding
|
||||
// threshold 0x3504F3 is the fractional part of sqrt(2) in IEEE-754 mantissa bits:
|
||||
// if mantissa >= (sqrt(2)-1)*2^23 ≈ 0x3504F3, then log2(x) >= n+0.5, so round up.
|
||||
//
|
||||
// Each MX element type provides an mse_error function that computes the round-trip
|
||||
// quantization error for a single value at a given scale. The traits structure
|
||||
// encapsulates this per-type behavior.
|
||||
// ====================== MXFP quantization infrastructure
|
||||
//
|
||||
// MSE-optimal E8M0 shared exponent: tests ±R candidates around round(log2(amax))
|
||||
// and picks whichever minimizes total round-trip quantization error per block.
|
||||
// Improves on OCP MX v1.0 §5.3 floor(log2(amax)) by 0.05-0.2 PPL.
|
||||
// Ref: OCP MX v1.0 spec, Four Over Six (arXiv:2512.02010)
|
||||
|
||||
// Per-type traits for MSE-optimal E8M0 scale computation.
|
||||
// emax_offset: type-specific offset from E8M0 bias to type's max representable exponent
|
||||
// to_elem/to_float: element conversion function pointers (NULL for MXFP4 which uses LUT)
|
||||
// mse_error: round-trip error function for a single value at a given scale
|
||||
typedef struct {
|
||||
int emax_offset;
|
||||
int emax_offset; // type-specific offset to max representable exponent
|
||||
uint8_t (*to_elem)(float);
|
||||
float (*to_float)(uint8_t);
|
||||
float (*mse_error)(float val, float inv_scale, float scale);
|
||||
} mxfp_elem_traits_t;
|
||||
|
||||
// Forward declaration — defined after kvalues_mxfp4 lookup table section.
|
||||
static inline int best_index_mxfp4(float x, float e);
|
||||
|
||||
// MXFP4 E2M1 MSE error: decision boundary quantization with HALF scale factor.
|
||||
//
|
||||
// This CPU implementation uses the doubled int8 kvalues_mxfp4 LUT {0,1,2,3,4,6,8,12}
|
||||
// with GGML_E8M0_TO_FP32_HALF(e) = scale/2 for efficient nibble-indexed integer arithmetic.
|
||||
// The MSE interface passes GGML_E8M0_TO_FP32(e) as scale, so we halve it.
|
||||
//
|
||||
// Canonical E2M1 values are {0, 0.5, 1, 1.5, 2, 3, 4, 6} (kvalues_mxfp4_float in ggml-common.h).
|
||||
// Doubled boundaries {0.5, 1.5, 2.5, 3.5, 5, 7, 10} ÷ 2 = canonical {0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5}.
|
||||
// Mathematically identical — the doubling is an implementation detail.
|
||||
// This is the Lloyd-Max quantizer for uniform input density.
|
||||
// MXFP4 MSE error using decision boundary quantization with half-scale
|
||||
// (kvalues_mxfp4 are doubled E2M1 values, so scale is halved to compensate)
|
||||
static float mse_error_mxfp4(float val, float inv_scale, float scale) {
|
||||
// Decision boundary quantization with direct reconstruction.
|
||||
// kvalues_mxfp4 positive sorted: {0, 1, 2, 3, 4, 6, 8, 12}
|
||||
// Use inv_scale * 2 since MXFP4 scale includes 0.5x factor.
|
||||
const float d = scale * 0.5f;
|
||||
const float inv_d = (d > 0.0f) ? 1.0f / d : 0.0f;
|
||||
const float normalized = fabsf(val) * inv_d;
|
||||
|
|
@ -366,22 +301,7 @@ static float mse_error_mxfp4(float val, float inv_scale, float scale) {
|
|||
|
||||
static const mxfp_elem_traits_t mxfp4_traits = { MXFP4_E2M1_EMAX_OFFSET, NULL, NULL, mse_error_mxfp4 };
|
||||
|
||||
// MSE-optimal E8M0 shared exponent computation.
|
||||
//
|
||||
// Algorithm:
|
||||
// 1. Find amax = max(|x[0..qk-1]|)
|
||||
// 2. Compute e_base = round(log2(amax)) - emax_offset + 127 via integer bit ops
|
||||
// 3. Test {e_base-R .. e_base+R}, pick the one minimizing total round-trip MSE
|
||||
// where R = MXFP_E8M0_MSE_RANGE (defined in ggml-common.h)
|
||||
//
|
||||
// The ±R search improves on the OCP spec's floor(log2(amax)). Wider search finds
|
||||
// better scales for blocks with non-uniform value distributions (especially FP4).
|
||||
// Cost is (2R+1) × qk roundtrip evaluations per block — negligible vs attention compute.
|
||||
//
|
||||
// Integer log2 avoids log2f() (SFU-dependent on GPU). The sqrt(2) rounding threshold
|
||||
// ensures we start from round() not floor().
|
||||
//
|
||||
// Ref: OCP MX v1.0 §5.3; Four Over Six (arXiv:2512.02010)
|
||||
// Find MSE-optimal E8M0 exponent by testing ±R candidates around round(log2(amax))
|
||||
static inline uint8_t mxfp_compute_e8m0_mse(const float * x, int qk, const mxfp_elem_traits_t * traits) {
|
||||
float amax = 0.0f;
|
||||
for (int j = 0; j < qk; j++) {
|
||||
|
|
@ -392,7 +312,6 @@ static inline uint8_t mxfp_compute_e8m0_mse(const float * x, int qk, const mxfp_
|
|||
|
||||
const int e_base = ggml_mxfp_e8m0_base_estimate(amax, traits->emax_offset);
|
||||
|
||||
// ±R MSE search: test 2R+1 candidates around e_base, pick lowest total MSE.
|
||||
int e_lo = e_base - MXFP_E8M0_MSE_RANGE;
|
||||
int e_hi = e_base + MXFP_E8M0_MSE_RANGE;
|
||||
if (e_lo < 1) e_lo = 1;
|
||||
|
|
@ -417,10 +336,8 @@ static inline uint8_t mxfp_compute_e8m0_mse(const float * x, int qk, const mxfp_
|
|||
return (uint8_t)best_e;
|
||||
}
|
||||
|
||||
// Decision boundary quantization for kvalues_mxfp4 {0,1,2,3,4,6,8,12}
|
||||
static inline int best_index_mxfp4(float x, float e) {
|
||||
// Decision boundary quantization: 7 comparisons instead of 16-element scan.
|
||||
// kvalues_mxfp4 positive sorted: {0, 1, 2, 3, 4, 6, 8, 12}
|
||||
// Decision boundaries (midpoints): {0.5, 1.5, 2.5, 3.5, 5, 7, 10}
|
||||
const float inv_e = (e > 0.0f) ? 1.0f / e : 0.0f;
|
||||
const float normalized = fabsf(x) * inv_e;
|
||||
int idx;
|
||||
|
|
@ -435,10 +352,6 @@ static inline int best_index_mxfp4(float x, float e) {
|
|||
return (x < 0.0f) ? (idx + 8) : idx;
|
||||
}
|
||||
|
||||
// FP4 E2M1: search-based quantization using best_index_mxfp4 lookup table.
|
||||
// Unlike FP6/FP8 which use direct float->element conversion, FP4 finds the
|
||||
// closest 4-bit value by minimizing reconstruction error against the lookup table.
|
||||
// Scale uses GGML_E8M0_TO_FP32_HALF (includes 0.5x factor for E2M1 mantissa range).
|
||||
void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k) {
|
||||
static const int qk = QK_MXFP4;
|
||||
|
||||
|
|
@ -652,37 +565,13 @@ void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_REST
|
|||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Hadamard Rotation (reference scalar implementation)
|
||||
// ============================================================================
|
||||
//
|
||||
// 32-element Walsh-Hadamard transform, applied to MX blocks before quantization
|
||||
// to spread outlier energy uniformly across the shared-exponent group.
|
||||
//
|
||||
// Without rotation, a single outlier in a block of 32 forces the shared E8M0
|
||||
// exponent high, wasting precision for all 31 other elements. The Hadamard
|
||||
// transform is orthogonal (H^T·H = I), so H(K)·H(Q) = K·Q — attention scores
|
||||
// are preserved exactly when both K and Q undergo the same rotation.
|
||||
//
|
||||
// Implementation: 5 butterfly stages (log2(32) = 5) of the fast Walsh-Hadamard
|
||||
// transform, followed by normalization by 1/sqrt(32). Total: 160 FP add/sub +
|
||||
// 32 FP mul. This is the standard "in-place" FWHT with O(n·log(n)) operations.
|
||||
//
|
||||
// The 1/sqrt(32) normalization factor makes the transform orthonormal:
|
||||
// H_normalized = H_unnormalized / sqrt(N)
|
||||
// This ensures the transform preserves vector norms (energy), which is critical
|
||||
// for maintaining attention score magnitudes after rotation.
|
||||
//
|
||||
// Prior art: QuIP# (Tseng et al. 2024), BRQ (Huang et al. 2024) apply Hadamard
|
||||
// for weight quantization. Our novel contribution: applying it to KV cache
|
||||
// quantization at the MX block boundary (block-32), where it matches the shared
|
||||
// exponent group size. Tested alternatives (block-8, block-16, sign flips,
|
||||
// permutations) all degraded quality — block-32 Hadamard is uniquely optimal
|
||||
// because it spreads energy across exactly the elements sharing an exponent.
|
||||
//
|
||||
// Empirical PPL impact WITHOUT Hadamard rotation (Qwen3-Coder-30B-A3B):
|
||||
// MXFP8 E4M3: +0.22, MXFP8 E5M2: +1.38, MXFP6 E2M3: +3.34, MXFP6 E3M2: +4.60
|
||||
// ====================== Hadamard rotation
|
||||
//
|
||||
// 32-element Walsh-Hadamard transform applied before MX quantization to spread
|
||||
// outlier energy across the shared-exponent group. Orthogonal (H^T·H = I), so
|
||||
// H(K)·H(Q) = K·Q — attention scores are preserved when both K and Q are rotated.
|
||||
// Prior art: QuIP# (Tseng et al. 2024), BRQ (Huang et al. 2024)
|
||||
|
||||
void ggml_hadamard_32_inplace(float vals[32]) {
|
||||
ggml_mxfp_hadamard_32_inplace(vals);
|
||||
}
|
||||
|
|
@ -697,9 +586,7 @@ uint8_t float_to_fp8_e5m2_rn(float x) { return ggml_mxfp_float_to_fp8_e5m2(x); }
|
|||
void pack_fp6x4(const uint8_t v[4], uint8_t out[3]) { ggml_mxfp_pack_fp6x4(v, out); }
|
||||
void unpack_fp6x4(const uint8_t in[3], uint8_t v[4]) { ggml_mxfp_unpack_fp6x4(in, v); }
|
||||
|
||||
// MSE error functions for FP8/FP6: quantize at given scale → dequantize → squared error.
|
||||
// Used by mxfp_compute_e8m0_mse() to evaluate candidate E8M0 exponents.
|
||||
// These call the public API wrappers which delegate to canonical ggml_mxfp_* in ggml-common.h.
|
||||
// round-trip quantization error for MSE-optimal exponent search
|
||||
static float mse_error_fp8_e4m3(float val, float inv_scale, float scale) {
|
||||
const float recon = fp8_e4m3_to_float(float_to_fp8_e4m3_rn(val * inv_scale)) * scale;
|
||||
const float err = val - recon;
|
||||
|
|
@ -710,14 +597,9 @@ static float mse_error_fp6_e2m3(float val, float inv_scale, float scale) {
|
|||
const float err = val - recon;
|
||||
return err * err;
|
||||
}
|
||||
// emax_offset = ceil(log2(max_finite_value)) for each element type.
|
||||
// This centers the E8M0 exponent search around the optimal scale for the type's range.
|
||||
// E4M3: max=448, ceil(log2(448)) = 9, but offset=8 matches CUDA (empirically better)
|
||||
// E2M3: max=7.5, ceil(log2(7.5)) = 3
|
||||
static const mxfp_elem_traits_t mxfp8_e4m3_traits = { MXFP8_E4M3_EMAX_OFFSET, float_to_fp8_e4m3_rn, fp8_e4m3_to_float, mse_error_fp8_e4m3 };
|
||||
static const mxfp_elem_traits_t mxfp6_e2m3_traits = { MXFP6_E2M3_EMAX_OFFSET, float_to_fp6_e2m3_rn, fp6_e2m3_to_float, mse_error_fp6_e2m3 };
|
||||
|
||||
// FP8 quantize/dequantize: byte-per-element, shared by E4M3 and E5M2
|
||||
static void quantize_row_mxfp8_impl(const float * GGML_RESTRICT x, block_mxfp8 * GGML_RESTRICT y,
|
||||
int64_t k, const mxfp_elem_traits_t * traits) {
|
||||
assert(k % QK_MXFP8 == 0);
|
||||
|
|
@ -748,7 +630,6 @@ static void dequantize_row_mxfp8_impl(const block_mxfp8 * GGML_RESTRICT x, float
|
|||
}
|
||||
}
|
||||
|
||||
// FP6 quantize/dequantize: tight 6-bit packing (4 values per 3 bytes), shared by E2M3 and E3M2
|
||||
static void quantize_row_mxfp6_impl(const float * GGML_RESTRICT x, block_mxfp6 * GGML_RESTRICT y,
|
||||
int64_t k, const mxfp_elem_traits_t * traits) {
|
||||
assert(k % QK_MXFP6 == 0);
|
||||
|
|
@ -787,8 +668,6 @@ static void dequantize_row_mxfp6_impl(const block_mxfp6 * GGML_RESTRICT x, float
|
|||
}
|
||||
}
|
||||
|
||||
// Public API wrappers — one-line delegates to the traits-parameterized impl
|
||||
|
||||
void quantize_row_mxfp8_ref(const float * GGML_RESTRICT x, block_mxfp8 * GGML_RESTRICT y, int64_t k) {
|
||||
quantize_row_mxfp8_impl(x, y, k, &mxfp8_e4m3_traits);
|
||||
}
|
||||
|
|
@ -805,13 +684,9 @@ void dequantize_row_mxfp6(const block_mxfp6 * GGML_RESTRICT x, float * GGML_REST
|
|||
dequantize_row_mxfp6_impl(x, y, k, &mxfp6_e2m3_traits);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// SoA (Struct-of-Arrays) quantize/dequantize — canonical reference for FA
|
||||
// ============================================================================
|
||||
// ====================== SoA (Struct-of-Arrays) quantize/dequantize for flash attention
|
||||
//
|
||||
// SoA layout per row: [qs_block0|qs_block1|...|qs_blockN][e8m0_0|e8m0_1|...|e8m0_N]
|
||||
// Total bytes per row = nblocks * (QS_PER_BLOCK + 1) = identical to AoS.
|
||||
// This is the ONLY layout used by flash attention across all backends.
|
||||
// Layout per row: [qs_block0|qs_block1|...|qs_blockN][e8m0_0|e8m0_1|...|e8m0_N]
|
||||
|
||||
void quantize_row_mxfp4_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k) {
|
||||
assert(k % QK_MXFP4 == 0);
|
||||
|
|
|
|||
|
|
@ -136,7 +136,7 @@ llama_kv_cache::llama_kv_cache(
|
|||
const bool has_k = true;
|
||||
const bool has_v = !is_mla;
|
||||
|
||||
// MXFP K cache: align block count to 16 for cp.async.
|
||||
// MXFP: align block count to 16 for cp.async
|
||||
uint32_t n_embd_k_alloc = n_embd_k_gqa;
|
||||
const bool is_mxfp_k = ggml_is_type_mxfp(type_k);
|
||||
if (is_mxfp_k) {
|
||||
|
|
@ -1037,8 +1037,7 @@ ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_k
|
|||
|
||||
auto * k = layers[ikv].k;
|
||||
|
||||
// For MXFP types: k->ne[0] may include alignment padding (blocks aligned to 16).
|
||||
// The row stride (k->nb[1]) reflects the padded allocation.
|
||||
// note: for MXFP types, k->ne[0] may be padded for block alignment; use nb[] for strides
|
||||
const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
|
||||
|
||||
return ggml_view_4d(ctx, k,
|
||||
|
|
@ -1107,14 +1106,13 @@ ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggm
|
|||
assert(kv_size == k->ne[1]);
|
||||
|
||||
// merge the buffer across all streams because the idxs are global
|
||||
// Use view_2d to preserve nb[1] (which includes alignment padding for MXFP types)
|
||||
// note: use view_2d to preserve nb[1] (includes MXFP alignment padding)
|
||||
k = ggml_view_2d(ctx, k, k->ne[0], kv_size*n_stream, k->nb[1], 0);
|
||||
}
|
||||
|
||||
const bool is_mxfp = ggml_is_type_mxfp(k->type);
|
||||
|
||||
// For MXFP: ne[0] may be padded for block alignment, but k_cur has n_embd_gqa.
|
||||
// Create view with ne[0]=n_embd_gqa, preserving the larger row stride nb[1].
|
||||
// for MXFP: ne[0] may be padded, narrow view to n_embd_gqa while keeping row stride
|
||||
ggml_tensor * k_dst = k;
|
||||
if (is_mxfp) {
|
||||
k_dst = ggml_view_2d(ctx, k, n_embd_gqa, k->ne[1], k->nb[1], 0);
|
||||
|
|
@ -1123,11 +1121,8 @@ ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggm
|
|||
// store the current K values into the cache
|
||||
ggml_tensor * result = ggml_set_rows(ctx, k_dst, k_cur, k_idxs);
|
||||
|
||||
// Flag K cache writes for Walsh-Hadamard rotation (QuaRot, arXiv:2404.00456; BRQ, arXiv:2511.04214).
|
||||
// The flash attention kernel applies matching rotation to Q so H(Q)·H(K)^T = Q·K^T.
|
||||
// V cache writes are NOT rotated (op_params[0] defaults to 0).
|
||||
// Skipped for: MLA (V is a view of K — rotation would corrupt V),
|
||||
// E5M2/E3M2 (2-bit mantissa — Hadamard provides no quality benefit).
|
||||
// enable Hadamard rotation for MXFP K cache (QuaRot arXiv:2404.00456, BRQ arXiv:2511.04214)
|
||||
// skipped for MLA (V is a view of K) and E5M2/E3M2 (2-bit mantissa, no benefit)
|
||||
if (is_mxfp && !hparams.is_mla() && ggml_mxfp_use_hadamard(k->type)) {
|
||||
((int32_t *)result->op_params)[0] = 1;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue