mirror of https://github.com/google/gemma.cpp.git
Add double-precision dot variant
PiperOrigin-RevId: 679243590
commit 47eb80a90e (parent 71116daf64)
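Reviewer note: the new variant promotes float inputs to f64 and accumulates with FMA, serving as an accuracy reference for the compensated kernels below. A minimal sketch of invoking it follows; the `hn` alias, tag setup, and wrapper function are assumptions, while `DotDouble`, `MakeSpan`, and `PackedSpan` come from this change and gemma.cpp.

    namespace hn = hwy::HWY_NAMESPACE;

    // Sketch: double-precision dot over float inputs. Only float supports
    // this path because only CompressTraits<float> gains the f64 overloads.
    float DotViaF64(const float* HWY_RESTRICT w, const float* HWY_RESTRICT v,
                    size_t num) {
      const hn::ScalableTag<float> df;
      return DotDouble(df, MakeSpan(w, num), /*w_ofs=*/0, v, num);
    }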
@@ -101,6 +101,19 @@ struct CompressTraits<float> {
     raw1 = hn::LoadU(df, packed.ptr + packed_ofs + N);
   }
 
+  template <class DD, HWY_IF_F64_D(DD), class VD = hn::Vec<DD>>
+  static HWY_INLINE void Load2(DD dd, const PackedSpan<const Packed>& packed,
+                               const size_t packed_ofs, VD& raw0, VD& raw1) {
+    const hn::Rebind<float, DD> df;
+    using VF = hn::Vec<decltype(df)>;
+    const size_t NF = hn::Lanes(df);
+    // Two half loads are likely cheaper than one full + UpperHalf.
+    const VF f0 = hn::LoadU(df, packed.ptr + packed_ofs + 0 * NF);
+    const VF f1 = hn::LoadU(df, packed.ptr + packed_ofs + 1 * NF);
+    raw0 = hn::PromoteTo(dd, f0);
+    raw1 = hn::PromoteTo(dd, f1);
+  }
+
   template <class DBF, HWY_IF_BF16_D(DBF)>
   static HWY_INLINE void DecompressAndZeroPad(
       DBF dbf, const PackedSpan<const Packed>& packed, const size_t packed_ofs,
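Reviewer note: the "two half loads" comment above rejects an alternative that loads one full float vector and widens each half. A sketch of that rejected form, under the same `df`/`dd` definitions; `Twice`, `PromoteLowerTo`, and `PromoteUpperTo` are existing Highway ops, but this snippet is illustrative, not from the commit.

    // Rejected alternative (sketch): one full load, then widen each half.
    const hn::Twice<decltype(df)> df2;  // 2 * NF float lanes
    const auto f = hn::LoadU(df2, packed.ptr + packed_ofs);
    raw0 = hn::PromoteLowerTo(dd, f);
    raw1 = hn::PromoteUpperTo(dd, f);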
@@ -149,6 +162,30 @@ struct CompressTraits<float> {
       hn::StoreU(vf, df, raw + i);  // adds zero padding
     }
   }
 
+  template <class DD, HWY_IF_F64_D(DD)>
+  static HWY_INLINE void DecompressAndZeroPad(
+      DD dd, const PackedSpan<const Packed>& packed, const size_t packed_ofs,
+      double* HWY_RESTRICT raw, size_t num) {
+    const hn::Rebind<float, DD> df;
+    using VF = hn::Vec<decltype(df)>;
+    using VD = hn::Vec<decltype(dd)>;
+    const size_t ND = hn::Lanes(dd);
+
+    size_t i = 0;
+    if (num >= ND) {
+      for (; i <= num - ND; i += ND) {
+        const VF vf = hn::LoadU(df, packed.ptr + packed_ofs + i);
+        hn::StoreU(hn::PromoteTo(dd, vf), dd, raw + i);
+      }
+    }
+    const size_t remaining = num - i;
+    HWY_DASSERT(remaining < ND);
+    if (HWY_UNLIKELY(remaining != 0)) {
+      const VF vf = hn::LoadN(df, packed.ptr + packed_ofs + i, remaining);
+      hn::StoreU(hn::PromoteTo(dd, vf), dd, raw + i);  // adds zero padding
+    }
+  }
 };
 
 template <>
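Reviewer note: the padding contract of the f64 overload above is easiest to see in scalar form. A sketch of the equivalent computation, assuming the caller sized `raw` for `num` rounded up to one vector as documented later in this file:

    // Scalar model (sketch): widen each float to double, zero the tail up to
    // the next multiple of ND so callers may read whole vectors.
    void DecompressAndZeroPadScalar(const float* packed, double* raw,
                                    size_t num, size_t ND) {
      const size_t padded = (num + ND - 1) / ND * ND;
      for (size_t i = 0; i < padded; ++i) {
        raw[i] = (i < num) ? static_cast<double>(packed[i]) : 0.0;
      }
    }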
@@ -460,12 +497,23 @@ void Compress2(DF df, VF raw0, VF raw1, const PackedSpan<Packed>& packed,
   Traits::Store2(df, raw0, raw1, packed, packed_ofs);
 }
 
-// Decompresses from any type of `packed`, to two float or BF16 vectors.
+// Compile-time-only check that `DRaw` and `Packed` are compatible. This makes
+// for better error messages than "no matching function found".
+template <class DRaw, typename Packed>
+HWY_INLINE void VerifyRawAndPacked() {
+  using TRaw = hn::TFromD<DRaw>;
+  constexpr bool kPackedF32 = hwy::IsSame<hwy::RemoveCvRef<Packed>, float>();
+  // We can decompress any Packed to f32 or BF16, or f32 to f64.
+  static_assert(hwy::IsSameEither<TRaw, float, BF16>() ||
+                (kPackedF32 && hwy::IsSame<TRaw, double>()));
+}
+
+// Decompresses from any type of `packed`, to two vectors of `float/BF16`, or
+// `double`, if `Packed` is `float`.
 template <class DRaw, typename Packed, class VRaw = hn::Vec<DRaw>>
 HWY_INLINE void Decompress2(DRaw d, const PackedSpan<Packed>& packed,
                             const size_t packed_ofs, VRaw& raw0, VRaw& raw1) {
-  using TRaw = hn::TFromD<DRaw>;
-  static_assert(hwy::IsSameEither<TRaw, float, BF16>());
+  VerifyRawAndPacked<DRaw, Packed>();
   packed.BoundsCheck(packed_ofs, 2 * hn::Lanes(d));
   using Traits = CompressTraits<hwy::RemoveCvRef<Packed>>;
   Traits::Load2(d, MakeConst(packed), packed_ofs, raw0, raw1);
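Reviewer note: the effect of `VerifyRawAndPacked` is easiest to see with concrete tags. A sketch of what it now permits and rejects; the wrapper function and span arguments are illustrative assumptions, while `Decompress2`, `PackedSpan`, and `BF16` are from gemma.cpp.

    void VerifyExamples(const PackedSpan<const float>& pf,
                        const PackedSpan<const BF16>& pb) {
      const hn::ScalableTag<double> dd;
      hn::Vec<decltype(dd)> r0, r1;
      Decompress2(dd, pf, 0, r0, r1);  // OK: f32 -> f64 is now supported.
      // Decompress2(dd, pb, 0, r0, r1);  // Would trip the static_assert:
      //                                  // BF16 -> f64 is not supported.
    }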
@@ -476,13 +524,14 @@ HWY_INLINE void Decompress2(DRaw d, const PackedSpan<Packed>& packed,
 // required to round `num` up to one vector, if it is not already. The caller is
 // responsible for scaling `raw` to the original range because `EmbedToken`
 // also wants to scale the decompressed elements.
+// `TRaw` can be `float/BF16`, or `double` if `Packed` is `float`.
 template <class DRaw, typename Packed, typename TRaw = hn::TFromD<DRaw>>
 HWY_NOINLINE void DecompressAndZeroPad(DRaw d, const PackedSpan<Packed>& packed,
                                        const size_t packed_ofs, TRaw* raw,
                                        size_t num) {
-  static_assert(hwy::IsSameEither<TRaw, float, BF16>());
-  using Traits = CompressTraits<hwy::RemoveCvRef<Packed>>;
+  VerifyRawAndPacked<DRaw, Packed>();
   packed.BoundsCheck(packed_ofs, num);
+  using Traits = CompressTraits<hwy::RemoveCvRef<Packed>>;
   Traits::DecompressAndZeroPad(d, MakeConst(packed), packed_ofs, raw, num);
 }
 
@@ -495,34 +544,38 @@ HWY_NOINLINE void DecompressAndZeroPad(DRaw d, const PackedSpan<Packed>& packed,
 // `hwy/contrib/unroller`, but also supports compressed types with simpler
 // remainder handling thanks to `DecompressAndZeroPad`.
 //
+// `D` can be BF16/float, or also double if `WeightT` and `VecT` are both float.
 // `w` can be any packed type, including NUQ, which requires a separate `w_ofs`
-// rather than pointer arithmetic. `vec_aligned` can also be any type, but
-// typically float or BF16. We omit a `v_ofs` because it is 0 in our use cases.
+// rather than pointer arithmetic. `vec` can also be any type, but typically
+// float or BF16. We omit a `v_ofs` because it is 0 in our use cases.
 // `num`, the number of elements to process, need not be a vector multiple.
 //
 // `kernel` is const& so we can pass an rvalue argument, but can contain
 // mutable state, though not vectors (see highway.h). We pass in the four
-// loaded vectors plus eight *f32* state vectors, independent of `D`.
+// loaded vectors plus eight state vectors. The state vectors' lane type is
+// either `double` (required for DotKernelDouble) or `float`.
 template <class D, typename WeightT, typename VecT, class Kernel>
 HWY_INLINE float DecompressAndCall(D d, const PackedSpan<const WeightT>& w,
                                    const size_t w_ofs,
                                    const PackedSpan<const VecT> vec,
                                    const Kernel& kernel) {
   // Decompressed inputs
+  using T = hn::TFromD<D>;
   using V = hn::Vec<decltype(d)>;
   V w0, w1, w2, w3, v0, v1, v2, v3;
 
   // State for Kernel
-  const hn::Repartition<float, D> df;
-  using VF = hn::Vec<decltype(df)>;
-  VF sum0 = hn::Zero(df);
-  VF sum1 = hn::Zero(df);
-  VF sum2 = hn::Zero(df);
-  VF sum3 = hn::Zero(df);
-  VF comp0 = hn::Zero(df);
-  VF comp1 = hn::Zero(df);
-  VF comp2 = hn::Zero(df);
-  VF comp3 = hn::Zero(df);
+  using StateT = hwy::If<hwy::IsSame<T, double>(), double, float>;
+  const hn::Repartition<StateT, D> ds;
+  using VS = hn::Vec<decltype(ds)>;
+  VS sum0 = hn::Zero(ds);
+  VS sum1 = hn::Zero(ds);
+  VS sum2 = hn::Zero(ds);
+  VS sum3 = hn::Zero(ds);
+  VS comp0 = hn::Zero(ds);
+  VS comp1 = hn::Zero(ds);
+  VS comp2 = hn::Zero(ds);
+  VS comp3 = hn::Zero(ds);
 
   const size_t N = hn::Lanes(d);
   size_t i = 0;
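Reviewer note: `StateT` makes the accumulator width follow the decompressed type. A small compile-time sketch of how it resolves; the alias names are illustrative, while `hwy::If` and `hwy::IsSame` are the same utilities the diff uses.

    // Sketch: BF16/float decompression keeps f32 state; only double
    // (DotKernelDouble) widens the accumulators.
    using StateForF64 = hwy::If<hwy::IsSame<double, double>(), double, float>;
    using StateForF32 = hwy::If<hwy::IsSame<float, double>(), double, float>;
    static_assert(hwy::IsSame<StateForF64, double>());
    static_assert(hwy::IsSame<StateForF32, float>());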
@@ -541,7 +594,6 @@ HWY_INLINE float DecompressAndCall(D d, const PackedSpan<const WeightT>& w,
   size_t remaining = vec.num - i;
   HWY_DASSERT(remaining < 4 * N);
   if (HWY_UNLIKELY(remaining != 0)) {
-    using T = hn::TFromD<D>;
     HWY_ALIGN T padded_w[4 * hn::MaxLanes(d)];
     HWY_ALIGN T padded_v[4 * hn::MaxLanes(d)];
     DecompressAndZeroPad(d, w, w_ofs + i, padded_w, remaining);
@@ -555,7 +607,7 @@ HWY_INLINE float DecompressAndCall(D d, const PackedSpan<const WeightT>& w,
     }
   }
 
-  return kernel.Reduce(df, sum0, sum1, sum2, sum3, comp0, comp1, comp2, comp3);
+  return kernel.Reduce(ds, sum0, sum1, sum2, sum3, comp0, comp1, comp2, comp3);
 }
 
 // Same as above, but single input array. Used by RMSNorm.
@@ -563,20 +615,22 @@ template <class D, typename VecT, class Kernel>
 HWY_INLINE float DecompressAndCall(D d, const PackedSpan<const VecT> vec,
                                    const Kernel& kernel) {
   // Decompressed inputs
+  using T = hn::TFromD<D>;
   using V = hn::Vec<decltype(d)>;
   V v0, v1, v2, v3;
 
   // State for Kernel
-  const hn::Repartition<float, D> df;
-  using VF = hn::Vec<decltype(df)>;
-  VF sum0 = hn::Zero(d);
-  VF sum1 = hn::Zero(d);
-  VF sum2 = hn::Zero(d);
-  VF sum3 = hn::Zero(d);
-  VF comp0 = hn::Zero(d);
-  VF comp1 = hn::Zero(d);
-  VF comp2 = hn::Zero(d);
-  VF comp3 = hn::Zero(d);
+  using StateT = hwy::If<hwy::IsSame<T, double>(), double, float>;
+  const hn::Repartition<StateT, D> ds;
+  using VS = hn::Vec<decltype(ds)>;
+  VS sum0 = hn::Zero(ds);
+  VS sum1 = hn::Zero(ds);
+  VS sum2 = hn::Zero(ds);
+  VS sum3 = hn::Zero(ds);
+  VS comp0 = hn::Zero(ds);
+  VS comp1 = hn::Zero(ds);
+  VS comp2 = hn::Zero(ds);
+  VS comp3 = hn::Zero(ds);
 
   const size_t N = hn::Lanes(d);
   size_t i = 0;
@@ -593,17 +647,17 @@ HWY_INLINE float DecompressAndCall(D d, const PackedSpan<const VecT> vec,
   size_t remaining = vec.num - i;
   HWY_DASSERT(remaining < 4 * N);
   if (HWY_UNLIKELY(remaining != 0)) {
-    HWY_ALIGN float padded_v[4 * hn::MaxLanes(d)];
+    HWY_ALIGN T padded_v[4 * hn::MaxLanes(d)];
     DecompressAndZeroPad(d, vec, i, padded_v, remaining);
 
     // 1..4 whole vectors, possibly zero-padded.
     for (size_t padded_pos = 0; padded_pos < remaining; padded_pos += N) {
-      const VF v0 = hn::Load(d, padded_v + padded_pos);
+      const V v0 = hn::Load(d, padded_v + padded_pos);
       kernel.Update1(d, v0, v0, sum0, comp0);
     }
   }
 
-  return kernel.Reduce(d, sum0, sum1, sum2, sum3, comp0, comp1, comp2, comp3);
+  return kernel.Reduce(ds, sum0, sum1, sum2, sum3, comp0, comp1, comp2, comp3);
 }
 
 // Functor called for each tensor, which compresses and stores them along with
@@ -50,6 +50,8 @@ void ForeachRawType() {
   // The argument selects the type to decode to: BF16 or float.
   test(BF16());
   test(float());
+  // Do not include double because it is not supported as an input type - we
+  // would also have to implement double -> Packed Compress().
 }
 
 template <template <class> class TestT>
@@ -57,6 +57,7 @@ enum { // alphabetical order for consistency and to avoid implying a preference
   kAddTwoSum,
   kComp2,
   kCompensated,
+  kDouble,
   kKahan,
   kNaive,
   kOnlyTwoProd,
@@ -75,6 +76,8 @@ const char* VariantName(size_t variant) {
       return "comp2";
     case kCompensated:
       return "comp";
+    case kDouble:
+      return "double";
     case kKahan:
       return "kahan";
     case kNaive:
@@ -153,6 +156,43 @@ HWY_INLINE float DotNaive(D d, const PackedSpan<const WeightT>& w, size_t w_ofs,
   return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelNaive());
 }
 
+struct DotKernelDouble {
+  template <class DD, class VD = hn::Vec<DD>, HWY_IF_F64_D(DD)>
+  HWY_INLINE void Update4(DD dd, const VD w0, const VD w1, const VD w2,
+                          const VD w3, const VD v0, const VD v1, const VD v2,
+                          const VD v3, VD& sum0, VD& sum1, VD& sum2, VD& sum3,
+                          VD&, VD&, VD&, VD&) const {
+    sum0 = hn::MulAdd(w0, v0, sum0);
+    sum1 = hn::MulAdd(w1, v1, sum1);
+    sum2 = hn::MulAdd(w2, v2, sum2);
+    sum3 = hn::MulAdd(w3, v3, sum3);
+  }
+
+  template <class DD, class VD = hn::Vec<DD>, HWY_IF_F64_D(DD)>
+  HWY_INLINE void Update1(DD dd, const VD w0, const VD v0, VD& sum0,
+                          VD&) const {
+    sum0 = hn::MulAdd(w0, v0, sum0);
+  }
+
+  template <class DD, class VD = hn::Vec<DD>, HWY_IF_F64_D(DD)>
+  HWY_INLINE float Reduce(DD dd, VD& sum0, VD& sum1, VD& sum2, VD& sum3, VD&,
+                          VD&, VD&, VD&) const {
+    // Reduction tree: sum of all accumulators by pairs, then across lanes.
+    sum0 = hn::Add(sum0, sum1);
+    sum2 = hn::Add(sum2, sum3);
+    sum0 = hn::Add(sum0, sum2);
+    return static_cast<float>(hn::ReduceSum(dd, sum0));
+  }
+};
+
+template <class D, typename WeightT, typename VecT>
+HWY_INLINE float DotDouble(D d, const PackedSpan<const WeightT>& w,
+                           size_t w_ofs, const VecT* HWY_RESTRICT vec,
+                           size_t num) {
+  const hn::Repartition<double, D> dd;
+  return DecompressAndCall(dd, w, w_ofs, MakeSpan(vec, num), DotKernelDouble());
+}
+
 // https://en.wikipedia.org/wiki/Kahan_summation_algorithm: FastTwoSum.
 struct DotKernelKahan {
   template <class DF, class VF = hn::Vec<DF>>
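Reviewer note: `DotKernelDouble` is deliberately the simplest kernel, with no compensation terms, just multiply-accumulate into four independent f64 accumulators for instruction-level parallelism. Its scalar equivalent is a sketch like the following (the function name is an illustrative assumption):

    // Scalar equivalent (sketch) of DotDouble for float inputs: widen,
    // multiply-accumulate in double, then narrow the total back to float.
    float DotDoubleScalar(const float* w, const float* v, size_t num) {
      double sum = 0.0;
      for (size_t i = 0; i < num; ++i) {
        sum += static_cast<double>(w[i]) * static_cast<double>(v[i]);
      }
      return static_cast<float>(sum);
    }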
@@ -533,9 +573,14 @@ HWY_INLINE float DotComp2(D d, const PackedSpan<const WeightT>& w, size_t w_ofs,
   return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelComp2());
 }
 
-template <class D, typename WeightT, typename VecT>
+template <class D, typename WeightT, typename VecT, HWY_IF_F32_D(D)>
 float CallDot(D d, size_t variant, const PackedSpan<const WeightT>& w,
               size_t w_ofs, const VecT* HWY_RESTRICT v, size_t num) {
+  // float inputs also support kDouble.
+  if constexpr (hwy::IsSame<WeightT, float>() && hwy::IsSame<VecT, float>()) {
+    if (variant == kDouble) return DotDouble(d, w, 0, v, num);
+  }
+
   switch (variant) {
     case kAddTwoProd:
       return DotTwoProdFast(d, w, 0, v, num);
@@ -720,9 +765,11 @@ class DotStats {
     ASSERT_INSIDE(kComp2, 1.001f, s_muls[kComp2].Max(), 2.4f);
     ASSERT_INSIDE(kComp2, 1.0, s_muls[kComp2].GeometricMean(), 1.2);
 
-    // Compensated is very accurate.
+    // Compensated and Double are very accurate.
     ASSERT_LESS(kCompensated, s_muls[kCompensated].Min(), 1.0f + 2E-6f);
     ASSERT_LESS(kCompensated, s_muls[kCompensated].Max(), 1.0f + 2E-5f);
+    ASSERT_LESS(kDouble, s_muls[kDouble].Min(), 1.0f + 2E-6f);
+    ASSERT_LESS(kDouble, s_muls[kDouble].Max(), 1.0f + 2E-5f);
 
     // Naive and OnlyTwoProd are considerably worse. >10x is for narrower
     // vectors, compared to AVX-512. GeometricMean overflows, must use Mean.
@@ -751,9 +798,11 @@ class DotStats {
     ASSERT_INSIDE(kComp2, 1E-5, s_l1s[kComp2].Mean(), 9E-4);
     ASSERT_INSIDE(kComp2, 1E-5f, s_l1s[kComp2].Max(), 2.6E-3f);
 
-    // Compensated is very accurate.
+    // Compensated and Double are very accurate.
     HWY_ASSERT(s_l1s[kCompensated].Min() == 0.0f);
     ASSERT_LESS(kCompensated, s_l1s[kCompensated].Max(), 3E-7f);
+    HWY_ASSERT(s_l1s[kDouble].Min() == 0.0f);
+    ASSERT_LESS(kDouble, s_l1s[kDouble].Max(), 3E-7f);
 
     // Naive and OnlyTwoProd are considerably higher, but not huge.
     ASSERT_INSIDE(kNaive, 1E-3, s_l1s[kNaive].Mean(), 2E-2);
@@ -778,9 +827,11 @@ class DotStats {
     ASSERT_INSIDE(kComp2, 2E-4, s_rels[kComp2].GeometricMean(), 3.7E-3);
     ASSERT_INSIDE(kComp2, 1E-5f, s_rels[kComp2].Max(), 0.4f);
 
-    // Compensated is very accurate.
+    // Compensated and Double are very accurate.
     ASSERT_LESS(kCompensated, s_rels[kCompensated].Min(), 1E-8f);
     ASSERT_LESS(kCompensated, s_rels[kCompensated].Max(), 8E-6f);
+    ASSERT_LESS(kDouble, s_rels[kDouble].Min(), 1E-8f);
+    ASSERT_LESS(kDouble, s_rels[kDouble].Max(), 8E-6f);
 
     // Naive and OnlyTwoProd are considerably higher, but not huge.
     ASSERT_INSIDE(kNaive, 1E-3, s_rels[kNaive].GeometricMean(), 8E-2);
@@ -807,8 +858,9 @@ class DotStats {
   void CheckBwd() const {
     ASSERT_INSIDE(kComp2, 7E-10f, s_rels[kComp2].Max(), 0.4f);
 
-    // Compensated is very accurate.
+    // Compensated and Double are very accurate.
     ASSERT_LESS(kCompensated, s_rels[kCompensated].Max(), 8E-6f);
+    ASSERT_LESS(kDouble, s_rels[kDouble].Max(), 8E-6f);
 
     // Naive and OnlyTwoProd are considerably higher than others
     ASSERT_INSIDE(kNaive, 1.5E-8f, s_rels[kNaive].Max(), 3080.f);
@@ -828,6 +880,7 @@ class DotStats {
   void CheckUlps() const {
     ASSERT_LESS(kComp2, s_ulps[kCompensated].Max(), 3.6E6f);
     ASSERT_LESS(kCompensated, s_ulps[kCompensated].Max(), 250.0f);
+    ASSERT_LESS(kDouble, s_ulps[kDouble].Max(), 250.0f);
     ASSERT_LESS(kNaive, s_ulps[kNaive].Max(), 4E9f);
     ASSERT_LESS(kOnlyTwoProd, s_ulps[kOnlyTwoProd].Max(), 3E9f);
     ASSERT_LESS(kKahan, s_ulps[kKahan].Max(), 4E7f);
@@ -987,7 +1040,9 @@ struct TestShortDotsT {
     const float dot_exact = ExactDot(raw_w.All(), raw_v.All(), num, buf);
     float dots[kVariants];
     for (size_t variant = 0; variant < kVariants; ++variant) {
-      dots[variant] = CallDot(df, variant, MakeConst(w), 0, v.ptr, num);
+      // Here Packed is not always float, so we must not call kDouble.
+      const size_t actual = (variant == kDouble) ? kCompensated : variant;
+      dots[variant] = CallDot(df, actual, MakeConst(w), 0, v.ptr, num);
 
       const float l1 = hwy::ScalarAbs(dots[variant] - dot_exact);
       s_l1[variant].Notify(l1);