Add forward and backward error

PiperOrigin-RevId: 678297584
This commit is contained in:
Jan Wassenberg 2024-09-24 10:09:50 -07:00 committed by Copybara-Service
parent 673673cc98
commit e70e686805
1 changed files with 87 additions and 7 deletions

View File

@ -611,15 +611,25 @@ class DotStats {
s_cond.Notify(cond);
const float mul_tol = cond > 1E8 ? 1.5f : cond > 1E7 ? 1.1f : 1.01f;
float muls[kVariants];
float l1s[kVariants];
float muls[kVariants]; // ratio
float l1s[kVariants]; // abs error
float rels[kVariants]; // relative forward error
float bwds[kVariants]; // backward error
int bits[kVariants]; // 'bits correct'
uint32_t ulps[kVariants];
for (size_t i = 0; i < kVariants; ++i) {
muls[i] = Ratio(dots[i], dot_exact);
max_muls[i] = HWY_MAX(max_muls[i], muls[i]);
l1s[i] = std::abs(dots[i] - dot_exact);
const float abs_dot = hwy::ScalarAbs(dots[i]);
rels[i] = l1s[i] / HWY_MAX(abs_dot, 1E-6f); // avoid infinity
bwds[i] = rels[i] / cond;
bits[i] = HWY_MIN(-std::log2(rels[i]), hwy::MantissaBits<float>());
s_l1s[i].Notify(l1s[i]);
s_rels[i].Notify(rels[i]);
s_bwds[i].Notify(bwds[i]);
s_bits[i].Notify(bits[i]);
ulps[i] = hwy::detail::ComputeUlpDelta(dots[i], dot_exact);
s_ulps[i].Notify(ulps[i]);
@ -629,8 +639,8 @@ class DotStats {
muls[kNaive] + 1E-3f < muls[kKahan] || ulps[kCompensated] > 10) {
fprintf(stderr, "num %2zu cond %.1E exact %.8f\n", num, cond, dot_exact);
for (size_t i = 0; i < kVariants; ++i) {
fprintf(stderr, " %9s dot %11.8f mul %.8f\n", VariantName(i), dots[i],
muls[i]);
fprintf(stderr, " %9s dot %11.8f mul %.8f rel %f bwd %f bits %d\n",
VariantName(i), dots[i], muls[i], rels[i], bwds[i], bits[i]);
}
}
}
@ -648,11 +658,15 @@ class DotStats {
}
}
// Forward to all members' Assimilate().
void Assimilate(const DotStats& other) {
s_cond.Assimilate(other.s_cond);
for (size_t i = 0; i < kVariants; ++i) {
s_muls[i].Assimilate(other.s_muls[i]);
s_l1s[i].Assimilate(other.s_l1s[i]);
s_rels[i].Assimilate(other.s_rels[i]);
s_bwds[i].Assimilate(other.s_bwds[i]);
s_bits[i].Assimilate(other.s_bits[i]);
s_ulps[i].Assimilate(other.s_ulps[i]);
s_times[i].Assimilate(other.s_times[i]);
}
@ -666,6 +680,15 @@ class DotStats {
for (size_t variant = 0; variant < kVariants; ++variant) {
PrintStats(" l1", variant, s_l1s[variant]);
}
for (size_t variant = 0; variant < kVariants; ++variant) {
PrintStats("rel", variant, s_rels[variant]);
}
for (size_t variant = 0; variant < kVariants; ++variant) {
PrintStats("bwd", variant, s_bwds[variant]);
}
for (size_t variant = 0; variant < kVariants; ++variant) {
PrintStats("bits", variant, s_bits[variant]);
}
for (size_t variant = 0; variant < kVariants; ++variant) {
PrintStats("ulp", variant, s_ulps[variant]);
}
@ -679,6 +702,9 @@ class DotStats {
void Check() const {
CheckMuls();
CheckL1();
CheckRel();
CheckBwd();
// No need to check bits, it is a monotonic function of rel.
CheckUlps();
// We do not check times because they can be noisy/nonportable, but
@ -687,7 +713,7 @@ class DotStats {
}
private:
// Factor by which the approximate result is off; larger is worse.
// Factor by which the approximate result is off; lower is better.
void CheckMuls() const {
// Comp2 is between Compensated and Kahan.
ASSERT_INSIDE(kComp2, 1.001, s_muls[kComp2].Mean(), 1.3);
@ -719,7 +745,7 @@ class DotStats {
ASSERT_INSIDE(kPairwise, 1.0, s_muls[kPairwise].GeometricMean(), 1.6);
}
// Absolute error; larger is worse.
// Absolute error; lower is better.
void CheckL1() const {
// Comp2 is between Compensated and Kahan.
ASSERT_INSIDE(kComp2, 1E-5, s_l1s[kComp2].Mean(), 9E-4);
@ -747,7 +773,58 @@ class DotStats {
ASSERT_INSIDE(kPairwise, 1.1E-3f, s_l1s[kPairwise].Max(), 1E-2f);
}
// Units in the last place; larger is worse.
// Forward relative error, lower is better.
void CheckRel() const {
ASSERT_INSIDE(kComp2, 2E-4, s_rels[kComp2].GeometricMean(), 3.5E-3);
ASSERT_INSIDE(kComp2, 1E-5f, s_rels[kComp2].Max(), 0.4f);
// Compensated is very accurate.
ASSERT_LESS(kCompensated, s_rels[kCompensated].Min(), 1E-8f);
ASSERT_LESS(kCompensated, s_rels[kCompensated].Max(), 8E-6f);
// Naive and OnlyTwoProd are considerably higher, but not huge.
ASSERT_INSIDE(kNaive, 1E-3, s_rels[kNaive].GeometricMean(), 8E-2);
ASSERT_INSIDE(kOnlyTwoProd, 1E-3, s_rels[kOnlyTwoProd].GeometricMean(),
0.06);
// Kahan (FastTwoSum) is decent:
ASSERT_INSIDE(kKahan, 3E-4, s_rels[kKahan].GeometricMean(), 3.5E-3);
ASSERT_INSIDE(kKahan, 6E-4f, s_rels[kKahan].Max(), 0.7f);
// TwoProducts and TwoSums are a bit better.
ASSERT_INSIDE(kAddTwoProd, 2.2E-4, s_rels[kAddTwoProd].GeometricMean(),
3E-3);
ASSERT_INSIDE(kAddTwoProd, 4E-4f, s_rels[kAddTwoProd].Max(), 0.19f);
ASSERT_INSIDE(kAddTwoSum, 1.5E-4, s_rels[kAddTwoSum].GeometricMean(),
2.4E-3);
ASSERT_INSIDE(kPairwise, 4.5E-4, s_rels[kPairwise].GeometricMean(), 1.5E-2);
// Extremely high error on aarch64.
ASSERT_INSIDE(kPairwise, 1.1E-3f, s_rels[kPairwise].Max(), 1250.f);
}
// Backward relative error, lower is better.
void CheckBwd() const {
ASSERT_INSIDE(kComp2, 7E-10f, s_rels[kComp2].Max(), 0.4f);
// Compensated is very accurate.
ASSERT_LESS(kCompensated, s_rels[kCompensated].Max(), 8E-6f);
// Naive and OnlyTwoProd are considerably higher than others
ASSERT_INSIDE(kNaive, 1.5E-8f, s_rels[kNaive].Max(), 3080.f);
ASSERT_INSIDE(kOnlyTwoProd, 1.5E-8f, s_rels[kNaive].Max(), 3080.f);
// Kahan (FastTwoSum) is not much better here!
ASSERT_INSIDE(kKahan, 6E-10f, s_rels[kKahan].Max(), 0.7f);
// But TwoProducts/TwoSums help a bit.
ASSERT_INSIDE(kAddTwoProd, 9E-10f, s_rels[kAddTwoProd].Max(), 0.19f);
ASSERT_INSIDE(kAddTwoSum, 5E-10f, s_rels[kAddTwoSum].Max(), 0.34f);
// Extremely high error on aarch64.
ASSERT_INSIDE(kPairwise, 7E-10f, s_rels[kPairwise].Max(), 1250.f);
}
// Units in the last place; lower is better.
void CheckUlps() const {
ASSERT_LESS(kComp2, s_ulps[kCompensated].Max(), 3.6E6f);
ASSERT_LESS(kCompensated, s_ulps[kCompensated].Max(), 250.0f);
@ -766,6 +843,9 @@ class DotStats {
hwy::Stats s_muls[kVariants];
hwy::Stats s_l1s[kVariants]; // Absolute error
hwy::Stats s_rels[kVariants]; // forward relative
hwy::Stats s_bwds[kVariants]; // = forward / condition number
hwy::Stats s_bits[kVariants]; // = -log2(rel), capped to 23
hwy::Stats s_ulps[kVariants]; // Only relevant for small cond
hwy::Stats s_times[kVariants];