Use f64 Dot and sum in softmax - faster than Cascaded

Also let the kernel specify the Raw and State types,
and rename WeightT/VecT -> WT/VT.

PiperOrigin-RevId: 680464427
Jan Wassenberg 2024-09-30 01:21:37 -07:00 committed by Copybara-Service
parent 47eb80a90e
commit 5e812f07f5
4 changed files with 342 additions and 232 deletions

@@ -120,7 +120,6 @@ struct CompressTraits<float> {
BF16* HWY_RESTRICT raw, size_t num) {
const hn::Repartition<float, decltype(dbf)> df;
using VF = hn::Vec<decltype(df)>;
using VBF = hn::Vec<decltype(dbf)>;
const size_t NF = hn::Lanes(df);
size_t i = 0;
@@ -169,7 +168,6 @@ struct CompressTraits<float> {
double* HWY_RESTRICT raw, size_t num) {
const hn::Rebind<float, DD> df;
using VF = hn::Vec<decltype(df)>;
using VD = hn::Vec<decltype(dd)>;
const size_t ND = hn::Lanes(dd);
size_t i = 0;
@@ -413,8 +411,6 @@ struct CompressTraits<NuqStream> {
static HWY_INLINE void Load2(D d, const PackedSpan<const Packed>& packed,
const size_t packed_ofs, hn::Vec<D>& raw0,
hn::Vec<D>& raw1) {
const hn::Twice<hn::Rebind<uint8_t, D>> d8;
using V8 = hn::Vec<decltype(d8)>;
NuqCodec::Dec2(d, packed, packed_ofs, raw0, raw1);
}
@@ -497,23 +493,39 @@ void Compress2(DF df, VF raw0, VF raw1, const PackedSpan<Packed>& packed,
Traits::Store2(df, raw0, raw1, packed, packed_ofs);
}
template <typename Packed>
constexpr bool IsF32() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, float>();
}
// TODO: `DotKernelDouble` requires both inputs to be `float` because currently
// only `CompressTraits<float>` can `Decompress2` to `double`. It is not yet
// clear whether we want to implement this for other Packed types.
template <typename WT, typename VT>
constexpr bool CanDecompressToDouble() {
return HWY_HAVE_FLOAT64 && IsF32<WT>() && IsF32<VT>();
}
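// For example, `CanDecompressToDouble<float, float>()` holds on targets with
// native f64 support, whereas any compressed or BF16 input makes it false.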
namespace detail {
// Compile-time-only check that `DRaw` and `Packed` are compatible. This makes
// for better error messages than "no matching function found".
template <class DRaw, typename Packed>
HWY_INLINE void VerifyRawAndPacked() {
HWY_INLINE void VerifyRawAndPackedForDecompress() {
using TRaw = hn::TFromD<DRaw>;
constexpr bool kPackedF32 = hwy::IsSame<hwy::RemoveCvRef<Packed>, float>();
// We can decompress any Packed to f32 or BF16, or f32 to f64.
static_assert(hwy::IsSameEither<TRaw, float, BF16>() ||
(kPackedF32 && hwy::IsSame<TRaw, double>()));
(IsF32<Packed>() && hwy::IsSame<TRaw, double>()));
}
} // namespace detail
// Decompresses from any type of `packed` to two vectors of `float`/`BF16`, or
// of `double` if `Packed` is `float`.
template <class DRaw, typename Packed, class VRaw = hn::Vec<DRaw>>
HWY_INLINE void Decompress2(DRaw d, const PackedSpan<Packed>& packed,
const size_t packed_ofs, VRaw& raw0, VRaw& raw1) {
VerifyRawAndPacked<DRaw, Packed>();
detail::VerifyRawAndPackedForDecompress<DRaw, Packed>();
packed.BoundsCheck(packed_ofs, 2 * hn::Lanes(d));
using Traits = CompressTraits<hwy::RemoveCvRef<Packed>>;
Traits::Load2(d, MakeConst(packed), packed_ofs, raw0, raw1);
@@ -529,135 +541,139 @@ template <class DRaw, typename Packed, typename TRaw = hn::TFromD<DRaw>>
HWY_NOINLINE void DecompressAndZeroPad(DRaw d, const PackedSpan<Packed>& packed,
const size_t packed_ofs, TRaw* raw,
size_t num) {
VerifyRawAndPacked<DRaw, Packed>();
detail::VerifyRawAndPackedForDecompress<DRaw, Packed>();
packed.BoundsCheck(packed_ofs, num);
using Traits = CompressTraits<hwy::RemoveCvRef<Packed>>;
Traits::DecompressAndZeroPad(d, MakeConst(packed), packed_ofs, raw, num);
}
// Decompresses to the type specified by `D` from each of two arrays in groups
// of four vectors, passes them to `kernel.Update4`, zero-pads to a vector
// multiple, then calls `kernel.Update1` for the remaining vectors. Returns
// `kernel.Reduce`.
// Invokes `kernel` for the `v.num` elements of `w` and `v`. Decompresses from
// both into groups of four vectors with lane type `Kernel::Raw`, passes them to
// `kernel.Update4`; loads the final vector(s) with zero-padding, then passes
// them to `kernel.Update1`, then returns `kernel.Reduce`. `v.num` is not
// required to be a multiple of the vector length.
//
// This is useful for implementing dot products, and similar to
// `hwy/contrib/unroller`, but also supports compressed types with simpler
// remainder handling thanks to `DecompressAndZeroPad`.
//
// `D` can be BF16/float, or also double if `WeightT` and `VecT` are both float.
// `w` can be any packed type, including NUQ, which requires a separate `w_ofs`
// rather than pointer arithmetic. `vec` can also be any type, but typically
// float or BF16. We omit a `v_ofs` because it is 0 in our use cases.
// `num`, the number of elements to process, need not be a vector multiple.
// Both `w` and `v` can be any packed type. To support random access in `w`
// even if it is `NuqStream`, we ignore `w.num` and provide a `w_ofs`, but no
// `v_ofs` because it is always 0 in our use cases. `D` only serves to specify
// the vector size/fraction.
//
// `kernel` is const& so we can pass an rvalue argument, but can contain
// mutable state, though not vectors (see highway.h). We pass in the four
// loaded vectors plus eight state vectors. The state vectors' lane type is
// either `double` (required for DotKernelDouble) or `float`.
template <class D, typename WeightT, typename VecT, class Kernel>
HWY_INLINE float DecompressAndCall(D d, const PackedSpan<const WeightT>& w,
// mutable state, though not vectors (see highway.h). In addition to the groups
// of four input vectors, we pass eight state vectors with lane type specified
// by `Kernel::State`, which is typically `float` but may differ if `Raw` is
// `double`, or `WT` and `VT` are `BF16`.
//
// Decoupling decompression and remainder handling from the actual usage of the
// vectors makes it easier to implement various dot product and sum algorithms.
// This is similar to `hwy/contrib/unroller`, but less general and relies on
// `DecompressAndZeroPad`.
template <class D, typename WT, typename VT, class Kernel>
HWY_INLINE float DecompressAndCall(D, const PackedSpan<const WT>& w,
const size_t w_ofs,
const PackedSpan<const VecT> vec,
const PackedSpan<const VT> v,
const Kernel& kernel) {
// Decompressed inputs
using T = hn::TFromD<D>;
using V = hn::Vec<decltype(d)>;
V w0, w1, w2, w3, v0, v1, v2, v3;
using Raw = typename Kernel::template Raw<WT, VT>;
const hn::Repartition<Raw, D> d_raw;
using VRaw = hn::Vec<decltype(d_raw)>;
VRaw w0, w1, w2, w3, v0, v1, v2, v3;
// State for Kernel
using StateT = hwy::If<hwy::IsSame<T, double>(), double, float>;
const hn::Repartition<StateT, D> ds;
using VS = hn::Vec<decltype(ds)>;
VS sum0 = hn::Zero(ds);
VS sum1 = hn::Zero(ds);
VS sum2 = hn::Zero(ds);
VS sum3 = hn::Zero(ds);
VS comp0 = hn::Zero(ds);
VS comp1 = hn::Zero(ds);
VS comp2 = hn::Zero(ds);
VS comp3 = hn::Zero(ds);
const hn::Repartition<typename Kernel::State, D> d_state;
using VState = hn::Vec<decltype(d_state)>;
VState sum0 = hn::Zero(d_state);
VState sum1 = hn::Zero(d_state);
VState sum2 = hn::Zero(d_state);
VState sum3 = hn::Zero(d_state);
VState comp0 = hn::Zero(d_state);
VState comp1 = hn::Zero(d_state);
VState comp2 = hn::Zero(d_state);
VState comp3 = hn::Zero(d_state);
const size_t N = hn::Lanes(d);
const size_t N = hn::Lanes(d_raw);
size_t i = 0;
if (vec.num >= 4 * N) {
for (; i <= vec.num - 4 * N; i += 4 * N) {
Decompress2(d, w, w_ofs + i + 0 * N, w0, w1);
Decompress2(d, w, w_ofs + i + 2 * N, w2, w3);
Decompress2(d, vec, i + 0 * N, v0, v1);
Decompress2(d, vec, i + 2 * N, v2, v3);
if (v.num >= 4 * N) {
for (; i <= v.num - 4 * N; i += 4 * N) {
Decompress2(d_raw, w, w_ofs + i + 0 * N, w0, w1);
Decompress2(d_raw, w, w_ofs + i + 2 * N, w2, w3);
Decompress2(d_raw, v, i + 0 * N, v0, v1);
Decompress2(d_raw, v, i + 2 * N, v2, v3);
kernel.Update4(d, w0, w1, w2, w3, v0, v1, v2, v3, sum0, sum1, sum2, sum3,
comp0, comp1, comp2, comp3);
kernel.Update4(d_raw, w0, w1, w2, w3, v0, v1, v2, v3, sum0, sum1, sum2,
sum3, comp0, comp1, comp2, comp3);
}
}
size_t remaining = vec.num - i;
size_t remaining = v.num - i;
HWY_DASSERT(remaining < 4 * N);
if (HWY_UNLIKELY(remaining != 0)) {
HWY_ALIGN T padded_w[4 * hn::MaxLanes(d)];
HWY_ALIGN T padded_v[4 * hn::MaxLanes(d)];
DecompressAndZeroPad(d, w, w_ofs + i, padded_w, remaining);
DecompressAndZeroPad(d, vec, i, padded_v, remaining);
HWY_ALIGN Raw padded_w[4 * hn::MaxLanes(d_raw)];
HWY_ALIGN Raw padded_v[4 * hn::MaxLanes(d_raw)];
DecompressAndZeroPad(d_raw, w, w_ofs + i, padded_w, remaining);
DecompressAndZeroPad(d_raw, v, i, padded_v, remaining);
// 1..4 whole vectors, possibly zero-padded.
for (size_t padded_pos = 0; padded_pos < remaining; padded_pos += N) {
const V w0 = hn::Load(d, padded_w + padded_pos);
const V v0 = hn::Load(d, padded_v + padded_pos);
kernel.Update1(d, w0, v0, sum0, comp0);
const VRaw w0 = hn::Load(d_raw, padded_w + padded_pos);
const VRaw v0 = hn::Load(d_raw, padded_v + padded_pos);
kernel.Update1(d_raw, w0, v0, sum0, comp0);
}
}
return kernel.Reduce(ds, sum0, sum1, sum2, sum3, comp0, comp1, comp2, comp3);
return kernel.Reduce(d_state, sum0, sum1, sum2, sum3, comp0, comp1, comp2,
comp3);
}
// Same as above, but single input array. Used by RMSNorm.
template <class D, typename VecT, class Kernel>
HWY_INLINE float DecompressAndCall(D d, const PackedSpan<const VecT> vec,
template <class D, typename VT, class Kernel>
HWY_INLINE float DecompressAndCall(D, const PackedSpan<const VT> v,
const Kernel& kernel) {
// Decompressed inputs
using T = hn::TFromD<D>;
using V = hn::Vec<decltype(d)>;
V v0, v1, v2, v3;
using Raw = typename Kernel::template Raw<VT, VT>;
const hn::Repartition<Raw, D> d_raw;
using VRaw = hn::Vec<decltype(d_raw)>;
VRaw v0, v1, v2, v3;
// State for Kernel
using StateT = hwy::If<hwy::IsSame<T, double>(), double, float>;
const hn::Repartition<StateT, D> ds;
using VS = hn::Vec<decltype(ds)>;
VS sum0 = hn::Zero(ds);
VS sum1 = hn::Zero(ds);
VS sum2 = hn::Zero(ds);
VS sum3 = hn::Zero(ds);
VS comp0 = hn::Zero(ds);
VS comp1 = hn::Zero(ds);
VS comp2 = hn::Zero(ds);
VS comp3 = hn::Zero(ds);
const hn::Repartition<typename Kernel::State, D> d_state;
using VState = hn::Vec<decltype(d_state)>;
VState sum0 = hn::Zero(d_state);
VState sum1 = hn::Zero(d_state);
VState sum2 = hn::Zero(d_state);
VState sum3 = hn::Zero(d_state);
VState comp0 = hn::Zero(d_state);
VState comp1 = hn::Zero(d_state);
VState comp2 = hn::Zero(d_state);
VState comp3 = hn::Zero(d_state);
const size_t N = hn::Lanes(d);
const size_t N = hn::Lanes(d_raw);
size_t i = 0;
if (vec.num >= 4 * N) {
for (; i <= vec.num - 4 * N; i += 4 * N) {
Decompress2(d, vec, i + 0 * N, v0, v1);
Decompress2(d, vec, i + 2 * N, v2, v3);
if (v.num >= 4 * N) {
for (; i <= v.num - 4 * N; i += 4 * N) {
Decompress2(d_raw, v, i + 0 * N, v0, v1);
Decompress2(d_raw, v, i + 2 * N, v2, v3);
kernel.Update4(d, v0, v1, v2, v3, v0, v1, v2, v3, sum0, sum1, sum2, sum3,
comp0, comp1, comp2, comp3);
kernel.Update4(d_raw, v0, v1, v2, v3, v0, v1, v2, v3, sum0, sum1, sum2,
sum3, comp0, comp1, comp2, comp3);
}
}
size_t remaining = vec.num - i;
size_t remaining = v.num - i;
HWY_DASSERT(remaining < 4 * N);
if (HWY_UNLIKELY(remaining != 0)) {
HWY_ALIGN T padded_v[4 * hn::MaxLanes(d)];
DecompressAndZeroPad(d, vec, i, padded_v, remaining);
HWY_ALIGN Raw padded_v[4 * hn::MaxLanes(d_raw)];
DecompressAndZeroPad(d_raw, v, i, padded_v, remaining);
// 1..4 whole vectors, possibly zero-padded.
for (size_t padded_pos = 0; padded_pos < remaining; padded_pos += N) {
const V v0 = hn::Load(d, padded_v + padded_pos);
kernel.Update1(d, v0, v0, sum0, comp0);
const VRaw v0 = hn::Load(d_raw, padded_v + padded_pos);
kernel.Update1(d_raw, v0, v0, sum0, comp0);
}
}
return kernel.Reduce(ds, sum0, sum1, sum2, sum3, comp0, comp1, comp2, comp3);
return kernel.Reduce(d_state, sum0, sum1, sum2, sum3, comp0, comp1, comp2,
comp3);
}
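// Illustrative sketch, not part of this commit: the smallest kernel that
// satisfies the contract both `DecompressAndCall` overloads expect is a plain
// FMA dot product that ignores the compensation vectors, mirroring the
// `DotKernelNaive` test kernel later in this commit.
struct ExampleDotKernel {
  template <typename VT, typename WT>
  using Raw = float;    // lane type both inputs are decompressed to
  using State = float;  // lane type of the eight accumulator vectors

  template <class DRaw, class VR = hn::Vec<DRaw>>
  HWY_INLINE void Update4(DRaw, const VR w0, const VR w1, const VR w2,
                          const VR w3, const VR v0, const VR v1, const VR v2,
                          const VR v3, VR& sum0, VR& sum1, VR& sum2, VR& sum3,
                          VR&, VR&, VR&, VR&) const {
    sum0 = hn::MulAdd(w0, v0, sum0);
    sum1 = hn::MulAdd(w1, v1, sum1);
    sum2 = hn::MulAdd(w2, v2, sum2);
    sum3 = hn::MulAdd(w3, v3, sum3);
  }
  template <class DRaw, class VR = hn::Vec<DRaw>>
  HWY_INLINE void Update1(DRaw, const VR w0, const VR v0, VR& sum0,
                          VR&) const {
    sum0 = hn::MulAdd(w0, v0, sum0);
  }
  template <class DState, class VS = hn::Vec<DState>>
  HWY_INLINE float Reduce(DState d_state, VS& sum0, VS& sum1, VS& sum2,
                          VS& sum3, VS&, VS&, VS&, VS&) const {
    return hn::ReduceSum(d_state,
                         hn::Add(hn::Add(sum0, sum1), hn::Add(sum2, sum3)));
  }
};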
// Functor called for each tensor, which compresses and stores them along with

@@ -40,15 +40,20 @@ namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
// Our naming convention for dot product arguments is `w` and `v`, in that
// order. This originated in `MatVec`, which computed dot products of a
// compressed "weight" type, and `BF16/float` "vectors". This implementation no
// longer restricts the types of the arguments, but we keep the names for
// consistency, also because there is still a `w_ofs` but not a `v_ofs`.
//------------------------------------------------------------------------------
// Returns 2 * sum(|w.*v|) / |sum(w.*v)|. This is large when there are many
// similar-magnitude and opposite-sign elements. See
// https://en.wikipedia.org/wiki/Condition_number.
template <typename WeightT, typename VecT>
HWY_MAYBE_UNUSED double ConditionNumber(const WeightT* HWY_RESTRICT w,
const VecT* HWY_RESTRICT v,
size_t num) {
// Returns 2 * sum(|w.*v|) / |sum(w.*v)|. The log2 of this value
// approximates the number of mantissa bits required for accurate computations.
// See https://en.wikipedia.org/wiki/Condition_number.
template <typename WT, typename VT>
HWY_MAYBE_UNUSED double ConditionNumber(const WT* HWY_RESTRICT w,
const VT* HWY_RESTRICT v, size_t num) {
PROFILER_FUNC;
const hn::ScalableTag<float> df;
using VF = hn::Vec<decltype(df)>;
@@ -104,9 +109,8 @@ HWY_MAYBE_UNUSED double ConditionNumber(const WeightT* HWY_RESTRICT w,
}
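// Scalar sketch (illustrative only) of the quantity described above, without
// the SIMD and f64 accumulation of the actual implementation; assumes <cmath>
// and a nonzero sum.
static inline double ConditionNumberScalar(const float* w, const float* v,
                                           size_t num) {
  double sum_abs = 0.0, sum = 0.0;
  for (size_t i = 0; i < num; ++i) {
    const double p = static_cast<double>(w[i]) * static_cast<double>(v[i]);
    sum_abs += std::abs(p);
    sum += p;
  }
  return 2.0 * sum_abs / std::abs(sum);
}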
// Same, but for a single vector - just skips the product.
template <typename VecT>
HWY_MAYBE_UNUSED double ConditionNumber(const VecT* HWY_RESTRICT v,
size_t num) {
template <typename VT>
HWY_MAYBE_UNUSED double ConditionNumber(const VT* HWY_RESTRICT v, size_t num) {
PROFILER_FUNC;
const hn::ScalableTag<float> df;
using VF = hn::Vec<decltype(df)>;
@@ -153,12 +157,60 @@ HWY_MAYBE_UNUSED double ConditionNumber(const VecT* HWY_RESTRICT v,
return cond;
}
// Algorithm 6.15 from Handbook of Floating-Point Arithmetic. 10 ops is too slow
// for compute-limited Matmul but might be OK for attention.
// Also supports bf16 inputs, used by matvec-inl.h.
// f64 FMA, called for f32 inputs promoted to f64. Runs at about half the speed
// of f32 FMA. Only usable if `CanDecompressToDouble<WT, VT>()`.
struct DotKernelDouble {
template <typename VT, typename WT>
using Raw = double;
using State = double;
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
HWY_INLINE void Update4(DRaw dd, const VR w0, const VR w1, const VR w2,
const VR w3, const VR v0, const VR v1, const VR v2,
const VR v3, VR& sum0, VR& sum1, VR& sum2, VR& sum3,
VR&, VR&, VR&, VR&) const {
sum0 = hn::MulAdd(w0, v0, sum0);
sum1 = hn::MulAdd(w1, v1, sum1);
sum2 = hn::MulAdd(w2, v2, sum2);
sum3 = hn::MulAdd(w3, v3, sum3);
}
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
HWY_INLINE void Update1(DRaw dd, const VR w0, const VR v0, VR& sum0,
VR&) const {
sum0 = hn::MulAdd(w0, v0, sum0);
}
template <class DState, class VS = hn::Vec<DState>, HWY_IF_F64_D(DState)>
HWY_INLINE float Reduce(DState dd, VS& sum0, VS& sum1, VS& sum2, VS& sum3,
VS&, VS&, VS&, VS&) const {
// Reduction tree: sum of all accumulators by pairs, then across lanes.
sum0 = hn::Add(sum0, sum1);
sum2 = hn::Add(sum2, sum3);
sum0 = hn::Add(sum0, sum2);
return static_cast<float>(hn::ReduceSum(dd, sum0));
}
};
template <class D, typename WT, typename VT>
HWY_INLINE float DotDouble(D d, const PackedSpan<const WT>& w, size_t w_ofs,
const VT* HWY_RESTRICT vec, size_t num) {
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelDouble());
}
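// Usage sketch (illustrative): `df` only specifies the vector width; the
// kernel decompresses both f32 inputs to f64 lanes.
//   const hn::ScalableTag<float> df;
//   const float dot = DotDouble(df, MakeConstSpan(w, num), /*w_ofs=*/0, v, num);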
// Algorithm 6.15 from Handbook of Floating-Point Arithmetic. This is slower
// than DotKernelDouble and about equally accurate.
struct DotKernelCompensated {
template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
// Unlike other kernels, this also supports bf16 inputs, used by matvec-inl.h.
// Even when `!HWY_NATIVE_DOT_BF16`, the BF16 overload is still faster than
// promoting to `float` because it does not call `TwoProducts`.
template <typename VT, typename WT>
using Raw = hwy::If<IsF32<VT>() || IsF32<WT>(), float, BF16>;
using State = float;
// Raw = float
template <class DRaw, class VF = hn::Vec<DRaw>, HWY_IF_F32_D(DRaw)>
HWY_INLINE void Update4(DRaw df, const VF w0, const VF w1, const VF w2,
const VF w3, const VF v0, const VF v1, const VF v2,
const VF v3, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
VF& comp0, VF& comp1, VF& comp2, VF& comp3) const {
@@ -180,20 +232,20 @@ struct DotKernelCompensated {
comp3 = hn::Add(comp3, hn::Add(perr3, serr3));
}
template <class DBF, class VBF = hn::Vec<DBF>, HWY_IF_BF16_D(DBF),
class DF = hn::Repartition<float, DBF>, class VF = hn::Vec<DF>>
HWY_INLINE void Update4(DBF /*dbf*/, const VBF w0, const VBF w1, const VBF w2,
const VBF w3, const VBF v0, const VBF v1,
const VBF v2, const VBF v3, VF& sum0, VF& sum1,
VF& sum2, VF& sum3, VF& comp0, VF& comp1, VF& comp2,
VF& comp3) const {
const DF df;
const VF prod0 = WidenMulPairwiseAdd(df, w0, v0);
const VF prod1 = WidenMulPairwiseAdd(df, w1, v1);
const VF prod2 = WidenMulPairwiseAdd(df, w2, v2);
const VF prod3 = WidenMulPairwiseAdd(df, w3, v3);
// Raw = BF16, State = float
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
class DS = hn::Repartition<float, DRaw>, class VS = hn::Vec<DS>>
HWY_INLINE void Update4(DRaw, const VR w0, const VR w1, const VR w2,
const VR w3, const VR v0, const VR v1, const VR v2,
const VR v3, VS& sum0, VS& sum1, VS& sum2, VS& sum3,
VS& comp0, VS& comp1, VS& comp2, VS& comp3) const {
const DS df;
const VS prod0 = WidenMulPairwiseAdd(df, w0, v0);
const VS prod1 = WidenMulPairwiseAdd(df, w1, v1);
const VS prod2 = WidenMulPairwiseAdd(df, w2, v2);
const VS prod3 = WidenMulPairwiseAdd(df, w3, v3);
VF serr0, serr1, serr2, serr3;
VS serr0, serr1, serr2, serr3;
sum0 = TwoSums(df, prod0, sum0, serr0);
sum1 = TwoSums(df, prod1, sum1, serr1);
sum2 = TwoSums(df, prod2, sum2, serr2);
@@ -205,6 +257,7 @@ struct DotKernelCompensated {
comp3 = hn::Add(comp3, serr3);
}
// Raw = float
template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
HWY_INLINE void Update1(DF df, const VF w0, const VF v0, VF& sum0,
VF& comp0) const {
@@ -217,22 +270,23 @@ struct DotKernelCompensated {
comp0 = hn::Add(comp0, hn::Add(perr0, serr0));
}
template <class DBF, class VBF = hn::Vec<DBF>, HWY_IF_BF16_D(DBF),
class DF = hn::Repartition<float, DBF>, class VF = hn::Vec<DF>>
HWY_INLINE void Update1(DBF /*dbf*/, const VBF w0, const VBF v0, VF& sum0,
VF& comp0) const {
const DF df;
const VF prod0 = WidenMulPairwiseAdd(df, w0, v0);
// Raw = BF16, State = float
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
class DS = hn::Repartition<float, DRaw>, class VS = hn::Vec<DS>>
HWY_INLINE void Update1(DRaw, const VR w0, const VR v0, VS& sum0,
VS& comp0) const {
const DS df;
const VS prod0 = WidenMulPairwiseAdd(df, w0, v0);
VF serr0;
VS serr0;
sum0 = TwoSums(df, prod0, sum0, serr0);
comp0 = hn::Add(comp0, serr0);
}
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE float Reduce(DF df, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
VF& comp0, VF& comp1, VF& comp2, VF& comp3) const {
template <class DS, class VS = hn::Vec<DS>>
HWY_INLINE float Reduce(DS df, VS& sum0, VS& sum1, VS& sum2, VS& sum3,
VS& comp0, VS& comp1, VS& comp2, VS& comp3) const {
// Reduction tree: sum of all accumulators by pairs, then across lanes.
AssimilateCascadedSums(df, sum1, comp1, sum0, comp0);
AssimilateCascadedSums(df, sum3, comp3, sum2, comp2);
@@ -241,35 +295,38 @@ struct DotKernelCompensated {
}
};
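// Scalar sketch (illustrative) of the error-free transformations the kernel
// above vectorizes: each returns a rounded result plus its exact rounding
// error. Assumes <cmath> for std::fma.
static inline float TwoProductsScalar(const float a, const float b,
                                      float& err) {
  const float prod = a * b;
  err = std::fma(a, b, -prod);  // exact error of the product
  return prod;
}
static inline float TwoSumsScalar(const float a, const float b, float& err) {
  const float sum = a + b;
  const float b_virtual = sum - a;
  err = (a - (sum - b_virtual)) + (b - b_virtual);  // exact error of the sum
  return sum;
}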
// Default kernel
template <class D, typename WeightT, typename VecT>
HWY_INLINE float Dot(D d, const PackedSpan<const WeightT>& w, size_t w_ofs,
const VecT* HWY_RESTRICT vec, size_t num) {
template <typename WT, typename VT>
using DotKernelDefault = hwy::If<CanDecompressToDouble<WT, VT>(),
DotKernelDouble, DotKernelCompensated>;
// `D` only serves to specify the vector size; its lane type is ignored.
template <class D, typename WT, typename VT>
HWY_INLINE float Dot(D d, const PackedSpan<const WT>& w, size_t w_ofs,
const VT* HWY_RESTRICT vec, size_t num) {
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
DotKernelCompensated());
DotKernelDefault<WT, VT>());
}
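// Illustrative: on targets with native f64, two f32 inputs select
// `DotKernelDouble`; compressed or BF16 inputs fall back to
// `DotKernelCompensated`.
//   static_assert(!HWY_HAVE_FLOAT64 ||
//                 hwy::IsSame<DotKernelDefault<float, float>,
//                             DotKernelDouble>());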
// Adapter for a single pointer, no bounds checking.
template <typename WeightT, typename VecT>
HWY_INLINE float Dot(const WeightT* HWY_RESTRICT w, const VecT* vec,
size_t num) {
const hn::ScalableTag<VecT> d;
// Adapter for two pointers, no bounds checking.
template <typename WT, typename VT>
HWY_INLINE float Dot(const WT* HWY_RESTRICT w, const VT* vec, size_t num) {
const hn::ScalableTag<VT> d;
return Dot(d, MakeConstSpan(w, num), /*w_ofs=*/0, vec, num);
}
// Adapter for use by matvec-inl.h. TODO: remove when that is no longer used.
template <size_t kCapacity, typename VecT>
template <size_t kCapacity, typename VT>
HWY_INLINE float Dot(const std::array<float, kCapacity>& w, size_t w_ofs,
const VecT* vec, size_t num) {
const hn::ScalableTag<VecT> d;
const VT* vec, size_t num) {
const hn::ScalableTag<VT> d;
return Dot(d, MakeConstSpan(w.data(), kCapacity), w_ofs, vec, num);
}
// Adapter for use by matvec-inl.h. TODO: remove when that is no longer used.
template <typename MatT, size_t kCapacity, typename VecT>
template <typename MatT, size_t kCapacity, typename VT>
HWY_INLINE float Dot(const CompressedArray<MatT, kCapacity>& w, size_t w_ofs,
const VecT* vec, size_t num) {
const hn::ScalableTag<VecT> d;
const VT* vec, size_t num) {
const hn::ScalableTag<VT> d;
return w.scale() *
Dot(d, MakeConstSpan(w.data(), kCapacity), w_ofs, vec, num);
}

@@ -120,6 +120,10 @@ void AssertLess(size_t variant, T actual, T max, int line) {
// All combinations of {*, TwoProducts} x {+, FastTwoSums, TwoSums}.
struct DotKernelNaive {
template <typename VT, typename WT>
using Raw = float;
using State = float;
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
const VF w3, const VF v0, const VF v1, const VF v2,
@@ -150,51 +154,18 @@
}
};
template <class D, typename WeightT, typename VecT>
HWY_INLINE float DotNaive(D d, const PackedSpan<const WeightT>& w, size_t w_ofs,
const VecT* HWY_RESTRICT vec, size_t num) {
template <class D, typename WT, typename VT>
HWY_INLINE float DotNaive(D d, const PackedSpan<const WT>& w, size_t w_ofs,
const VT* HWY_RESTRICT vec, size_t num) {
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelNaive());
}
struct DotKernelDouble {
template <class DD, class VD = hn::Vec<DD>, HWY_IF_F64_D(DD)>
HWY_INLINE void Update4(DD dd, const VD w0, const VD w1, const VD w2,
const VD w3, const VD v0, const VD v1, const VD v2,
const VD v3, VD& sum0, VD& sum1, VD& sum2, VD& sum3,
VD&, VD&, VD&, VD&) const {
sum0 = hn::MulAdd(w0, v0, sum0);
sum1 = hn::MulAdd(w1, v1, sum1);
sum2 = hn::MulAdd(w2, v2, sum2);
sum3 = hn::MulAdd(w3, v3, sum3);
}
template <class DD, class VD = hn::Vec<DD>, HWY_IF_F64_D(DD)>
HWY_INLINE void Update1(DD dd, const VD w0, const VD v0, VD& sum0,
VD&) const {
sum0 = hn::MulAdd(w0, v0, sum0);
}
template <class DD, class VD = hn::Vec<DD>, HWY_IF_F64_D(DD)>
HWY_INLINE float Reduce(DD dd, VD& sum0, VD& sum1, VD& sum2, VD& sum3, VD&,
VD&, VD&, VD&) const {
// Reduction tree: sum of all accumulators by pairs, then across lanes.
sum0 = hn::Add(sum0, sum1);
sum2 = hn::Add(sum2, sum3);
sum0 = hn::Add(sum0, sum2);
return static_cast<float>(hn::ReduceSum(dd, sum0));
}
};
template <class D, typename WeightT, typename VecT>
HWY_INLINE float DotDouble(D d, const PackedSpan<const WeightT>& w,
size_t w_ofs, const VecT* HWY_RESTRICT vec,
size_t num) {
const hn::Repartition<double, D> dd;
return DecompressAndCall(dd, w, w_ofs, MakeSpan(vec, num), DotKernelDouble());
}
// https://en.wikipedia.org/wiki/Kahan_summation_algorithm: FastTwoSum.
struct DotKernelKahan {
template <typename VT, typename WT>
using Raw = float;
using State = float;
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
const VF w3, const VF v0, const VF v1, const VF v2,
@@ -234,15 +205,15 @@ struct DotKernelKahan {
}
};
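// Scalar reference (illustrative): classic Kahan with FastTwoSum. Each step
// depends on the previous `sum`, unlike the cascaded kernels whose comp*
// accumulators are independent.
static inline float KahanDotScalar(const float* w, const float* v,
                                   size_t num) {
  float sum = 0.0f, comp = 0.0f;
  for (size_t i = 0; i < num; ++i) {
    const float y = w[i] * v[i] - comp;
    const float t = sum + y;
    comp = (t - sum) - y;  // rounding error, fed back into the next step
    sum = t;
  }
  return sum;
}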
template <class D, typename WeightT, typename VecT>
HWY_INLINE float DotKahan(D d, const PackedSpan<const WeightT>& w, size_t w_ofs,
const VecT* HWY_RESTRICT vec, size_t num) {
template <class D, typename WT, typename VT>
HWY_INLINE float DotKahan(D d, const PackedSpan<const WT>& w, size_t w_ofs,
const VT* HWY_RESTRICT vec, size_t num) {
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelKahan());
}
template <class D, typename WeightT, typename VecT>
HWY_INLINE float DotCompensated(D d, const PackedSpan<const WeightT>& w,
size_t w_ofs, const VecT* HWY_RESTRICT vec,
template <class D, typename WT, typename VT>
HWY_INLINE float DotCompensated(D d, const PackedSpan<const WT>& w,
size_t w_ofs, const VT* HWY_RESTRICT vec,
size_t num) {
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
DotKernelCompensated());
@@ -250,6 +221,10 @@ HWY_INLINE float DotCompensated(D d, const PackedSpan<const WeightT>& w,
// Like Compensated, but FastTwoSum instead of TwoSum.
struct DotKernelTwoProdFast {
template <typename VT, typename WT>
using Raw = float;
using State = float;
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
const VF w3, const VF v0, const VF v1, const VF v2,
@@ -296,9 +271,9 @@ struct DotKernelTwoProdFast {
}
};
template <class D, typename WeightT, typename VecT>
HWY_INLINE float DotTwoProdFast(D d, const PackedSpan<const WeightT>& w,
size_t w_ofs, const VecT* HWY_RESTRICT vec,
template <class D, typename WT, typename VT>
HWY_INLINE float DotTwoProdFast(D d, const PackedSpan<const WT>& w,
size_t w_ofs, const VT* HWY_RESTRICT vec,
size_t num) {
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
DotKernelTwoProdFast());
@@ -307,6 +282,10 @@ HWY_INLINE float DotTwoProdFast(D d, const PackedSpan<const WeightT>& w,
// Like Compensated, but without TwoProducts. Vs Kahan, upgrades FastTwoSums
// to TwoSums.
struct DotKernelMulTwoSum {
template <typename VT, typename WT>
using Raw = float;
using State = float;
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
const VF w3, const VF v0, const VF v1, const VF v2,
@@ -351,10 +330,9 @@ struct DotKernelMulTwoSum {
}
};
template <class D, typename WeightT, typename VecT>
HWY_INLINE float DotMulTwoSum(D d, const PackedSpan<const WeightT>& w,
size_t w_ofs, const VecT* HWY_RESTRICT vec,
size_t num) {
template <class D, typename WT, typename VT>
HWY_INLINE float DotMulTwoSum(D d, const PackedSpan<const WT>& w, size_t w_ofs,
const VT* HWY_RESTRICT vec, size_t num) {
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
DotKernelMulTwoSum());
}
@@ -362,6 +340,10 @@ HWY_INLINE float DotMulTwoSum(D d, const PackedSpan<const WeightT>& w,
// Like Compensated, but only TwoProducts, no [Fast]TwoSums. This is only 10%
// better (mul) than naive.
struct DotKernelTwoProdAdd {
template <typename VT, typename WT>
using Raw = float;
using State = float;
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
const VF w3, const VF v0, const VF v1, const VF v2,
@@ -406,10 +388,9 @@ struct DotKernelTwoProdAdd {
}
};
template <class D, typename WeightT, typename VecT>
HWY_INLINE float DotTwoProdAdd(D d, const PackedSpan<const WeightT>& w,
size_t w_ofs, const VecT* HWY_RESTRICT vec,
size_t num) {
template <class D, typename WT, typename VT>
HWY_INLINE float DotTwoProdAdd(D d, const PackedSpan<const WT>& w, size_t w_ofs,
const VT* HWY_RESTRICT vec, size_t num) {
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
DotKernelTwoProdAdd());
}
@@ -417,6 +398,10 @@ HWY_INLINE float DotTwoProdAdd(D d, const PackedSpan<const WeightT>& w,
// From "SIMDizing Pairwise Sums". Slower and generally higher error than
// Kahan, but uses fewer regs.
struct DotKernelPairwise {
template <typename VT, typename WT>
using Raw = float;
using State = float;
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
const VF w3, const VF v0, const VF v1, const VF v2,
@@ -472,10 +457,9 @@ struct DotKernelPairwise {
mutable size_t num_ = 0;
};
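// Scalar sketch (illustrative) of pairwise summation: partial sums are
// combined as a tree, so rounding error grows with log2(num) instead of num.
static inline float PairwiseSumScalar(const float* x, size_t num) {
  if (num <= 8) {  // small base case, summed naively
    float sum = 0.0f;
    for (size_t i = 0; i < num; ++i) sum += x[i];
    return sum;
  }
  const size_t half = num / 2;
  return PairwiseSumScalar(x, half) + PairwiseSumScalar(x + half, num - half);
}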
template <class D, typename WeightT, typename VecT>
HWY_INLINE float DotPairwise(D d, const PackedSpan<const WeightT>& w,
size_t w_ofs, const VecT* HWY_RESTRICT vec,
size_t num) {
template <class D, typename WT, typename VT>
HWY_INLINE float DotPairwise(D d, const PackedSpan<const WT>& w, size_t w_ofs,
const VT* HWY_RESTRICT vec, size_t num) {
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
DotKernelPairwise());
}
@@ -483,6 +467,10 @@ HWY_INLINE float DotPairwise(D d, const PackedSpan<const WeightT>& w,
// Hybrid of Pairwise and Compensated. 1.14x time vs. Kahan, but geomean mul
// is 1.02 vs 1.06, mean L1 is 1.21x better, and uses two fewer regs.
struct DotKernelComp2 {
template <typename VT, typename WT>
using Raw = float;
using State = float;
template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
const VF w3, const VF v0, const VF v1, const VF v2,
@@ -567,17 +555,17 @@ struct DotKernelComp2 {
}
};
template <class D, typename WeightT, typename VecT>
HWY_INLINE float DotComp2(D d, const PackedSpan<const WeightT>& w, size_t w_ofs,
const VecT* HWY_RESTRICT vec, size_t num) {
template <class D, typename WT, typename VT>
HWY_INLINE float DotComp2(D d, const PackedSpan<const WT>& w, size_t w_ofs,
const VT* HWY_RESTRICT vec, size_t num) {
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelComp2());
}
template <class D, typename WeightT, typename VecT, HWY_IF_F32_D(D)>
float CallDot(D d, size_t variant, const PackedSpan<const WeightT>& w,
size_t w_ofs, const VecT* HWY_RESTRICT v, size_t num) {
template <class D, typename WT, typename VT, HWY_IF_F32_D(D)>
float CallDot(D d, size_t variant, const PackedSpan<const WT>& w, size_t w_ofs,
const VT* HWY_RESTRICT v, size_t num) {
// float inputs also support kDouble.
if constexpr (hwy::IsSame<WeightT, float>() && hwy::IsSame<VecT, float>()) {
if constexpr (CanDecompressToDouble<WT, VT>()) {
if (variant == kDouble) return DotDouble(d, w, 0, v, num);
}
@@ -608,9 +596,9 @@ float CallDot(D d, size_t variant, const PackedSpan<const WeightT>& w,
// and round to nearest. See "Accurate and efficient floating point summation".
// Much too slow to be useful. Kept separate from the above kernels because it
// is used to compute their error.
template <typename WeightT, typename VecT>
float ExactDot(const WeightT* HWY_RESTRICT w, const VecT* HWY_RESTRICT v,
size_t num, double* HWY_RESTRICT buf) {
template <typename WT, typename VT>
float ExactDot(const WT* HWY_RESTRICT w, const VT* HWY_RESTRICT v, size_t num,
double* HWY_RESTRICT buf) {
PROFILER_FUNC;
for (size_t i = 0; i < num; ++i) {
buf[i] =
@@ -944,18 +932,17 @@ void GenerateWellConditionedInputs(const size_t num, float* HWY_RESTRICT raw,
// Returns the actual condition number. Based on Algorithm 6.1 from "Accurate
// Sum and Dot Product". `num` is the (arbitrary) size of w, v, and buf.
template <typename WeightT, typename VecT>
double GenerateIllConditionedInputs(const size_t num, WeightT* w,
VecT* HWY_RESTRICT v, std::mt19937& rng) {
template <typename WT, typename VT>
double GenerateIllConditionedInputs(const size_t num, WT* w, VT* HWY_RESTRICT v,
std::mt19937& rng) {
PROFILER_FUNC;
const size_t half = HWY_MAX(1, num / 2); // generate at least one random
HWY_DASSERT(half != 0);
const hn::ScalableTag<float> df;
const PackedSpan<WT> w_span(w, num);
const PackedSpan<WeightT> w_span(w, num);
// Regardless of WeightT and VecT, we will accumulate into float. Multiplying
// Regardless of WT and VT, we will accumulate into float. Multiplying
// two maximal inputs and accumulating `num` times is enough for some loss of
// precision and condition numbers between 1E6-1E9, which is what we see for
// Attention Dot and `RMSNormMul`.
@@ -966,8 +953,8 @@ double GenerateIllConditionedInputs(const size_t num, WeightT* w,
for (size_t i = 0; i < half; ++i) {
// Ensure the min and max exponents are used.
const int e = i == 0 ? 0 : i == 1 ? max_exp : e_dist(rng);
w[i] = hwy::ConvertScalarTo<WeightT>(RandomFloat(rng) * (1 << e));
v[i] = hwy::ConvertScalarTo<VecT>(RandomFloat(rng) * (1 << e));
w[i] = hwy::ConvertScalarTo<WT>(RandomFloat(rng) * (1 << e));
v[i] = hwy::ConvertScalarTo<VT>(RandomFloat(rng) * (1 << e));
}
const float a_exp_step =
@@ -976,16 +963,16 @@ double GenerateIllConditionedInputs(const size_t num, WeightT* w,
for (size_t i = half; i < num; ++i, a_exp -= a_exp_step) {
const int e = static_cast<int>(a_exp);
HWY_DASSERT(e >= 0);
w[i] = hwy::ConvertScalarTo<WeightT>(RandomFloat(rng) * (1 << e));
w[i] = hwy::ConvertScalarTo<WT>(RandomFloat(rng) * (1 << e));
const float r = RandomFloat(rng) * (1 << e);
if (hwy::ConvertScalarTo<float>(w[i]) == 0.0f) {
v[i] = hwy::ConvertScalarTo<VecT>(0.0f);
v[i] = hwy::ConvertScalarTo<VT>(0.0f);
} else {
// This is called >100K times. DotCompensated is much faster than ExactDot
// and just about as accurate.
const float exact =
DotCompensated(df, MakeConst(w_span), /*w_ofs=*/0, v, i);
v[i] = hwy::ConvertScalarTo<VecT>(
v[i] = hwy::ConvertScalarTo<VT>(
r - exact / hwy::ConvertScalarTo<float>(w[i]));
}
}

@@ -142,11 +142,11 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Sigmoid(float* HWY_RESTRICT x,
namespace detail {
// Shared by RMSNorm and RMSNormInplace.
template <typename VecT>
float RMSNormMul(const VecT* HWY_RESTRICT x, size_t size) {
const hn::ScalableTag<float> df;
template <typename VT>
float RMSNormMul(const VT* HWY_RESTRICT x, size_t size) {
const hn::ScalableTag<float> d;
const float l2 =
DecompressAndCall(df, MakeSpan(x, size), DotKernelCompensated());
DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault<VT, VT>());
constexpr float kEps = 1e-6f; // avoid divide by zero
return 1.0f / sqrtf(l2 / StaticCast<float>(size) + kEps);
}
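// In scalar form (illustrative), `l2` is the dot product of `x` with itself,
// so the returned multiplier is:
//   mul = 1 / sqrt(sum(x[i]^2) / size + kEps)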
@@ -504,6 +504,40 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
MulByConstAndAdd(c, x, out, size, size);
}
// f64 Add, called for f32 inputs promoted to f64. Runs at about half the speed
// of f32 sums. Only usable if `CanDecompressToDouble<VT, VT>()`.
struct SumKernelDouble {
template <typename VT, typename WT>
using Raw = double;
using State = double;
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
HWY_INLINE void Update4(DRaw /*dd*/, const VR w0, const VR w1, const VR w2,
const VR w3, VR, VR, VR, VR, VR& sum0, VR& sum1,
VR& sum2, VR& sum3, VR&, VR&, VR&, VR&) const {
sum0 = hn::Add(sum0, w0);
sum1 = hn::Add(sum1, w1);
sum2 = hn::Add(sum2, w2);
sum3 = hn::Add(sum3, w3);
}
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
HWY_INLINE void Update1(DRaw /*dd*/, const VR w0, const VR v0, VR& sum0,
VR& comp0) const {
sum0 = hn::Add(sum0, w0);
}
template <class DState, class VS = hn::Vec<DState>>
HWY_INLINE float Reduce(DState dd, VS& sum0, VS& sum1, VS& sum2, VS& sum3,
VS&, VS&, VS&, VS&) const {
// Reduction tree: sum of all accumulators by pairs, then across lanes.
sum0 = hn::Add(sum0, sum1);
sum2 = hn::Add(sum2, sum3);
sum0 = hn::Add(sum0, sum2);
return static_cast<float>(hn::ReduceSum(dd, sum0));
}
};
// ORO Cascaded Summation, algorithm 6.11 from Handbook of Floating-Point
// Arithmetic. Note that Algorithm 6.7 (KBN) appears erroneous. We use TwoSums
// instead of FastTwoSums because the magnitude of the initial sum is not
@@ -511,7 +545,13 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
// generation results. Note that Kahan summation differs in that it first adds
// comp* to w*, so each operation is serially dependent. By contrast, the sum*
// and comp* here have shorter dependency chains.
struct KernelCascadedSum {
//
// This is slower than SumKernelDouble and about equally accurate.
struct SumKernelCascaded {
template <typename VT, typename WT>
using Raw = float;
using State = float;
template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
const VF w3, VF, VF, VF, VF, VF& sum0, VF& sum1,
@@ -549,6 +589,17 @@ struct KernelCascadedSum {
}
};
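// Scalar sketch (illustrative) of cascaded summation: unlike Kahan, the error
// term is accumulated separately and only folded in at the end, which
// shortens the serial dependency chain.
static inline float CascadedSumScalar(const float* x, size_t num) {
  float sum = 0.0f, comp = 0.0f;
  for (size_t i = 0; i < num; ++i) {
    const float a = sum;
    sum += x[i];
    const float b_virtual = sum - a;  // TwoSums error term
    comp += (a - (sum - b_virtual)) + (x[i] - b_virtual);
  }
  return sum + comp;
}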
template <typename VT>
using SumKernelDefault = hwy::If<CanDecompressToDouble<VT, VT>(),
SumKernelDouble, SumKernelCascaded>;
template <class D, typename VT>
HWY_INLINE float Sum(D d, const VT* HWY_RESTRICT vec, size_t num) {
using Raw = hwy::If<HWY_HAVE_FLOAT64, double, float>;
const hn::Repartition<Raw, D> d_raw;
return DecompressAndCall(d_raw, MakeSpan(vec, num), SumKernelDefault<VT>());
}
static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
const size_t mask_pos) {
PROFILER_FUNC;
@@ -582,8 +633,7 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
// not make a huge difference. It halves the standard deviation of the sum of
// the normalized probabilities from 1E-7 to 5E-8, but actually also changes
// the generated text after a few hundred tokens.
const float sum_exp =
DecompressAndCall(d, MakeConstSpan(x, mask_pos), KernelCascadedSum());
const float sum_exp = Sum(d, x, mask_pos);
// Double-precision reciprocal does not appear to affect the results.
const float mul = 1.0f / sum_exp;
MulByConst(mul, x, size, mask_pos);
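// Scalar reference (illustrative; the vectorized code above, partly elided by
// the diff, uses the usual max-subtraction for numerical stability). Assumes
// <cmath>.
static inline void SoftmaxScalar(float* HWY_RESTRICT x, size_t num) {
  float max = x[0];
  for (size_t i = 1; i < num; ++i) max = HWY_MAX(max, x[i]);
  float sum_exp = 0.0f;
  for (size_t i = 0; i < num; ++i) {
    x[i] = std::exp(x[i] - max);
    sum_exp += x[i];
  }
  const float mul = 1.0f / sum_exp;
  for (size_t i = 0; i < num; ++i) x[i] *= mul;
}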