mirror of https://github.com/google/gemma.cpp.git
Also enable f64 dot/sum for <f32 inputs
Add bf16 support to Dot/SumKernelDouble in the same way as *Compensated.

PiperOrigin-RevId: 682308683
parent 895ee4c6ce
commit 5a71d819cb
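For orientation, here is a scalar sketch (not part of this diff; the function name and the even-`num` assumption are illustrative) of what the new BF16 path of DotKernelDouble computes: inputs narrower than f32 are decompressed to BF16, adjacent products are summed in f32, and accumulation happens in f64.

#include <cstddef>

#include "hwy/base.h"  // hwy::bfloat16_t, hwy::ConvertScalarTo

// Scalar reference: BF16 inputs, f32 pairwise products, f64 accumulation.
// The vectorized kernels in the diff below do the same, one vector at a time.
double ReferenceDotBF16(const hwy::bfloat16_t* w, const hwy::bfloat16_t* v,
                        size_t num) {
  double sum = 0.0;
  for (size_t i = 0; i < num; i += 2) {  // assumes even num for brevity
    // Mirrors WidenMulPairwiseAdd: two bf16*bf16 products summed in f32.
    const float pair =
        hwy::ConvertScalarTo<float>(w[i]) * hwy::ConvertScalarTo<float>(v[i]) +
        hwy::ConvertScalarTo<float>(w[i + 1]) *
            hwy::ConvertScalarTo<float>(v[i + 1]);
    sum += static_cast<double>(pair);  // promote and accumulate in f64
  }
  return sum;
}

The BF16 path is only selected when at least one input type is narrower than f32; if both inputs are f32, Raw stays double as before.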
@@ -498,14 +498,6 @@ constexpr bool IsF32() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, float>();
}

// TODO: `DotKernelDouble` requires both inputs to be `float` because currently
// only `CompressTraits<float>` can `Decompress2` to `double`. It is not yet
// clear whether we want to implement this for other Packed types.
template <typename WT, typename VT>
constexpr bool CanDecompressToDouble() {
return HWY_HAVE_FLOAT64 && IsF32<WT>() && IsF32<VT>();
}

namespace detail {

// Compile-time-only check that `DRaw` and `Packed` are compatible. This makes
@@ -157,13 +157,17 @@ HWY_MAYBE_UNUSED double ConditionNumber(const VT* HWY_RESTRICT v, size_t num) {
return cond;
}

// f64 FMA, called for f32 inputs promoted to f64. Runs at about half the speed
// of f32 FMA. Only usable if `CanDecompressToDouble<WT, VT>()`.
// f64 FMA. Inputs are both f32 promoted to f64, or any types that are either
// promoted or even DEMOTED to bf16. Runs at about half the speed of f32 FMA.
struct DotKernelDouble {
// Only `CompressTraits<float>` can `Decompress2` to `double`, so both have
// to be `float` in order to have `Raw = double`. Note that if either type is
// smaller than `float`, we may demote the other type from `float` to `BF16`.
template <typename VT, typename WT>
using Raw = double;
using Raw = hwy::If<IsF32<VT>() && IsF32<WT>(), double, BF16>;
using State = double;

// Raw = double
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
HWY_INLINE void Update4(DRaw dd, const VR w0, const VR w1, const VR w2,
const VR w3, const VR v0, const VR v1, const VR v2,
@@ -175,18 +179,77 @@ struct DotKernelDouble {
sum3 = hn::MulAdd(w3, v3, sum3);
}

// Raw = BF16
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>
HWY_INLINE void Update4(DRaw, const VR w0, const VR w1, const VR w2,
const VR w3, const VR v0, const VR v1, const VR v2,
const VR v3, VS& sum0, VS& sum1, VS& sum2, VS& sum3,
VS&, VS&, VS&, VS&) const {
const hn::Repartition<float, DRaw> df;
using VF = hn::Vec<decltype(df)>;
const VF prod0 = hn::WidenMulPairwiseAdd(df, w0, v0);
const VF prod1 = hn::WidenMulPairwiseAdd(df, w1, v1);
// Reduce to two f32 sums so we can promote them to four f64 vectors.
VF sum02, sum13;
if constexpr (HWY_NATIVE_DOT_BF16) {
// Fuse WidenMulPairwiseAdd plus Add into ReorderWidenMulAccumulate.
VF unused0 = hn::Zero(df);
VF unused1 = hn::Zero(df);
sum02 = hn::ReorderWidenMulAccumulate(df, w2, v2, prod0, unused0);
sum13 = hn::ReorderWidenMulAccumulate(df, w3, v3, prod1, unused1);
} else {
// ReorderWidenMulAccumulate does not help because we still end up with
// four accumulators.
const VF prod2 = hn::WidenMulPairwiseAdd(df, w2, v2);
const VF prod3 = hn::WidenMulPairwiseAdd(df, w3, v3);
sum02 = hn::Add(prod0, prod2);
sum13 = hn::Add(prod1, prod3);
}

const DS ds;
const VS d0 = hn::PromoteLowerTo(ds, sum02);
const VS d1 = hn::PromoteUpperTo(ds, sum02);
const VS d2 = hn::PromoteLowerTo(ds, sum13);
const VS d3 = hn::PromoteUpperTo(ds, sum13);

sum0 = hn::Add(sum0, d0);
sum1 = hn::Add(sum1, d1);
sum2 = hn::Add(sum2, d2);
sum3 = hn::Add(sum3, d3);
}

// Raw = double
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
HWY_INLINE void Update1(DRaw dd, const VR w0, const VR v0, VR& sum0,
VR&) const {
sum0 = hn::MulAdd(w0, v0, sum0);
}

// Raw = BF16
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>
HWY_INLINE void Update1(DRaw, const VR w0, const VR v0, VS& sum0,
VS& extra0) const {
const hn::Repartition<float, DRaw> df;
using VF = hn::Vec<decltype(df)>;
const VF prod0 = hn::WidenMulPairwiseAdd(df, w0, v0);

const DS ds;
const VS d0 = hn::PromoteLowerTo(ds, prod0);
const VS d1 = hn::PromoteUpperTo(ds, prod0);

sum0 = hn::Add(sum0, d0);
extra0 = hn::Add(extra0, d1);
}

template <class DState, class VS = hn::Vec<DState>, HWY_IF_F64_D(DState)>
HWY_INLINE float Reduce(DState dd, VS& sum0, VS& sum1, VS& sum2, VS& sum3,
VS&, VS&, VS&, VS&) const {
VS& extra0, VS&, VS&, VS&) const {
// Reduction tree: sum of all accumulators by pairs, then across lanes.
sum0 = hn::Add(sum0, sum1);
sum2 = hn::Add(sum2, sum3);
sum0 = hn::Add(sum0, extra0); // from Update1
sum0 = hn::Add(sum0, sum2);
return static_cast<float>(hn::ReduceSum(dd, sum0));
}
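A note on the extra0 accumulator used by the BF16 overloads above: one BF16 input vector yields half as many f32 pairwise products, and promoting those to f64 yields two vectors of a quarter the lanes each, so Update1 has to write into a second accumulator, which Reduce then folds back into sum0. The standalone sketch below (not part of this diff; assumes static dispatch and a target with f64 support) only prints the lane counts involved.

#include <cstdio>

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;  // static-dispatch target namespace

void PrintBF16PromotionLaneCounts() {
#if HWY_HAVE_FLOAT64
  const hn::ScalableTag<hwy::bfloat16_t> dbf;
  const hn::Repartition<float, decltype(dbf)> df;   // same vector width
  const hn::Repartition<double, decltype(dbf)> dd;  // same vector width
  // E.g. with 256-bit vectors: 16 bf16 lanes -> 8 f32 pairwise products
  // -> two f64 vectors of 4 lanes each, hence sum0 plus extra0.
  printf("bf16=%zu f32=%zu f64=%zu lanes\n", hn::Lanes(dbf), hn::Lanes(df),
         hn::Lanes(dd));
#endif
}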
@@ -198,12 +261,13 @@ HWY_INLINE float DotDouble(D d, const PackedSpan<const WT>& w, size_t w_ofs,
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelDouble());
}

// Algorithm 6.15 from Handbook of Floating-Point Arithmetic. This is slower
// than DotKernelDouble and about equally accurate.
// Algorithm 6.15 from Handbook of Floating-Point Arithmetic. This is about as
// accurate as DotKernelDouble but slower, hence we only use this if f64 is
// not supported on this target.
struct DotKernelCompensated {
// Unlike other kernels, this also supports bf16 inputs, used by matvec-inl.h.
// Even when `!HWY_NATIVE_DOT_BF16`, the BF16 overload is still faster than
// promoting to `float` because it does not call `TwoProducts`.
// The `BF16` overload uses `ReorderWidenMulAccumulate`, which requires both
// `VT` and `WT` to be `BF16`, or smaller types decompressed to `BF16`.
// Otherwise, we decompress both inputs to `float`.
template <typename VT, typename WT>
using Raw = hwy::If<IsF32<VT>() || IsF32<WT>(), float, BF16>;
using State = float;
@@ -240,10 +304,10 @@ struct DotKernelCompensated {
const VR v3, VS& sum0, VS& sum1, VS& sum2, VS& sum3,
VS& comp0, VS& comp1, VS& comp2, VS& comp3) const {
const DS df;
const VS prod0 = WidenMulPairwiseAdd(df, w0, v0);
const VS prod1 = WidenMulPairwiseAdd(df, w1, v1);
const VS prod2 = WidenMulPairwiseAdd(df, w2, v2);
const VS prod3 = WidenMulPairwiseAdd(df, w3, v3);
const VS prod0 = hn::WidenMulPairwiseAdd(df, w0, v0);
const VS prod1 = hn::WidenMulPairwiseAdd(df, w1, v1);
const VS prod2 = hn::WidenMulPairwiseAdd(df, w2, v2);
const VS prod3 = hn::WidenMulPairwiseAdd(df, w3, v3);

VS serr0, serr1, serr2, serr3;
sum0 = TwoSums(df, prod0, sum0, serr0);
@@ -295,16 +359,14 @@ struct DotKernelCompensated {
}
};

template <typename WT, typename VT>
using DotKernelDefault = hwy::If<CanDecompressToDouble<WT, VT>(),
DotKernelDouble, DotKernelCompensated>;
using DotKernelDefault =
hwy::If<HWY_HAVE_FLOAT64, DotKernelDouble, DotKernelCompensated>;

// `D` only serves to specify the vector size; its lane type is ignored.
template <class D, typename WT, typename VT>
HWY_INLINE float Dot(D d, const PackedSpan<const WT>& w, size_t w_ofs,
const VT* HWY_RESTRICT vec, size_t num) {
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
DotKernelDefault<WT, VT>());
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelDefault());
}
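A hedged usage sketch, not part of this diff (DotBF16F32, weights, and activations are illustrative names; it assumes placement in the same per-target namespace as Dot so that MakeSpan and hn are visible): mixing BF16 weights with f32 activations now reaches DotKernelDouble on any target with f64 support, where it previously fell back to DotKernelCompensated.

float DotBF16F32(const BF16* HWY_RESTRICT weights,
                 const float* HWY_RESTRICT activations, size_t num) {
  // The lane type of df only determines the vector width; DotKernelDefault is
  // now DotKernelDouble whenever HWY_HAVE_FLOAT64, regardless of input types.
  const hn::ScalableTag<float> df;
  return Dot(df, MakeSpan(weights, num), /*w_ofs=*/0, activations, num);
}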

// Adapter for two pointers, no bounds checking.
@@ -507,10 +507,10 @@ struct DotKernelComp2 {
VF& sum2, VF& sum3, VF& comp0, VF& comp1, VF& comp2,
VF& comp3) const {
const DF df;
VF prod0 = WidenMulPairwiseAdd(df, w0, v0);
VF prod1 = WidenMulPairwiseAdd(df, w1, v1);
VF prod2 = WidenMulPairwiseAdd(df, w2, v2);
VF prod3 = WidenMulPairwiseAdd(df, w3, v3);
VF prod0 = hn::WidenMulPairwiseAdd(df, w0, v0);
VF prod1 = hn::WidenMulPairwiseAdd(df, w1, v1);
VF prod2 = hn::WidenMulPairwiseAdd(df, w2, v2);
VF prod3 = hn::WidenMulPairwiseAdd(df, w3, v3);

// Pairwise sums
prod0 = hn::Add(prod0, prod1);
@@ -564,11 +564,6 @@ HWY_INLINE float DotComp2(D d, const PackedSpan<const WT>& w, size_t w_ofs,
template <class D, typename WT, typename VT, HWY_IF_F32_D(D)>
float CallDot(D d, size_t variant, const PackedSpan<const WT>& w, size_t w_ofs,
const VT* HWY_RESTRICT v, size_t num) {
// float inputs also support kDouble.
if constexpr (CanDecompressToDouble<WT, VT>()) {
if (variant == kDouble) return DotDouble(d, w, 0, v, num);
}

switch (variant) {
case kAddTwoProd:
return DotTwoProdFast(d, w, 0, v, num);
@@ -578,6 +573,12 @@ float CallDot(D d, size_t variant, const PackedSpan<const WT>& w, size_t w_ofs,
return DotComp2(d, w, 0, v, num);
case kCompensated:
return DotCompensated(d, w, 0, v, num);
case kDouble:
if constexpr (HWY_HAVE_FLOAT64) {
return DotDouble(d, w, 0, v, num);
} else {
return DotCompensated(d, w, 0, v, num);
}
case kKahan:
return DotKahan(d, w, 0, v, num);
case kNaive:
@@ -142,7 +142,9 @@ void AssertClose(size_t rows_ac, size_t cols_ab, size_t cols_c_rows_b,

const double norm = MaxColAbsSum(a.get(), rows_ac, cols_ab) *
MaxColAbsSum(b_trans.get(), cols_c_rows_b, cols_ab);
const double epsilon = hwy::ConvertScalarTo<double>(hwy::Epsilon<float>());
// Dot(float,BF16) rounds both to BF16.
using RefType = hwy::If<IsF32<MatTA>() && IsF32<MatTB>(), float, BF16>;
const double epsilon = hwy::ConvertScalarTo<double>(hwy::Epsilon<RefType>());
const double tolerance = 200.0 * norm * epsilon;

for (size_t idx = 0; idx < num_c; idx++) {
ops/ops-inl.h
@@ -146,8 +146,7 @@ namespace detail {
template <typename VT>
float RMSNormMul(const VT* HWY_RESTRICT x, size_t size) {
const hn::ScalableTag<float> d;
const float l2 =
DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault<VT, VT>());
const float l2 = DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault());
constexpr float kEps = 1e-6f; // avoid divide by zero
return 1.0f / sqrtf(l2 / StaticCast<float>(size) + kEps);
}
@@ -506,12 +505,16 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
}

// f64 Add, called for f32 inputs promoted to f64. Runs at about half the speed
// of f32 sums. Only usable if `CanDecompressToDouble<VT, VT>()`.
// of f32 sums.
struct SumKernelDouble {
// Only `CompressTraits<float>` can `Decompress2` to `double`, so both have
// to be `float` in order to have `Raw = double`. Note that if either type is
// smaller than `float`, we may demote the other type from `float` to `BF16`.
template <typename VT, typename WT>
using Raw = double;
using Raw = hwy::If<IsF32<VT>() && IsF32<WT>(), double, BF16>;
using State = double;

// Raw = double
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
HWY_INLINE void Update4(DRaw /*dd*/, const VR w0, const VR w1, const VR w2,
const VR w3, VR, VR, VR, VR, VR& sum0, VR& sum1,
@@ -522,18 +525,95 @@ struct SumKernelDouble {
sum3 = hn::Add(sum3, w3);
}

// Raw = BF16
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>
HWY_INLINE void Update4(DRaw dr, const VR w0, const VR w1, const VR w2,
const VR w3, VR, VR, VR, VR, VS& sum0, VS& sum1,
VS& sum2, VS& sum3, VS&, VS&, VS&, VS&) const {
const hn::Repartition<float, DRaw> df;
using VF = hn::Vec<decltype(df)>;
// Reduce to two f32 sums so we can promote them to four f64 vectors.
VF sum02, sum13;
if constexpr (HWY_NATIVE_DOT_BF16) {
const VR k1 = hn::Set(dr, hwy::ConvertScalarTo<BF16>(1.0f));
const VF prod0 = hn::WidenMulPairwiseAdd(df, w0, k1);
const VF prod1 = hn::WidenMulPairwiseAdd(df, w1, k1);
// Fuse WidenMulPairwiseAdd plus Add into ReorderWidenMulAccumulate.
VF unused0 = hn::Zero(df);
VF unused1 = hn::Zero(df);
sum02 = hn::ReorderWidenMulAccumulate(df, w2, k1, prod0, unused0);
sum13 = hn::ReorderWidenMulAccumulate(df, w3, k1, prod1, unused1);
} else {
// If not native, the multiplication costs extra, so convert to f32.
// PromoteEvenTo is cheaper than PromoteUpperTo especially on `SVE`.
const VF fe0 = hn::PromoteEvenTo(df, w0);
const VF fe1 = hn::PromoteEvenTo(df, w1);
const VF fe2 = hn::PromoteEvenTo(df, w2);
const VF fe3 = hn::PromoteEvenTo(df, w3);
const VF fo0 = hn::PromoteOddTo(df, w0);
const VF fo1 = hn::PromoteOddTo(df, w1);
const VF fo2 = hn::PromoteOddTo(df, w2);
const VF fo3 = hn::PromoteOddTo(df, w3);
const VF fe01 = hn::Add(fe0, fe1);
const VF fe23 = hn::Add(fe2, fe3);
const VF fo01 = hn::Add(fo0, fo1);
const VF fo23 = hn::Add(fo2, fo3);
sum02 = hn::Add(fe01, fe23);
sum13 = hn::Add(fo01, fo23);
}

const DS ds;
const VS d0 = hn::PromoteLowerTo(ds, sum02);
const VS d1 = hn::PromoteUpperTo(ds, sum02);
const VS d2 = hn::PromoteLowerTo(ds, sum13);
const VS d3 = hn::PromoteUpperTo(ds, sum13);

sum0 = hn::Add(sum0, d0);
sum1 = hn::Add(sum1, d1);
sum2 = hn::Add(sum2, d2);
sum3 = hn::Add(sum3, d3);
}

// Raw = double
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
HWY_INLINE void Update1(DRaw /*dd*/, const VR w0, const VR v0, VR& sum0,
VR& comp0) const {
sum0 = hn::Add(sum0, w0);
}

// Raw = BF16
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>
HWY_INLINE void Update1(DRaw dr, const VR w0, VR, VS& sum0,
VS& extra0) const {
const hn::Repartition<float, DRaw> df;
using VF = hn::Vec<decltype(df)>;
VF f0;
if constexpr (HWY_NATIVE_DOT_BF16) {
const VR k1 = hn::Set(dr, hwy::ConvertScalarTo<BF16>(1.0f));
f0 = hn::WidenMulPairwiseAdd(df, w0, k1);
} else {
const VF fe0 = hn::PromoteEvenTo(df, w0);
const VF fo0 = hn::PromoteOddTo(df, w0);
f0 = hn::Add(fe0, fo0);
}

const DS ds;
const VS d0 = hn::PromoteLowerTo(ds, f0);
const VS d1 = hn::PromoteUpperTo(ds, f0);

sum0 = hn::Add(sum0, d0);
extra0 = hn::Add(extra0, d1);
}

template <class DState, class VS = hn::Vec<DState>>
HWY_INLINE float Reduce(DState dd, VS& sum0, VS& sum1, VS& sum2, VS& sum3,
VS&, VS&, VS&, VS&) const {
VS& extra0, VS&, VS&, VS&) const {
// Reduction tree: sum of all accumulators by pairs, then across lanes.
sum0 = hn::Add(sum0, sum1);
sum2 = hn::Add(sum2, sum3);
sum0 = hn::Add(sum0, extra0); // from Update1
sum0 = hn::Add(sum0, sum2);
return static_cast<float>(hn::ReduceSum(dd, sum0));
}
@@ -547,7 +627,8 @@ struct SumKernelDouble {
// comp* to w*, so each operation is serially dependent. By contrast, the sum*
// and comp* here have shorter dependency chains.
//
// This is slower than SumKernelDouble and about equally accurate.
// This is about as accurate as SumKernelDouble but slower, hence we only use
// this if f64 is not supported on this target.
struct SumKernelCascaded {
template <typename VT, typename WT>
using Raw = float;
@@ -590,15 +671,14 @@ struct SumKernelCascaded {
}
};

template <typename VT>
using SumKernelDefault = hwy::If<CanDecompressToDouble<VT, VT>(),
SumKernelDouble, SumKernelCascaded>;
using SumKernelDefault =
hwy::If<HWY_HAVE_FLOAT64, SumKernelDouble, SumKernelCascaded>;

template <class D, typename VT>
HWY_INLINE float Sum(D d, const VT* HWY_RESTRICT vec, size_t num) {
using Raw = hwy::If<HWY_HAVE_FLOAT64, double, float>;
const hn::Repartition<Raw, D> d_raw;
return DecompressAndCall(d_raw, MakeSpan(vec, num), SumKernelDefault<VT>());
return DecompressAndCall(d_raw, MakeSpan(vec, num), SumKernelDefault());
}

// See below for a specialized version for top-1 sampling.
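And a matching hedged sketch for the Sum path (SumBF16 and values are illustrative names; same per-target namespace assumption as the Dot sketch above): a BF16 array now also sums via SumKernelDouble, with f64 accumulation, whenever the target supports f64.

float SumBF16(const BF16* HWY_RESTRICT values, size_t num) {
  // Sum() repartitions d to its Raw type internally (double on f64 targets),
  // so the lane type of df only sets the vector width.
  const hn::ScalableTag<float> df;
  return Sum(df, values, num);
}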