Use f64 Dot and sum in softmax - faster than Cascaded

Also let the kernel specify the Raw and State types,
rename WeightT/VecT -> WT/VT.

PiperOrigin-RevId: 680464427
Jan Wassenberg, 2024-09-30 01:21:37 -07:00 (committed by Copybara-Service)
parent 47eb80a90e
commit 5e812f07f5
4 changed files with 342 additions and 232 deletions
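
For orientation before reading the diff: the key new concept is that each kernel now declares its own `Raw` (decompressed lane) type and `State` (accumulator lane) type. The struct below is an illustrative sketch of that contract, not code from this commit; the kernel name is hypothetical and the bodies simply accumulate products, like `DotKernelNaive` in the test file.

// Illustrative only (not in the commit): the kernel contract assumed by
// DecompressAndCall. `Raw` selects the lane type inputs are decompressed to;
// `State` selects the lane type of the eight accumulator vectors.
struct ExampleDotKernel {
  template <typename WT, typename VT>
  using Raw = float;    // decompress both inputs to f32
  using State = float;  // accumulate in f32

  template <class DF, class VF = hn::Vec<DF>>
  HWY_INLINE void Update4(DF /*df*/, const VF w0, const VF w1, const VF w2,
                          const VF w3, const VF v0, const VF v1, const VF v2,
                          const VF v3, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
                          VF&, VF&, VF&, VF&) const {
    sum0 = hn::MulAdd(w0, v0, sum0);
    sum1 = hn::MulAdd(w1, v1, sum1);
    sum2 = hn::MulAdd(w2, v2, sum2);
    sum3 = hn::MulAdd(w3, v3, sum3);
  }

  template <class DF, class VF = hn::Vec<DF>>
  HWY_INLINE void Update1(DF /*df*/, const VF w0, const VF v0, VF& sum0,
                          VF&) const {
    sum0 = hn::MulAdd(w0, v0, sum0);
  }

  template <class DF, class VF = hn::Vec<DF>>
  HWY_INLINE float Reduce(DF df, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
                          VF&, VF&, VF&, VF&) const {
    sum0 = hn::Add(hn::Add(sum0, sum1), hn::Add(sum2, sum3));
    return hn::ReduceSum(df, sum0);
  }
};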


@@ -120,7 +120,6 @@ struct CompressTraits<float> {
                                  BF16* HWY_RESTRICT raw, size_t num) {
     const hn::Repartition<float, decltype(dbf)> df;
     using VF = hn::Vec<decltype(df)>;
-    using VBF = hn::Vec<decltype(dbf)>;
     const size_t NF = hn::Lanes(df);
 
     size_t i = 0;
@@ -169,7 +168,6 @@ struct CompressTraits<float> {
                                  double* HWY_RESTRICT raw, size_t num) {
     const hn::Rebind<float, DD> df;
     using VF = hn::Vec<decltype(df)>;
-    using VD = hn::Vec<decltype(dd)>;
     const size_t ND = hn::Lanes(dd);
 
     size_t i = 0;
@@ -413,8 +411,6 @@ struct CompressTraits<NuqStream> {
   static HWY_INLINE void Load2(D d, const PackedSpan<const Packed>& packed,
                                const size_t packed_ofs, hn::Vec<D>& raw0,
                                hn::Vec<D>& raw1) {
-    const hn::Twice<hn::Rebind<uint8_t, D>> d8;
-    using V8 = hn::Vec<decltype(d8)>;
     NuqCodec::Dec2(d, packed, packed_ofs, raw0, raw1);
   }
@@ -497,23 +493,39 @@ void Compress2(DF df, VF raw0, VF raw1, const PackedSpan<Packed>& packed,
   Traits::Store2(df, raw0, raw1, packed, packed_ofs);
 }
 
+template <typename Packed>
+constexpr bool IsF32() {
+  return hwy::IsSame<hwy::RemoveCvRef<Packed>, float>();
+}
+
+// TODO: `DotKernelDouble` requires both inputs to be `float` because currently
+// only `CompressTraits<float>` can `Decompress2` to `double`. It is not yet
+// clear whether we want to implement this for other Packed types.
+template <typename WT, typename VT>
+constexpr bool CanDecompressToDouble() {
+  return HWY_HAVE_FLOAT64 && IsF32<WT>() && IsF32<VT>();
+}
+
+namespace detail {
+
 // Compile-time-only check that `DRaw` and `Packed` are compatible. This makes
 // for better error messages than "no matching function found".
 template <class DRaw, typename Packed>
-HWY_INLINE void VerifyRawAndPacked() {
+HWY_INLINE void VerifyRawAndPackedForDecompress() {
   using TRaw = hn::TFromD<DRaw>;
-  constexpr bool kPackedF32 = hwy::IsSame<hwy::RemoveCvRef<Packed>, float>();
   // We can decompress any Packed to f32 or BF16, or f32 to f64.
   static_assert(hwy::IsSameEither<TRaw, float, BF16>() ||
-                (kPackedF32 && hwy::IsSame<TRaw, double>()));
+                (IsF32<Packed>() && hwy::IsSame<TRaw, double>()));
 }
+
+}  // namespace detail
 
 // Decompresses from any type of `packed`, to two vectors of `float/BF16`, or
 // `double`, if `Packed` is `float`.
 template <class DRaw, typename Packed, class VRaw = hn::Vec<DRaw>>
 HWY_INLINE void Decompress2(DRaw d, const PackedSpan<Packed>& packed,
                             const size_t packed_ofs, VRaw& raw0, VRaw& raw1) {
-  VerifyRawAndPacked<DRaw, Packed>();
+  detail::VerifyRawAndPackedForDecompress<DRaw, Packed>();
   packed.BoundsCheck(packed_ofs, 2 * hn::Lanes(d));
   using Traits = CompressTraits<hwy::RemoveCvRef<Packed>>;
   Traits::Load2(d, MakeConst(packed), packed_ofs, raw0, raw1);
@@ -529,135 +541,139 @@ template <class DRaw, typename Packed, typename TRaw = hn::TFromD<DRaw>>
 HWY_NOINLINE void DecompressAndZeroPad(DRaw d, const PackedSpan<Packed>& packed,
                                        const size_t packed_ofs, TRaw* raw,
                                        size_t num) {
-  VerifyRawAndPacked<DRaw, Packed>();
+  detail::VerifyRawAndPackedForDecompress<DRaw, Packed>();
   packed.BoundsCheck(packed_ofs, num);
   using Traits = CompressTraits<hwy::RemoveCvRef<Packed>>;
   Traits::DecompressAndZeroPad(d, MakeConst(packed), packed_ofs, raw, num);
 }
 
-// Decompresses to the type specified by `D` from each of two arrays in groups
-// of four vectors, passes them to `kernel.Update4`, zero-pads to a vector
-// multiple, then calls `kernel.Update1` for the remaining vectors. Returns
-// `kernel.Reduce`.
+// Invokes `kernel` for the `v.num` elements of `w` and `v`. Decompresses from
+// both into groups of four vectors with lane type `Kernel::Raw`, passes them to
+// `kernel.Update4`; loads the final vector(s) with zero-padding, then passes
+// them to `kernel.Update1`, then returns `kernel.Reduce`. `v.num` is not
+// required to be a multiple of the vector length.
 //
-// This is useful for implementing dot products, and similar to
-// `hwy/contrib/unroller`, but also supports compressed types with simpler
-// remainder handling thanks to `DecompressAndZeroPad`.
-//
-// `D` can be BF16/float, or also double if `WeightT` and `VecT` are both float.
-// `w` can be any packed type, including NUQ, which requires a separate `w_ofs`
-// rather than pointer arithmetic. `vec` can also be any type, but typically
-// float or BF16. We omit a `v_ofs` because it is 0 in our use cases.
-// `num`, the number of elements to process, need not be a vector multiple.
+// Both `w` and `v` can be any packed type. To support random access in `w`
+// even if it is `NuqStream`, we ignore `w.num` and provide a `w_ofs`, but no
+// `v_ofs` because it is always 0 in our use cases. `D` only serves to specify
+// the vector size/fraction.
 //
 // `kernel` is const& so we can pass an rvalue argument, but can contain
-// mutable state, though not vectors (see highway.h). We pass in the four
-// loaded vectors plus eight state vectors. The state vectors' lane type is
-// either `double` (required for DotKernelDouble) or `float`.
-template <class D, typename WeightT, typename VecT, class Kernel>
-HWY_INLINE float DecompressAndCall(D d, const PackedSpan<const WeightT>& w,
+// mutable state, though not vectors (see highway.h). In addition to the groups
+// of four input vectors, we pass eight state vectors with lane type specified
+// by `Kernel::State`, which is typically `float` but may differ if `Raw` is
+// `double`, or `WT` and `VT` are `BF16`.
+//
+// Decoupling decompression and remainder handling from the actual usage of the
+// vectors makes it easier to implement various dot product and sum algorithms.
+// This is similar to `hwy/contrib/unroller`, but less general and relies on
+// `DecompressAndZeroPad`.
+template <class D, typename WT, typename VT, class Kernel>
+HWY_INLINE float DecompressAndCall(D, const PackedSpan<const WT>& w,
                                    const size_t w_ofs,
-                                   const PackedSpan<const VecT> vec,
+                                   const PackedSpan<const VT> v,
                                    const Kernel& kernel) {
   // Decompressed inputs
-  using T = hn::TFromD<D>;
-  using V = hn::Vec<decltype(d)>;
-  V w0, w1, w2, w3, v0, v1, v2, v3;
+  using Raw = typename Kernel::template Raw<WT, VT>;
+  const hn::Repartition<Raw, D> d_raw;
+  using VRaw = hn::Vec<decltype(d_raw)>;
+  VRaw w0, w1, w2, w3, v0, v1, v2, v3;
 
   // State for Kernel
-  using StateT = hwy::If<hwy::IsSame<T, double>(), double, float>;
-  const hn::Repartition<StateT, D> ds;
-  using VS = hn::Vec<decltype(ds)>;
-  VS sum0 = hn::Zero(ds);
-  VS sum1 = hn::Zero(ds);
-  VS sum2 = hn::Zero(ds);
-  VS sum3 = hn::Zero(ds);
-  VS comp0 = hn::Zero(ds);
-  VS comp1 = hn::Zero(ds);
-  VS comp2 = hn::Zero(ds);
-  VS comp3 = hn::Zero(ds);
+  const hn::Repartition<typename Kernel::State, D> d_state;
+  using VState = hn::Vec<decltype(d_state)>;
+  VState sum0 = hn::Zero(d_state);
+  VState sum1 = hn::Zero(d_state);
+  VState sum2 = hn::Zero(d_state);
+  VState sum3 = hn::Zero(d_state);
+  VState comp0 = hn::Zero(d_state);
+  VState comp1 = hn::Zero(d_state);
+  VState comp2 = hn::Zero(d_state);
+  VState comp3 = hn::Zero(d_state);
 
-  const size_t N = hn::Lanes(d);
+  const size_t N = hn::Lanes(d_raw);
   size_t i = 0;
-  if (vec.num >= 4 * N) {
-    for (; i <= vec.num - 4 * N; i += 4 * N) {
-      Decompress2(d, w, w_ofs + i + 0 * N, w0, w1);
-      Decompress2(d, w, w_ofs + i + 2 * N, w2, w3);
-      Decompress2(d, vec, i + 0 * N, v0, v1);
-      Decompress2(d, vec, i + 2 * N, v2, v3);
-      kernel.Update4(d, w0, w1, w2, w3, v0, v1, v2, v3, sum0, sum1, sum2, sum3,
-                     comp0, comp1, comp2, comp3);
+  if (v.num >= 4 * N) {
+    for (; i <= v.num - 4 * N; i += 4 * N) {
+      Decompress2(d_raw, w, w_ofs + i + 0 * N, w0, w1);
+      Decompress2(d_raw, w, w_ofs + i + 2 * N, w2, w3);
+      Decompress2(d_raw, v, i + 0 * N, v0, v1);
+      Decompress2(d_raw, v, i + 2 * N, v2, v3);
+      kernel.Update4(d_raw, w0, w1, w2, w3, v0, v1, v2, v3, sum0, sum1, sum2,
+                     sum3, comp0, comp1, comp2, comp3);
     }
   }
 
-  size_t remaining = vec.num - i;
+  size_t remaining = v.num - i;
   HWY_DASSERT(remaining < 4 * N);
   if (HWY_UNLIKELY(remaining != 0)) {
-    HWY_ALIGN T padded_w[4 * hn::MaxLanes(d)];
-    HWY_ALIGN T padded_v[4 * hn::MaxLanes(d)];
-    DecompressAndZeroPad(d, w, w_ofs + i, padded_w, remaining);
-    DecompressAndZeroPad(d, vec, i, padded_v, remaining);
+    HWY_ALIGN Raw padded_w[4 * hn::MaxLanes(d_raw)];
+    HWY_ALIGN Raw padded_v[4 * hn::MaxLanes(d_raw)];
+    DecompressAndZeroPad(d_raw, w, w_ofs + i, padded_w, remaining);
+    DecompressAndZeroPad(d_raw, v, i, padded_v, remaining);
 
     // 1..4 whole vectors, possibly zero-padded.
     for (size_t padded_pos = 0; padded_pos < remaining; padded_pos += N) {
-      const V w0 = hn::Load(d, padded_w + padded_pos);
-      const V v0 = hn::Load(d, padded_v + padded_pos);
-      kernel.Update1(d, w0, v0, sum0, comp0);
+      const VRaw w0 = hn::Load(d_raw, padded_w + padded_pos);
+      const VRaw v0 = hn::Load(d_raw, padded_v + padded_pos);
+      kernel.Update1(d_raw, w0, v0, sum0, comp0);
     }
   }
 
-  return kernel.Reduce(ds, sum0, sum1, sum2, sum3, comp0, comp1, comp2, comp3);
+  return kernel.Reduce(d_state, sum0, sum1, sum2, sum3, comp0, comp1, comp2,
+                       comp3);
 }
 
 // Same as above, but single input array. Used by RMSNorm.
-template <class D, typename VecT, class Kernel>
-HWY_INLINE float DecompressAndCall(D d, const PackedSpan<const VecT> vec,
+template <class D, typename VT, class Kernel>
+HWY_INLINE float DecompressAndCall(D, const PackedSpan<const VT> v,
                                    const Kernel& kernel) {
   // Decompressed inputs
-  using T = hn::TFromD<D>;
-  using V = hn::Vec<decltype(d)>;
-  V v0, v1, v2, v3;
+  using Raw = typename Kernel::template Raw<VT, VT>;
+  const hn::Repartition<Raw, D> d_raw;
+  using VRaw = hn::Vec<decltype(d_raw)>;
+  VRaw v0, v1, v2, v3;
 
   // State for Kernel
-  using StateT = hwy::If<hwy::IsSame<T, double>(), double, float>;
-  const hn::Repartition<StateT, D> ds;
-  using VS = hn::Vec<decltype(ds)>;
-  VS sum0 = hn::Zero(ds);
-  VS sum1 = hn::Zero(ds);
-  VS sum2 = hn::Zero(ds);
-  VS sum3 = hn::Zero(ds);
-  VS comp0 = hn::Zero(ds);
-  VS comp1 = hn::Zero(ds);
-  VS comp2 = hn::Zero(ds);
-  VS comp3 = hn::Zero(ds);
+  const hn::Repartition<typename Kernel::State, D> d_state;
+  using VState = hn::Vec<decltype(d_state)>;
+  VState sum0 = hn::Zero(d_state);
+  VState sum1 = hn::Zero(d_state);
+  VState sum2 = hn::Zero(d_state);
+  VState sum3 = hn::Zero(d_state);
+  VState comp0 = hn::Zero(d_state);
+  VState comp1 = hn::Zero(d_state);
+  VState comp2 = hn::Zero(d_state);
+  VState comp3 = hn::Zero(d_state);
 
-  const size_t N = hn::Lanes(d);
+  const size_t N = hn::Lanes(d_raw);
   size_t i = 0;
-  if (vec.num >= 4 * N) {
-    for (; i <= vec.num - 4 * N; i += 4 * N) {
-      Decompress2(d, vec, i + 0 * N, v0, v1);
-      Decompress2(d, vec, i + 2 * N, v2, v3);
-      kernel.Update4(d, v0, v1, v2, v3, v0, v1, v2, v3, sum0, sum1, sum2, sum3,
-                     comp0, comp1, comp2, comp3);
+  if (v.num >= 4 * N) {
+    for (; i <= v.num - 4 * N; i += 4 * N) {
+      Decompress2(d_raw, v, i + 0 * N, v0, v1);
+      Decompress2(d_raw, v, i + 2 * N, v2, v3);
+      kernel.Update4(d_raw, v0, v1, v2, v3, v0, v1, v2, v3, sum0, sum1, sum2,
+                     sum3, comp0, comp1, comp2, comp3);
     }
   }
 
-  size_t remaining = vec.num - i;
+  size_t remaining = v.num - i;
   HWY_DASSERT(remaining < 4 * N);
   if (HWY_UNLIKELY(remaining != 0)) {
-    HWY_ALIGN T padded_v[4 * hn::MaxLanes(d)];
-    DecompressAndZeroPad(d, vec, i, padded_v, remaining);
+    HWY_ALIGN Raw padded_v[4 * hn::MaxLanes(d_raw)];
+    DecompressAndZeroPad(d_raw, v, i, padded_v, remaining);
 
     // 1..4 whole vectors, possibly zero-padded.
     for (size_t padded_pos = 0; padded_pos < remaining; padded_pos += N) {
-      const V v0 = hn::Load(d, padded_v + padded_pos);
-      kernel.Update1(d, v0, v0, sum0, comp0);
+      const VRaw v0 = hn::Load(d_raw, padded_v + padded_pos);
+      kernel.Update1(d_raw, v0, v0, sum0, comp0);
     }
   }
 
-  return kernel.Reduce(ds, sum0, sum1, sum2, sum3, comp0, comp1, comp2, comp3);
+  return kernel.Reduce(d_state, sum0, sum1, sum2, sum3, comp0, comp1, comp2,
+                       comp3);
 }
 
 // Functor called for each tensor, which compresses and stores them along with
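
For orientation, here is a minimal hypothetical call site for the two-array overload above; `RowDot`, `row`, and `cols` are illustrative names and the kernel choice is arbitrary, so treat this as a sketch rather than code from the commit.

// Hypothetical call site, not from the commit: dot of one row of compressed
// weights with f32 activations. The lane type of `d` is ignored; the kernel's
// Raw/State aliases choose the decompressed and accumulator lane types.
float RowDot(const PackedSpan<const NuqStream>& w, size_t row, size_t cols,
             const PackedSpan<const float>& v) {
  const hn::ScalableTag<float> d;
  return DecompressAndCall(d, w, /*w_ofs=*/row * cols, v,
                           DotKernelCompensated());
}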


@@ -40,15 +40,20 @@ namespace gcpp {
 namespace HWY_NAMESPACE {
 namespace hn = hwy::HWY_NAMESPACE;
 
+// Our naming convention for dot product arguments is `w` and `v`, in that
+// order. This originated in `MatVec`, which computed dot products of a
+// compressed "weight" type, and `BF16/float` "vectors". This implementation no
+// longer restricts the types of the arguments, but we keep the names for
+// consistency, also because there is still a `w_ofs` but not a `v_ofs`.
+
 //------------------------------------------------------------------------------
-// Returns 2 * sum(|w.*v|) / |sum(w.*v)|. This is large when there are many
-// similar-magnitude and opposite-sign elements. See
-// https://en.wikipedia.org/wiki/Condition_number.
-template <typename WeightT, typename VecT>
-HWY_MAYBE_UNUSED double ConditionNumber(const WeightT* HWY_RESTRICT w,
-                                        const VecT* HWY_RESTRICT v,
-                                        size_t num) {
+
+// Returns 2 * sum(|w.*v|) / |sum(w.*v)|. The log2 of this value
+// approximates the number of mantissa bits required for accurate computations.
+// See https://en.wikipedia.org/wiki/Condition_number.
+template <typename WT, typename VT>
+HWY_MAYBE_UNUSED double ConditionNumber(const WT* HWY_RESTRICT w,
+                                        const VT* HWY_RESTRICT v, size_t num) {
   PROFILER_FUNC;
   const hn::ScalableTag<float> df;
   using VF = hn::Vec<decltype(df)>;
@@ -104,9 +109,8 @@ HWY_MAYBE_UNUSED double ConditionNumber(const WeightT* HWY_RESTRICT w,
 }
 
 // Same, but for a single vector - just skips the product.
-template <typename VecT>
-HWY_MAYBE_UNUSED double ConditionNumber(const VecT* HWY_RESTRICT v,
-                                        size_t num) {
+template <typename VT>
+HWY_MAYBE_UNUSED double ConditionNumber(const VT* HWY_RESTRICT v, size_t num) {
   PROFILER_FUNC;
   const hn::ScalableTag<float> df;
   using VF = hn::Vec<decltype(df)>;
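
To make the new condition-number comment concrete (illustrative numbers, not from the commit): for w = (1, 1) and v = (1, -1 + 1e-6), sum(|w.*v|) is about 2 while |sum(w.*v)| is 1e-6, so the returned value is about 4e6. log2(4e6) is roughly 22 of f32's 24 mantissa bits, signaling that plain f32 accumulation is borderline for such inputs.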
@@ -153,12 +157,60 @@ HWY_MAYBE_UNUSED double ConditionNumber(const VecT* HWY_RESTRICT v,
   return cond;
 }
 
-// Algorithm 6.15 from Handbook of Floating-Point Arithmetic. 10 ops is too slow
-// for compute-limited Matmul but might be OK for attention.
-// Also supports bf16 inputs, used by matvec-inl.h.
+// f64 FMA, called for f32 inputs promoted to f64. Runs at about half the speed
+// of f32 FMA. Only usable if `CanDecompressToDouble<WT, VT>()`.
+struct DotKernelDouble {
+  template <typename VT, typename WT>
+  using Raw = double;
+  using State = double;
+
+  template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
+  HWY_INLINE void Update4(DRaw dd, const VR w0, const VR w1, const VR w2,
+                          const VR w3, const VR v0, const VR v1, const VR v2,
+                          const VR v3, VR& sum0, VR& sum1, VR& sum2, VR& sum3,
+                          VR&, VR&, VR&, VR&) const {
+    sum0 = hn::MulAdd(w0, v0, sum0);
+    sum1 = hn::MulAdd(w1, v1, sum1);
+    sum2 = hn::MulAdd(w2, v2, sum2);
+    sum3 = hn::MulAdd(w3, v3, sum3);
+  }
+
+  template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
+  HWY_INLINE void Update1(DRaw dd, const VR w0, const VR v0, VR& sum0,
+                          VR&) const {
+    sum0 = hn::MulAdd(w0, v0, sum0);
+  }
+
+  template <class DState, class VS = hn::Vec<DState>, HWY_IF_F64_D(DState)>
+  HWY_INLINE float Reduce(DState dd, VS& sum0, VS& sum1, VS& sum2, VS& sum3,
+                          VS&, VS&, VS&, VS&) const {
+    // Reduction tree: sum of all accumulators by pairs, then across lanes.
+    sum0 = hn::Add(sum0, sum1);
+    sum2 = hn::Add(sum2, sum3);
+    sum0 = hn::Add(sum0, sum2);
+    return static_cast<float>(hn::ReduceSum(dd, sum0));
+  }
+};
+
+template <class D, typename WT, typename VT>
+HWY_INLINE float DotDouble(D d, const PackedSpan<const WT>& w, size_t w_ofs,
+                           const VT* HWY_RESTRICT vec, size_t num) {
+  return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelDouble());
+}
+
+// Algorithm 6.15 from Handbook of Floating-Point Arithmetic. This is slower
+// than DotKernelDouble and about equally accurate.
 struct DotKernelCompensated {
-  template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
-  HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
+  // Unlike other kernels, this also supports bf16 inputs, used by matvec-inl.h.
+  // Even when `!HWY_NATIVE_DOT_BF16`, the BF16 overload is still faster than
+  // promoting to `float` because it does not call `TwoProducts`.
+  template <typename VT, typename WT>
+  using Raw = hwy::If<IsF32<VT>() || IsF32<WT>(), float, BF16>;
+  using State = float;
+
+  // Raw = float
+  template <class DRaw, class VF = hn::Vec<DRaw>, HWY_IF_F32_D(DRaw)>
+  HWY_INLINE void Update4(DRaw df, const VF w0, const VF w1, const VF w2,
                           const VF w3, const VF v0, const VF v1, const VF v2,
                           const VF v3, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
                           VF& comp0, VF& comp1, VF& comp2, VF& comp3) const {
@@ -180,20 +232,20 @@ struct DotKernelCompensated {
     comp3 = hn::Add(comp3, hn::Add(perr3, serr3));
   }
 
-  template <class DBF, class VBF = hn::Vec<DBF>, HWY_IF_BF16_D(DBF),
-            class DF = hn::Repartition<float, DBF>, class VF = hn::Vec<DF>>
-  HWY_INLINE void Update4(DBF /*dbf*/, const VBF w0, const VBF w1, const VBF w2,
-                          const VBF w3, const VBF v0, const VBF v1,
-                          const VBF v2, const VBF v3, VF& sum0, VF& sum1,
-                          VF& sum2, VF& sum3, VF& comp0, VF& comp1, VF& comp2,
-                          VF& comp3) const {
-    const DF df;
+  // Raw = BF16, State = float
+  template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
+            class DS = hn::Repartition<float, DRaw>, class VS = hn::Vec<DS>>
+  HWY_INLINE void Update4(DRaw, const VR w0, const VR w1, const VR w2,
+                          const VR w3, const VR v0, const VR v1, const VR v2,
+                          const VR v3, VS& sum0, VS& sum1, VS& sum2, VS& sum3,
+                          VS& comp0, VS& comp1, VS& comp2, VS& comp3) const {
+    const DS df;
 
-    const VF prod0 = WidenMulPairwiseAdd(df, w0, v0);
-    const VF prod1 = WidenMulPairwiseAdd(df, w1, v1);
-    const VF prod2 = WidenMulPairwiseAdd(df, w2, v2);
-    const VF prod3 = WidenMulPairwiseAdd(df, w3, v3);
+    const VS prod0 = WidenMulPairwiseAdd(df, w0, v0);
+    const VS prod1 = WidenMulPairwiseAdd(df, w1, v1);
+    const VS prod2 = WidenMulPairwiseAdd(df, w2, v2);
+    const VS prod3 = WidenMulPairwiseAdd(df, w3, v3);
 
-    VF serr0, serr1, serr2, serr3;
+    VS serr0, serr1, serr2, serr3;
     sum0 = TwoSums(df, prod0, sum0, serr0);
     sum1 = TwoSums(df, prod1, sum1, serr1);
     sum2 = TwoSums(df, prod2, sum2, serr2);
@@ -205,6 +257,7 @@ struct DotKernelCompensated {
     comp3 = hn::Add(comp3, serr3);
   }
 
+  // Raw = float
   template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
   HWY_INLINE void Update1(DF df, const VF w0, const VF v0, VF& sum0,
                           VF& comp0) const {
@@ -217,22 +270,23 @@ struct DotKernelCompensated {
     comp0 = hn::Add(comp0, hn::Add(perr0, serr0));
   }
 
-  template <class DBF, class VBF = hn::Vec<DBF>, HWY_IF_BF16_D(DBF),
-            class DF = hn::Repartition<float, DBF>, class VF = hn::Vec<DF>>
-  HWY_INLINE void Update1(DBF /*dbf*/, const VBF w0, const VBF v0, VF& sum0,
-                          VF& comp0) const {
-    const DF df;
-    const VF prod0 = WidenMulPairwiseAdd(df, w0, v0);
+  // Raw = BF16, State = float
+  template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
+            class DS = hn::Repartition<float, DRaw>, class VS = hn::Vec<DS>>
+  HWY_INLINE void Update1(DRaw, const VR w0, const VR v0, VS& sum0,
+                          VS& comp0) const {
+    const DS df;
+    const VS prod0 = WidenMulPairwiseAdd(df, w0, v0);
 
-    VF serr0;
+    VS serr0;
     sum0 = TwoSums(df, prod0, sum0, serr0);
     comp0 = hn::Add(comp0, serr0);
   }
 
-  template <class DF, class VF = hn::Vec<DF>>
-  HWY_INLINE float Reduce(DF df, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
-                          VF& comp0, VF& comp1, VF& comp2, VF& comp3) const {
+  template <class DS, class VS = hn::Vec<DS>>
+  HWY_INLINE float Reduce(DS df, VS& sum0, VS& sum1, VS& sum2, VS& sum3,
+                          VS& comp0, VS& comp1, VS& comp2, VS& comp3) const {
     // Reduction tree: sum of all accumulators by pairs, then across lanes.
     AssimilateCascadedSums(df, sum1, comp1, sum0, comp0);
     AssimilateCascadedSums(df, sum3, comp3, sum2, comp2);
@@ -241,35 +295,38 @@ struct DotKernelCompensated {
   }
 };
 
-// Default kernel
-template <class D, typename WeightT, typename VecT>
-HWY_INLINE float Dot(D d, const PackedSpan<const WeightT>& w, size_t w_ofs,
-                     const VecT* HWY_RESTRICT vec, size_t num) {
+template <typename WT, typename VT>
+using DotKernelDefault = hwy::If<CanDecompressToDouble<WT, VT>(),
+                                 DotKernelDouble, DotKernelCompensated>;
+
+// `D` only serves to specify the vector size; its lane type is ignored.
+template <class D, typename WT, typename VT>
+HWY_INLINE float Dot(D d, const PackedSpan<const WT>& w, size_t w_ofs,
+                     const VT* HWY_RESTRICT vec, size_t num) {
   return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
-                           DotKernelCompensated());
+                           DotKernelDefault<WT, VT>());
 }
 
-// Adapter for a single pointer, no bounds checking.
-template <typename WeightT, typename VecT>
-HWY_INLINE float Dot(const WeightT* HWY_RESTRICT w, const VecT* vec,
-                     size_t num) {
-  const hn::ScalableTag<VecT> d;
+// Adapter for two pointers, no bounds checking.
+template <typename WT, typename VT>
+HWY_INLINE float Dot(const WT* HWY_RESTRICT w, const VT* vec, size_t num) {
+  const hn::ScalableTag<VT> d;
   return Dot(d, MakeConstSpan(w, num), /*w_ofs=*/0, vec, num);
 }
 
 // Adapter for use by matvec-inl.h. TODO: remove when that is no longer used.
-template <size_t kCapacity, typename VecT>
+template <size_t kCapacity, typename VT>
 HWY_INLINE float Dot(const std::array<float, kCapacity>& w, size_t w_ofs,
-                     const VecT* vec, size_t num) {
-  const hn::ScalableTag<VecT> d;
+                     const VT* vec, size_t num) {
+  const hn::ScalableTag<VT> d;
   return Dot(d, MakeConstSpan(w.data(), kCapacity), w_ofs, vec, num);
 }
 
 // Adapter for use by matvec-inl.h. TODO: remove when that is no longer used.
-template <typename MatT, size_t kCapacity, typename VecT>
+template <typename MatT, size_t kCapacity, typename VT>
 HWY_INLINE float Dot(const CompressedArray<MatT, kCapacity>& w, size_t w_ofs,
-                     const VecT* vec, size_t num) {
-  const hn::ScalableTag<VecT> d;
+                     const VT* vec, size_t num) {
+  const hn::ScalableTag<VT> d;
   return w.scale() *
          Dot(d, MakeConstSpan(w.data(), kCapacity), w_ofs, vec, num);
 }
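
To illustrate the static dispatch that `DotKernelDefault` introduces, here is a hedged sketch; the two wrapper functions are hypothetical, not part of the commit.

// f32 x f32: CanDecompressToDouble<float, float>() holds (given
// HWY_HAVE_FLOAT64), so this dispatches to DotKernelDouble.
float DotF32(const float* HWY_RESTRICT w, const float* HWY_RESTRICT v,
             size_t num) {
  return Dot(w, v, num);
}

// BF16 weights: mixed types, so this falls back to DotKernelCompensated.
float DotBF16(const BF16* HWY_RESTRICT w, const float* HWY_RESTRICT v,
              size_t num) {
  return Dot(w, v, num);
}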


@@ -120,6 +120,10 @@ void AssertLess(size_t variant, T actual, T max, int line) {
 
 // All combinations of {*, TwoProducts} x {+, FastTwoSums, TwoSums}.
 struct DotKernelNaive {
+  template <typename VT, typename WT>
+  using Raw = float;
+  using State = float;
+
   template <class DF, class VF = hn::Vec<DF>>
   HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
                           const VF w3, const VF v0, const VF v1, const VF v2,
@@ -150,51 +154,18 @@ struct DotKernelNaive {
   }
 };
 
-template <class D, typename WeightT, typename VecT>
-HWY_INLINE float DotNaive(D d, const PackedSpan<const WeightT>& w, size_t w_ofs,
-                          const VecT* HWY_RESTRICT vec, size_t num) {
+template <class D, typename WT, typename VT>
+HWY_INLINE float DotNaive(D d, const PackedSpan<const WT>& w, size_t w_ofs,
+                          const VT* HWY_RESTRICT vec, size_t num) {
   return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelNaive());
 }
 
-struct DotKernelDouble {
-  template <class DD, class VD = hn::Vec<DD>, HWY_IF_F64_D(DD)>
-  HWY_INLINE void Update4(DD dd, const VD w0, const VD w1, const VD w2,
-                          const VD w3, const VD v0, const VD v1, const VD v2,
-                          const VD v3, VD& sum0, VD& sum1, VD& sum2, VD& sum3,
-                          VD&, VD&, VD&, VD&) const {
-    sum0 = hn::MulAdd(w0, v0, sum0);
-    sum1 = hn::MulAdd(w1, v1, sum1);
-    sum2 = hn::MulAdd(w2, v2, sum2);
-    sum3 = hn::MulAdd(w3, v3, sum3);
-  }
-
-  template <class DD, class VD = hn::Vec<DD>, HWY_IF_F64_D(DD)>
-  HWY_INLINE void Update1(DD dd, const VD w0, const VD v0, VD& sum0,
-                          VD&) const {
-    sum0 = hn::MulAdd(w0, v0, sum0);
-  }
-
-  template <class DD, class VD = hn::Vec<DD>, HWY_IF_F64_D(DD)>
-  HWY_INLINE float Reduce(DD dd, VD& sum0, VD& sum1, VD& sum2, VD& sum3, VD&,
-                          VD&, VD&, VD&) const {
-    // Reduction tree: sum of all accumulators by pairs, then across lanes.
-    sum0 = hn::Add(sum0, sum1);
-    sum2 = hn::Add(sum2, sum3);
-    sum0 = hn::Add(sum0, sum2);
-    return static_cast<float>(hn::ReduceSum(dd, sum0));
-  }
-};
-
-template <class D, typename WeightT, typename VecT>
-HWY_INLINE float DotDouble(D d, const PackedSpan<const WeightT>& w,
-                           size_t w_ofs, const VecT* HWY_RESTRICT vec,
-                           size_t num) {
-  const hn::Repartition<double, D> dd;
-  return DecompressAndCall(dd, w, w_ofs, MakeSpan(vec, num), DotKernelDouble());
-}
-
 // https://en.wikipedia.org/wiki/Kahan_summation_algorithm: FastTwoSum.
 struct DotKernelKahan {
+  template <typename VT, typename WT>
+  using Raw = float;
+  using State = float;
+
   template <class DF, class VF = hn::Vec<DF>>
   HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
                           const VF w3, const VF v0, const VF v1, const VF v2,
@@ -234,15 +205,15 @@ struct DotKernelKahan {
   }
 };
 
-template <class D, typename WeightT, typename VecT>
-HWY_INLINE float DotKahan(D d, const PackedSpan<const WeightT>& w, size_t w_ofs,
-                          const VecT* HWY_RESTRICT vec, size_t num) {
+template <class D, typename WT, typename VT>
+HWY_INLINE float DotKahan(D d, const PackedSpan<const WT>& w, size_t w_ofs,
+                          const VT* HWY_RESTRICT vec, size_t num) {
   return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelKahan());
 }
 
-template <class D, typename WeightT, typename VecT>
-HWY_INLINE float DotCompensated(D d, const PackedSpan<const WeightT>& w,
-                                size_t w_ofs, const VecT* HWY_RESTRICT vec,
+template <class D, typename WT, typename VT>
+HWY_INLINE float DotCompensated(D d, const PackedSpan<const WT>& w,
+                                size_t w_ofs, const VT* HWY_RESTRICT vec,
                                 size_t num) {
   return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
                            DotKernelCompensated());
@@ -250,6 +221,10 @@ HWY_INLINE float DotCompensated(D d, const PackedSpan<const WeightT>& w,
 
 // Like Compensated, but FastTwoSum instead of TwoSum.
 struct DotKernelTwoProdFast {
+  template <typename VT, typename WT>
+  using Raw = float;
+  using State = float;
+
   template <class DF, class VF = hn::Vec<DF>>
   HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
                           const VF w3, const VF v0, const VF v1, const VF v2,
@@ -296,9 +271,9 @@ struct DotKernelTwoProdFast {
   }
 };
 
-template <class D, typename WeightT, typename VecT>
-HWY_INLINE float DotTwoProdFast(D d, const PackedSpan<const WeightT>& w,
-                                size_t w_ofs, const VecT* HWY_RESTRICT vec,
+template <class D, typename WT, typename VT>
+HWY_INLINE float DotTwoProdFast(D d, const PackedSpan<const WT>& w,
+                                size_t w_ofs, const VT* HWY_RESTRICT vec,
                                 size_t num) {
   return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
                            DotKernelTwoProdFast());
@@ -307,6 +282,10 @@ HWY_INLINE float DotTwoProdFast(D d, const PackedSpan<const WeightT>& w,
 
 // Like Compensated, but without TwoProducts. Vs Kahan, upgrades FastTwoSums
 // to TwoSums.
 struct DotKernelMulTwoSum {
+  template <typename VT, typename WT>
+  using Raw = float;
+  using State = float;
+
   template <class DF, class VF = hn::Vec<DF>>
   HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
                           const VF w3, const VF v0, const VF v1, const VF v2,
@@ -351,10 +330,9 @@ struct DotKernelMulTwoSum {
   }
 };
 
-template <class D, typename WeightT, typename VecT>
-HWY_INLINE float DotMulTwoSum(D d, const PackedSpan<const WeightT>& w,
-                              size_t w_ofs, const VecT* HWY_RESTRICT vec,
-                              size_t num) {
+template <class D, typename WT, typename VT>
+HWY_INLINE float DotMulTwoSum(D d, const PackedSpan<const WT>& w, size_t w_ofs,
+                              const VT* HWY_RESTRICT vec, size_t num) {
   return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
                            DotKernelMulTwoSum());
 }
@@ -362,6 +340,10 @@ HWY_INLINE float DotMulTwoSum(D d, const PackedSpan<const WeightT>& w,
 
 // Like Compensated, but only TwoProducts, no [Fast]TwoSums. This is only 10%
 // better (mul) than naive.
 struct DotKernelTwoProdAdd {
+  template <typename VT, typename WT>
+  using Raw = float;
+  using State = float;
+
   template <class DF, class VF = hn::Vec<DF>>
   HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
                           const VF w3, const VF v0, const VF v1, const VF v2,
@@ -406,10 +388,9 @@ struct DotKernelTwoProdAdd {
   }
 };
 
-template <class D, typename WeightT, typename VecT>
-HWY_INLINE float DotTwoProdAdd(D d, const PackedSpan<const WeightT>& w,
-                               size_t w_ofs, const VecT* HWY_RESTRICT vec,
-                               size_t num) {
+template <class D, typename WT, typename VT>
+HWY_INLINE float DotTwoProdAdd(D d, const PackedSpan<const WT>& w, size_t w_ofs,
+                               const VT* HWY_RESTRICT vec, size_t num) {
   return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
                            DotKernelTwoProdAdd());
 }
@@ -417,6 +398,10 @@ HWY_INLINE float DotTwoProdAdd(D d, const PackedSpan<const WeightT>& w,
 
 // From "SIMDizing Pairwise Sums". Slower and generally higher error than
 // Kahan, but uses fewer regs.
 struct DotKernelPairwise {
+  template <typename VT, typename WT>
+  using Raw = float;
+  using State = float;
+
   template <class DF, class VF = hn::Vec<DF>>
   HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
                           const VF w3, const VF v0, const VF v1, const VF v2,
@@ -472,10 +457,9 @@ struct DotKernelPairwise {
   mutable size_t num_ = 0;
 };
 
-template <class D, typename WeightT, typename VecT>
-HWY_INLINE float DotPairwise(D d, const PackedSpan<const WeightT>& w,
-                             size_t w_ofs, const VecT* HWY_RESTRICT vec,
-                             size_t num) {
+template <class D, typename WT, typename VT>
+HWY_INLINE float DotPairwise(D d, const PackedSpan<const WT>& w, size_t w_ofs,
+                             const VT* HWY_RESTRICT vec, size_t num) {
   return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
                            DotKernelPairwise());
 }
@@ -483,6 +467,10 @@ HWY_INLINE float DotPairwise(D d, const PackedSpan<const WeightT>& w,
 
 // Hybrid of Pairwise and Compensated. 1.14x time vs. Kahan, but geomean mul
 // is 1.02 vs 1.06, mean L1 is 1.21x better, and uses two fewer regs.
 struct DotKernelComp2 {
+  template <typename VT, typename WT>
+  using Raw = float;
+  using State = float;
+
   template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
   HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
                           const VF w3, const VF v0, const VF v1, const VF v2,
@@ -567,17 +555,17 @@ struct DotKernelComp2 {
   }
 };
 
-template <class D, typename WeightT, typename VecT>
-HWY_INLINE float DotComp2(D d, const PackedSpan<const WeightT>& w, size_t w_ofs,
-                          const VecT* HWY_RESTRICT vec, size_t num) {
+template <class D, typename WT, typename VT>
+HWY_INLINE float DotComp2(D d, const PackedSpan<const WT>& w, size_t w_ofs,
+                          const VT* HWY_RESTRICT vec, size_t num) {
   return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelComp2());
 }
 
-template <class D, typename WeightT, typename VecT, HWY_IF_F32_D(D)>
-float CallDot(D d, size_t variant, const PackedSpan<const WeightT>& w,
-              size_t w_ofs, const VecT* HWY_RESTRICT v, size_t num) {
+template <class D, typename WT, typename VT, HWY_IF_F32_D(D)>
+float CallDot(D d, size_t variant, const PackedSpan<const WT>& w, size_t w_ofs,
+              const VT* HWY_RESTRICT v, size_t num) {
   // float inputs also support kDouble.
-  if constexpr (hwy::IsSame<WeightT, float>() && hwy::IsSame<VecT, float>()) {
+  if constexpr (CanDecompressToDouble<WT, VT>()) {
     if (variant == kDouble) return DotDouble(d, w, 0, v, num);
   }
@@ -608,9 +596,9 @@ float CallDot(D d, size_t variant, const PackedSpan<const WeightT>& w,
 // and round to nearest. See "Accurate and efficient floating point summation".
 // Much too slow to be useful. Kept separate from the above kernels because it
 // is used to compute their error.
-template <typename WeightT, typename VecT>
-float ExactDot(const WeightT* HWY_RESTRICT w, const VecT* HWY_RESTRICT v,
-               size_t num, double* HWY_RESTRICT buf) {
+template <typename WT, typename VT>
+float ExactDot(const WT* HWY_RESTRICT w, const VT* HWY_RESTRICT v, size_t num,
+               double* HWY_RESTRICT buf) {
   PROFILER_FUNC;
   for (size_t i = 0; i < num; ++i) {
     buf[i] =
@@ -944,18 +932,17 @@ void GenerateWellConditionedInputs(const size_t num, float* HWY_RESTRICT raw,
 
 // Returns the actual condition number. Based on Algorithm 6.1 from "Accurate
 // Sum and Dot Product". `num` is the (arbitrary) size of w, v, and buf.
-template <typename WeightT, typename VecT>
-double GenerateIllConditionedInputs(const size_t num, WeightT* w,
-                                    VecT* HWY_RESTRICT v, std::mt19937& rng) {
+template <typename WT, typename VT>
+double GenerateIllConditionedInputs(const size_t num, WT* w, VT* HWY_RESTRICT v,
+                                    std::mt19937& rng) {
   PROFILER_FUNC;
   const size_t half = HWY_MAX(1, num / 2);  // generate at least one random
   HWY_DASSERT(half != 0);
 
   const hn::ScalableTag<float> df;
+  const PackedSpan<WT> w_span(w, num);
 
-  const PackedSpan<WeightT> w_span(w, num);
-  // Regardless of WeightT and VecT, we will accumulate into float. Multiplying
+  // Regardless of WT and VT, we will accumulate into float. Multiplying
   // two maximal inputs and accumulating `num` times is enough for some loss of
   // precision and condition numbers between 1E6-1E9, which is what we see for
   // Attention Dot and `RMSNormMul`.
@@ -966,8 +953,8 @@ double GenerateIllConditionedInputs(const size_t num, WeightT* w,
   for (size_t i = 0; i < half; ++i) {
     // Ensure the min and max exponents are used.
     const int e = i == 0 ? 0 : i == 1 ? max_exp : e_dist(rng);
-    w[i] = hwy::ConvertScalarTo<WeightT>(RandomFloat(rng) * (1 << e));
-    v[i] = hwy::ConvertScalarTo<VecT>(RandomFloat(rng) * (1 << e));
+    w[i] = hwy::ConvertScalarTo<WT>(RandomFloat(rng) * (1 << e));
+    v[i] = hwy::ConvertScalarTo<VT>(RandomFloat(rng) * (1 << e));
   }
 
   const float a_exp_step =
const float a_exp_step = const float a_exp_step =
@ -976,16 +963,16 @@ double GenerateIllConditionedInputs(const size_t num, WeightT* w,
for (size_t i = half; i < num; ++i, a_exp -= a_exp_step) { for (size_t i = half; i < num; ++i, a_exp -= a_exp_step) {
const int e = static_cast<int>(a_exp); const int e = static_cast<int>(a_exp);
HWY_DASSERT(e >= 0); HWY_DASSERT(e >= 0);
w[i] = hwy::ConvertScalarTo<WeightT>(RandomFloat(rng) * (1 << e)); w[i] = hwy::ConvertScalarTo<WT>(RandomFloat(rng) * (1 << e));
const float r = RandomFloat(rng) * (1 << e); const float r = RandomFloat(rng) * (1 << e);
if (hwy::ConvertScalarTo<float>(w[i]) == 0.0f) { if (hwy::ConvertScalarTo<float>(w[i]) == 0.0f) {
v[i] = hwy::ConvertScalarTo<VecT>(0.0f); v[i] = hwy::ConvertScalarTo<VT>(0.0f);
} else { } else {
// This is called >100K times. DotCompensated is much faster than ExactDot // This is called >100K times. DotCompensated is much faster than ExactDot
// and just about as accurate. // and just about as accurate.
const float exact = const float exact =
DotCompensated(df, MakeConst(w_span), /*w_ofs=*/0, v, i); DotCompensated(df, MakeConst(w_span), /*w_ofs=*/0, v, i);
v[i] = hwy::ConvertScalarTo<VecT>( v[i] = hwy::ConvertScalarTo<VT>(
r - exact / hwy::ConvertScalarTo<float>(w[i])); r - exact / hwy::ConvertScalarTo<float>(w[i]));
} }
} }


@@ -142,11 +142,11 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Sigmoid(float* HWY_RESTRICT x,
 namespace detail {
 
 // Shared by RMSNorm and RMSNormInplace.
-template <typename VecT>
-float RMSNormMul(const VecT* HWY_RESTRICT x, size_t size) {
-  const hn::ScalableTag<float> df;
+template <typename VT>
+float RMSNormMul(const VT* HWY_RESTRICT x, size_t size) {
+  const hn::ScalableTag<float> d;
   const float l2 =
-      DecompressAndCall(df, MakeSpan(x, size), DotKernelCompensated());
+      DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault<VT, VT>());
   constexpr float kEps = 1e-6f;  // avoid divide by zero
   return 1.0f / sqrtf(l2 / StaticCast<float>(size) + kEps);
 }
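
In other words, `RMSNormMul` computes 1 / sqrt(mean(x_i^2) + 1e-6): the dot product of `x` with itself supplies the sum of squares, and with f32 inputs (and `HWY_HAVE_FLOAT64`) that dot now runs through `DotKernelDouble` via `DotKernelDefault`.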
@@ -504,6 +504,40 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
   MulByConstAndAdd(c, x, out, size, size);
 }
 
+// f64 Add, called for f32 inputs promoted to f64. Runs at about half the speed
+// of f32 sums. Only usable if `CanDecompressToDouble<VT, VT>()`.
+struct SumKernelDouble {
+  template <typename VT, typename WT>
+  using Raw = double;
+  using State = double;
+
+  template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
+  HWY_INLINE void Update4(DRaw /*dd*/, const VR w0, const VR w1, const VR w2,
+                          const VR w3, VR, VR, VR, VR, VR& sum0, VR& sum1,
+                          VR& sum2, VR& sum3, VR&, VR&, VR&, VR&) const {
+    sum0 = hn::Add(sum0, w0);
+    sum1 = hn::Add(sum1, w1);
+    sum2 = hn::Add(sum2, w2);
+    sum3 = hn::Add(sum3, w3);
+  }
+
+  template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
+  HWY_INLINE void Update1(DRaw /*dd*/, const VR w0, const VR v0, VR& sum0,
+                          VR& comp0) const {
+    sum0 = hn::Add(sum0, w0);
+  }
+
+  template <class DState, class VS = hn::Vec<DState>>
+  HWY_INLINE float Reduce(DState dd, VS& sum0, VS& sum1, VS& sum2, VS& sum3,
+                          VS&, VS&, VS&, VS&) const {
+    // Reduction tree: sum of all accumulators by pairs, then across lanes.
+    sum0 = hn::Add(sum0, sum1);
+    sum2 = hn::Add(sum2, sum3);
+    sum0 = hn::Add(sum0, sum2);
+    return static_cast<float>(hn::ReduceSum(dd, sum0));
+  }
+};
+
 // ORO Cascaded Summation, algorithm 6.11 from Handbook of Floating-Point
 // Arithmetic. Note that Algorithm 6.7 (KBN) appears erroneous. We use TwoSums
 // instead of FastTwoSums because the magnitude of the initial sum is not
@@ -511,7 +545,13 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
 // generation results. Note that Kahan summation differs in that it first adds
 // comp* to w*, so each operation is serially dependent. By contrast, the sum*
 // and comp* here have shorter dependency chains.
-struct KernelCascadedSum {
+//
+// This is slower than SumKernelDouble and about equally accurate.
+struct SumKernelCascaded {
+  template <typename VT, typename WT>
+  using Raw = float;
+  using State = float;
+
   template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
   HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
                           const VF w3, VF, VF, VF, VF, VF& sum0, VF& sum1,
@@ -549,6 +589,17 @@ struct KernelCascadedSum {
   }
 };
 
+template <typename VT>
+using SumKernelDefault = hwy::If<CanDecompressToDouble<VT, VT>(),
+                                 SumKernelDouble, SumKernelCascaded>;
+
+template <class D, typename VT>
+HWY_INLINE float Sum(D d, const VT* HWY_RESTRICT vec, size_t num) {
+  using Raw = hwy::If<HWY_HAVE_FLOAT64, double, float>;
+  const hn::Repartition<Raw, D> d_raw;
+  return DecompressAndCall(d_raw, MakeSpan(vec, num), SumKernelDefault<VT>());
+}
+
 static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
                                  const size_t mask_pos) {
   PROFILER_FUNC;
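
A minimal usage sketch of the new `Sum` (the wrapper function is hypothetical, not from the commit). `D` again only determines the vector width; `SumKernelDefault` promotes f32 inputs to f64 accumulators where the target supports them, else falls back to the cascaded f32 kernel.

// Hypothetical wrapper: sums f32 values, with f64 accumulation given
// HWY_HAVE_FLOAT64, otherwise via SumKernelCascaded.
float SumF32(const float* HWY_RESTRICT x, size_t num) {
  const hn::ScalableTag<float> d;
  return Sum(d, x, num);
}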
@@ -582,8 +633,7 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
   // not make a huge difference. It halves the standard deviation of the sum of
   // the normalized probabilities from 1E-7 to 5E-8, but actually also changes
   // the generated text after a few hundred tokens.
-  const float sum_exp =
-      DecompressAndCall(d, MakeConstSpan(x, mask_pos), KernelCascadedSum());
+  const float sum_exp = Sum(d, x, mask_pos);
   // Double-precision reciprocal does not appear to affect the results.
   const float mul = 1.0f / sum_exp;
   MulByConst(mul, x, size, mask_pos);
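
As a sanity-check sketch of the rewritten Softmax (illustrative, not from the commit): with uniform logits, every output should be the reciprocal of the element count.

// Hypothetical example: all-zero logits, so exp(0 - max) == 1 per entry,
// Sum should return 64, and each entry should become 1.0f / 64.
void SoftmaxExample() {
  HWY_ALIGN float logits[64] = {};  // zero-initialized
  Softmax(logits, /*size=*/64, /*mask_pos=*/64);
}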