diff --git a/ops/dot_test.cc b/ops/dot_test.cc
index f24467c..24e2edb 100644
--- a/ops/dot_test.cc
+++ b/ops/dot_test.cc
@@ -611,15 +611,25 @@ class DotStats {
     s_cond.Notify(cond);
     const float mul_tol = cond > 1E8 ? 1.5f : cond > 1E7 ? 1.1f : 1.01f;
 
-    float muls[kVariants];
-    float l1s[kVariants];
+    float muls[kVariants];  // ratio
+    float l1s[kVariants];   // abs error
+    float rels[kVariants];  // relative forward error
+    float bwds[kVariants];  // backward error
+    int bits[kVariants];    // 'bits correct'
     uint32_t ulps[kVariants];
     for (size_t i = 0; i < kVariants; ++i) {
       muls[i] = Ratio(dots[i], dot_exact);
       max_muls[i] = HWY_MAX(max_muls[i], muls[i]);
 
       l1s[i] = std::abs(dots[i] - dot_exact);
+      const float abs_dot = hwy::ScalarAbs(dots[i]);
+      rels[i] = l1s[i] / HWY_MAX(abs_dot, 1E-6f);  // avoid division by zero
+      bwds[i] = rels[i] / cond;
+      bits[i] = HWY_MIN(-std::log2(rels[i]), hwy::MantissaBits<float>());
       s_l1s[i].Notify(l1s[i]);
+      s_rels[i].Notify(rels[i]);
+      s_bwds[i].Notify(bwds[i]);
+      s_bits[i].Notify(bits[i]);
 
       ulps[i] = hwy::detail::ComputeUlpDelta(dots[i], dot_exact);
       s_ulps[i].Notify(ulps[i]);
@@ -629,8 +639,8 @@ class DotStats {
         muls[kNaive] + 1E-3f < muls[kKahan] || ulps[kCompensated] > 10) {
       fprintf(stderr, "num %2zu cond %.1E exact %.8f\n", num, cond, dot_exact);
       for (size_t i = 0; i < kVariants; ++i) {
-        fprintf(stderr, " %9s dot %11.8f mul %.8f\n", VariantName(i), dots[i],
-                muls[i]);
+        fprintf(stderr, " %9s dot %11.8f mul %.8f rel %f bwd %f bits %d\n",
+                VariantName(i), dots[i], muls[i], rels[i], bwds[i], bits[i]);
       }
     }
   }
@@ -648,11 +658,15 @@ class DotStats {
     }
   }
 
+  // Merges `other` into this by forwarding to every member's Assimilate().
   void Assimilate(const DotStats& other) {
     s_cond.Assimilate(other.s_cond);
     for (size_t i = 0; i < kVariants; ++i) {
       s_muls[i].Assimilate(other.s_muls[i]);
       s_l1s[i].Assimilate(other.s_l1s[i]);
+      s_rels[i].Assimilate(other.s_rels[i]);
+      s_bwds[i].Assimilate(other.s_bwds[i]);
+      s_bits[i].Assimilate(other.s_bits[i]);
       s_ulps[i].Assimilate(other.s_ulps[i]);
       s_times[i].Assimilate(other.s_times[i]);
     }
@@ -666,6 +680,15 @@ class DotStats {
     for (size_t variant = 0; variant < kVariants; ++variant) {
       PrintStats(" l1", variant, s_l1s[variant]);
     }
+    for (size_t variant = 0; variant < kVariants; ++variant) {
+      PrintStats("rel", variant, s_rels[variant]);
+    }
+    for (size_t variant = 0; variant < kVariants; ++variant) {
+      PrintStats("bwd", variant, s_bwds[variant]);
+    }
+    for (size_t variant = 0; variant < kVariants; ++variant) {
+      PrintStats("bits", variant, s_bits[variant]);
+    }
     for (size_t variant = 0; variant < kVariants; ++variant) {
       PrintStats("ulp", variant, s_ulps[variant]);
     }
@@ -679,6 +702,9 @@ class DotStats {
   void Check() const {
     CheckMuls();
     CheckL1();
+    CheckRel();
+    CheckBwd();
+    // No need to check bits, it is a monotonic function of rel.
     CheckUlps();
 
     // We do not check times because they can be noisy/nonportable, but
@@ -687,7 +713,7 @@ class DotStats {
   }
 
  private:
-  // Factor by which the approximate result is off; larger is worse.
+  // Factor by which the approximate result is off; lower is better.
   void CheckMuls() const {
     // Comp2 is between Compensated and Kahan.
     ASSERT_INSIDE(kComp2, 1.001, s_muls[kComp2].Mean(), 1.3);
@@ -719,7 +745,7 @@ class DotStats {
     ASSERT_INSIDE(kPairwise, 1.0, s_muls[kPairwise].GeometricMean(), 1.6);
   }
 
-  // Absolute error; larger is worse.
+  // Absolute error; lower is better.
   void CheckL1() const {
     // Comp2 is between Compensated and Kahan.
     ASSERT_INSIDE(kComp2, 1E-5, s_l1s[kComp2].Mean(), 9E-4);
@@ -747,7 +773,58 @@ class DotStats {
     ASSERT_INSIDE(kPairwise, 1.1E-3f, s_l1s[kPairwise].Max(), 1E-2f);
   }
 
-  // Units in the last place; larger is worse.
+  // Forward relative error, lower is better.
+  void CheckRel() const {
+    ASSERT_INSIDE(kComp2, 2E-4, s_rels[kComp2].GeometricMean(), 3.5E-3);
+    ASSERT_INSIDE(kComp2, 1E-5f, s_rels[kComp2].Max(), 0.4f);
+
+    // Compensated is very accurate.
+    ASSERT_LESS(kCompensated, s_rels[kCompensated].Min(), 1E-8f);
+    ASSERT_LESS(kCompensated, s_rels[kCompensated].Max(), 8E-6f);
+
+    // Naive and OnlyTwoProd are considerably higher, but not huge.
+    ASSERT_INSIDE(kNaive, 1E-3, s_rels[kNaive].GeometricMean(), 8E-2);
+    ASSERT_INSIDE(kOnlyTwoProd, 1E-3, s_rels[kOnlyTwoProd].GeometricMean(),
+                  0.06);
+
+    // Kahan (FastTwoSum) is decent:
+    ASSERT_INSIDE(kKahan, 3E-4, s_rels[kKahan].GeometricMean(), 3.5E-3);
+    ASSERT_INSIDE(kKahan, 6E-4f, s_rels[kKahan].Max(), 0.7f);
+
+    // TwoProducts and TwoSums are a bit better.
+    ASSERT_INSIDE(kAddTwoProd, 2.2E-4, s_rels[kAddTwoProd].GeometricMean(),
+                  3E-3);
+    ASSERT_INSIDE(kAddTwoProd, 4E-4f, s_rels[kAddTwoProd].Max(), 0.19f);
+    ASSERT_INSIDE(kAddTwoSum, 1.5E-4, s_rels[kAddTwoSum].GeometricMean(),
+                  2.4E-3);
+
+    ASSERT_INSIDE(kPairwise, 4.5E-4, s_rels[kPairwise].GeometricMean(), 1.5E-2);
+    // Extremely high error on aarch64.
+    ASSERT_INSIDE(kPairwise, 1.1E-3f, s_rels[kPairwise].Max(), 1250.f);
+  }
+
+  // Backward relative error (= forward / condition number), lower is better.
+  void CheckBwd() const {
+    ASSERT_INSIDE(kComp2, 7E-10f, s_bwds[kComp2].Max(), 0.4f);
+
+    // Compensated is very accurate.
+    ASSERT_LESS(kCompensated, s_bwds[kCompensated].Max(), 8E-6f);
+
+    // Naive and OnlyTwoProd are considerably higher than others.
+    ASSERT_INSIDE(kNaive, 1.5E-8f, s_bwds[kNaive].Max(), 3080.f);
+    ASSERT_INSIDE(kOnlyTwoProd, 1.5E-8f, s_bwds[kOnlyTwoProd].Max(), 3080.f);
+    // Kahan (FastTwoSum) is not much better here!
+    ASSERT_INSIDE(kKahan, 6E-10f, s_bwds[kKahan].Max(), 0.7f);
+
+    // But TwoProducts/TwoSums help a bit.
+    ASSERT_INSIDE(kAddTwoProd, 9E-10f, s_bwds[kAddTwoProd].Max(), 0.19f);
+    ASSERT_INSIDE(kAddTwoSum, 5E-10f, s_bwds[kAddTwoSum].Max(), 0.34f);
+
+    // Extremely high error on aarch64.
+    ASSERT_INSIDE(kPairwise, 7E-10f, s_bwds[kPairwise].Max(), 1250.f);
+  }
+
+  // Units in the last place; lower is better.
   void CheckUlps() const {
     ASSERT_LESS(kComp2, s_ulps[kCompensated].Max(), 3.6E6f);
     ASSERT_LESS(kCompensated, s_ulps[kCompensated].Max(), 250.0f);
@@ -766,6 +843,9 @@ class DotStats {
 
   hwy::Stats s_muls[kVariants];
   hwy::Stats s_l1s[kVariants];   // Absolute error
+  hwy::Stats s_rels[kVariants];  // forward relative
+  hwy::Stats s_bwds[kVariants];  // = forward / condition number
+  hwy::Stats s_bits[kVariants];  // = -log2(rel), capped to 23
   hwy::Stats s_ulps[kVariants];  // Only relevant for small cond
   hwy::Stats s_times[kVariants];
 