diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index cdaded865b..6ff7c3c036 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -325,7 +325,26 @@ void quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RE } // UE4M3 scale: amax / 6.0 maps the max E2M1 value (6.0) to amax - const uint8_t ue = ggml_fp32_to_ue4m3(amax / 6.0f); + const uint8_t first_ue = ggml_fp32_to_ue4m3(amax / 6.0f); + uint8_t ue = first_ue; + float lowest_err = INFINITY; + for (int difference = -2; difference <= 2; ++difference) { + const int candidate = (int) first_ue + difference; + if (candidate < 0 || candidate > 0x7E) { + continue; + } + const float test_scale = ggml_ue4m3_to_fp32((uint8_t) candidate); + float test_scale_error = 0.0f; + for (int j = 0; j < qk_sub; ++j) { + const int qi = best_index_mxfp4(xb[j], test_scale); + const float err = xb[j] - kvalues_mxfp4[qi] * test_scale; + test_scale_error += err * err; + } + if (test_scale_error < lowest_err) { + lowest_err = test_scale_error; + ue = (uint8_t) candidate; + } + } y[i].d[s] = ue; const float d = ggml_ue4m3_to_fp32(ue);