From 5a71d819cbb42af18d841cc77b9b17eaf5f23de9 Mon Sep 17 00:00:00 2001
From: Jan Wassenberg
Date: Fri, 4 Oct 2024 07:11:27 -0700
Subject: [PATCH] Also enable f64 dot/sum for <f32 inputs

---
diff --git a/compression/compress-inl.h b/compression/compress-inl.h
--- a/compression/compress-inl.h
+++ b/compression/compress-inl.h
@@ ... @@
   return hwy::IsSame<hwy::RemoveCvRef<Packed>, float>();
 }
 
-// TODO: `DotKernelDouble` requires both inputs to be `float` because currently
-// only `CompressTraits<float>` can `Decompress2` to `double`. It is not yet
-// clear whether we want to implement this for other Packed types.
-template <typename VT, typename WT>
-constexpr bool CanDecompressToDouble() {
-  return HWY_HAVE_FLOAT64 && IsF32<VT>() && IsF32<WT>();
-}
-
 namespace detail {
 
 // Compile-time-only check that `DRaw` and `Packed` are compatible. This makes
diff --git a/ops/dot-inl.h b/ops/dot-inl.h
index f8e0a4f..82935f7 100644
--- a/ops/dot-inl.h
+++ b/ops/dot-inl.h
@@ -157,13 +157,17 @@ HWY_MAYBE_UNUSED double ConditionNumber(const VT* HWY_RESTRICT v, size_t num) {
   return cond;
 }
 
-// f64 FMA, called for f32 inputs promoted to f64. Runs at about half the speed
-// of f32 FMA. Only usable if `CanDecompressToDouble()`.
+// f64 FMA. Inputs are either both f32, promoted to f64, or any other types,
+// promoted or even DEMOTED to bf16. Runs at about half the speed of f32 FMA.
 struct DotKernelDouble {
+  // Only `CompressTraits<float>` can `Decompress2` to `double`, so both have
+  // to be `float` in order to have `Raw = double`. Note that if either type is
+  // smaller than `float`, we may demote the other type from `float` to `BF16`.
   template <typename VT, typename WT>
-  using Raw = double;
+  using Raw = hwy::If<IsF32<VT>() && IsF32<WT>(), double, BF16>;
   using State = double;
 
+  // Raw = double
   template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
   HWY_INLINE void Update4(DRaw dd, const VR w0, const VR w1, const VR w2,
                           const VR w3, const VR v0, const VR v1, const VR v2,
@@ -175,18 +179,77 @@ struct DotKernelDouble {
     sum3 = hn::MulAdd(w3, v3, sum3);
   }
 
+  // Raw = BF16
+  template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
+            class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>
+  HWY_INLINE void Update4(DRaw, const VR w0, const VR w1, const VR w2,
+                          const VR w3, const VR v0, const VR v1, const VR v2,
+                          const VR v3, VS& sum0, VS& sum1, VS& sum2, VS& sum3,
+                          VS&, VS&, VS&, VS&) const {
+    const hn::Repartition<float, DRaw> df;
+    using VF = hn::Vec<decltype(df)>;
+    const VF prod0 = hn::WidenMulPairwiseAdd(df, w0, v0);
+    const VF prod1 = hn::WidenMulPairwiseAdd(df, w1, v1);
+    // Reduce to two f32 sums so we can promote them to four f64 vectors.
+    VF sum02, sum13;
+    if constexpr (HWY_NATIVE_DOT_BF16) {
+      // Fuse WidenMulPairwiseAdd plus Add into ReorderWidenMulAccumulate.
+      VF unused0 = hn::Zero(df);
+      VF unused1 = hn::Zero(df);
+      sum02 = hn::ReorderWidenMulAccumulate(df, w2, v2, prod0, unused0);
+      sum13 = hn::ReorderWidenMulAccumulate(df, w3, v3, prod1, unused1);
+    } else {
+      // ReorderWidenMulAccumulate does not help because we still end up with
+      // four accumulators.
+      const VF prod2 = hn::WidenMulPairwiseAdd(df, w2, v2);
+      const VF prod3 = hn::WidenMulPairwiseAdd(df, w3, v3);
+      sum02 = hn::Add(prod0, prod2);
+      sum13 = hn::Add(prod1, prod3);
+    }
+
+    const DS ds;
+    const VS d0 = hn::PromoteLowerTo(ds, sum02);
+    const VS d1 = hn::PromoteUpperTo(ds, sum02);
+    const VS d2 = hn::PromoteLowerTo(ds, sum13);
+    const VS d3 = hn::PromoteUpperTo(ds, sum13);
+
+    sum0 = hn::Add(sum0, d0);
+    sum1 = hn::Add(sum1, d1);
+    sum2 = hn::Add(sum2, d2);
+    sum3 = hn::Add(sum3, d3);
+  }
+
+  // Raw = double
   template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
   HWY_INLINE void Update1(DRaw dd, const VR w0, const VR v0, VR& sum0,
                           VR&) const {
     sum0 = hn::MulAdd(w0, v0, sum0);
   }
 
+  // Raw = BF16
+  template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
+            class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>
+  HWY_INLINE void Update1(DRaw, const VR w0, const VR v0, VS& sum0,
+                          VS& extra0) const {
+    const hn::Repartition<float, DRaw> df;
+    using VF = hn::Vec<decltype(df)>;
+    const VF prod0 = hn::WidenMulPairwiseAdd(df, w0, v0);
+
+    const DS ds;
+    const VS d0 = hn::PromoteLowerTo(ds, prod0);
+    const VS d1 = hn::PromoteUpperTo(ds, prod0);
+
+    sum0 = hn::Add(sum0, d0);
+    extra0 = hn::Add(extra0, d1);
+  }
+
   template <class DState, class VS = hn::Vec<DState>, HWY_IF_F64_D(DState)>
   HWY_INLINE float Reduce(DState dd, VS& sum0, VS& sum1, VS& sum2, VS& sum3,
-                          VS&, VS&, VS&, VS&) const {
+                          VS& extra0, VS&, VS&, VS&) const {
     // Reduction tree: sum of all accumulators by pairs, then across lanes.
     sum0 = hn::Add(sum0, sum1);
     sum2 = hn::Add(sum2, sum3);
+    sum0 = hn::Add(sum0, extra0);  // from Update1
     sum0 = hn::Add(sum0, sum2);
     return static_cast<float>(hn::ReduceSum(dd, sum0));
   }
@@ -198,12 +261,13 @@ HWY_INLINE float DotDouble(D d, const PackedSpan<const WT>& w, size_t w_ofs,
   return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelDouble());
 }
 
-// Algorithm 6.15 from Handbook of Floating-Point Arithmetic. This is slower
-// than DotKernelDouble and about equally accurate.
+// Algorithm 6.15 from Handbook of Floating-Point Arithmetic. This is about as
+// accurate as DotKernelDouble but slower, hence we only use it if f64 is not
+// supported on this target.
 struct DotKernelCompensated {
-  // Unlike other kernels, this also supports bf16 inputs, used by matvec-inl.h.
-  // Even when `!HWY_NATIVE_DOT_BF16`, the BF16 overload is still faster than
-  // promoting to `float` because it does not call `TwoProducts`.
+  // The `BF16` overload uses `ReorderWidenMulAccumulate`, which requires both
+  // `VT` and `WT` to be `BF16`, or smaller types decompressed to `BF16`.
+  // Otherwise, we decompress both inputs to `float`.
   template <typename VT, typename WT>
   using Raw = hwy::If<IsF32<VT>() || IsF32<WT>(), float, BF16>;
   using State = float;
@@ -240,10 +304,10 @@ struct DotKernelCompensated {
                           const VR v3, VS& sum0, VS& sum1, VS& sum2, VS& sum3,
                           VS& comp0, VS& comp1, VS& comp2, VS& comp3) const {
     const DS df;
-    const VS prod0 = WidenMulPairwiseAdd(df, w0, v0);
-    const VS prod1 = WidenMulPairwiseAdd(df, w1, v1);
-    const VS prod2 = WidenMulPairwiseAdd(df, w2, v2);
-    const VS prod3 = WidenMulPairwiseAdd(df, w3, v3);
+    const VS prod0 = hn::WidenMulPairwiseAdd(df, w0, v0);
+    const VS prod1 = hn::WidenMulPairwiseAdd(df, w1, v1);
+    const VS prod2 = hn::WidenMulPairwiseAdd(df, w2, v2);
+    const VS prod3 = hn::WidenMulPairwiseAdd(df, w3, v3);
 
     VS serr0, serr1, serr2, serr3;
     sum0 = TwoSums(df, prod0, sum0, serr0);
@@ -295,16 +359,14 @@ struct DotKernelCompensated {
   }
 };
 
-template <typename VT, typename WT>
-using DotKernelDefault = hwy::If<CanDecompressToDouble<VT, WT>(),
-                                 DotKernelDouble, DotKernelCompensated>;
+using DotKernelDefault =
+    hwy::If<HWY_HAVE_FLOAT64, DotKernelDouble, DotKernelCompensated>;
 
 // `D` only serves to specify the vector size; its lane type is ignored.
 template <class D, typename WT, typename VT>
 HWY_INLINE float Dot(D d, const PackedSpan<const WT>& w, size_t w_ofs,
                      const VT* HWY_RESTRICT vec, size_t num) {
-  return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
-                           DotKernelDefault<VT, WT>());
+  return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelDefault());
 }
 
 // Adapter for two pointers, no bounds checking.
diff --git a/ops/dot_test.cc b/ops/dot_test.cc
index f39088b..42efa01 100644
--- a/ops/dot_test.cc
+++ b/ops/dot_test.cc
@@ -507,10 +507,10 @@ struct DotKernelComp2 {
                          VF& sum2, VF& sum3, VF& comp0, VF& comp1, VF& comp2,
                          VF& comp3) const {
     const DF df;
-    VF prod0 = WidenMulPairwiseAdd(df, w0, v0);
-    VF prod1 = WidenMulPairwiseAdd(df, w1, v1);
-    VF prod2 = WidenMulPairwiseAdd(df, w2, v2);
-    VF prod3 = WidenMulPairwiseAdd(df, w3, v3);
+    VF prod0 = hn::WidenMulPairwiseAdd(df, w0, v0);
+    VF prod1 = hn::WidenMulPairwiseAdd(df, w1, v1);
+    VF prod2 = hn::WidenMulPairwiseAdd(df, w2, v2);
+    VF prod3 = hn::WidenMulPairwiseAdd(df, w3, v3);
 
     // Pairwise sums
     prod0 = hn::Add(prod0, prod1);
@@ -564,11 +564,6 @@ HWY_INLINE float DotComp2(D d, const PackedSpan<const WT>& w, size_t w_ofs,
 template <class D, typename WT, typename VT>
 float CallDot(D d, size_t variant, const PackedSpan<const WT>& w, size_t w_ofs,
               const VT* HWY_RESTRICT v, size_t num) {
-  // float inputs also support kDouble.
-  if constexpr (CanDecompressToDouble<VT, WT>()) {
-    if (variant == kDouble) return DotDouble(d, w, 0, v, num);
-  }
-
   switch (variant) {
     case kAddTwoProd:
       return DotTwoProdFast(d, w, 0, v, num);
@@ -578,6 +573,12 @@ float CallDot(D d, size_t variant, const PackedSpan<const WT>& w, size_t w_ofs,
       return DotComp2(d, w, 0, v, num);
     case kCompensated:
       return DotCompensated(d, w, 0, v, num);
+    case kDouble:
+      if constexpr (HWY_HAVE_FLOAT64) {
+        return DotDouble(d, w, 0, v, num);
+      } else {
+        return DotCompensated(d, w, 0, v, num);
+      }
     case kKahan:
       return DotKahan(d, w, 0, v, num);
     case kNaive:
diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc
index 6e6c674..5468122 100644
--- a/ops/matmul_test.cc
+++ b/ops/matmul_test.cc
@@ -142,7 +142,9 @@ void AssertClose(size_t rows_ac, size_t cols_ab, size_t cols_c_rows_b,
 
   const double norm = MaxColAbsSum(a.get(), rows_ac, cols_ab) *
                       MaxColAbsSum(b_trans.get(), cols_c_rows_b, cols_ab);
-  const double epsilon = hwy::ConvertScalarTo<double>(hwy::Epsilon<float>());
+  // Dot(float, BF16) rounds both inputs to BF16.
+  using RefType = hwy::If<IsF32<MatTA>() && IsF32<MatTB>(), float, BF16>;
+  const double epsilon = hwy::ConvertScalarTo<double>(hwy::Epsilon<RefType>());
   const double tolerance = 200.0 * norm * epsilon;
 
   for (size_t idx = 0; idx < num_c; idx++) {
diff --git a/ops/ops-inl.h b/ops/ops-inl.h
index 29fff5b..adb06fe 100644
--- a/ops/ops-inl.h
+++ b/ops/ops-inl.h
@@ -146,8 +146,7 @@ namespace detail {
 template <typename VT>
 float RMSNormMul(const VT* HWY_RESTRICT x, size_t size) {
   const hn::ScalableTag<float> d;
-  const float l2 =
-      DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault<VT, VT>());
+  const float l2 = DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault());
   constexpr float kEps = 1e-6f;  // avoid divide by zero
   return 1.0f / sqrtf(l2 / StaticCast<float>(size) + kEps);
 }
@@ -506,12 +505,16 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
 }
 
 // f64 Add, called for f32 inputs promoted to f64. Runs at about half the speed
-// of f32 sums. Only usable if `CanDecompressToDouble()`.
+// of f32 sums.
 struct SumKernelDouble {
+  // Only `CompressTraits<float>` can `Decompress2` to `double`, so both have
+  // to be `float` in order to have `Raw = double`. Note that if either type is
+  // smaller than `float`, we may demote the other type from `float` to `BF16`.
   template <typename VT, typename WT>
-  using Raw = double;
+  using Raw = hwy::If<IsF32<VT>() && IsF32<WT>(), double, BF16>;
   using State = double;
 
+  // Raw = double
   template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
   HWY_INLINE void Update4(DRaw /*dd*/, const VR w0, const VR w1, const VR w2,
                           const VR w3, VR, VR, VR, VR, VR& sum0, VR& sum1,
@@ -522,18 +525,95 @@ struct SumKernelDouble {
     sum3 = hn::Add(sum3, w3);
   }
 
+  // Raw = BF16
+  template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
+            class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>
+  HWY_INLINE void Update4(DRaw dr, const VR w0, const VR w1, const VR w2,
+                          const VR w3, VR, VR, VR, VR, VS& sum0, VS& sum1,
+                          VS& sum2, VS& sum3, VS&, VS&, VS&, VS&) const {
+    const hn::Repartition<float, DRaw> df;
+    using VF = hn::Vec<decltype(df)>;
+    // Reduce to two f32 sums so we can promote them to four f64 vectors.
+    VF sum02, sum13;
+    if constexpr (HWY_NATIVE_DOT_BF16) {
+      const VR k1 = hn::Set(dr, hwy::ConvertScalarTo<BF16>(1.0f));
+      const VF prod0 = hn::WidenMulPairwiseAdd(df, w0, k1);
+      const VF prod1 = hn::WidenMulPairwiseAdd(df, w1, k1);
+      // Fuse WidenMulPairwiseAdd plus Add into ReorderWidenMulAccumulate.
+      VF unused0 = hn::Zero(df);
+      VF unused1 = hn::Zero(df);
+      sum02 = hn::ReorderWidenMulAccumulate(df, w2, k1, prod0, unused0);
+      sum13 = hn::ReorderWidenMulAccumulate(df, w3, k1, prod1, unused1);
+    } else {
+      // If not native, the multiplication costs extra, so convert to f32.
+      // PromoteEvenTo is cheaper than PromoteUpperTo, especially on `SVE`.
+      const VF fe0 = hn::PromoteEvenTo(df, w0);
+      const VF fe1 = hn::PromoteEvenTo(df, w1);
+      const VF fe2 = hn::PromoteEvenTo(df, w2);
+      const VF fe3 = hn::PromoteEvenTo(df, w3);
+      const VF fo0 = hn::PromoteOddTo(df, w0);
+      const VF fo1 = hn::PromoteOddTo(df, w1);
+      const VF fo2 = hn::PromoteOddTo(df, w2);
+      const VF fo3 = hn::PromoteOddTo(df, w3);
+      const VF fe01 = hn::Add(fe0, fe1);
+      const VF fe23 = hn::Add(fe2, fe3);
+      const VF fo01 = hn::Add(fo0, fo1);
+      const VF fo23 = hn::Add(fo2, fo3);
+      sum02 = hn::Add(fe01, fe23);
+      sum13 = hn::Add(fo01, fo23);
+    }
+
+    const DS ds;
+    const VS d0 = hn::PromoteLowerTo(ds, sum02);
+    const VS d1 = hn::PromoteUpperTo(ds, sum02);
+    const VS d2 = hn::PromoteLowerTo(ds, sum13);
+    const VS d3 = hn::PromoteUpperTo(ds, sum13);
+
+    sum0 = hn::Add(sum0, d0);
+    sum1 = hn::Add(sum1, d1);
+    sum2 = hn::Add(sum2, d2);
+    sum3 = hn::Add(sum3, d3);
+  }
+
+  // Raw = double
   template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
   HWY_INLINE void Update1(DRaw /*dd*/, const VR w0, const VR v0, VR& sum0,
                           VR& comp0) const {
     sum0 = hn::Add(sum0, w0);
   }
 
+  // Raw = BF16
+  template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
+            class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>
+  HWY_INLINE void Update1(DRaw dr, const VR w0, VR, VS& sum0,
+                          VS& extra0) const {
+    const hn::Repartition<float, DRaw> df;
+    using VF = hn::Vec<decltype(df)>;
+    VF f0;
+    if constexpr (HWY_NATIVE_DOT_BF16) {
+      const VR k1 = hn::Set(dr, hwy::ConvertScalarTo<BF16>(1.0f));
+      f0 = hn::WidenMulPairwiseAdd(df, w0, k1);
+    } else {
+      const VF fe0 = hn::PromoteEvenTo(df, w0);
+      const VF fo0 = hn::PromoteOddTo(df, w0);
+      f0 = hn::Add(fe0, fo0);
+    }
+
+    const DS ds;
+    const VS d0 = hn::PromoteLowerTo(ds, f0);
+    const VS d1 = hn::PromoteUpperTo(ds, f0);
+
+    sum0 = hn::Add(sum0, d0);
+    extra0 = hn::Add(extra0, d1);
+  }
+
   template <class DState, class VS = hn::Vec<DState>>
   HWY_INLINE float Reduce(DState dd, VS& sum0, VS& sum1, VS& sum2, VS& sum3,
-                          VS&, VS&, VS&, VS&) const {
+                          VS& extra0, VS&, VS&, VS&) const {
     // Reduction tree: sum of all accumulators by pairs, then across lanes.
     sum0 = hn::Add(sum0, sum1);
     sum2 = hn::Add(sum2, sum3);
+    sum0 = hn::Add(sum0, extra0);  // from Update1
     sum0 = hn::Add(sum0, sum2);
     return static_cast<float>(hn::ReduceSum(dd, sum0));
   }
@@ -547,7 +627,8 @@ struct SumKernelDouble {
 // comp* to w*, so each operation is serially dependent. By contrast, the sum*
 // and comp* here have shorter dependency chains.
 //
-// This is slower than SumKernelDouble and about equally accurate.
+// This is about as accurate as SumKernelDouble but slower, hence we only use
+// it if f64 is not supported on this target.
 struct SumKernelCascaded {
   template <typename VT, typename WT>
   using Raw = float;
@@ -590,15 +671,14 @@ struct SumKernelCascaded {
   }
 };
 
-template <typename VT>
-using SumKernelDefault = hwy::If<CanDecompressToDouble<VT, VT>(),
-                                 SumKernelDouble, SumKernelCascaded>;
+using SumKernelDefault =
+    hwy::If<HWY_HAVE_FLOAT64, SumKernelDouble, SumKernelCascaded>;
 
 template <class D, typename VT>
 HWY_INLINE float Sum(D d, const VT* HWY_RESTRICT vec, size_t num) {
   using Raw = hwy::If<IsF32<VT>(), float, BF16>;
   const hn::Repartition<Raw, D> d_raw;
-  return DecompressAndCall(d_raw, MakeSpan(vec, num), SumKernelDefault<VT>());
+  return DecompressAndCall(d_raw, MakeSpan(vec, num), SumKernelDefault());
 }
 
 // See below for a specialized version for top-1 sampling.
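
For reference, the scalar sketch below (not part of the patch) models the numeric
behavior the new BF16 overloads of DotKernelDouble and the relaxed tolerance in
matmul_test.cc assume: when either input is narrower than f32, both sides are
reduced to BF16 precision, products are formed in f32, and accumulation happens
in f64. The helper names (ToBF16Precision, DotBF16RefF64) are made up for
illustration only, and the truncation is a simplification of the library's
rounding demotion.

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

// Hypothetical helper (illustration only): reduce an f32 to BF16 precision by
// keeping the sign, exponent and top 7 mantissa bits. Truncation is a
// simplification; the library demotes with rounding.
static float ToBF16Precision(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  bits &= 0xFFFF0000u;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Reference model: both inputs at BF16 precision, products in f32, sum in f64.
// This is the precision at which the BF16 overloads of DotKernelDouble
// accumulate.
static double DotBF16RefF64(const std::vector<float>& w,
                            const std::vector<float>& v) {
  double sum = 0.0;
  for (std::size_t i = 0; i < w.size(); ++i) {
    const float wi = ToBF16Precision(w[i]);
    const float vi = ToBF16Precision(v[i]);
    sum += static_cast<double>(wi * vi);  // f32 product, f64 accumulation
  }
  return sum;
}

int main() {
  std::vector<float> w(1024), v(1024);
  for (std::size_t i = 0; i < w.size(); ++i) {
    w[i] = std::sin(0.01f * static_cast<float>(i));
    v[i] = std::cos(0.01f * static_cast<float>(i));
  }
  // The rounding error of such a dot product scales with the BF16 epsilon
  // rather than the f32 epsilon, which is why AssertClose above derives its
  // epsilon from RefType instead of always using float.
  std::printf("f64-accumulated BF16 dot: %.6f\n", DotBF16RefF64(w, v));
  return 0;
}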