Also enable f64 dot/sum for <f32 inputs

Add bf16 support to Dot/SumKernelDouble in the same way as *Compensated.

PiperOrigin-RevId: 682308683
Jan Wassenberg 2024-10-04 07:11:27 -07:00 committed by Copybara-Service
parent 895ee4c6ce
commit 5a71d819cb
5 changed files with 183 additions and 46 deletions
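
In brief, this change removes the `CanDecompressToDouble` gate: the f64 kernels are now selected whenever the target supports f64, and each kernel chooses its own decompression type per input. A summary sketch, assembled from names in the diffs below and not itself part of the commit:

    // Before: the f64 dot kernel was only used when both inputs were f32.
    //   template <typename WT, typename VT>
    //   using DotKernelDefault = hwy::If<CanDecompressToDouble<WT, VT>(),
    //                                    DotKernelDouble, DotKernelCompensated>;
    // After: use the f64 kernel whenever the target has f64. Its Raw type is
    // double when both inputs are f32, otherwise BF16, whose pairwise f32
    // products are accumulated in f64.
    using DotKernelDefault =
        hwy::If<HWY_HAVE_FLOAT64, DotKernelDouble, DotKernelCompensated>;
    using SumKernelDefault =
        hwy::If<HWY_HAVE_FLOAT64, SumKernelDouble, SumKernelCascaded>;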

View File

@@ -498,14 +498,6 @@ constexpr bool IsF32() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, float>();
}
-// TODO: `DotKernelDouble` requires both inputs to be `float` because currently
-// only `CompressTraits<float>` can `Decompress2` to `double`. It is not yet
-// clear whether we want to implement this for other Packed types.
-template <typename WT, typename VT>
-constexpr bool CanDecompressToDouble() {
-return HWY_HAVE_FLOAT64 && IsF32<WT>() && IsF32<VT>();
-}
namespace detail {
// Compile-time-only check that `DRaw` and `Packed` are compatible. This makes

View File

@@ -157,13 +157,17 @@ HWY_MAYBE_UNUSED double ConditionNumber(const VT* HWY_RESTRICT v, size_t num) {
return cond;
}
-// f64 FMA, called for f32 inputs promoted to f64. Runs at about half the speed
-// of f32 FMA. Only usable if `CanDecompressToDouble<WT, VT>()`.
+// f64 FMA. Inputs are both f32 promoted to f64, or any types that are either
+// promoted or even DEMOTED to bf16. Runs at about half the speed of f32 FMA.
struct DotKernelDouble {
// Only `CompressTraits<float>` can `Decompress2` to `double`, so both have
// to be `float` in order to have `Raw = double`. Note that if either type is
// smaller than `float`, we may demote the other type from `float` to `BF16`.
template <typename VT, typename WT>
-using Raw = double;
+using Raw = hwy::If<IsF32<VT>() && IsF32<WT>(), double, BF16>;
using State = double;
// Raw = double
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
HWY_INLINE void Update4(DRaw dd, const VR w0, const VR w1, const VR w2,
const VR w3, const VR v0, const VR v1, const VR v2,
@@ -175,18 +179,77 @@ struct DotKernelDouble {
sum3 = hn::MulAdd(w3, v3, sum3);
}
// Raw = BF16
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>
HWY_INLINE void Update4(DRaw, const VR w0, const VR w1, const VR w2,
const VR w3, const VR v0, const VR v1, const VR v2,
const VR v3, VS& sum0, VS& sum1, VS& sum2, VS& sum3,
VS&, VS&, VS&, VS&) const {
const hn::Repartition<float, DRaw> df;
using VF = hn::Vec<decltype(df)>;
const VF prod0 = hn::WidenMulPairwiseAdd(df, w0, v0);
const VF prod1 = hn::WidenMulPairwiseAdd(df, w1, v1);
// Reduce to two f32 sums so we can promote them to four f64 vectors.
VF sum02, sum13;
if constexpr (HWY_NATIVE_DOT_BF16) {
// Fuse WidenMulPairwiseAdd plus Add into ReorderWidenMulAccumulate.
VF unused0 = hn::Zero(df);
VF unused1 = hn::Zero(df);
sum02 = hn::ReorderWidenMulAccumulate(df, w2, v2, prod0, unused0);
sum13 = hn::ReorderWidenMulAccumulate(df, w3, v3, prod1, unused1);
} else {
// ReorderWidenMulAccumulate does not help because we still end up with
// four accumulators.
const VF prod2 = hn::WidenMulPairwiseAdd(df, w2, v2);
const VF prod3 = hn::WidenMulPairwiseAdd(df, w3, v3);
sum02 = hn::Add(prod0, prod2);
sum13 = hn::Add(prod1, prod3);
}
const DS ds;
const VS d0 = hn::PromoteLowerTo(ds, sum02);
const VS d1 = hn::PromoteUpperTo(ds, sum02);
const VS d2 = hn::PromoteLowerTo(ds, sum13);
const VS d3 = hn::PromoteUpperTo(ds, sum13);
sum0 = hn::Add(sum0, d0);
sum1 = hn::Add(sum1, d1);
sum2 = hn::Add(sum2, d2);
sum3 = hn::Add(sum3, d3);
}
// Raw = double
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
HWY_INLINE void Update1(DRaw dd, const VR w0, const VR v0, VR& sum0,
VR&) const {
sum0 = hn::MulAdd(w0, v0, sum0);
}
// Raw = BF16
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>
HWY_INLINE void Update1(DRaw, const VR w0, const VR v0, VS& sum0,
VS& extra0) const {
const hn::Repartition<float, DRaw> df;
using VF = hn::Vec<decltype(df)>;
const VF prod0 = hn::WidenMulPairwiseAdd(df, w0, v0);
const DS ds;
const VS d0 = hn::PromoteLowerTo(ds, prod0);
const VS d1 = hn::PromoteUpperTo(ds, prod0);
sum0 = hn::Add(sum0, d0);
extra0 = hn::Add(extra0, d1);
}
template <class DState, class VS = hn::Vec<DState>, HWY_IF_F64_D(DState)>
HWY_INLINE float Reduce(DState dd, VS& sum0, VS& sum1, VS& sum2, VS& sum3,
-VS&, VS&, VS&, VS&) const {
+VS& extra0, VS&, VS&, VS&) const {
// Reduction tree: sum of all accumulators by pairs, then across lanes.
sum0 = hn::Add(sum0, sum1);
sum2 = hn::Add(sum2, sum3);
sum0 = hn::Add(sum0, extra0); // from Update1
sum0 = hn::Add(sum0, sum2);
return static_cast<float>(hn::ReduceSum(dd, sum0));
}
@@ -198,12 +261,13 @@ HWY_INLINE float DotDouble(D d, const PackedSpan<const WT>& w, size_t w_ofs,
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelDouble());
}
-// Algorithm 6.15 from Handbook of Floating-Point Arithmetic. This is slower
-// than DotKernelDouble and about equally accurate.
+// Algorithm 6.15 from Handbook of Floating-Point Arithmetic. This is about as
+// accurate as DotKernelDouble but slower, hence we only use it if f64 is not
+// supported on this target.
struct DotKernelCompensated {
-// Unlike other kernels, this also supports bf16 inputs, used by matvec-inl.h.
-// Even when `!HWY_NATIVE_DOT_BF16`, the BF16 overload is still faster than
-// promoting to `float` because it does not call `TwoProducts`.
+// The `BF16` overload uses `ReorderWidenMulAccumulate`, which requires both
+// `VT` and `WT` to be `BF16`, or smaller types decompressed to `BF16`.
+// Otherwise, we decompress both inputs to `float`.
template <typename VT, typename WT>
using Raw = hwy::If<IsF32<VT>() || IsF32<WT>(), float, BF16>;
using State = float;
@@ -240,10 +304,10 @@ struct DotKernelCompensated {
const VR v3, VS& sum0, VS& sum1, VS& sum2, VS& sum3,
VS& comp0, VS& comp1, VS& comp2, VS& comp3) const {
const DS df;
-const VS prod0 = WidenMulPairwiseAdd(df, w0, v0);
-const VS prod1 = WidenMulPairwiseAdd(df, w1, v1);
-const VS prod2 = WidenMulPairwiseAdd(df, w2, v2);
-const VS prod3 = WidenMulPairwiseAdd(df, w3, v3);
+const VS prod0 = hn::WidenMulPairwiseAdd(df, w0, v0);
+const VS prod1 = hn::WidenMulPairwiseAdd(df, w1, v1);
+const VS prod2 = hn::WidenMulPairwiseAdd(df, w2, v2);
+const VS prod3 = hn::WidenMulPairwiseAdd(df, w3, v3);
VS serr0, serr1, serr2, serr3;
sum0 = TwoSums(df, prod0, sum0, serr0);
@@ -295,16 +359,14 @@ struct DotKernelCompensated {
}
};
-template <typename WT, typename VT>
-using DotKernelDefault = hwy::If<CanDecompressToDouble<WT, VT>(),
-                                 DotKernelDouble, DotKernelCompensated>;
+using DotKernelDefault =
+    hwy::If<HWY_HAVE_FLOAT64, DotKernelDouble, DotKernelCompensated>;
// `D` only serves to specify the vector size; its lane type is ignored.
template <class D, typename WT, typename VT>
HWY_INLINE float Dot(D d, const PackedSpan<const WT>& w, size_t w_ofs,
const VT* HWY_RESTRICT vec, size_t num) {
-return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
-                         DotKernelDefault<WT, VT>());
+return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelDefault());
}
// Adapter for two pointers, no bounds checking.
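
As a reference for the new `Update4`/`Update1` overloads above, here is a scalar model of the arithmetic of the bf16 path: inputs that are not both f32 are decompressed to bf16, adjacent pairs are multiplied and summed in f32 (as `WidenMulPairwiseAdd` does), and the f32 partial sums are promoted to and accumulated in f64. This is an illustration only; the helpers `RoundToBF16` and `DotBF16ToF64` are not from the commit.

    #include <stddef.h>
    #include <stdint.h>
    #include <string.h>

    // Round a float to the nearest bfloat16 value (round to nearest even;
    // NaN/overflow handling omitted for brevity).
    static float RoundToBF16(float f) {
      uint32_t bits;
      memcpy(&bits, &f, sizeof(bits));
      bits += 0x7FFFu + ((bits >> 16) & 1u);
      bits &= 0xFFFF0000u;
      memcpy(&f, &bits, sizeof(bits));
      return f;
    }

    // Dot product as the bf16 path computes it: bf16 inputs, f32 pairwise
    // products, f64 accumulation. Assumes `num` is even.
    static double DotBF16ToF64(const float* w, const float* v, size_t num) {
      double sum = 0.0;
      for (size_t i = 0; i + 2 <= num; i += 2) {
        const float pair = RoundToBF16(w[i]) * RoundToBF16(v[i]) +
                           RoundToBF16(w[i + 1]) * RoundToBF16(v[i + 1]);
        sum += static_cast<double>(pair);
      }
      return sum;
    }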

View File

@@ -507,10 +507,10 @@ struct DotKernelComp2 {
VF& sum2, VF& sum3, VF& comp0, VF& comp1, VF& comp2,
VF& comp3) const {
const DF df;
-VF prod0 = WidenMulPairwiseAdd(df, w0, v0);
-VF prod1 = WidenMulPairwiseAdd(df, w1, v1);
-VF prod2 = WidenMulPairwiseAdd(df, w2, v2);
-VF prod3 = WidenMulPairwiseAdd(df, w3, v3);
+VF prod0 = hn::WidenMulPairwiseAdd(df, w0, v0);
+VF prod1 = hn::WidenMulPairwiseAdd(df, w1, v1);
+VF prod2 = hn::WidenMulPairwiseAdd(df, w2, v2);
+VF prod3 = hn::WidenMulPairwiseAdd(df, w3, v3);
// Pairwise sums
prod0 = hn::Add(prod0, prod1);
@@ -564,11 +564,6 @@ HWY_INLINE float DotComp2(D d, const PackedSpan<const WT>& w, size_t w_ofs,
template <class D, typename WT, typename VT, HWY_IF_F32_D(D)>
float CallDot(D d, size_t variant, const PackedSpan<const WT>& w, size_t w_ofs,
const VT* HWY_RESTRICT v, size_t num) {
-// float inputs also support kDouble.
-if constexpr (CanDecompressToDouble<WT, VT>()) {
-if (variant == kDouble) return DotDouble(d, w, 0, v, num);
-}
switch (variant) {
case kAddTwoProd:
return DotTwoProdFast(d, w, 0, v, num);
@@ -578,6 +573,12 @@ float CallDot(D d, size_t variant, const PackedSpan<const WT>& w, size_t w_ofs,
return DotComp2(d, w, 0, v, num);
case kCompensated:
return DotCompensated(d, w, 0, v, num);
case kDouble:
if constexpr (HWY_HAVE_FLOAT64) {
return DotDouble(d, w, 0, v, num);
} else {
return DotCompensated(d, w, 0, v, num);
}
case kKahan:
return DotKahan(d, w, 0, v, num);
case kNaive:

View File

@@ -142,7 +142,9 @@ void AssertClose(size_t rows_ac, size_t cols_ab, size_t cols_c_rows_b,
const double norm = MaxColAbsSum(a.get(), rows_ac, cols_ab) *
MaxColAbsSum(b_trans.get(), cols_c_rows_b, cols_ab);
-const double epsilon = hwy::ConvertScalarTo<double>(hwy::Epsilon<float>());
+// Dot(float,BF16) rounds both to BF16.
+using RefType = hwy::If<IsF32<MatTA>() && IsF32<MatTB>(), float, BF16>;
+const double epsilon = hwy::ConvertScalarTo<double>(hwy::Epsilon<RefType>());
const double tolerance = 200.0 * norm * epsilon;
for (size_t idx = 0; idx < num_c; idx++) {

View File

@@ -146,8 +146,7 @@ namespace detail {
template <typename VT>
float RMSNormMul(const VT* HWY_RESTRICT x, size_t size) {
const hn::ScalableTag<float> d;
-const float l2 =
-    DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault<VT, VT>());
+const float l2 = DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault());
constexpr float kEps = 1e-6f;  // avoid divide by zero
return 1.0f / sqrtf(l2 / StaticCast<float>(size) + kEps);
}
@@ -506,12 +505,16 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
}
// f64 Add, called for f32 inputs promoted to f64. Runs at about half the speed
-// of f32 sums. Only usable if `CanDecompressToDouble<VT, VT>()`.
+// of f32 sums.
struct SumKernelDouble {
// Only `CompressTraits<float>` can `Decompress2` to `double`, so both have
// to be `float` in order to have `Raw = double`. Note that if either type is
// smaller than `float`, we may demote the other type from `float` to `BF16`.
template <typename VT, typename WT>
-using Raw = double;
+using Raw = hwy::If<IsF32<VT>() && IsF32<WT>(), double, BF16>;
using State = double;
// Raw = double
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
HWY_INLINE void Update4(DRaw /*dd*/, const VR w0, const VR w1, const VR w2,
const VR w3, VR, VR, VR, VR, VR& sum0, VR& sum1,
@@ -522,18 +525,95 @@ struct SumKernelDouble {
sum3 = hn::Add(sum3, w3);
}
// Raw = BF16
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>
HWY_INLINE void Update4(DRaw dr, const VR w0, const VR w1, const VR w2,
const VR w3, VR, VR, VR, VR, VS& sum0, VS& sum1,
VS& sum2, VS& sum3, VS&, VS&, VS&, VS&) const {
const hn::Repartition<float, DRaw> df;
using VF = hn::Vec<decltype(df)>;
// Reduce to two f32 sums so we can promote them to four f64 vectors.
VF sum02, sum13;
if constexpr (HWY_NATIVE_DOT_BF16) {
const VR k1 = hn::Set(dr, hwy::ConvertScalarTo<BF16>(1.0f));
const VF prod0 = hn::WidenMulPairwiseAdd(df, w0, k1);
const VF prod1 = hn::WidenMulPairwiseAdd(df, w1, k1);
// Fuse WidenMulPairwiseAdd plus Add into ReorderWidenMulAccumulate.
VF unused0 = hn::Zero(df);
VF unused1 = hn::Zero(df);
sum02 = hn::ReorderWidenMulAccumulate(df, w2, k1, prod0, unused0);
sum13 = hn::ReorderWidenMulAccumulate(df, w3, k1, prod1, unused1);
} else {
// If not native, the multiplication costs extra, so convert to f32.
// PromoteEvenTo is cheaper than PromoteUpperTo especially on `SVE`.
const VF fe0 = hn::PromoteEvenTo(df, w0);
const VF fe1 = hn::PromoteEvenTo(df, w1);
const VF fe2 = hn::PromoteEvenTo(df, w2);
const VF fe3 = hn::PromoteEvenTo(df, w3);
const VF fo0 = hn::PromoteOddTo(df, w0);
const VF fo1 = hn::PromoteOddTo(df, w1);
const VF fo2 = hn::PromoteOddTo(df, w2);
const VF fo3 = hn::PromoteOddTo(df, w3);
const VF fe01 = hn::Add(fe0, fe1);
const VF fe23 = hn::Add(fe2, fe3);
const VF fo01 = hn::Add(fo0, fo1);
const VF fo23 = hn::Add(fo2, fo3);
sum02 = hn::Add(fe01, fe23);
sum13 = hn::Add(fo01, fo23);
}
const DS ds;
const VS d0 = hn::PromoteLowerTo(ds, sum02);
const VS d1 = hn::PromoteUpperTo(ds, sum02);
const VS d2 = hn::PromoteLowerTo(ds, sum13);
const VS d3 = hn::PromoteUpperTo(ds, sum13);
sum0 = hn::Add(sum0, d0);
sum1 = hn::Add(sum1, d1);
sum2 = hn::Add(sum2, d2);
sum3 = hn::Add(sum3, d3);
}
// Raw = double
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
HWY_INLINE void Update1(DRaw /*dd*/, const VR w0, const VR v0, VR& sum0,
VR& comp0) const {
sum0 = hn::Add(sum0, w0);
}
// Raw = BF16
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>
HWY_INLINE void Update1(DRaw dr, const VR w0, VR, VS& sum0,
VS& extra0) const {
const hn::Repartition<float, DRaw> df;
using VF = hn::Vec<decltype(df)>;
VF f0;
if constexpr (HWY_NATIVE_DOT_BF16) {
const VR k1 = hn::Set(dr, hwy::ConvertScalarTo<BF16>(1.0f));
f0 = hn::WidenMulPairwiseAdd(df, w0, k1);
} else {
const VF fe0 = hn::PromoteEvenTo(df, w0);
const VF fo0 = hn::PromoteOddTo(df, w0);
f0 = hn::Add(fe0, fo0);
}
const DS ds;
const VS d0 = hn::PromoteLowerTo(ds, f0);
const VS d1 = hn::PromoteUpperTo(ds, f0);
sum0 = hn::Add(sum0, d0);
extra0 = hn::Add(extra0, d1);
}
template <class DState, class VS = hn::Vec<DState>>
HWY_INLINE float Reduce(DState dd, VS& sum0, VS& sum1, VS& sum2, VS& sum3,
-VS&, VS&, VS&, VS&) const {
+VS& extra0, VS&, VS&, VS&) const {
// Reduction tree: sum of all accumulators by pairs, then across lanes.
sum0 = hn::Add(sum0, sum1);
sum2 = hn::Add(sum2, sum3);
sum0 = hn::Add(sum0, extra0); // from Update1
sum0 = hn::Add(sum0, sum2);
return static_cast<float>(hn::ReduceSum(dd, sum0));
}
@@ -547,7 +627,8 @@ struct SumKernelDouble {
// comp* to w*, so each operation is serially dependent. By contrast, the sum*
// and comp* here have shorter dependency chains.
//
-// This is slower than SumKernelDouble and about equally accurate.
+// This is about as accurate as SumKernelDouble but slower, hence we only use
+// it if f64 is not supported on this target.
struct SumKernelCascaded {
template <typename VT, typename WT>
using Raw = float;
@@ -590,15 +671,14 @@ struct SumKernelCascaded {
}
};
-template <typename VT>
-using SumKernelDefault = hwy::If<CanDecompressToDouble<VT, VT>(),
-                                 SumKernelDouble, SumKernelCascaded>;
+using SumKernelDefault =
+    hwy::If<HWY_HAVE_FLOAT64, SumKernelDouble, SumKernelCascaded>;
template <class D, typename VT>
HWY_INLINE float Sum(D d, const VT* HWY_RESTRICT vec, size_t num) {
using Raw = hwy::If<HWY_HAVE_FLOAT64, double, float>;
const hn::Repartition<Raw, D> d_raw;
-return DecompressAndCall(d_raw, MakeSpan(vec, num), SumKernelDefault<VT>());
+return DecompressAndCall(d_raw, MakeSpan(vec, num), SumKernelDefault());
}
// See below for a specialized version for top-1 sampling.
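
For context, `RMSNormMul` above now also reaches the f64 dot kernel for bf16 inputs. A scalar sketch of what it computes, 1 / sqrt(mean(x_i^2) + 1e-6), assuming the single-span `DecompressAndCall` applies the kernel to `x` paired with itself; the helper name `RMSNormMulRef` is not from the commit.

    #include <math.h>
    #include <stddef.h>

    static float RMSNormMulRef(const float* x, size_t size) {
      double l2 = 0.0;  // sum of squares, accumulated in f64 as in DotKernelDouble
      for (size_t i = 0; i < size; ++i) l2 += static_cast<double>(x[i]) * x[i];
      constexpr float kEps = 1e-6f;  // avoid divide by zero
      return 1.0f / sqrtf(static_cast<float>(l2) / static_cast<float>(size) + kEps);
    }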