mirror of https://github.com/google/gemma.cpp.git
Add forward and backward error
PiperOrigin-RevId: 678297584
This commit is contained in:
parent
673673cc98
commit
e70e686805
|
|
@ -611,15 +611,25 @@ class DotStats {
|
|||
s_cond.Notify(cond);
|
||||
const float mul_tol = cond > 1E8 ? 1.5f : cond > 1E7 ? 1.1f : 1.01f;
|
||||
|
||||
float muls[kVariants];
|
||||
float l1s[kVariants];
|
||||
float muls[kVariants]; // ratio
|
||||
float l1s[kVariants]; // abs error
|
||||
float rels[kVariants]; // relative forward error
|
||||
float bwds[kVariants]; // backward error
|
||||
int bits[kVariants]; // 'bits correct'
|
||||
uint32_t ulps[kVariants];
|
||||
for (size_t i = 0; i < kVariants; ++i) {
|
||||
muls[i] = Ratio(dots[i], dot_exact);
|
||||
max_muls[i] = HWY_MAX(max_muls[i], muls[i]);
|
||||
|
||||
l1s[i] = std::abs(dots[i] - dot_exact);
|
||||
const float abs_dot = hwy::ScalarAbs(dots[i]);
|
||||
rels[i] = l1s[i] / HWY_MAX(abs_dot, 1E-6f); // avoid infinity
|
||||
bwds[i] = rels[i] / cond;
|
||||
bits[i] = HWY_MIN(-std::log2(rels[i]), hwy::MantissaBits<float>());
|
||||
s_l1s[i].Notify(l1s[i]);
|
||||
s_rels[i].Notify(rels[i]);
|
||||
s_bwds[i].Notify(bwds[i]);
|
||||
s_bits[i].Notify(bits[i]);
|
||||
|
||||
ulps[i] = hwy::detail::ComputeUlpDelta(dots[i], dot_exact);
|
||||
s_ulps[i].Notify(ulps[i]);
|
||||
|
|
@ -629,8 +639,8 @@ class DotStats {
|
|||
muls[kNaive] + 1E-3f < muls[kKahan] || ulps[kCompensated] > 10) {
|
||||
fprintf(stderr, "num %2zu cond %.1E exact %.8f\n", num, cond, dot_exact);
|
||||
for (size_t i = 0; i < kVariants; ++i) {
|
||||
fprintf(stderr, " %9s dot %11.8f mul %.8f\n", VariantName(i), dots[i],
|
||||
muls[i]);
|
||||
fprintf(stderr, " %9s dot %11.8f mul %.8f rel %f bwd %f bits %d\n",
|
||||
VariantName(i), dots[i], muls[i], rels[i], bwds[i], bits[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -648,11 +658,15 @@ class DotStats {
|
|||
}
|
||||
}
|
||||
|
||||
// Forward to all members' Assimilate().
|
||||
void Assimilate(const DotStats& other) {
|
||||
s_cond.Assimilate(other.s_cond);
|
||||
for (size_t i = 0; i < kVariants; ++i) {
|
||||
s_muls[i].Assimilate(other.s_muls[i]);
|
||||
s_l1s[i].Assimilate(other.s_l1s[i]);
|
||||
s_rels[i].Assimilate(other.s_rels[i]);
|
||||
s_bwds[i].Assimilate(other.s_bwds[i]);
|
||||
s_bits[i].Assimilate(other.s_bits[i]);
|
||||
s_ulps[i].Assimilate(other.s_ulps[i]);
|
||||
s_times[i].Assimilate(other.s_times[i]);
|
||||
}
|
||||
|
|
@ -666,6 +680,15 @@ class DotStats {
|
|||
for (size_t variant = 0; variant < kVariants; ++variant) {
|
||||
PrintStats(" l1", variant, s_l1s[variant]);
|
||||
}
|
||||
for (size_t variant = 0; variant < kVariants; ++variant) {
|
||||
PrintStats("rel", variant, s_rels[variant]);
|
||||
}
|
||||
for (size_t variant = 0; variant < kVariants; ++variant) {
|
||||
PrintStats("bwd", variant, s_bwds[variant]);
|
||||
}
|
||||
for (size_t variant = 0; variant < kVariants; ++variant) {
|
||||
PrintStats("bits", variant, s_bits[variant]);
|
||||
}
|
||||
for (size_t variant = 0; variant < kVariants; ++variant) {
|
||||
PrintStats("ulp", variant, s_ulps[variant]);
|
||||
}
|
||||
|
|
@ -679,6 +702,9 @@ class DotStats {
|
|||
void Check() const {
|
||||
CheckMuls();
|
||||
CheckL1();
|
||||
CheckRel();
|
||||
CheckBwd();
|
||||
// No need to check bits, it is a monotonic function of rel.
|
||||
CheckUlps();
|
||||
|
||||
// We do not check times because they can be noisy/nonportable, but
|
||||
|
|
@ -687,7 +713,7 @@ class DotStats {
|
|||
}
|
||||
|
||||
private:
|
||||
// Factor by which the approximate result is off; larger is worse.
|
||||
// Factor by which the approximate result is off; lower is better.
|
||||
void CheckMuls() const {
|
||||
// Comp2 is between Compensated and Kahan.
|
||||
ASSERT_INSIDE(kComp2, 1.001, s_muls[kComp2].Mean(), 1.3);
|
||||
|
|
@ -719,7 +745,7 @@ class DotStats {
|
|||
ASSERT_INSIDE(kPairwise, 1.0, s_muls[kPairwise].GeometricMean(), 1.6);
|
||||
}
|
||||
|
||||
// Absolute error; larger is worse.
|
||||
// Absolute error; lower is better.
|
||||
void CheckL1() const {
|
||||
// Comp2 is between Compensated and Kahan.
|
||||
ASSERT_INSIDE(kComp2, 1E-5, s_l1s[kComp2].Mean(), 9E-4);
|
||||
|
|
@ -747,7 +773,58 @@ class DotStats {
|
|||
ASSERT_INSIDE(kPairwise, 1.1E-3f, s_l1s[kPairwise].Max(), 1E-2f);
|
||||
}
|
||||
|
||||
// Units in the last place; larger is worse.
|
||||
// Forward relative error, lower is better.
|
||||
void CheckRel() const {
|
||||
ASSERT_INSIDE(kComp2, 2E-4, s_rels[kComp2].GeometricMean(), 3.5E-3);
|
||||
ASSERT_INSIDE(kComp2, 1E-5f, s_rels[kComp2].Max(), 0.4f);
|
||||
|
||||
// Compensated is very accurate.
|
||||
ASSERT_LESS(kCompensated, s_rels[kCompensated].Min(), 1E-8f);
|
||||
ASSERT_LESS(kCompensated, s_rels[kCompensated].Max(), 8E-6f);
|
||||
|
||||
// Naive and OnlyTwoProd are considerably higher, but not huge.
|
||||
ASSERT_INSIDE(kNaive, 1E-3, s_rels[kNaive].GeometricMean(), 8E-2);
|
||||
ASSERT_INSIDE(kOnlyTwoProd, 1E-3, s_rels[kOnlyTwoProd].GeometricMean(),
|
||||
0.06);
|
||||
|
||||
// Kahan (FastTwoSum) is decent:
|
||||
ASSERT_INSIDE(kKahan, 3E-4, s_rels[kKahan].GeometricMean(), 3.5E-3);
|
||||
ASSERT_INSIDE(kKahan, 6E-4f, s_rels[kKahan].Max(), 0.7f);
|
||||
|
||||
// TwoProducts and TwoSums are a bit better.
|
||||
ASSERT_INSIDE(kAddTwoProd, 2.2E-4, s_rels[kAddTwoProd].GeometricMean(),
|
||||
3E-3);
|
||||
ASSERT_INSIDE(kAddTwoProd, 4E-4f, s_rels[kAddTwoProd].Max(), 0.19f);
|
||||
ASSERT_INSIDE(kAddTwoSum, 1.5E-4, s_rels[kAddTwoSum].GeometricMean(),
|
||||
2.4E-3);
|
||||
|
||||
ASSERT_INSIDE(kPairwise, 4.5E-4, s_rels[kPairwise].GeometricMean(), 1.5E-2);
|
||||
// Extremely high error on aarch64.
|
||||
ASSERT_INSIDE(kPairwise, 1.1E-3f, s_rels[kPairwise].Max(), 1250.f);
|
||||
}
|
||||
|
||||
// Backward relative error, lower is better.
|
||||
void CheckBwd() const {
|
||||
ASSERT_INSIDE(kComp2, 7E-10f, s_rels[kComp2].Max(), 0.4f);
|
||||
|
||||
// Compensated is very accurate.
|
||||
ASSERT_LESS(kCompensated, s_rels[kCompensated].Max(), 8E-6f);
|
||||
|
||||
// Naive and OnlyTwoProd are considerably higher than others
|
||||
ASSERT_INSIDE(kNaive, 1.5E-8f, s_rels[kNaive].Max(), 3080.f);
|
||||
ASSERT_INSIDE(kOnlyTwoProd, 1.5E-8f, s_rels[kNaive].Max(), 3080.f);
|
||||
// Kahan (FastTwoSum) is not much better here!
|
||||
ASSERT_INSIDE(kKahan, 6E-10f, s_rels[kKahan].Max(), 0.7f);
|
||||
|
||||
// But TwoProducts/TwoSums help a bit.
|
||||
ASSERT_INSIDE(kAddTwoProd, 9E-10f, s_rels[kAddTwoProd].Max(), 0.19f);
|
||||
ASSERT_INSIDE(kAddTwoSum, 5E-10f, s_rels[kAddTwoSum].Max(), 0.34f);
|
||||
|
||||
// Extremely high error on aarch64.
|
||||
ASSERT_INSIDE(kPairwise, 7E-10f, s_rels[kPairwise].Max(), 1250.f);
|
||||
}
|
||||
|
||||
// Units in the last place; lower is better.
|
||||
void CheckUlps() const {
|
||||
ASSERT_LESS(kComp2, s_ulps[kCompensated].Max(), 3.6E6f);
|
||||
ASSERT_LESS(kCompensated, s_ulps[kCompensated].Max(), 250.0f);
|
||||
|
|
@ -766,6 +843,9 @@ class DotStats {
|
|||
hwy::Stats s_muls[kVariants];
|
||||
|
||||
hwy::Stats s_l1s[kVariants]; // Absolute error
|
||||
hwy::Stats s_rels[kVariants]; // forward relative
|
||||
hwy::Stats s_bwds[kVariants]; // = forward / condition number
|
||||
hwy::Stats s_bits[kVariants]; // = -log2(rel), capped to 23
|
||||
|
||||
hwy::Stats s_ulps[kVariants]; // Only relevant for small cond
|
||||
hwy::Stats s_times[kVariants];
|
||||
|
|
|
|||
Loading…
Reference in New Issue