mirror of https://github.com/google/gemma.cpp.git
Use f64 Dot and sum in softmax - faster than Cascaded
Also let the kernel specify the Raw and State types, rename WeightT/VecT -> WT/VT.
PiperOrigin-RevId: 680464427
This commit is contained in:
parent
47eb80a90e
commit
5e812f07f5
|
|
@ -120,7 +120,6 @@ struct CompressTraits<float> {
|
||||||
BF16* HWY_RESTRICT raw, size_t num) {
|
BF16* HWY_RESTRICT raw, size_t num) {
|
||||||
const hn::Repartition<float, decltype(dbf)> df;
|
const hn::Repartition<float, decltype(dbf)> df;
|
||||||
using VF = hn::Vec<decltype(df)>;
|
using VF = hn::Vec<decltype(df)>;
|
||||||
using VBF = hn::Vec<decltype(dbf)>;
|
|
||||||
const size_t NF = hn::Lanes(df);
|
const size_t NF = hn::Lanes(df);
|
||||||
|
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
|
|
@ -169,7 +168,6 @@ struct CompressTraits<float> {
|
||||||
double* HWY_RESTRICT raw, size_t num) {
|
double* HWY_RESTRICT raw, size_t num) {
|
||||||
const hn::Rebind<float, DD> df;
|
const hn::Rebind<float, DD> df;
|
||||||
using VF = hn::Vec<decltype(df)>;
|
using VF = hn::Vec<decltype(df)>;
|
||||||
using VD = hn::Vec<decltype(dd)>;
|
|
||||||
const size_t ND = hn::Lanes(dd);
|
const size_t ND = hn::Lanes(dd);
|
||||||
|
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
|
|
@ -413,8 +411,6 @@ struct CompressTraits<NuqStream> {
|
||||||
static HWY_INLINE void Load2(D d, const PackedSpan<const Packed>& packed,
|
static HWY_INLINE void Load2(D d, const PackedSpan<const Packed>& packed,
|
||||||
const size_t packed_ofs, hn::Vec<D>& raw0,
|
const size_t packed_ofs, hn::Vec<D>& raw0,
|
||||||
hn::Vec<D>& raw1) {
|
hn::Vec<D>& raw1) {
|
||||||
const hn::Twice<hn::Rebind<uint8_t, D>> d8;
|
|
||||||
using V8 = hn::Vec<decltype(d8)>;
|
|
||||||
NuqCodec::Dec2(d, packed, packed_ofs, raw0, raw1);
|
NuqCodec::Dec2(d, packed, packed_ofs, raw0, raw1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -497,23 +493,39 @@ void Compress2(DF df, VF raw0, VF raw1, const PackedSpan<Packed>& packed,
|
||||||
Traits::Store2(df, raw0, raw1, packed, packed_ofs);
|
Traits::Store2(df, raw0, raw1, packed, packed_ofs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename Packed>
|
||||||
|
constexpr bool IsF32() {
|
||||||
|
return hwy::IsSame<hwy::RemoveCvRef<Packed>, float>();
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: `DotKernelDouble` requires both inputs to be `float` because currently
|
||||||
|
// only `CompressTraits<float>` can `Decompress2` to `double`. It is not yet
|
||||||
|
// clear whether we want to implement this for other Packed types.
|
||||||
|
template <typename WT, typename VT>
|
||||||
|
constexpr bool CanDecompressToDouble() {
|
||||||
|
return HWY_HAVE_FLOAT64 && IsF32<WT>() && IsF32<VT>();
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
// Compile-time-only check that `DRaw` and `Packed` are compatible. This makes
|
// Compile-time-only check that `DRaw` and `Packed` are compatible. This makes
|
||||||
// for better error messages than "no matching function found".
|
// for better error messages than "no matching function found".
|
||||||
template <class DRaw, typename Packed>
|
template <class DRaw, typename Packed>
|
||||||
HWY_INLINE void VerifyRawAndPacked() {
|
HWY_INLINE void VerifyRawAndPackedForDecompress() {
|
||||||
using TRaw = hn::TFromD<DRaw>;
|
using TRaw = hn::TFromD<DRaw>;
|
||||||
constexpr bool kPackedF32 = hwy::IsSame<hwy::RemoveCvRef<Packed>, float>();
|
|
||||||
// We can decompress any Packed to f32 or BF16, or f32 to f64.
|
// We can decompress any Packed to f32 or BF16, or f32 to f64.
|
||||||
static_assert(hwy::IsSameEither<TRaw, float, BF16>() ||
|
static_assert(hwy::IsSameEither<TRaw, float, BF16>() ||
|
||||||
(kPackedF32 && hwy::IsSame<TRaw, double>()));
|
(IsF32<Packed>() && hwy::IsSame<TRaw, double>()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
// Decompresses from any type of `packed`, to two vectors of `float/BF16`, or
|
// Decompresses from any type of `packed`, to two vectors of `float/BF16`, or
|
||||||
// `double`, if `Packed` is `float`.
|
// `double`, if `Packed` is `float`.
|
||||||
template <class DRaw, typename Packed, class VRaw = hn::Vec<DRaw>>
|
template <class DRaw, typename Packed, class VRaw = hn::Vec<DRaw>>
|
||||||
HWY_INLINE void Decompress2(DRaw d, const PackedSpan<Packed>& packed,
|
HWY_INLINE void Decompress2(DRaw d, const PackedSpan<Packed>& packed,
|
||||||
const size_t packed_ofs, VRaw& raw0, VRaw& raw1) {
|
const size_t packed_ofs, VRaw& raw0, VRaw& raw1) {
|
||||||
VerifyRawAndPacked<DRaw, Packed>();
|
detail::VerifyRawAndPackedForDecompress<DRaw, Packed>();
|
||||||
packed.BoundsCheck(packed_ofs, 2 * hn::Lanes(d));
|
packed.BoundsCheck(packed_ofs, 2 * hn::Lanes(d));
|
||||||
using Traits = CompressTraits<hwy::RemoveCvRef<Packed>>;
|
using Traits = CompressTraits<hwy::RemoveCvRef<Packed>>;
|
||||||
Traits::Load2(d, MakeConst(packed), packed_ofs, raw0, raw1);
|
Traits::Load2(d, MakeConst(packed), packed_ofs, raw0, raw1);
|
||||||
|
|
@ -529,135 +541,139 @@ template <class DRaw, typename Packed, typename TRaw = hn::TFromD<DRaw>>
|
||||||
HWY_NOINLINE void DecompressAndZeroPad(DRaw d, const PackedSpan<Packed>& packed,
|
HWY_NOINLINE void DecompressAndZeroPad(DRaw d, const PackedSpan<Packed>& packed,
|
||||||
const size_t packed_ofs, TRaw* raw,
|
const size_t packed_ofs, TRaw* raw,
|
||||||
size_t num) {
|
size_t num) {
|
||||||
VerifyRawAndPacked<DRaw, Packed>();
|
detail::VerifyRawAndPackedForDecompress<DRaw, Packed>();
|
||||||
packed.BoundsCheck(packed_ofs, num);
|
packed.BoundsCheck(packed_ofs, num);
|
||||||
using Traits = CompressTraits<hwy::RemoveCvRef<Packed>>;
|
using Traits = CompressTraits<hwy::RemoveCvRef<Packed>>;
|
||||||
Traits::DecompressAndZeroPad(d, MakeConst(packed), packed_ofs, raw, num);
|
Traits::DecompressAndZeroPad(d, MakeConst(packed), packed_ofs, raw, num);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decompresses to the type specified by `D` from each of two arrays in groups
|
// Invokes `kernel` for the `v.num` elements of `w` and `v`. Decompresses from
|
||||||
// of four vectors, passes them to `kernel.Update4`, zero-pads to a vector
|
// both into groups of four vectors with lane type `Kernel::Raw`, passes them to
|
||||||
// multiple, then calls `kernel.Update1` for the remaining vectors. Returns
|
// `kernel.Update4`; loads the final vector(s) with zero-padding, then passes
|
||||||
// `kernel.Reduce`.
|
// them to `kernel.Update1`, then returns `kernel.Reduce`. `v.num` is not
|
||||||
|
// required to be a multiple of the vector length.
|
||||||
//
|
//
|
||||||
// This is useful for implementing dot products, and similar to
|
// Both `w` and `v` can be any packed type. To support random access in `w`
|
||||||
// `hwy/contrib/unroller`, but also supports compressed types with simpler
|
// even if it is `NuqStream`, we ignore `w.num` and provide a `w_ofs`, but no
|
||||||
// remainder handling thanks to `DecompressAndZeroPad`.
|
// `v_ofs` because it is always 0 in our use cases. `D` only serves to specify
|
||||||
//
|
// the vector size/fraction.
|
||||||
// `D` can be BF16/float, or also double if `WeightT` and `VecT` are both float.
|
|
||||||
// `w` can be any packed type, including NUQ, which requires a separate `w_ofs`
|
|
||||||
// rather than pointer arithmetic. `vec` can also be any type, but typically
|
|
||||||
// float or BF16. We omit a `v_ofs` because it is 0 in our use cases.
|
|
||||||
// `num`, the number of elements to process, need not be a vector multiple.
|
|
||||||
//
|
//
|
||||||
// `kernel` is const& so we can pass an rvalue argument, but can contain
|
// `kernel` is const& so we can pass an rvalue argument, but can contain
|
||||||
// mutable state, though not vectors (see highway.h). We pass in the four
|
// mutable state, though not vectors (see highway.h). In addition to the groups
|
||||||
// loaded vectors plus eight state vectors. The state vectors' lane type is
|
// of four input vectors, we pass eight state vectors with lane type specified
|
||||||
// either `double` (required for DotKernelDouble) or `float`.
|
// by `Kernel::State`, which is typically `float` but may differ if `Raw` is
|
||||||
template <class D, typename WeightT, typename VecT, class Kernel>
|
// `double`, or `WT` and `VT` are `BF16`.
|
||||||
HWY_INLINE float DecompressAndCall(D d, const PackedSpan<const WeightT>& w,
|
//
|
||||||
|
// Decoupling decompression and remainder handling from the actual usage of the
|
||||||
|
// vectors makes it easier to implement various dot product and sum algorithms.
|
||||||
|
// This is similar to `hwy/contrib/unroller`, but less general and relies on
|
||||||
|
// `DecompressAndZeroPad`.
|
||||||
|
template <class D, typename WT, typename VT, class Kernel>
|
||||||
|
HWY_INLINE float DecompressAndCall(D, const PackedSpan<const WT>& w,
|
||||||
const size_t w_ofs,
|
const size_t w_ofs,
|
||||||
const PackedSpan<const VecT> vec,
|
const PackedSpan<const VT> v,
|
||||||
const Kernel& kernel) {
|
const Kernel& kernel) {
|
||||||
// Decompressed inputs
|
// Decompressed inputs
|
||||||
using T = hn::TFromD<D>;
|
using Raw = typename Kernel::template Raw<WT, VT>;
|
||||||
using V = hn::Vec<decltype(d)>;
|
const hn::Repartition<Raw, D> d_raw;
|
||||||
V w0, w1, w2, w3, v0, v1, v2, v3;
|
using VRaw = hn::Vec<decltype(d_raw)>;
|
||||||
|
VRaw w0, w1, w2, w3, v0, v1, v2, v3;
|
||||||
|
|
||||||
// State for Kernel
|
// State for Kernel
|
||||||
using StateT = hwy::If<hwy::IsSame<T, double>(), double, float>;
|
const hn::Repartition<typename Kernel::State, D> d_state;
|
||||||
const hn::Repartition<StateT, D> ds;
|
using VState = hn::Vec<decltype(d_state)>;
|
||||||
using VS = hn::Vec<decltype(ds)>;
|
VState sum0 = hn::Zero(d_state);
|
||||||
VS sum0 = hn::Zero(ds);
|
VState sum1 = hn::Zero(d_state);
|
||||||
VS sum1 = hn::Zero(ds);
|
VState sum2 = hn::Zero(d_state);
|
||||||
VS sum2 = hn::Zero(ds);
|
VState sum3 = hn::Zero(d_state);
|
||||||
VS sum3 = hn::Zero(ds);
|
VState comp0 = hn::Zero(d_state);
|
||||||
VS comp0 = hn::Zero(ds);
|
VState comp1 = hn::Zero(d_state);
|
||||||
VS comp1 = hn::Zero(ds);
|
VState comp2 = hn::Zero(d_state);
|
||||||
VS comp2 = hn::Zero(ds);
|
VState comp3 = hn::Zero(d_state);
|
||||||
VS comp3 = hn::Zero(ds);
|
|
||||||
|
|
||||||
const size_t N = hn::Lanes(d);
|
const size_t N = hn::Lanes(d_raw);
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
if (vec.num >= 4 * N) {
|
if (v.num >= 4 * N) {
|
||||||
for (; i <= vec.num - 4 * N; i += 4 * N) {
|
for (; i <= v.num - 4 * N; i += 4 * N) {
|
||||||
Decompress2(d, w, w_ofs + i + 0 * N, w0, w1);
|
Decompress2(d_raw, w, w_ofs + i + 0 * N, w0, w1);
|
||||||
Decompress2(d, w, w_ofs + i + 2 * N, w2, w3);
|
Decompress2(d_raw, w, w_ofs + i + 2 * N, w2, w3);
|
||||||
Decompress2(d, vec, i + 0 * N, v0, v1);
|
Decompress2(d_raw, v, i + 0 * N, v0, v1);
|
||||||
Decompress2(d, vec, i + 2 * N, v2, v3);
|
Decompress2(d_raw, v, i + 2 * N, v2, v3);
|
||||||
|
|
||||||
kernel.Update4(d, w0, w1, w2, w3, v0, v1, v2, v3, sum0, sum1, sum2, sum3,
|
kernel.Update4(d_raw, w0, w1, w2, w3, v0, v1, v2, v3, sum0, sum1, sum2,
|
||||||
comp0, comp1, comp2, comp3);
|
sum3, comp0, comp1, comp2, comp3);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t remaining = vec.num - i;
|
size_t remaining = v.num - i;
|
||||||
HWY_DASSERT(remaining < 4 * N);
|
HWY_DASSERT(remaining < 4 * N);
|
||||||
if (HWY_UNLIKELY(remaining != 0)) {
|
if (HWY_UNLIKELY(remaining != 0)) {
|
||||||
HWY_ALIGN T padded_w[4 * hn::MaxLanes(d)];
|
HWY_ALIGN Raw padded_w[4 * hn::MaxLanes(d_raw)];
|
||||||
HWY_ALIGN T padded_v[4 * hn::MaxLanes(d)];
|
HWY_ALIGN Raw padded_v[4 * hn::MaxLanes(d_raw)];
|
||||||
DecompressAndZeroPad(d, w, w_ofs + i, padded_w, remaining);
|
DecompressAndZeroPad(d_raw, w, w_ofs + i, padded_w, remaining);
|
||||||
DecompressAndZeroPad(d, vec, i, padded_v, remaining);
|
DecompressAndZeroPad(d_raw, v, i, padded_v, remaining);
|
||||||
|
|
||||||
// 1..4 whole vectors, possibly zero-padded.
|
// 1..4 whole vectors, possibly zero-padded.
|
||||||
for (size_t padded_pos = 0; padded_pos < remaining; padded_pos += N) {
|
for (size_t padded_pos = 0; padded_pos < remaining; padded_pos += N) {
|
||||||
const V w0 = hn::Load(d, padded_w + padded_pos);
|
const VRaw w0 = hn::Load(d_raw, padded_w + padded_pos);
|
||||||
const V v0 = hn::Load(d, padded_v + padded_pos);
|
const VRaw v0 = hn::Load(d_raw, padded_v + padded_pos);
|
||||||
kernel.Update1(d, w0, v0, sum0, comp0);
|
kernel.Update1(d_raw, w0, v0, sum0, comp0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return kernel.Reduce(ds, sum0, sum1, sum2, sum3, comp0, comp1, comp2, comp3);
|
return kernel.Reduce(d_state, sum0, sum1, sum2, sum3, comp0, comp1, comp2,
|
||||||
|
comp3);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Same as above, but single input array. Used by RMSNorm.
|
// Same as above, but single input array. Used by RMSNorm.
|
||||||
template <class D, typename VecT, class Kernel>
|
template <class D, typename VT, class Kernel>
|
||||||
HWY_INLINE float DecompressAndCall(D d, const PackedSpan<const VecT> vec,
|
HWY_INLINE float DecompressAndCall(D, const PackedSpan<const VT> v,
|
||||||
const Kernel& kernel) {
|
const Kernel& kernel) {
|
||||||
// Decompressed inputs
|
// Decompressed inputs
|
||||||
using T = hn::TFromD<D>;
|
using Raw = typename Kernel::template Raw<VT, VT>;
|
||||||
using V = hn::Vec<decltype(d)>;
|
const hn::Repartition<Raw, D> d_raw;
|
||||||
V v0, v1, v2, v3;
|
using VRaw = hn::Vec<decltype(d_raw)>;
|
||||||
|
VRaw v0, v1, v2, v3;
|
||||||
|
|
||||||
// State for Kernel
|
// State for Kernel
|
||||||
using StateT = hwy::If<hwy::IsSame<T, double>(), double, float>;
|
const hn::Repartition<typename Kernel::State, D> d_state;
|
||||||
const hn::Repartition<StateT, D> ds;
|
using VState = hn::Vec<decltype(d_state)>;
|
||||||
using VS = hn::Vec<decltype(ds)>;
|
VState sum0 = hn::Zero(d_state);
|
||||||
VS sum0 = hn::Zero(ds);
|
VState sum1 = hn::Zero(d_state);
|
||||||
VS sum1 = hn::Zero(ds);
|
VState sum2 = hn::Zero(d_state);
|
||||||
VS sum2 = hn::Zero(ds);
|
VState sum3 = hn::Zero(d_state);
|
||||||
VS sum3 = hn::Zero(ds);
|
VState comp0 = hn::Zero(d_state);
|
||||||
VS comp0 = hn::Zero(ds);
|
VState comp1 = hn::Zero(d_state);
|
||||||
VS comp1 = hn::Zero(ds);
|
VState comp2 = hn::Zero(d_state);
|
||||||
VS comp2 = hn::Zero(ds);
|
VState comp3 = hn::Zero(d_state);
|
||||||
VS comp3 = hn::Zero(ds);
|
|
||||||
|
|
||||||
const size_t N = hn::Lanes(d);
|
const size_t N = hn::Lanes(d_raw);
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
if (vec.num >= 4 * N) {
|
if (v.num >= 4 * N) {
|
||||||
for (; i <= vec.num - 4 * N; i += 4 * N) {
|
for (; i <= v.num - 4 * N; i += 4 * N) {
|
||||||
Decompress2(d, vec, i + 0 * N, v0, v1);
|
Decompress2(d_raw, v, i + 0 * N, v0, v1);
|
||||||
Decompress2(d, vec, i + 2 * N, v2, v3);
|
Decompress2(d_raw, v, i + 2 * N, v2, v3);
|
||||||
|
|
||||||
kernel.Update4(d, v0, v1, v2, v3, v0, v1, v2, v3, sum0, sum1, sum2, sum3,
|
kernel.Update4(d_raw, v0, v1, v2, v3, v0, v1, v2, v3, sum0, sum1, sum2,
|
||||||
comp0, comp1, comp2, comp3);
|
sum3, comp0, comp1, comp2, comp3);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t remaining = vec.num - i;
|
size_t remaining = v.num - i;
|
||||||
HWY_DASSERT(remaining < 4 * N);
|
HWY_DASSERT(remaining < 4 * N);
|
||||||
if (HWY_UNLIKELY(remaining != 0)) {
|
if (HWY_UNLIKELY(remaining != 0)) {
|
||||||
HWY_ALIGN T padded_v[4 * hn::MaxLanes(d)];
|
HWY_ALIGN Raw padded_v[4 * hn::MaxLanes(d_raw)];
|
||||||
DecompressAndZeroPad(d, vec, i, padded_v, remaining);
|
DecompressAndZeroPad(d_raw, v, i, padded_v, remaining);
|
||||||
|
|
||||||
// 1..4 whole vectors, possibly zero-padded.
|
// 1..4 whole vectors, possibly zero-padded.
|
||||||
for (size_t padded_pos = 0; padded_pos < remaining; padded_pos += N) {
|
for (size_t padded_pos = 0; padded_pos < remaining; padded_pos += N) {
|
||||||
const V v0 = hn::Load(d, padded_v + padded_pos);
|
const VRaw v0 = hn::Load(d_raw, padded_v + padded_pos);
|
||||||
kernel.Update1(d, v0, v0, sum0, comp0);
|
kernel.Update1(d_raw, v0, v0, sum0, comp0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return kernel.Reduce(ds, sum0, sum1, sum2, sum3, comp0, comp1, comp2, comp3);
|
return kernel.Reduce(d_state, sum0, sum1, sum2, sum3, comp0, comp1, comp2,
|
||||||
|
comp3);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Functor called for each tensor, which compresses and stores them along with
|
// Functor called for each tensor, which compresses and stores them along with
|
||||||
|
|
|
||||||
165
ops/dot-inl.h
165
ops/dot-inl.h
|
|
@ -40,15 +40,20 @@ namespace gcpp {
|
||||||
namespace HWY_NAMESPACE {
|
namespace HWY_NAMESPACE {
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
|
||||||
|
// Our naming convention for dot product arguments is `w` and `v`, in that
|
||||||
|
// order. This originated in `MatVec`, which computed dot products of a
|
||||||
|
// compressed "weight" type, and `BF16/float` "vectors". This implementation no
|
||||||
|
// longer restricts the types of the arguments, but we keep the names for
|
||||||
|
// consistency, also because there is still a `w_ofs` but not a `v_ofs`.
|
||||||
|
|
||||||
//------------------------------------------------------------------------------
|
//------------------------------------------------------------------------------
|
||||||
|
|
||||||
// Returns 2 * sum(|w.*v|) / |sum(w.*v)|. This is large when there are many
|
// Returns 2 * sum(|w.*v|) / |sum(w.*v)|. The log2 of this value
|
||||||
// similar-magnitude and opposite-sign elements. See
|
// approximates the number of mantissa bits required for accurate computations.
|
||||||
// https://en.wikipedia.org/wiki/Condition_number.
|
// See https://en.wikipedia.org/wiki/Condition_number.
|
||||||
template <typename WeightT, typename VecT>
|
template <typename WT, typename VT>
|
||||||
HWY_MAYBE_UNUSED double ConditionNumber(const WeightT* HWY_RESTRICT w,
|
HWY_MAYBE_UNUSED double ConditionNumber(const WT* HWY_RESTRICT w,
|
||||||
const VecT* HWY_RESTRICT v,
|
const VT* HWY_RESTRICT v, size_t num) {
|
||||||
size_t num) {
|
|
||||||
PROFILER_FUNC;
|
PROFILER_FUNC;
|
||||||
const hn::ScalableTag<float> df;
|
const hn::ScalableTag<float> df;
|
||||||
using VF = hn::Vec<decltype(df)>;
|
using VF = hn::Vec<decltype(df)>;
|
||||||
|
|
@ -104,9 +109,8 @@ HWY_MAYBE_UNUSED double ConditionNumber(const WeightT* HWY_RESTRICT w,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Same, but for a single vector - just skips the product.
|
// Same, but for a single vector - just skips the product.
|
||||||
template <typename VecT>
|
template <typename VT>
|
||||||
HWY_MAYBE_UNUSED double ConditionNumber(const VecT* HWY_RESTRICT v,
|
HWY_MAYBE_UNUSED double ConditionNumber(const VT* HWY_RESTRICT v, size_t num) {
|
||||||
size_t num) {
|
|
||||||
PROFILER_FUNC;
|
PROFILER_FUNC;
|
||||||
const hn::ScalableTag<float> df;
|
const hn::ScalableTag<float> df;
|
||||||
using VF = hn::Vec<decltype(df)>;
|
using VF = hn::Vec<decltype(df)>;
|
||||||
|
|
@ -153,12 +157,60 @@ HWY_MAYBE_UNUSED double ConditionNumber(const VecT* HWY_RESTRICT v,
|
||||||
return cond;
|
return cond;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Algorithm 6.15 from Handbook of Floating-Point Arithmetic. 10 ops is too slow
|
// f64 FMA, called for f32 inputs promoted to f64. Runs at about half the speed
|
||||||
// for compute-limited Matmul but might be OK for attention.
|
// of f32 FMA. Only usable if `CanDecompressToDouble<WT, VT>()`.
|
||||||
// Also supports bf16 inputs, used by matvec-inl.h.
|
struct DotKernelDouble {
|
||||||
|
template <typename VT, typename WT>
|
||||||
|
using Raw = double;
|
||||||
|
using State = double;
|
||||||
|
|
||||||
|
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
|
||||||
|
HWY_INLINE void Update4(DRaw dd, const VR w0, const VR w1, const VR w2,
|
||||||
|
const VR w3, const VR v0, const VR v1, const VR v2,
|
||||||
|
const VR v3, VR& sum0, VR& sum1, VR& sum2, VR& sum3,
|
||||||
|
VR&, VR&, VR&, VR&) const {
|
||||||
|
sum0 = hn::MulAdd(w0, v0, sum0);
|
||||||
|
sum1 = hn::MulAdd(w1, v1, sum1);
|
||||||
|
sum2 = hn::MulAdd(w2, v2, sum2);
|
||||||
|
sum3 = hn::MulAdd(w3, v3, sum3);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
|
||||||
|
HWY_INLINE void Update1(DRaw dd, const VR w0, const VR v0, VR& sum0,
|
||||||
|
VR&) const {
|
||||||
|
sum0 = hn::MulAdd(w0, v0, sum0);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class DState, class VS = hn::Vec<DState>, HWY_IF_F64_D(DState)>
|
||||||
|
HWY_INLINE float Reduce(DState dd, VS& sum0, VS& sum1, VS& sum2, VS& sum3,
|
||||||
|
VS&, VS&, VS&, VS&) const {
|
||||||
|
// Reduction tree: sum of all accumulators by pairs, then across lanes.
|
||||||
|
sum0 = hn::Add(sum0, sum1);
|
||||||
|
sum2 = hn::Add(sum2, sum3);
|
||||||
|
sum0 = hn::Add(sum0, sum2);
|
||||||
|
return static_cast<float>(hn::ReduceSum(dd, sum0));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <class D, typename WT, typename VT>
|
||||||
|
HWY_INLINE float DotDouble(D d, const PackedSpan<const WT>& w, size_t w_ofs,
|
||||||
|
const VT* HWY_RESTRICT vec, size_t num) {
|
||||||
|
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelDouble());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Algorithm 6.15 from Handbook of Floating-Point Arithmetic. This is slower
|
||||||
|
// than DotKernelDouble and about equally accurate.
|
||||||
struct DotKernelCompensated {
|
struct DotKernelCompensated {
|
||||||
template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
|
// Unlike other kernels, this also supports bf16 inputs, used by matvec-inl.h.
|
||||||
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
|
// Even when `!HWY_NATIVE_DOT_BF16`, the BF16 overload is still faster than
|
||||||
|
// promoting to `float` because it does not call `TwoProducts`.
|
||||||
|
template <typename VT, typename WT>
|
||||||
|
using Raw = hwy::If<IsF32<VT>() || IsF32<WT>(), float, BF16>;
|
||||||
|
using State = float;
|
||||||
|
|
||||||
|
// Raw = float
|
||||||
|
template <class DRaw, class VF = hn::Vec<DRaw>, HWY_IF_F32_D(DRaw)>
|
||||||
|
HWY_INLINE void Update4(DRaw df, const VF w0, const VF w1, const VF w2,
|
||||||
const VF w3, const VF v0, const VF v1, const VF v2,
|
const VF w3, const VF v0, const VF v1, const VF v2,
|
||||||
const VF v3, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
|
const VF v3, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
|
||||||
VF& comp0, VF& comp1, VF& comp2, VF& comp3) const {
|
VF& comp0, VF& comp1, VF& comp2, VF& comp3) const {
|
||||||
|
|
@ -180,20 +232,20 @@ struct DotKernelCompensated {
|
||||||
comp3 = hn::Add(comp3, hn::Add(perr3, serr3));
|
comp3 = hn::Add(comp3, hn::Add(perr3, serr3));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class DBF, class VBF = hn::Vec<DBF>, HWY_IF_BF16_D(DBF),
|
// Raw = BF16, State = float
|
||||||
class DF = hn::Repartition<float, DBF>, class VF = hn::Vec<DF>>
|
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
|
||||||
HWY_INLINE void Update4(DBF /*dbf*/, const VBF w0, const VBF w1, const VBF w2,
|
class DS = hn::Repartition<float, DRaw>, class VS = hn::Vec<DS>>
|
||||||
const VBF w3, const VBF v0, const VBF v1,
|
HWY_INLINE void Update4(DRaw, const VR w0, const VR w1, const VR w2,
|
||||||
const VBF v2, const VBF v3, VF& sum0, VF& sum1,
|
const VR w3, const VR v0, const VR v1, const VR v2,
|
||||||
VF& sum2, VF& sum3, VF& comp0, VF& comp1, VF& comp2,
|
const VR v3, VS& sum0, VS& sum1, VS& sum2, VS& sum3,
|
||||||
VF& comp3) const {
|
VS& comp0, VS& comp1, VS& comp2, VS& comp3) const {
|
||||||
const DF df;
|
const DS df;
|
||||||
const VF prod0 = WidenMulPairwiseAdd(df, w0, v0);
|
const VS prod0 = WidenMulPairwiseAdd(df, w0, v0);
|
||||||
const VF prod1 = WidenMulPairwiseAdd(df, w1, v1);
|
const VS prod1 = WidenMulPairwiseAdd(df, w1, v1);
|
||||||
const VF prod2 = WidenMulPairwiseAdd(df, w2, v2);
|
const VS prod2 = WidenMulPairwiseAdd(df, w2, v2);
|
||||||
const VF prod3 = WidenMulPairwiseAdd(df, w3, v3);
|
const VS prod3 = WidenMulPairwiseAdd(df, w3, v3);
|
||||||
|
|
||||||
VF serr0, serr1, serr2, serr3;
|
VS serr0, serr1, serr2, serr3;
|
||||||
sum0 = TwoSums(df, prod0, sum0, serr0);
|
sum0 = TwoSums(df, prod0, sum0, serr0);
|
||||||
sum1 = TwoSums(df, prod1, sum1, serr1);
|
sum1 = TwoSums(df, prod1, sum1, serr1);
|
||||||
sum2 = TwoSums(df, prod2, sum2, serr2);
|
sum2 = TwoSums(df, prod2, sum2, serr2);
|
||||||
|
|
@ -205,6 +257,7 @@ struct DotKernelCompensated {
|
||||||
comp3 = hn::Add(comp3, serr3);
|
comp3 = hn::Add(comp3, serr3);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Raw = float
|
||||||
template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
|
template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
|
||||||
HWY_INLINE void Update1(DF df, const VF w0, const VF v0, VF& sum0,
|
HWY_INLINE void Update1(DF df, const VF w0, const VF v0, VF& sum0,
|
||||||
VF& comp0) const {
|
VF& comp0) const {
|
||||||
|
|
@ -217,22 +270,23 @@ struct DotKernelCompensated {
|
||||||
comp0 = hn::Add(comp0, hn::Add(perr0, serr0));
|
comp0 = hn::Add(comp0, hn::Add(perr0, serr0));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class DBF, class VBF = hn::Vec<DBF>, HWY_IF_BF16_D(DBF),
|
// Raw = BF16, State = float
|
||||||
class DF = hn::Repartition<float, DBF>, class VF = hn::Vec<DF>>
|
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
|
||||||
HWY_INLINE void Update1(DBF /*dbf*/, const VBF w0, const VBF v0, VF& sum0,
|
class DS = hn::Repartition<float, DRaw>, class VS = hn::Vec<DS>>
|
||||||
VF& comp0) const {
|
HWY_INLINE void Update1(DRaw, const VR w0, const VR v0, VS& sum0,
|
||||||
const DF df;
|
VS& comp0) const {
|
||||||
const VF prod0 = WidenMulPairwiseAdd(df, w0, v0);
|
const DS df;
|
||||||
|
const VS prod0 = WidenMulPairwiseAdd(df, w0, v0);
|
||||||
|
|
||||||
VF serr0;
|
VS serr0;
|
||||||
sum0 = TwoSums(df, prod0, sum0, serr0);
|
sum0 = TwoSums(df, prod0, sum0, serr0);
|
||||||
|
|
||||||
comp0 = hn::Add(comp0, serr0);
|
comp0 = hn::Add(comp0, serr0);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class DF, class VF = hn::Vec<DF>>
|
template <class DS, class VS = hn::Vec<DS>>
|
||||||
HWY_INLINE float Reduce(DF df, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
|
HWY_INLINE float Reduce(DS df, VS& sum0, VS& sum1, VS& sum2, VS& sum3,
|
||||||
VF& comp0, VF& comp1, VF& comp2, VF& comp3) const {
|
VS& comp0, VS& comp1, VS& comp2, VS& comp3) const {
|
||||||
// Reduction tree: sum of all accumulators by pairs, then across lanes.
|
// Reduction tree: sum of all accumulators by pairs, then across lanes.
|
||||||
AssimilateCascadedSums(df, sum1, comp1, sum0, comp0);
|
AssimilateCascadedSums(df, sum1, comp1, sum0, comp0);
|
||||||
AssimilateCascadedSums(df, sum3, comp3, sum2, comp2);
|
AssimilateCascadedSums(df, sum3, comp3, sum2, comp2);
|
||||||
|
|
@ -241,35 +295,38 @@ struct DotKernelCompensated {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Default kernel
|
template <typename WT, typename VT>
|
||||||
template <class D, typename WeightT, typename VecT>
|
using DotKernelDefault = hwy::If<CanDecompressToDouble<WT, VT>(),
|
||||||
HWY_INLINE float Dot(D d, const PackedSpan<const WeightT>& w, size_t w_ofs,
|
DotKernelDouble, DotKernelCompensated>;
|
||||||
const VecT* HWY_RESTRICT vec, size_t num) {
|
|
||||||
|
// `D` only serves to specify the vector size; its lane type is ignored.
|
||||||
|
template <class D, typename WT, typename VT>
|
||||||
|
HWY_INLINE float Dot(D d, const PackedSpan<const WT>& w, size_t w_ofs,
|
||||||
|
const VT* HWY_RESTRICT vec, size_t num) {
|
||||||
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
|
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
|
||||||
DotKernelCompensated());
|
DotKernelDefault<WT, VT>());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Adapter for a single pointer, no bounds checking.
|
// Adapter for two pointers, no bounds checking.
|
||||||
template <typename WeightT, typename VecT>
|
template <typename WT, typename VT>
|
||||||
HWY_INLINE float Dot(const WeightT* HWY_RESTRICT w, const VecT* vec,
|
HWY_INLINE float Dot(const WT* HWY_RESTRICT w, const VT* vec, size_t num) {
|
||||||
size_t num) {
|
const hn::ScalableTag<VT> d;
|
||||||
const hn::ScalableTag<VecT> d;
|
|
||||||
return Dot(d, MakeConstSpan(w, num), /*w_ofs=*/0, vec, num);
|
return Dot(d, MakeConstSpan(w, num), /*w_ofs=*/0, vec, num);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Adapter for use by matvec-inl.h. TODO: remove when that is no longer used.
|
// Adapter for use by matvec-inl.h. TODO: remove when that is no longer used.
|
||||||
template <size_t kCapacity, typename VecT>
|
template <size_t kCapacity, typename VT>
|
||||||
HWY_INLINE float Dot(const std::array<float, kCapacity>& w, size_t w_ofs,
|
HWY_INLINE float Dot(const std::array<float, kCapacity>& w, size_t w_ofs,
|
||||||
const VecT* vec, size_t num) {
|
const VT* vec, size_t num) {
|
||||||
const hn::ScalableTag<VecT> d;
|
const hn::ScalableTag<VT> d;
|
||||||
return Dot(d, MakeConstSpan(w.data(), kCapacity), w_ofs, vec, num);
|
return Dot(d, MakeConstSpan(w.data(), kCapacity), w_ofs, vec, num);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Adapter for use by matvec-inl.h. TODO: remove when that is no longer used.
|
// Adapter for use by matvec-inl.h. TODO: remove when that is no longer used.
|
||||||
template <typename MatT, size_t kCapacity, typename VecT>
|
template <typename MatT, size_t kCapacity, typename VT>
|
||||||
HWY_INLINE float Dot(const CompressedArray<MatT, kCapacity>& w, size_t w_ofs,
|
HWY_INLINE float Dot(const CompressedArray<MatT, kCapacity>& w, size_t w_ofs,
|
||||||
const VecT* vec, size_t num) {
|
const VT* vec, size_t num) {
|
||||||
const hn::ScalableTag<VecT> d;
|
const hn::ScalableTag<VT> d;
|
||||||
return w.scale() *
|
return w.scale() *
|
||||||
Dot(d, MakeConstSpan(w.data(), kCapacity), w_ofs, vec, num);
|
Dot(d, MakeConstSpan(w.data(), kCapacity), w_ofs, vec, num);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
151
ops/dot_test.cc
151
ops/dot_test.cc
|
|
@ -120,6 +120,10 @@ void AssertLess(size_t variant, T actual, T max, int line) {
|
||||||
// All combinations of {*, TwoProducts} x {+, FastTwoSums, TwoSums}.
|
// All combinations of {*, TwoProducts} x {+, FastTwoSums, TwoSums}.
|
||||||
|
|
||||||
struct DotKernelNaive {
|
struct DotKernelNaive {
|
||||||
|
template <typename VT, typename WT>
|
||||||
|
using Raw = float;
|
||||||
|
using State = float;
|
||||||
|
|
||||||
template <class DF, class VF = hn::Vec<DF>>
|
template <class DF, class VF = hn::Vec<DF>>
|
||||||
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
|
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
|
||||||
const VF w3, const VF v0, const VF v1, const VF v2,
|
const VF w3, const VF v0, const VF v1, const VF v2,
|
||||||
|
|
@ -150,51 +154,18 @@ struct DotKernelNaive {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <class D, typename WeightT, typename VecT>
|
template <class D, typename WT, typename VT>
|
||||||
HWY_INLINE float DotNaive(D d, const PackedSpan<const WeightT>& w, size_t w_ofs,
|
HWY_INLINE float DotNaive(D d, const PackedSpan<const WT>& w, size_t w_ofs,
|
||||||
const VecT* HWY_RESTRICT vec, size_t num) {
|
const VT* HWY_RESTRICT vec, size_t num) {
|
||||||
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelNaive());
|
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelNaive());
|
||||||
}
|
}
|
||||||
|
|
||||||
struct DotKernelDouble {
|
|
||||||
template <class DD, class VD = hn::Vec<DD>, HWY_IF_F64_D(DD)>
|
|
||||||
HWY_INLINE void Update4(DD dd, const VD w0, const VD w1, const VD w2,
|
|
||||||
const VD w3, const VD v0, const VD v1, const VD v2,
|
|
||||||
const VD v3, VD& sum0, VD& sum1, VD& sum2, VD& sum3,
|
|
||||||
VD&, VD&, VD&, VD&) const {
|
|
||||||
sum0 = hn::MulAdd(w0, v0, sum0);
|
|
||||||
sum1 = hn::MulAdd(w1, v1, sum1);
|
|
||||||
sum2 = hn::MulAdd(w2, v2, sum2);
|
|
||||||
sum3 = hn::MulAdd(w3, v3, sum3);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class DD, class VD = hn::Vec<DD>, HWY_IF_F64_D(DD)>
|
|
||||||
HWY_INLINE void Update1(DD dd, const VD w0, const VD v0, VD& sum0,
|
|
||||||
VD&) const {
|
|
||||||
sum0 = hn::MulAdd(w0, v0, sum0);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class DD, class VD = hn::Vec<DD>, HWY_IF_F64_D(DD)>
|
|
||||||
HWY_INLINE float Reduce(DD dd, VD& sum0, VD& sum1, VD& sum2, VD& sum3, VD&,
|
|
||||||
VD&, VD&, VD&) const {
|
|
||||||
// Reduction tree: sum of all accumulators by pairs, then across lanes.
|
|
||||||
sum0 = hn::Add(sum0, sum1);
|
|
||||||
sum2 = hn::Add(sum2, sum3);
|
|
||||||
sum0 = hn::Add(sum0, sum2);
|
|
||||||
return static_cast<float>(hn::ReduceSum(dd, sum0));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <class D, typename WeightT, typename VecT>
|
|
||||||
HWY_INLINE float DotDouble(D d, const PackedSpan<const WeightT>& w,
|
|
||||||
size_t w_ofs, const VecT* HWY_RESTRICT vec,
|
|
||||||
size_t num) {
|
|
||||||
const hn::Repartition<double, D> dd;
|
|
||||||
return DecompressAndCall(dd, w, w_ofs, MakeSpan(vec, num), DotKernelDouble());
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://en.wikipedia.org/wiki/Kahan_summation_algorithm: FastTwoSum.
|
// https://en.wikipedia.org/wiki/Kahan_summation_algorithm: FastTwoSum.
|
||||||
struct DotKernelKahan {
|
struct DotKernelKahan {
|
||||||
|
template <typename VT, typename WT>
|
||||||
|
using Raw = float;
|
||||||
|
using State = float;
|
||||||
|
|
||||||
template <class DF, class VF = hn::Vec<DF>>
|
template <class DF, class VF = hn::Vec<DF>>
|
||||||
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
|
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
|
||||||
const VF w3, const VF v0, const VF v1, const VF v2,
|
const VF w3, const VF v0, const VF v1, const VF v2,
|
||||||
|
|
@ -234,15 +205,15 @@ struct DotKernelKahan {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <class D, typename WeightT, typename VecT>
|
template <class D, typename WT, typename VT>
|
||||||
HWY_INLINE float DotKahan(D d, const PackedSpan<const WeightT>& w, size_t w_ofs,
|
HWY_INLINE float DotKahan(D d, const PackedSpan<const WT>& w, size_t w_ofs,
|
||||||
const VecT* HWY_RESTRICT vec, size_t num) {
|
const VT* HWY_RESTRICT vec, size_t num) {
|
||||||
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelKahan());
|
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelKahan());
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class D, typename WeightT, typename VecT>
|
template <class D, typename WT, typename VT>
|
||||||
HWY_INLINE float DotCompensated(D d, const PackedSpan<const WeightT>& w,
|
HWY_INLINE float DotCompensated(D d, const PackedSpan<const WT>& w,
|
||||||
size_t w_ofs, const VecT* HWY_RESTRICT vec,
|
size_t w_ofs, const VT* HWY_RESTRICT vec,
|
||||||
size_t num) {
|
size_t num) {
|
||||||
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
|
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
|
||||||
DotKernelCompensated());
|
DotKernelCompensated());
|
||||||
|
|
@ -250,6 +221,10 @@ HWY_INLINE float DotCompensated(D d, const PackedSpan<const WeightT>& w,
|
||||||
|
|
||||||
// Like Compensated, but FastTwoSum instead of TwoSum.
|
// Like Compensated, but FastTwoSum instead of TwoSum.
|
||||||
struct DotKernelTwoProdFast {
|
struct DotKernelTwoProdFast {
|
||||||
|
template <typename VT, typename WT>
|
||||||
|
using Raw = float;
|
||||||
|
using State = float;
|
||||||
|
|
||||||
template <class DF, class VF = hn::Vec<DF>>
|
template <class DF, class VF = hn::Vec<DF>>
|
||||||
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
|
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
|
||||||
const VF w3, const VF v0, const VF v1, const VF v2,
|
const VF w3, const VF v0, const VF v1, const VF v2,
|
||||||
|
|
@ -296,9 +271,9 @@ struct DotKernelTwoProdFast {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <class D, typename WeightT, typename VecT>
|
template <class D, typename WT, typename VT>
|
||||||
HWY_INLINE float DotTwoProdFast(D d, const PackedSpan<const WeightT>& w,
|
HWY_INLINE float DotTwoProdFast(D d, const PackedSpan<const WT>& w,
|
||||||
size_t w_ofs, const VecT* HWY_RESTRICT vec,
|
size_t w_ofs, const VT* HWY_RESTRICT vec,
|
||||||
size_t num) {
|
size_t num) {
|
||||||
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
|
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
|
||||||
DotKernelTwoProdFast());
|
DotKernelTwoProdFast());
|
||||||
|
|
@ -307,6 +282,10 @@ HWY_INLINE float DotTwoProdFast(D d, const PackedSpan<const WeightT>& w,
|
||||||
// Like Compensated, but without TwoProducts. Vs Kahan, upgrades FastTwoSums
|
// Like Compensated, but without TwoProducts. Vs Kahan, upgrades FastTwoSums
|
||||||
// to TwoSums.
|
// to TwoSums.
|
||||||
struct DotKernelMulTwoSum {
|
struct DotKernelMulTwoSum {
|
||||||
|
template <typename VT, typename WT>
|
||||||
|
using Raw = float;
|
||||||
|
using State = float;
|
||||||
|
|
||||||
template <class DF, class VF = hn::Vec<DF>>
|
template <class DF, class VF = hn::Vec<DF>>
|
||||||
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
|
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
|
||||||
const VF w3, const VF v0, const VF v1, const VF v2,
|
const VF w3, const VF v0, const VF v1, const VF v2,
|
||||||
|
|
@ -351,10 +330,9 @@ struct DotKernelMulTwoSum {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <class D, typename WeightT, typename VecT>
|
template <class D, typename WT, typename VT>
|
||||||
HWY_INLINE float DotMulTwoSum(D d, const PackedSpan<const WeightT>& w,
|
HWY_INLINE float DotMulTwoSum(D d, const PackedSpan<const WT>& w, size_t w_ofs,
|
||||||
size_t w_ofs, const VecT* HWY_RESTRICT vec,
|
const VT* HWY_RESTRICT vec, size_t num) {
|
||||||
size_t num) {
|
|
||||||
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
|
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
|
||||||
DotKernelMulTwoSum());
|
DotKernelMulTwoSum());
|
||||||
}
|
}
|
||||||
|
|
@ -362,6 +340,10 @@ HWY_INLINE float DotMulTwoSum(D d, const PackedSpan<const WeightT>& w,
|
||||||
// -Like Compensated, but only TwoProducts, no [Fast]TwoSums. This is only 10%
|
// -Like Compensated, but only TwoProducts, no [Fast]TwoSums. This is only 10%
|
||||||
// better (mul) than naive.
|
// better (mul) than naive.
|
||||||
struct DotKernelTwoProdAdd {
|
struct DotKernelTwoProdAdd {
|
||||||
|
template <typename VT, typename WT>
|
||||||
|
using Raw = float;
|
||||||
|
using State = float;
|
||||||
|
|
||||||
template <class DF, class VF = hn::Vec<DF>>
|
template <class DF, class VF = hn::Vec<DF>>
|
||||||
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
|
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
|
||||||
const VF w3, const VF v0, const VF v1, const VF v2,
|
const VF w3, const VF v0, const VF v1, const VF v2,
|
||||||
|
|
@ -406,10 +388,9 @@ struct DotKernelTwoProdAdd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <class D, typename WeightT, typename VecT>
|
template <class D, typename WT, typename VT>
|
||||||
HWY_INLINE float DotTwoProdAdd(D d, const PackedSpan<const WeightT>& w,
|
HWY_INLINE float DotTwoProdAdd(D d, const PackedSpan<const WT>& w, size_t w_ofs,
|
||||||
size_t w_ofs, const VecT* HWY_RESTRICT vec,
|
const VT* HWY_RESTRICT vec, size_t num) {
|
||||||
size_t num) {
|
|
||||||
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
|
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
|
||||||
DotKernelTwoProdAdd());
|
DotKernelTwoProdAdd());
|
||||||
}
|
}
|
||||||
|
|
@ -417,6 +398,10 @@ HWY_INLINE float DotTwoProdAdd(D d, const PackedSpan<const WeightT>& w,
|
||||||
// From "SIMDizing Pairwise Sums". Slower and generally higher error than
|
// From "SIMDizing Pairwise Sums". Slower and generally higher error than
|
||||||
// Kahan, but uses fewer regs.
|
// Kahan, but uses fewer regs.
|
||||||
struct DotKernelPairwise {
|
struct DotKernelPairwise {
|
||||||
|
template <typename VT, typename WT>
|
||||||
|
using Raw = float;
|
||||||
|
using State = float;
|
||||||
|
|
||||||
template <class DF, class VF = hn::Vec<DF>>
|
template <class DF, class VF = hn::Vec<DF>>
|
||||||
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
|
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
|
||||||
const VF w3, const VF v0, const VF v1, const VF v2,
|
const VF w3, const VF v0, const VF v1, const VF v2,
|
||||||
|
|
@ -472,10 +457,9 @@ struct DotKernelPairwise {
|
||||||
mutable size_t num_ = 0;
|
mutable size_t num_ = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <class D, typename WeightT, typename VecT>
|
template <class D, typename WT, typename VT>
|
||||||
HWY_INLINE float DotPairwise(D d, const PackedSpan<const WeightT>& w,
|
HWY_INLINE float DotPairwise(D d, const PackedSpan<const WT>& w, size_t w_ofs,
|
||||||
size_t w_ofs, const VecT* HWY_RESTRICT vec,
|
const VT* HWY_RESTRICT vec, size_t num) {
|
||||||
size_t num) {
|
|
||||||
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
|
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
|
||||||
DotKernelPairwise());
|
DotKernelPairwise());
|
||||||
}
|
}
|
||||||
|
|
@ -483,6 +467,10 @@ HWY_INLINE float DotPairwise(D d, const PackedSpan<const WeightT>& w,
|
||||||
// Hybrid of Pairwise and Compensated. 1.14x time vs. Kahan, but geomean mul
|
// Hybrid of Pairwise and Compensated. 1.14x time vs. Kahan, but geomean mul
|
||||||
// is 1.02 vs 1.06, mean L1 is 1.21x better, and uses two fewer regs.
|
// is 1.02 vs 1.06, mean L1 is 1.21x better, and uses two fewer regs.
|
||||||
struct DotKernelComp2 {
|
struct DotKernelComp2 {
|
||||||
|
template <typename VT, typename WT>
|
||||||
|
using Raw = float;
|
||||||
|
using State = float;
|
||||||
|
|
||||||
template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
|
template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
|
||||||
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
|
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
|
||||||
const VF w3, const VF v0, const VF v1, const VF v2,
|
const VF w3, const VF v0, const VF v1, const VF v2,
|
||||||
|
|
@ -567,17 +555,17 @@ struct DotKernelComp2 {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <class D, typename WeightT, typename VecT>
|
template <class D, typename WT, typename VT>
|
||||||
HWY_INLINE float DotComp2(D d, const PackedSpan<const WeightT>& w, size_t w_ofs,
|
HWY_INLINE float DotComp2(D d, const PackedSpan<const WT>& w, size_t w_ofs,
|
||||||
const VecT* HWY_RESTRICT vec, size_t num) {
|
const VT* HWY_RESTRICT vec, size_t num) {
|
||||||
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelComp2());
|
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelComp2());
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class D, typename WeightT, typename VecT, HWY_IF_F32_D(D)>
|
template <class D, typename WT, typename VT, HWY_IF_F32_D(D)>
|
||||||
float CallDot(D d, size_t variant, const PackedSpan<const WeightT>& w,
|
float CallDot(D d, size_t variant, const PackedSpan<const WT>& w, size_t w_ofs,
|
||||||
size_t w_ofs, const VecT* HWY_RESTRICT v, size_t num) {
|
const VT* HWY_RESTRICT v, size_t num) {
|
||||||
// float inputs also support kDouble.
|
// float inputs also support kDouble.
|
||||||
if constexpr (hwy::IsSame<WeightT, float>() && hwy::IsSame<VecT, float>()) {
|
if constexpr (CanDecompressToDouble<WT, VT>()) {
|
||||||
if (variant == kDouble) return DotDouble(d, w, 0, v, num);
|
if (variant == kDouble) return DotDouble(d, w, 0, v, num);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -608,9 +596,9 @@ float CallDot(D d, size_t variant, const PackedSpan<const WeightT>& w,
|
||||||
// and round to nearest. See "Accurate and efficient floating point summation".
|
// and round to nearest. See "Accurate and efficient floating point summation".
|
||||||
// Much too slow to be useful. Kept separate from the above kernels because it
|
// Much too slow to be useful. Kept separate from the above kernels because it
|
||||||
// is used to compute their error.
|
// is used to compute their error.
|
||||||
template <typename WeightT, typename VecT>
|
template <typename WT, typename VT>
|
||||||
float ExactDot(const WeightT* HWY_RESTRICT w, const VecT* HWY_RESTRICT v,
|
float ExactDot(const WT* HWY_RESTRICT w, const VT* HWY_RESTRICT v, size_t num,
|
||||||
size_t num, double* HWY_RESTRICT buf) {
|
double* HWY_RESTRICT buf) {
|
||||||
PROFILER_FUNC;
|
PROFILER_FUNC;
|
||||||
for (size_t i = 0; i < num; ++i) {
|
for (size_t i = 0; i < num; ++i) {
|
||||||
buf[i] =
|
buf[i] =
|
||||||
|
|
@ -944,18 +932,17 @@ void GenerateWellConditionedInputs(const size_t num, float* HWY_RESTRICT raw,
|
||||||
|
|
||||||
// Returns the actual condition number. Based on Algorithm 6.1 from "Accurate
|
// Returns the actual condition number. Based on Algorithm 6.1 from "Accurate
|
||||||
// Sum and Dot Product". `num` is the (arbitrary) size of w, v, and buf.
|
// Sum and Dot Product". `num` is the (arbitrary) size of w, v, and buf.
|
||||||
template <typename WeightT, typename VecT>
|
template <typename WT, typename VT>
|
||||||
double GenerateIllConditionedInputs(const size_t num, WeightT* w,
|
double GenerateIllConditionedInputs(const size_t num, WT* w, VT* HWY_RESTRICT v,
|
||||||
VecT* HWY_RESTRICT v, std::mt19937& rng) {
|
std::mt19937& rng) {
|
||||||
PROFILER_FUNC;
|
PROFILER_FUNC;
|
||||||
const size_t half = HWY_MAX(1, num / 2); // generate at least one random
|
const size_t half = HWY_MAX(1, num / 2); // generate at least one random
|
||||||
HWY_DASSERT(half != 0);
|
HWY_DASSERT(half != 0);
|
||||||
|
|
||||||
const hn::ScalableTag<float> df;
|
const hn::ScalableTag<float> df;
|
||||||
|
const PackedSpan<WT> w_span(w, num);
|
||||||
|
|
||||||
const PackedSpan<WeightT> w_span(w, num);
|
// Regardless of WT and VT, we will accumulate into float. Multiplying
|
||||||
|
|
||||||
// Regardless of WeightT and VecT, we will accumulate into float. Multiplying
|
|
||||||
// two maximal inputs and accumulating `num` times is enough for some loss of
|
// two maximal inputs and accumulating `num` times is enough for some loss of
|
||||||
// precision and condition numbers between 1E6-1E9, which is what we see for
|
// precision and condition numbers between 1E6-1E9, which is what we see for
|
||||||
// Attention Dot and `RMSNormMul`.
|
// Attention Dot and `RMSNormMul`.
|
||||||
|
|
@ -966,8 +953,8 @@ double GenerateIllConditionedInputs(const size_t num, WeightT* w,
|
||||||
for (size_t i = 0; i < half; ++i) {
|
for (size_t i = 0; i < half; ++i) {
|
||||||
// Ensure the min and max exponents are used.
|
// Ensure the min and max exponents are used.
|
||||||
const int e = i == 0 ? 0 : i == 1 ? max_exp : e_dist(rng);
|
const int e = i == 0 ? 0 : i == 1 ? max_exp : e_dist(rng);
|
||||||
w[i] = hwy::ConvertScalarTo<WeightT>(RandomFloat(rng) * (1 << e));
|
w[i] = hwy::ConvertScalarTo<WT>(RandomFloat(rng) * (1 << e));
|
||||||
v[i] = hwy::ConvertScalarTo<VecT>(RandomFloat(rng) * (1 << e));
|
v[i] = hwy::ConvertScalarTo<VT>(RandomFloat(rng) * (1 << e));
|
||||||
}
|
}
|
||||||
|
|
||||||
const float a_exp_step =
|
const float a_exp_step =
|
||||||
|
|
@ -976,16 +963,16 @@ double GenerateIllConditionedInputs(const size_t num, WeightT* w,
|
||||||
for (size_t i = half; i < num; ++i, a_exp -= a_exp_step) {
|
for (size_t i = half; i < num; ++i, a_exp -= a_exp_step) {
|
||||||
const int e = static_cast<int>(a_exp);
|
const int e = static_cast<int>(a_exp);
|
||||||
HWY_DASSERT(e >= 0);
|
HWY_DASSERT(e >= 0);
|
||||||
w[i] = hwy::ConvertScalarTo<WeightT>(RandomFloat(rng) * (1 << e));
|
w[i] = hwy::ConvertScalarTo<WT>(RandomFloat(rng) * (1 << e));
|
||||||
const float r = RandomFloat(rng) * (1 << e);
|
const float r = RandomFloat(rng) * (1 << e);
|
||||||
if (hwy::ConvertScalarTo<float>(w[i]) == 0.0f) {
|
if (hwy::ConvertScalarTo<float>(w[i]) == 0.0f) {
|
||||||
v[i] = hwy::ConvertScalarTo<VecT>(0.0f);
|
v[i] = hwy::ConvertScalarTo<VT>(0.0f);
|
||||||
} else {
|
} else {
|
||||||
// This is called >100K times. DotCompensated is much faster than ExactDot
|
// This is called >100K times. DotCompensated is much faster than ExactDot
|
||||||
// and just about as accurate.
|
// and just about as accurate.
|
||||||
const float exact =
|
const float exact =
|
||||||
DotCompensated(df, MakeConst(w_span), /*w_ofs=*/0, v, i);
|
DotCompensated(df, MakeConst(w_span), /*w_ofs=*/0, v, i);
|
||||||
v[i] = hwy::ConvertScalarTo<VecT>(
|
v[i] = hwy::ConvertScalarTo<VT>(
|
||||||
r - exact / hwy::ConvertScalarTo<float>(w[i]));
|
r - exact / hwy::ConvertScalarTo<float>(w[i]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -142,11 +142,11 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Sigmoid(float* HWY_RESTRICT x,
|
||||||
namespace detail {
|
namespace detail {
|
||||||
|
|
||||||
// Shared by RMSNorm and RMSNormInplace.
|
// Shared by RMSNorm and RMSNormInplace.
|
||||||
template <typename VecT>
|
template <typename VT>
|
||||||
float RMSNormMul(const VecT* HWY_RESTRICT x, size_t size) {
|
float RMSNormMul(const VT* HWY_RESTRICT x, size_t size) {
|
||||||
const hn::ScalableTag<float> df;
|
const hn::ScalableTag<float> d;
|
||||||
const float l2 =
|
const float l2 =
|
||||||
DecompressAndCall(df, MakeSpan(x, size), DotKernelCompensated());
|
DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault<VT, VT>());
|
||||||
constexpr float kEps = 1e-6f; // avoid divide by zero
|
constexpr float kEps = 1e-6f; // avoid divide by zero
|
||||||
return 1.0f / sqrtf(l2 / StaticCast<float>(size) + kEps);
|
return 1.0f / sqrtf(l2 / StaticCast<float>(size) + kEps);
|
||||||
}
|
}
|
||||||
|
|
@ -504,6 +504,40 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
|
||||||
MulByConstAndAdd(c, x, out, size, size);
|
MulByConstAndAdd(c, x, out, size, size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// f64 Add, called for f32 inputs promoted to f64. Runs at about half the speed
|
||||||
|
// of f32 sums. Only usable if `CanDecompressToDouble<VT, VT>()`.
|
||||||
|
struct SumKernelDouble {
|
||||||
|
template <typename VT, typename WT>
|
||||||
|
using Raw = double;
|
||||||
|
using State = double;
|
||||||
|
|
||||||
|
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
|
||||||
|
HWY_INLINE void Update4(DRaw /*dd*/, const VR w0, const VR w1, const VR w2,
|
||||||
|
const VR w3, VR, VR, VR, VR, VR& sum0, VR& sum1,
|
||||||
|
VR& sum2, VR& sum3, VR&, VR&, VR&, VR&) const {
|
||||||
|
sum0 = hn::Add(sum0, w0);
|
||||||
|
sum1 = hn::Add(sum1, w1);
|
||||||
|
sum2 = hn::Add(sum2, w2);
|
||||||
|
sum3 = hn::Add(sum3, w3);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
|
||||||
|
HWY_INLINE void Update1(DRaw /*dd*/, const VR w0, const VR v0, VR& sum0,
|
||||||
|
VR& comp0) const {
|
||||||
|
sum0 = hn::Add(sum0, w0);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class DState, class VS = hn::Vec<DState>>
|
||||||
|
HWY_INLINE float Reduce(DState dd, VS& sum0, VS& sum1, VS& sum2, VS& sum3,
|
||||||
|
VS&, VS&, VS&, VS&) const {
|
||||||
|
// Reduction tree: sum of all accumulators by pairs, then across lanes.
|
||||||
|
sum0 = hn::Add(sum0, sum1);
|
||||||
|
sum2 = hn::Add(sum2, sum3);
|
||||||
|
sum0 = hn::Add(sum0, sum2);
|
||||||
|
return static_cast<float>(hn::ReduceSum(dd, sum0));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// ORO Cascaded Summation, algorithm 6.11 from Handbook of Floating-Point
|
// ORO Cascaded Summation, algorithm 6.11 from Handbook of Floating-Point
|
||||||
// Arithmetic. Note that Algorithm 6.7 (KBN) appears erroneous. We use TwoSums
|
// Arithmetic. Note that Algorithm 6.7 (KBN) appears erroneous. We use TwoSums
|
||||||
// instead of FastTwoSums because the magnitude of the initial sum is not
|
// instead of FastTwoSums because the magnitude of the initial sum is not
|
||||||
|
|
@ -511,7 +545,13 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
|
||||||
// generation results. Note that Kahan summation differs in that it first adds
|
// generation results. Note that Kahan summation differs in that it first adds
|
||||||
// comp* to w*, so each operation is serially dependent. By contrast, the sum*
|
// comp* to w*, so each operation is serially dependent. By contrast, the sum*
|
||||||
// and comp* here have shorter dependency chains.
|
// and comp* here have shorter dependency chains.
|
||||||
struct KernelCascadedSum {
|
//
|
||||||
|
// This is slower than SumKernelDouble and about equally accurate.
|
||||||
|
struct SumKernelCascaded {
|
||||||
|
template <typename VT, typename WT>
|
||||||
|
using Raw = float;
|
||||||
|
using State = float;
|
||||||
|
|
||||||
template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
|
template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
|
||||||
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
|
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
|
||||||
const VF w3, VF, VF, VF, VF, VF& sum0, VF& sum1,
|
const VF w3, VF, VF, VF, VF, VF& sum0, VF& sum1,
|
||||||
|
|
@ -549,6 +589,17 @@ struct KernelCascadedSum {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename VT>
|
||||||
|
using SumKernelDefault = hwy::If<CanDecompressToDouble<VT, VT>(),
|
||||||
|
SumKernelDouble, SumKernelCascaded>;
|
||||||
|
|
||||||
|
template <class D, typename VT>
|
||||||
|
HWY_INLINE float Sum(D d, const VT* HWY_RESTRICT vec, size_t num) {
|
||||||
|
using Raw = hwy::If<HWY_HAVE_FLOAT64, double, float>;
|
||||||
|
const hn::Repartition<Raw, D> d_raw;
|
||||||
|
return DecompressAndCall(d_raw, MakeSpan(vec, num), SumKernelDefault<VT>());
|
||||||
|
}
|
||||||
|
|
||||||
static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
|
static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
|
||||||
const size_t mask_pos) {
|
const size_t mask_pos) {
|
||||||
PROFILER_FUNC;
|
PROFILER_FUNC;
|
||||||
|
|
@ -582,8 +633,7 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
|
||||||
// not make a huge difference. It halves the standard deviation of the sum of
|
// not make a huge difference. It halves the standard deviation of the sum of
|
||||||
// the normalized probabilities from 1E-7 to 5E-8, but actually also changes
|
// the normalized probabilities from 1E-7 to 5E-8, but actually also changes
|
||||||
// the generated text after a few hundred tokens.
|
// the generated text after a few hundred tokens.
|
||||||
const float sum_exp =
|
const float sum_exp = Sum(d, x, mask_pos);
|
||||||
DecompressAndCall(d, MakeConstSpan(x, mask_pos), KernelCascadedSum());
|
|
||||||
// Double-precision reciprocal does not appear to affect the results.
|
// Double-precision reciprocal does not appear to affect the results.
|
||||||
const float mul = 1.0f / sum_exp;
|
const float mul = 1.0f / sum_exp;
|
||||||
MulByConst(mul, x, size, mask_pos);
|
MulByConst(mul, x, size, mask_pos);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue