Merge c963a4fbef into 88915cb55c
This commit is contained in:
commit
18eb012fa9
|
|
@ -325,7 +325,26 @@ void quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RE
|
|||
}
|
||||
|
||||
// UE4M3 scale: amax / 6.0 maps the max E2M1 value (6.0) to amax
|
||||
const uint8_t ue = ggml_fp32_to_ue4m3(amax / 6.0f);
|
||||
const uint8_t first_ue = ggml_fp32_to_ue4m3(amax / 6.0f);
|
||||
uint8_t ue = first_ue;
|
||||
float lowest_err = INFINITY;
|
||||
for (int difference = -2; difference <= 2; ++difference) {
|
||||
const int candidate = (int) first_ue + difference;
|
||||
if (candidate < 0 || candidate > 0x7E) {
|
||||
continue;
|
||||
}
|
||||
const float test_scale = ggml_ue4m3_to_fp32((uint8_t) candidate);
|
||||
float test_scale_error = 0.0f;
|
||||
for (int j = 0; j < qk_sub; ++j) {
|
||||
const int qi = best_index_mxfp4(xb[j], test_scale);
|
||||
const float err = xb[j] - kvalues_mxfp4[qi] * test_scale;
|
||||
test_scale_error += err * err;
|
||||
}
|
||||
if (test_scale_error < lowest_err) {
|
||||
lowest_err = test_scale_error;
|
||||
ue = (uint8_t) candidate;
|
||||
}
|
||||
}
|
||||
y[i].d[s] = ue;
|
||||
const float d = ggml_ue4m3_to_fp32(ue);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue