From 5e812f07f554bcc75b85f2650a57968fa1a597c0 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Mon, 30 Sep 2024 01:21:37 -0700 Subject: [PATCH] Use f64 Dot and sum in softmax - faster than Cascaded Also let the kernel specify the Raw and State types, rename WeightT/VecT -> WT/VT. PiperOrigin-RevId: 680464427 --- compression/compress-inl.h | 194 ++++++++++++++++++++----------------- ops/dot-inl.h | 165 ++++++++++++++++++++----------- ops/dot_test.cc | 151 +++++++++++++---------------- ops/ops-inl.h | 64 ++++++++++-- 4 files changed, 342 insertions(+), 232 deletions(-) diff --git a/compression/compress-inl.h b/compression/compress-inl.h index 2d440f3..08c7c06 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -120,7 +120,6 @@ struct CompressTraits { BF16* HWY_RESTRICT raw, size_t num) { const hn::Repartition df; using VF = hn::Vec; - using VBF = hn::Vec; const size_t NF = hn::Lanes(df); size_t i = 0; @@ -169,7 +168,6 @@ struct CompressTraits { double* HWY_RESTRICT raw, size_t num) { const hn::Rebind df; using VF = hn::Vec; - using VD = hn::Vec; const size_t ND = hn::Lanes(dd); size_t i = 0; @@ -413,8 +411,6 @@ struct CompressTraits { static HWY_INLINE void Load2(D d, const PackedSpan& packed, const size_t packed_ofs, hn::Vec& raw0, hn::Vec& raw1) { - const hn::Twice> d8; - using V8 = hn::Vec; NuqCodec::Dec2(d, packed, packed_ofs, raw0, raw1); } @@ -497,23 +493,39 @@ void Compress2(DF df, VF raw0, VF raw1, const PackedSpan& packed, Traits::Store2(df, raw0, raw1, packed, packed_ofs); } +template +constexpr bool IsF32() { + return hwy::IsSame, float>(); +} + +// TODO: `DotKernelDouble` requires both inputs to be `float` because currently +// only `CompressTraits` can `Decompress2` to `double`. It is not yet +// clear whether we want to implement this for other Packed types. +template +constexpr bool CanDecompressToDouble() { + return HWY_HAVE_FLOAT64 && IsF32() && IsF32(); +} + +namespace detail { + // Compile-time-only check that `DRaw` and `Packed` are compatible. This makes // for better error messages than "no matching function found". template -HWY_INLINE void VerifyRawAndPacked() { +HWY_INLINE void VerifyRawAndPackedForDecompress() { using TRaw = hn::TFromD; - constexpr bool kPackedF32 = hwy::IsSame, float>(); // We can decompress any Packed to f32 or BF16, or f32 to f64. static_assert(hwy::IsSameEither() || - (kPackedF32 && hwy::IsSame())); + (IsF32() && hwy::IsSame())); } +} // namespace detail + // Decompresses from any type of `packed`, to two vectors of `float/BF16`, or // `double`, if `Packed` is `float`. template > HWY_INLINE void Decompress2(DRaw d, const PackedSpan& packed, const size_t packed_ofs, VRaw& raw0, VRaw& raw1) { - VerifyRawAndPacked(); + detail::VerifyRawAndPackedForDecompress(); packed.BoundsCheck(packed_ofs, 2 * hn::Lanes(d)); using Traits = CompressTraits>; Traits::Load2(d, MakeConst(packed), packed_ofs, raw0, raw1); @@ -529,135 +541,139 @@ template > HWY_NOINLINE void DecompressAndZeroPad(DRaw d, const PackedSpan& packed, const size_t packed_ofs, TRaw* raw, size_t num) { - VerifyRawAndPacked(); + detail::VerifyRawAndPackedForDecompress(); packed.BoundsCheck(packed_ofs, num); using Traits = CompressTraits>; Traits::DecompressAndZeroPad(d, MakeConst(packed), packed_ofs, raw, num); } -// Decompresses to the type specified by `D` from each of two arrays in groups -// of four vectors, passes them to `kernel.Update4`, zero-pads to a vector -// multiple, then calls `kernel.Update1` for the remaining vectors. Returns -// `kernel.Reduce`. 
+// Invokes `kernel` for the `v.num` elements of `w` and `v`. Decompresses from +// both into groups of four vectors with lane type `Kernel::Raw`, passes them to +// `kernel.Update4`; loads the final vector(s) with zero-padding, then passes +// them to `kernel.Update1`, then returns `kernel.Reduce`. `v.num` is not +// required to be a multiple of the vector length. // -// This is useful for implementing dot products, and similar to -// `hwy/contrib/unroller`, but also supports compressed types with simpler -// remainder handling thanks to `DecompressAndZeroPad`. -// -// `D` can be BF16/float, or also double if `WeightT` and `VecT` are both float. -// `w` can be any packed type, including NUQ, which requires a separate `w_ofs` -// rather than pointer arithmetic. `vec` can also be any type, but typically -// float or BF16. We omit a `v_ofs` because it is 0 in our use cases. -// `num`, the number of elements to process, need not be a vector multiple. +// Both `w` and `v` can be any packed type. To support random access in `w` +// even if it is `NuqStream`, we ignore `w.num` and provide a `w_ofs`, but no +// `v_ofs` because it is always 0 in our use cases. `D` only serves to specify +// the vector size/fraction. // // `kernel` is const& so we can pass an rvalue argument, but can contain -// mutable state, though not vectors (see highway.h). We pass in the four -// loaded vectors plus eight state vectors. The state vectors' lane type is -// either `double` (required for DotKernelDouble) or `float`. -template -HWY_INLINE float DecompressAndCall(D d, const PackedSpan& w, +// mutable state, though not vectors (see highway.h). In addition to the groups +// of four input vectors, we pass eight state vectors with lane type specified +// by `Kernel::State`, which is typically `float` but may differ if `Raw` is +// `double`, or `WT` and `VT` are `BF16`. +// +// Decoupling decompression and remainder handling from the actual usage of the +// vectors makes it easier to implement various dot product and sum algorithms. +// This is similar to `hwy/contrib/unroller`, but less general and relies on +// `DecompressAndZeroPad`. 
+template +HWY_INLINE float DecompressAndCall(D, const PackedSpan& w, const size_t w_ofs, - const PackedSpan vec, + const PackedSpan v, const Kernel& kernel) { // Decompressed inputs - using T = hn::TFromD; - using V = hn::Vec; - V w0, w1, w2, w3, v0, v1, v2, v3; + using Raw = typename Kernel::template Raw; + const hn::Repartition d_raw; + using VRaw = hn::Vec; + VRaw w0, w1, w2, w3, v0, v1, v2, v3; // State for Kernel - using StateT = hwy::If(), double, float>; - const hn::Repartition ds; - using VS = hn::Vec; - VS sum0 = hn::Zero(ds); - VS sum1 = hn::Zero(ds); - VS sum2 = hn::Zero(ds); - VS sum3 = hn::Zero(ds); - VS comp0 = hn::Zero(ds); - VS comp1 = hn::Zero(ds); - VS comp2 = hn::Zero(ds); - VS comp3 = hn::Zero(ds); + const hn::Repartition d_state; + using VState = hn::Vec; + VState sum0 = hn::Zero(d_state); + VState sum1 = hn::Zero(d_state); + VState sum2 = hn::Zero(d_state); + VState sum3 = hn::Zero(d_state); + VState comp0 = hn::Zero(d_state); + VState comp1 = hn::Zero(d_state); + VState comp2 = hn::Zero(d_state); + VState comp3 = hn::Zero(d_state); - const size_t N = hn::Lanes(d); + const size_t N = hn::Lanes(d_raw); size_t i = 0; - if (vec.num >= 4 * N) { - for (; i <= vec.num - 4 * N; i += 4 * N) { - Decompress2(d, w, w_ofs + i + 0 * N, w0, w1); - Decompress2(d, w, w_ofs + i + 2 * N, w2, w3); - Decompress2(d, vec, i + 0 * N, v0, v1); - Decompress2(d, vec, i + 2 * N, v2, v3); + if (v.num >= 4 * N) { + for (; i <= v.num - 4 * N; i += 4 * N) { + Decompress2(d_raw, w, w_ofs + i + 0 * N, w0, w1); + Decompress2(d_raw, w, w_ofs + i + 2 * N, w2, w3); + Decompress2(d_raw, v, i + 0 * N, v0, v1); + Decompress2(d_raw, v, i + 2 * N, v2, v3); - kernel.Update4(d, w0, w1, w2, w3, v0, v1, v2, v3, sum0, sum1, sum2, sum3, - comp0, comp1, comp2, comp3); + kernel.Update4(d_raw, w0, w1, w2, w3, v0, v1, v2, v3, sum0, sum1, sum2, + sum3, comp0, comp1, comp2, comp3); } } - size_t remaining = vec.num - i; + size_t remaining = v.num - i; HWY_DASSERT(remaining < 4 * N); if (HWY_UNLIKELY(remaining != 0)) { - HWY_ALIGN T padded_w[4 * hn::MaxLanes(d)]; - HWY_ALIGN T padded_v[4 * hn::MaxLanes(d)]; - DecompressAndZeroPad(d, w, w_ofs + i, padded_w, remaining); - DecompressAndZeroPad(d, vec, i, padded_v, remaining); + HWY_ALIGN Raw padded_w[4 * hn::MaxLanes(d_raw)]; + HWY_ALIGN Raw padded_v[4 * hn::MaxLanes(d_raw)]; + DecompressAndZeroPad(d_raw, w, w_ofs + i, padded_w, remaining); + DecompressAndZeroPad(d_raw, v, i, padded_v, remaining); // 1..4 whole vectors, possibly zero-padded. for (size_t padded_pos = 0; padded_pos < remaining; padded_pos += N) { - const V w0 = hn::Load(d, padded_w + padded_pos); - const V v0 = hn::Load(d, padded_v + padded_pos); - kernel.Update1(d, w0, v0, sum0, comp0); + const VRaw w0 = hn::Load(d_raw, padded_w + padded_pos); + const VRaw v0 = hn::Load(d_raw, padded_v + padded_pos); + kernel.Update1(d_raw, w0, v0, sum0, comp0); } } - return kernel.Reduce(ds, sum0, sum1, sum2, sum3, comp0, comp1, comp2, comp3); + return kernel.Reduce(d_state, sum0, sum1, sum2, sum3, comp0, comp1, comp2, + comp3); } // Same as above, but single input array. Used by RMSNorm. 
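Both overloads accept any kernel that provides the `Raw`/`State` type aliases plus `Update4`, `Update1` and `Reduce`. As a sketch of that protocol only (illustrative, not part of this change; the production kernels are `DotKernelDouble` and `DotKernelCompensated` in ops/dot-inl.h), a minimal uncompensated f32 multiply-add kernel would look like:

struct ExampleDotKernel {
  // Decompress both inputs to f32 and accumulate in f32.
  template <typename WT, typename VT>
  using Raw = float;
  using State = float;

  template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
  HWY_INLINE void Update4(DF /*df*/, const VF w0, const VF w1, const VF w2,
                          const VF w3, const VF v0, const VF v1, const VF v2,
                          const VF v3, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
                          VF&, VF&, VF&, VF&) const {
    sum0 = hn::MulAdd(w0, v0, sum0);
    sum1 = hn::MulAdd(w1, v1, sum1);
    sum2 = hn::MulAdd(w2, v2, sum2);
    sum3 = hn::MulAdd(w3, v3, sum3);
  }

  template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
  HWY_INLINE void Update1(DF /*df*/, const VF w0, const VF v0, VF& sum0,
                          VF&) const {
    sum0 = hn::MulAdd(w0, v0, sum0);
  }

  template <class DF, class VF = hn::Vec<DF>>
  HWY_INLINE float Reduce(DF df, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
                          VF&, VF&, VF&, VF&) const {
    // Sum accumulators by pairs, then across lanes.
    sum0 = hn::Add(sum0, sum1);
    sum2 = hn::Add(sum2, sum3);
    sum0 = hn::Add(sum0, sum2);
    return hn::ReduceSum(df, sum0);
  }
};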
-template -HWY_INLINE float DecompressAndCall(D d, const PackedSpan vec, +template +HWY_INLINE float DecompressAndCall(D, const PackedSpan v, const Kernel& kernel) { // Decompressed inputs - using T = hn::TFromD; - using V = hn::Vec; - V v0, v1, v2, v3; + using Raw = typename Kernel::template Raw; + const hn::Repartition d_raw; + using VRaw = hn::Vec; + VRaw v0, v1, v2, v3; // State for Kernel - using StateT = hwy::If(), double, float>; - const hn::Repartition ds; - using VS = hn::Vec; - VS sum0 = hn::Zero(ds); - VS sum1 = hn::Zero(ds); - VS sum2 = hn::Zero(ds); - VS sum3 = hn::Zero(ds); - VS comp0 = hn::Zero(ds); - VS comp1 = hn::Zero(ds); - VS comp2 = hn::Zero(ds); - VS comp3 = hn::Zero(ds); + const hn::Repartition d_state; + using VState = hn::Vec; + VState sum0 = hn::Zero(d_state); + VState sum1 = hn::Zero(d_state); + VState sum2 = hn::Zero(d_state); + VState sum3 = hn::Zero(d_state); + VState comp0 = hn::Zero(d_state); + VState comp1 = hn::Zero(d_state); + VState comp2 = hn::Zero(d_state); + VState comp3 = hn::Zero(d_state); - const size_t N = hn::Lanes(d); + const size_t N = hn::Lanes(d_raw); size_t i = 0; - if (vec.num >= 4 * N) { - for (; i <= vec.num - 4 * N; i += 4 * N) { - Decompress2(d, vec, i + 0 * N, v0, v1); - Decompress2(d, vec, i + 2 * N, v2, v3); + if (v.num >= 4 * N) { + for (; i <= v.num - 4 * N; i += 4 * N) { + Decompress2(d_raw, v, i + 0 * N, v0, v1); + Decompress2(d_raw, v, i + 2 * N, v2, v3); - kernel.Update4(d, v0, v1, v2, v3, v0, v1, v2, v3, sum0, sum1, sum2, sum3, - comp0, comp1, comp2, comp3); + kernel.Update4(d_raw, v0, v1, v2, v3, v0, v1, v2, v3, sum0, sum1, sum2, + sum3, comp0, comp1, comp2, comp3); } } - size_t remaining = vec.num - i; + size_t remaining = v.num - i; HWY_DASSERT(remaining < 4 * N); if (HWY_UNLIKELY(remaining != 0)) { - HWY_ALIGN T padded_v[4 * hn::MaxLanes(d)]; - DecompressAndZeroPad(d, vec, i, padded_v, remaining); + HWY_ALIGN Raw padded_v[4 * hn::MaxLanes(d_raw)]; + DecompressAndZeroPad(d_raw, v, i, padded_v, remaining); // 1..4 whole vectors, possibly zero-padded. for (size_t padded_pos = 0; padded_pos < remaining; padded_pos += N) { - const V v0 = hn::Load(d, padded_v + padded_pos); - kernel.Update1(d, v0, v0, sum0, comp0); + const VRaw v0 = hn::Load(d_raw, padded_v + padded_pos); + kernel.Update1(d_raw, v0, v0, sum0, comp0); } } - return kernel.Reduce(ds, sum0, sum1, sum2, sum3, comp0, comp1, comp2, comp3); + return kernel.Reduce(d_state, sum0, sum1, sum2, sum3, comp0, comp1, comp2, + comp3); } // Functor called for each tensor, which compresses and stores them along with diff --git a/ops/dot-inl.h b/ops/dot-inl.h index 815bdbb..f8e0a4f 100644 --- a/ops/dot-inl.h +++ b/ops/dot-inl.h @@ -40,15 +40,20 @@ namespace gcpp { namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; +// Our naming convention for dot product arguments is `w` and `v`, in that +// order. This originated in `MatVec`, which computed dot products of a +// compressed "weight" type, and `BF16/float` "vectors". This implementation no +// longer restricts the types of the arguments, but we keep the names for +// consistency, also because there is still a `w_ofs` but not a `v_ofs`. + //------------------------------------------------------------------------------ -// Returns 2 * sum(|w.*v|) / |sum(w.*v)|. This is large when there are many -// similar-magnitude and opposite-sign elements. See -// https://en.wikipedia.org/wiki/Condition_number. 
-template -HWY_MAYBE_UNUSED double ConditionNumber(const WeightT* HWY_RESTRICT w, - const VecT* HWY_RESTRICT v, - size_t num) { +// Returns 2 * sum(|w.*v|) / |sum(w.*v)|. The log2 of this value +// approximates the number of mantissa bits required for accurate computations. +// See https://en.wikipedia.org/wiki/Condition_number. +template +HWY_MAYBE_UNUSED double ConditionNumber(const WT* HWY_RESTRICT w, + const VT* HWY_RESTRICT v, size_t num) { PROFILER_FUNC; const hn::ScalableTag df; using VF = hn::Vec; @@ -104,9 +109,8 @@ HWY_MAYBE_UNUSED double ConditionNumber(const WeightT* HWY_RESTRICT w, } // Same, but for a single vector - just skips the product. -template -HWY_MAYBE_UNUSED double ConditionNumber(const VecT* HWY_RESTRICT v, - size_t num) { +template +HWY_MAYBE_UNUSED double ConditionNumber(const VT* HWY_RESTRICT v, size_t num) { PROFILER_FUNC; const hn::ScalableTag df; using VF = hn::Vec; @@ -153,12 +157,60 @@ HWY_MAYBE_UNUSED double ConditionNumber(const VecT* HWY_RESTRICT v, return cond; } -// Algorithm 6.15 from Handbook of Floating-Point Arithmetic. 10 ops is too slow -// for compute-limited Matmul but might be OK for attention. -// Also supports bf16 inputs, used by matvec-inl.h. +// f64 FMA, called for f32 inputs promoted to f64. Runs at about half the speed +// of f32 FMA. Only usable if `CanDecompressToDouble()`. +struct DotKernelDouble { + template + using Raw = double; + using State = double; + + template , HWY_IF_F64_D(DRaw)> + HWY_INLINE void Update4(DRaw dd, const VR w0, const VR w1, const VR w2, + const VR w3, const VR v0, const VR v1, const VR v2, + const VR v3, VR& sum0, VR& sum1, VR& sum2, VR& sum3, + VR&, VR&, VR&, VR&) const { + sum0 = hn::MulAdd(w0, v0, sum0); + sum1 = hn::MulAdd(w1, v1, sum1); + sum2 = hn::MulAdd(w2, v2, sum2); + sum3 = hn::MulAdd(w3, v3, sum3); + } + + template , HWY_IF_F64_D(DRaw)> + HWY_INLINE void Update1(DRaw dd, const VR w0, const VR v0, VR& sum0, + VR&) const { + sum0 = hn::MulAdd(w0, v0, sum0); + } + + template , HWY_IF_F64_D(DState)> + HWY_INLINE float Reduce(DState dd, VS& sum0, VS& sum1, VS& sum2, VS& sum3, + VS&, VS&, VS&, VS&) const { + // Reduction tree: sum of all accumulators by pairs, then across lanes. + sum0 = hn::Add(sum0, sum1); + sum2 = hn::Add(sum2, sum3); + sum0 = hn::Add(sum0, sum2); + return static_cast(hn::ReduceSum(dd, sum0)); + } +}; + +template +HWY_INLINE float DotDouble(D d, const PackedSpan& w, size_t w_ofs, + const VT* HWY_RESTRICT vec, size_t num) { + return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelDouble()); +} + +// Algorithm 6.15 from Handbook of Floating-Point Arithmetic. This is slower +// than DotKernelDouble and about equally accurate. struct DotKernelCompensated { - template , HWY_IF_F32_D(DF)> - HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2, + // Unlike other kernels, this also supports bf16 inputs, used by matvec-inl.h. + // Even when `!HWY_NATIVE_DOT_BF16`, the BF16 overload is still faster than + // promoting to `float` because it does not call `TwoProducts`. 
+ template + using Raw = hwy::If() || IsF32(), float, BF16>; + using State = float; + + // Raw = float + template , HWY_IF_F32_D(DRaw)> + HWY_INLINE void Update4(DRaw df, const VF w0, const VF w1, const VF w2, const VF w3, const VF v0, const VF v1, const VF v2, const VF v3, VF& sum0, VF& sum1, VF& sum2, VF& sum3, VF& comp0, VF& comp1, VF& comp2, VF& comp3) const { @@ -180,20 +232,20 @@ struct DotKernelCompensated { comp3 = hn::Add(comp3, hn::Add(perr3, serr3)); } - template , HWY_IF_BF16_D(DBF), - class DF = hn::Repartition, class VF = hn::Vec> - HWY_INLINE void Update4(DBF /*dbf*/, const VBF w0, const VBF w1, const VBF w2, - const VBF w3, const VBF v0, const VBF v1, - const VBF v2, const VBF v3, VF& sum0, VF& sum1, - VF& sum2, VF& sum3, VF& comp0, VF& comp1, VF& comp2, - VF& comp3) const { - const DF df; - const VF prod0 = WidenMulPairwiseAdd(df, w0, v0); - const VF prod1 = WidenMulPairwiseAdd(df, w1, v1); - const VF prod2 = WidenMulPairwiseAdd(df, w2, v2); - const VF prod3 = WidenMulPairwiseAdd(df, w3, v3); + // Raw = BF16, State = float + template , HWY_IF_BF16_D(DRaw), + class DS = hn::Repartition, class VS = hn::Vec> + HWY_INLINE void Update4(DRaw, const VR w0, const VR w1, const VR w2, + const VR w3, const VR v0, const VR v1, const VR v2, + const VR v3, VS& sum0, VS& sum1, VS& sum2, VS& sum3, + VS& comp0, VS& comp1, VS& comp2, VS& comp3) const { + const DS df; + const VS prod0 = WidenMulPairwiseAdd(df, w0, v0); + const VS prod1 = WidenMulPairwiseAdd(df, w1, v1); + const VS prod2 = WidenMulPairwiseAdd(df, w2, v2); + const VS prod3 = WidenMulPairwiseAdd(df, w3, v3); - VF serr0, serr1, serr2, serr3; + VS serr0, serr1, serr2, serr3; sum0 = TwoSums(df, prod0, sum0, serr0); sum1 = TwoSums(df, prod1, sum1, serr1); sum2 = TwoSums(df, prod2, sum2, serr2); @@ -205,6 +257,7 @@ struct DotKernelCompensated { comp3 = hn::Add(comp3, serr3); } + // Raw = float template , HWY_IF_F32_D(DF)> HWY_INLINE void Update1(DF df, const VF w0, const VF v0, VF& sum0, VF& comp0) const { @@ -217,22 +270,23 @@ struct DotKernelCompensated { comp0 = hn::Add(comp0, hn::Add(perr0, serr0)); } - template , HWY_IF_BF16_D(DBF), - class DF = hn::Repartition, class VF = hn::Vec> - HWY_INLINE void Update1(DBF /*dbf*/, const VBF w0, const VBF v0, VF& sum0, - VF& comp0) const { - const DF df; - const VF prod0 = WidenMulPairwiseAdd(df, w0, v0); + // Raw = BF16, State = float + template , HWY_IF_BF16_D(DRaw), + class DS = hn::Repartition, class VS = hn::Vec> + HWY_INLINE void Update1(DRaw, const VR w0, const VR v0, VS& sum0, + VS& comp0) const { + const DS df; + const VS prod0 = WidenMulPairwiseAdd(df, w0, v0); - VF serr0; + VS serr0; sum0 = TwoSums(df, prod0, sum0, serr0); comp0 = hn::Add(comp0, serr0); } - template > - HWY_INLINE float Reduce(DF df, VF& sum0, VF& sum1, VF& sum2, VF& sum3, - VF& comp0, VF& comp1, VF& comp2, VF& comp3) const { + template > + HWY_INLINE float Reduce(DS df, VS& sum0, VS& sum1, VS& sum2, VS& sum3, + VS& comp0, VS& comp1, VS& comp2, VS& comp3) const { // Reduction tree: sum of all accumulators by pairs, then across lanes. AssimilateCascadedSums(df, sum1, comp1, sum0, comp0); AssimilateCascadedSums(df, sum3, comp3, sum2, comp2); @@ -241,35 +295,38 @@ struct DotKernelCompensated { } }; -// Default kernel -template -HWY_INLINE float Dot(D d, const PackedSpan& w, size_t w_ofs, - const VecT* HWY_RESTRICT vec, size_t num) { +template +using DotKernelDefault = hwy::If(), + DotKernelDouble, DotKernelCompensated>; + +// `D` only serves to specify the vector size; its lane type is ignored. 
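The compensated kernel above is built from the error-free transformations `TwoProducts` and `TwoSums`. In scalar form they amount to the following sketch (assumes FMA support and strict FP evaluation, requires <cmath> for std::fma; illustrative only, not part of the patch):

// Returns round(a*b) and stores the exact rounding error in `err`.
inline float TwoProductsScalar(float a, float b, float& err) {
  const float prod = a * b;
  err = std::fma(a, b, -prod);  // a*b - round(a*b), exact thanks to FMA
  return prod;
}

// Knuth's TwoSum: returns round(a+b) and stores the exact rounding error.
inline float TwoSumsScalar(float a, float b, float& err) {
  const float sum = a + b;
  const float a2 = sum - b;
  const float b2 = sum - a2;
  err = (a - a2) + (b - b2);
  return sum;
}

Accumulating these errors in the `comp*` vectors and folding them back in `Reduce` is what keeps the result close to the correctly rounded dot product.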
+template +HWY_INLINE float Dot(D d, const PackedSpan& w, size_t w_ofs, + const VT* HWY_RESTRICT vec, size_t num) { return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), - DotKernelCompensated()); + DotKernelDefault()); } -// Adapter for a single pointer, no bounds checking. -template -HWY_INLINE float Dot(const WeightT* HWY_RESTRICT w, const VecT* vec, - size_t num) { - const hn::ScalableTag d; +// Adapter for two pointers, no bounds checking. +template +HWY_INLINE float Dot(const WT* HWY_RESTRICT w, const VT* vec, size_t num) { + const hn::ScalableTag d; return Dot(d, MakeConstSpan(w, num), /*w_ofs=*/0, vec, num); } // Adapter for use by matvec-inl.h. TODO: remove when that is no longer used. -template +template HWY_INLINE float Dot(const std::array& w, size_t w_ofs, - const VecT* vec, size_t num) { - const hn::ScalableTag d; + const VT* vec, size_t num) { + const hn::ScalableTag d; return Dot(d, MakeConstSpan(w.data(), kCapacity), w_ofs, vec, num); } // Adapter for use by matvec-inl.h. TODO: remove when that is no longer used. -template +template HWY_INLINE float Dot(const CompressedArray& w, size_t w_ofs, - const VecT* vec, size_t num) { - const hn::ScalableTag d; + const VT* vec, size_t num) { + const hn::ScalableTag d; return w.scale() * Dot(d, MakeConstSpan(w.data(), kCapacity), w_ofs, vec, num); } diff --git a/ops/dot_test.cc b/ops/dot_test.cc index a615627..f39088b 100644 --- a/ops/dot_test.cc +++ b/ops/dot_test.cc @@ -120,6 +120,10 @@ void AssertLess(size_t variant, T actual, T max, int line) { // All combinations of {*, TwoProducts} x {+, FastTwoSums, TwoSums}. struct DotKernelNaive { + template + using Raw = float; + using State = float; + template > HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2, const VF w3, const VF v0, const VF v1, const VF v2, @@ -150,51 +154,18 @@ struct DotKernelNaive { } }; -template -HWY_INLINE float DotNaive(D d, const PackedSpan& w, size_t w_ofs, - const VecT* HWY_RESTRICT vec, size_t num) { +template +HWY_INLINE float DotNaive(D d, const PackedSpan& w, size_t w_ofs, + const VT* HWY_RESTRICT vec, size_t num) { return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelNaive()); } -struct DotKernelDouble { - template , HWY_IF_F64_D(DD)> - HWY_INLINE void Update4(DD dd, const VD w0, const VD w1, const VD w2, - const VD w3, const VD v0, const VD v1, const VD v2, - const VD v3, VD& sum0, VD& sum1, VD& sum2, VD& sum3, - VD&, VD&, VD&, VD&) const { - sum0 = hn::MulAdd(w0, v0, sum0); - sum1 = hn::MulAdd(w1, v1, sum1); - sum2 = hn::MulAdd(w2, v2, sum2); - sum3 = hn::MulAdd(w3, v3, sum3); - } - - template , HWY_IF_F64_D(DD)> - HWY_INLINE void Update1(DD dd, const VD w0, const VD v0, VD& sum0, - VD&) const { - sum0 = hn::MulAdd(w0, v0, sum0); - } - - template , HWY_IF_F64_D(DD)> - HWY_INLINE float Reduce(DD dd, VD& sum0, VD& sum1, VD& sum2, VD& sum3, VD&, - VD&, VD&, VD&) const { - // Reduction tree: sum of all accumulators by pairs, then across lanes. - sum0 = hn::Add(sum0, sum1); - sum2 = hn::Add(sum2, sum3); - sum0 = hn::Add(sum0, sum2); - return static_cast(hn::ReduceSum(dd, sum0)); - } -}; - -template -HWY_INLINE float DotDouble(D d, const PackedSpan& w, - size_t w_ofs, const VecT* HWY_RESTRICT vec, - size_t num) { - const hn::Repartition dd; - return DecompressAndCall(dd, w, w_ofs, MakeSpan(vec, num), DotKernelDouble()); -} - // https://en.wikipedia.org/wiki/Kahan_summation_algorithm: FastTwoSum. 
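In scalar form this is the following sketch (not part of the patch); note that the compensation feeds into the next addend, so every step is serially dependent, unlike the cascaded kernels whose sum* and comp* chains are independent:

float KahanSum(const float* x, size_t num) {
  float sum = 0.0f, comp = 0.0f;  // comp holds the running rounding error
  for (size_t i = 0; i < num; ++i) {
    const float y = x[i] - comp;  // apply previous compensation
    const float t = sum + y;      // rounded sum
    comp = (t - sum) - y;         // recover the rounding error (FastTwoSum)
    sum = t;
  }
  return sum;
}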
struct DotKernelKahan { + template + using Raw = float; + using State = float; + template > HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2, const VF w3, const VF v0, const VF v1, const VF v2, @@ -234,15 +205,15 @@ struct DotKernelKahan { } }; -template -HWY_INLINE float DotKahan(D d, const PackedSpan& w, size_t w_ofs, - const VecT* HWY_RESTRICT vec, size_t num) { +template +HWY_INLINE float DotKahan(D d, const PackedSpan& w, size_t w_ofs, + const VT* HWY_RESTRICT vec, size_t num) { return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelKahan()); } -template -HWY_INLINE float DotCompensated(D d, const PackedSpan& w, - size_t w_ofs, const VecT* HWY_RESTRICT vec, +template +HWY_INLINE float DotCompensated(D d, const PackedSpan& w, + size_t w_ofs, const VT* HWY_RESTRICT vec, size_t num) { return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelCompensated()); @@ -250,6 +221,10 @@ HWY_INLINE float DotCompensated(D d, const PackedSpan& w, // Like Compensated, but FastTwoSum instead of TwoSum. struct DotKernelTwoProdFast { + template + using Raw = float; + using State = float; + template > HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2, const VF w3, const VF v0, const VF v1, const VF v2, @@ -296,9 +271,9 @@ struct DotKernelTwoProdFast { } }; -template -HWY_INLINE float DotTwoProdFast(D d, const PackedSpan& w, - size_t w_ofs, const VecT* HWY_RESTRICT vec, +template +HWY_INLINE float DotTwoProdFast(D d, const PackedSpan& w, + size_t w_ofs, const VT* HWY_RESTRICT vec, size_t num) { return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelTwoProdFast()); @@ -307,6 +282,10 @@ HWY_INLINE float DotTwoProdFast(D d, const PackedSpan& w, // Like Compensated, but without TwoProducts. Vs Kahan, upgrades FastTwoSums // to TwoSums. struct DotKernelMulTwoSum { + template + using Raw = float; + using State = float; + template > HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2, const VF w3, const VF v0, const VF v1, const VF v2, @@ -351,10 +330,9 @@ struct DotKernelMulTwoSum { } }; -template -HWY_INLINE float DotMulTwoSum(D d, const PackedSpan& w, - size_t w_ofs, const VecT* HWY_RESTRICT vec, - size_t num) { +template +HWY_INLINE float DotMulTwoSum(D d, const PackedSpan& w, size_t w_ofs, + const VT* HWY_RESTRICT vec, size_t num) { return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelMulTwoSum()); } @@ -362,6 +340,10 @@ HWY_INLINE float DotMulTwoSum(D d, const PackedSpan& w, // -Like Compensated, but only TwoProducts, no [Fast]TwoSums. This is only 10% // better (mul) than naive. struct DotKernelTwoProdAdd { + template + using Raw = float; + using State = float; + template > HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2, const VF w3, const VF v0, const VF v1, const VF v2, @@ -406,10 +388,9 @@ struct DotKernelTwoProdAdd { } }; -template -HWY_INLINE float DotTwoProdAdd(D d, const PackedSpan& w, - size_t w_ofs, const VecT* HWY_RESTRICT vec, - size_t num) { +template +HWY_INLINE float DotTwoProdAdd(D d, const PackedSpan& w, size_t w_ofs, + const VT* HWY_RESTRICT vec, size_t num) { return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelTwoProdAdd()); } @@ -417,6 +398,10 @@ HWY_INLINE float DotTwoProdAdd(D d, const PackedSpan& w, // From "SIMDizing Pairwise Sums". Slower and generally higher error than // Kahan, but uses fewer regs. 
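In scalar form, pairwise (tree) summation is the following sketch (illustrative, not part of the patch): each input participates in only O(log n) additions, which bounds rounding-error growth, whereas a running sum grows as O(n).

float PairwiseSumSketch(const float* x, size_t num) {
  if (num <= 8) {  // small base case: plain accumulation
    float sum = 0.0f;
    for (size_t i = 0; i < num; ++i) sum += x[i];
    return sum;
  }
  const size_t half = num / 2;
  return PairwiseSumSketch(x, half) + PairwiseSumSketch(x + half, num - half);
}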
struct DotKernelPairwise { + template + using Raw = float; + using State = float; + template > HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2, const VF w3, const VF v0, const VF v1, const VF v2, @@ -472,10 +457,9 @@ struct DotKernelPairwise { mutable size_t num_ = 0; }; -template -HWY_INLINE float DotPairwise(D d, const PackedSpan& w, - size_t w_ofs, const VecT* HWY_RESTRICT vec, - size_t num) { +template +HWY_INLINE float DotPairwise(D d, const PackedSpan& w, size_t w_ofs, + const VT* HWY_RESTRICT vec, size_t num) { return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelPairwise()); } @@ -483,6 +467,10 @@ HWY_INLINE float DotPairwise(D d, const PackedSpan& w, // Hybrid of Pairwise and Compensated. 1.14x time vs. Kahan, but geomean mul // is 1.02 vs 1.06, mean L1 is 1.21x better, and uses two fewer regs. struct DotKernelComp2 { + template + using Raw = float; + using State = float; + template , HWY_IF_F32_D(DF)> HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2, const VF w3, const VF v0, const VF v1, const VF v2, @@ -567,17 +555,17 @@ struct DotKernelComp2 { } }; -template -HWY_INLINE float DotComp2(D d, const PackedSpan& w, size_t w_ofs, - const VecT* HWY_RESTRICT vec, size_t num) { +template +HWY_INLINE float DotComp2(D d, const PackedSpan& w, size_t w_ofs, + const VT* HWY_RESTRICT vec, size_t num) { return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelComp2()); } -template -float CallDot(D d, size_t variant, const PackedSpan& w, - size_t w_ofs, const VecT* HWY_RESTRICT v, size_t num) { +template +float CallDot(D d, size_t variant, const PackedSpan& w, size_t w_ofs, + const VT* HWY_RESTRICT v, size_t num) { // float inputs also support kDouble. - if constexpr (hwy::IsSame() && hwy::IsSame()) { + if constexpr (CanDecompressToDouble()) { if (variant == kDouble) return DotDouble(d, w, 0, v, num); } @@ -608,9 +596,9 @@ float CallDot(D d, size_t variant, const PackedSpan& w, // and round to nearest. See "Accurate and efficient floating point summation". // Much too slow to be useful. Kept separate from the above kernels because it // is used to compute their error. -template -float ExactDot(const WeightT* HWY_RESTRICT w, const VecT* HWY_RESTRICT v, - size_t num, double* HWY_RESTRICT buf) { +template +float ExactDot(const WT* HWY_RESTRICT w, const VT* HWY_RESTRICT v, size_t num, + double* HWY_RESTRICT buf) { PROFILER_FUNC; for (size_t i = 0; i < num; ++i) { buf[i] = @@ -944,18 +932,17 @@ void GenerateWellConditionedInputs(const size_t num, float* HWY_RESTRICT raw, // Returns the actual condition number. Based on Algorithm 6.1 from "Accurate // Sum and Dot Product". `num` is the (arbitrary) size of w, v, and buf. -template -double GenerateIllConditionedInputs(const size_t num, WeightT* w, - VecT* HWY_RESTRICT v, std::mt19937& rng) { +template +double GenerateIllConditionedInputs(const size_t num, WT* w, VT* HWY_RESTRICT v, + std::mt19937& rng) { PROFILER_FUNC; const size_t half = HWY_MAX(1, num / 2); // generate at least one random HWY_DASSERT(half != 0); const hn::ScalableTag df; + const PackedSpan w_span(w, num); - const PackedSpan w_span(w, num); - - // Regardless of WeightT and VecT, we will accumulate into float. Multiplying + // Regardless of WT and VT, we will accumulate into float. Multiplying // two maximal inputs and accumulating `num` times is enough for some loss of // precision and condition numbers between 1E6-1E9, which is what we see for // Attention Dot and `RMSNormMul`. 
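For scale, applying the `ConditionNumber` comment in ops/dot-inl.h (its log2 approximates the number of mantissa bits required): a condition number of 1E6 calls for about 20 bits and 1E9 for about 30, at or beyond the 24-bit f32 significand, which is why plain f32 accumulation cannot resolve such dot products and the compensated or f64 kernels are measured against them.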
@@ -966,8 +953,8 @@ double GenerateIllConditionedInputs(const size_t num, WeightT* w, for (size_t i = 0; i < half; ++i) { // Ensure the min and max exponents are used. const int e = i == 0 ? 0 : i == 1 ? max_exp : e_dist(rng); - w[i] = hwy::ConvertScalarTo(RandomFloat(rng) * (1 << e)); - v[i] = hwy::ConvertScalarTo(RandomFloat(rng) * (1 << e)); + w[i] = hwy::ConvertScalarTo(RandomFloat(rng) * (1 << e)); + v[i] = hwy::ConvertScalarTo(RandomFloat(rng) * (1 << e)); } const float a_exp_step = @@ -976,16 +963,16 @@ double GenerateIllConditionedInputs(const size_t num, WeightT* w, for (size_t i = half; i < num; ++i, a_exp -= a_exp_step) { const int e = static_cast(a_exp); HWY_DASSERT(e >= 0); - w[i] = hwy::ConvertScalarTo(RandomFloat(rng) * (1 << e)); + w[i] = hwy::ConvertScalarTo(RandomFloat(rng) * (1 << e)); const float r = RandomFloat(rng) * (1 << e); if (hwy::ConvertScalarTo(w[i]) == 0.0f) { - v[i] = hwy::ConvertScalarTo(0.0f); + v[i] = hwy::ConvertScalarTo(0.0f); } else { // This is called >100K times. DotCompensated is much faster than ExactDot // and just about as accurate. const float exact = DotCompensated(df, MakeConst(w_span), /*w_ofs=*/0, v, i); - v[i] = hwy::ConvertScalarTo( + v[i] = hwy::ConvertScalarTo( r - exact / hwy::ConvertScalarTo(w[i])); } } diff --git a/ops/ops-inl.h b/ops/ops-inl.h index eaf1870..ef4b48e 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -142,11 +142,11 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Sigmoid(float* HWY_RESTRICT x, namespace detail { // Shared by RMSNorm and RMSNormInplace. -template -float RMSNormMul(const VecT* HWY_RESTRICT x, size_t size) { - const hn::ScalableTag df; +template +float RMSNormMul(const VT* HWY_RESTRICT x, size_t size) { + const hn::ScalableTag d; const float l2 = - DecompressAndCall(df, MakeSpan(x, size), DotKernelCompensated()); + DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault()); constexpr float kEps = 1e-6f; // avoid divide by zero return 1.0f / sqrtf(l2 / StaticCast(size) + kEps); } @@ -504,6 +504,40 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd( MulByConstAndAdd(c, x, out, size, size); } +// f64 Add, called for f32 inputs promoted to f64. Runs at about half the speed +// of f32 sums. Only usable if `CanDecompressToDouble()`. +struct SumKernelDouble { + template + using Raw = double; + using State = double; + + template , HWY_IF_F64_D(DRaw)> + HWY_INLINE void Update4(DRaw /*dd*/, const VR w0, const VR w1, const VR w2, + const VR w3, VR, VR, VR, VR, VR& sum0, VR& sum1, + VR& sum2, VR& sum3, VR&, VR&, VR&, VR&) const { + sum0 = hn::Add(sum0, w0); + sum1 = hn::Add(sum1, w1); + sum2 = hn::Add(sum2, w2); + sum3 = hn::Add(sum3, w3); + } + + template , HWY_IF_F64_D(DRaw)> + HWY_INLINE void Update1(DRaw /*dd*/, const VR w0, const VR v0, VR& sum0, + VR& comp0) const { + sum0 = hn::Add(sum0, w0); + } + + template > + HWY_INLINE float Reduce(DState dd, VS& sum0, VS& sum1, VS& sum2, VS& sum3, + VS&, VS&, VS&, VS&) const { + // Reduction tree: sum of all accumulators by pairs, then across lanes. + sum0 = hn::Add(sum0, sum1); + sum2 = hn::Add(sum2, sum3); + sum0 = hn::Add(sum0, sum2); + return static_cast(hn::ReduceSum(dd, sum0)); + } +}; + // ORO Cascaded Summation, algorithm 6.11 from Handbook of Floating-Point // Arithmetic. Note that Algorithm 6.7 (KBN) appears erroneous. We use TwoSums // instead of FastTwoSums because the magnitude of the initial sum is not @@ -511,7 +545,13 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd( // generation results. 
Note that Kahan summation differs in that it first adds // comp* to w*, so each operation is serially dependent. By contrast, the sum* // and comp* here have shorter dependency chains. -struct KernelCascadedSum { +// +// This is slower than SumKernelDouble and about equally accurate. +struct SumKernelCascaded { + template + using Raw = float; + using State = float; + template , HWY_IF_F32_D(DF)> HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2, const VF w3, VF, VF, VF, VF, VF& sum0, VF& sum1, @@ -549,6 +589,17 @@ struct KernelCascadedSum { } }; +template +using SumKernelDefault = hwy::If(), + SumKernelDouble, SumKernelCascaded>; + +template +HWY_INLINE float Sum(D d, const VT* HWY_RESTRICT vec, size_t num) { + using Raw = hwy::If; + const hn::Repartition d_raw; + return DecompressAndCall(d_raw, MakeSpan(vec, num), SumKernelDefault()); +} + static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size, const size_t mask_pos) { PROFILER_FUNC; @@ -582,8 +633,7 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size, // not make a huge difference. It halves the standard deviation of the sum of // the normalized probabilities from 1E-7 to 5E-8, but actually also changes // the generated text after a few hundred tokens. - const float sum_exp = - DecompressAndCall(d, MakeConstSpan(x, mask_pos), KernelCascadedSum()); + const float sum_exp = Sum(d, x, mask_pos); // Double-precision reciprocal does not appear to affect the results. const float mul = 1.0f / sum_exp; MulByConst(mul, x, size, mask_pos);
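A scalar sketch of the resulting Softmax flow (assuming the usual max-subtraction in the part of `Softmax` not shown in this hunk; illustrative, not part of the patch): `exp` stays in f32 and only the normalizer is accumulated in f64, which is what `SumKernelDouble` does vector-wise when `CanDecompressToDouble<float, float>()` holds.

// Requires <cmath> and <algorithm>.
void SoftmaxSketch(float* HWY_RESTRICT x, size_t num) {
  if (num == 0) return;
  float max = x[0];
  for (size_t i = 1; i < num; ++i) max = std::max(max, x[i]);
  double sum_exp = 0.0;  // f64 accumulator, as in SumKernelDouble
  for (size_t i = 0; i < num; ++i) {
    x[i] = std::exp(x[i] - max);
    sum_exp += x[i];
  }
  // Per the comment above, a double-precision reciprocal does not change the
  // results, so a single f32 multiplier suffices.
  const float mul = 1.0f / static_cast<float>(sum_exp);
  for (size_t i = 0; i < num; ++i) x[i] *= mul;
}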