diff --git a/compression/compress-inl.h b/compression/compress-inl.h
index c29db0f..2d440f3 100644
--- a/compression/compress-inl.h
+++ b/compression/compress-inl.h
@@ -101,6 +101,19 @@ struct CompressTraits<float> {
     raw1 = hn::LoadU(df, packed.ptr + packed_ofs + N);
   }
 
+  template <class DD, HWY_IF_F64_D(DD), class VD = hn::Vec<DD>>
+  static HWY_INLINE void Load2(DD dd, const PackedSpan<const Packed>& packed,
+                               const size_t packed_ofs, VD& raw0, VD& raw1) {
+    const hn::Rebind<float, DD> df;
+    using VF = hn::Vec<decltype(df)>;
+    const size_t NF = hn::Lanes(df);
+    // Two half loads are likely cheaper than one full + UpperHalf.
+    const VF f0 = hn::LoadU(df, packed.ptr + packed_ofs + 0 * NF);
+    const VF f1 = hn::LoadU(df, packed.ptr + packed_ofs + 1 * NF);
+    raw0 = hn::PromoteTo(dd, f0);
+    raw1 = hn::PromoteTo(dd, f1);
+  }
+
   template <class DBF, HWY_IF_BF16_D(DBF)>
   static HWY_INLINE void DecompressAndZeroPad(
       DBF dbf, const PackedSpan<const Packed>& packed, const size_t packed_ofs,
@@ -149,6 +162,30 @@ struct CompressTraits<float> {
       hn::StoreU(vf, df, raw + i);  // adds zero padding
     }
   }
+
+  template <class DD, HWY_IF_F64_D(DD)>
+  static HWY_INLINE void DecompressAndZeroPad(
+      DD dd, const PackedSpan<const Packed>& packed, const size_t packed_ofs,
+      double* HWY_RESTRICT raw, size_t num) {
+    const hn::Rebind<float, DD> df;
+    using VF = hn::Vec<decltype(df)>;
+    using VD = hn::Vec<DD>;
+    const size_t ND = hn::Lanes(dd);
+
+    size_t i = 0;
+    if (num >= ND) {
+      for (; i <= num - ND; i += ND) {
+        const VF vf = hn::LoadU(df, packed.ptr + packed_ofs + i);
+        hn::StoreU(hn::PromoteTo(dd, vf), dd, raw + i);
+      }
+    }
+    const size_t remaining = num - i;
+    HWY_DASSERT(remaining < ND);
+    if (HWY_UNLIKELY(remaining != 0)) {
+      const VF vf = hn::LoadN(df, packed.ptr + packed_ofs + i, remaining);
+      hn::StoreU(hn::PromoteTo(dd, vf), dd, raw + i);  // adds zero padding
+    }
+  }
 };
 
 template <>
@@ -460,12 +497,23 @@ void Compress2(DF df, VF raw0, VF raw1, const PackedSpan<Packed>& packed,
   Traits::Store2(df, raw0, raw1, packed, packed_ofs);
 }
 
-// Decompresses from any type of `packed`, to two float or BF16 vectors.
+// Compile-time-only check that `DRaw` and `Packed` are compatible. This makes
+// for better error messages than "no matching function found".
+template <class DRaw, typename Packed>
+HWY_INLINE void VerifyRawAndPacked() {
+  using TRaw = hn::TFromD<DRaw>;
+  constexpr bool kPackedF32 = hwy::IsSame<hwy::RemoveCvRef<Packed>, float>();
+  // We can decompress any Packed to f32 or BF16, or f32 to f64.
+  static_assert(hwy::IsSameEither<TRaw, float, BF16>() ||
+                (kPackedF32 && hwy::IsSame<TRaw, double>()));
+}
+
+// Decompresses from any type of `packed`, to two vectors of `float/BF16`, or
+// `double`, if `Packed` is `float`.
 template <class DRaw, typename Packed, class VRaw = hn::Vec<DRaw>>
 HWY_INLINE void Decompress2(DRaw d, const PackedSpan<Packed>& packed,
                             const size_t packed_ofs, VRaw& raw0, VRaw& raw1) {
-  using TRaw = hn::TFromD<DRaw>;
-  static_assert(hwy::IsSameEither<TRaw, float, BF16>());
+  VerifyRawAndPacked<DRaw, Packed>();
   packed.BoundsCheck(packed_ofs, 2 * hn::Lanes(d));
   using Traits = CompressTraits<hwy::RemoveCvRef<Packed>>;
   Traits::Load2(d, MakeConst(packed), packed_ofs, raw0, raw1);
@@ -476,13 +524,14 @@ HWY_INLINE void Decompress2(DRaw d, const PackedSpan<Packed>& packed,
 // required to round `num` up to one vector, if it is not already. The caller is
 // responsible for scaling `raw` to the original range because `EmbedToken`
 // also wants to scale the decompressed elements.
+// `TRaw` can be `float/BF16`, or `double` if `Packed` is `float`.
 template <class DRaw, typename Packed, typename TRaw = hn::TFromD<DRaw>>
 HWY_NOINLINE void DecompressAndZeroPad(DRaw d, const PackedSpan<Packed>& packed,
                                        const size_t packed_ofs, TRaw* raw,
                                        size_t num) {
-  static_assert(hwy::IsSameEither<TRaw, float, BF16>());
-  using Traits = CompressTraits<hwy::RemoveCvRef<Packed>>;
+  VerifyRawAndPacked<DRaw, Packed>();
   packed.BoundsCheck(packed_ofs, num);
+  using Traits = CompressTraits<hwy::RemoveCvRef<Packed>>;
   Traits::DecompressAndZeroPad(d, MakeConst(packed), packed_ofs, raw, num);
 }
 
@@ -495,34 +544,38 @@ HWY_NOINLINE void DecompressAndZeroPad(DRaw d, const PackedSpan<Packed>& packed,
 // `hwy/contrib/unroller`, but also supports compressed types with simpler
 // remainder handling thanks to `DecompressAndZeroPad`.
 //
+// `D` can be BF16/float, or also double if `WeightT` and `VecT` are both float.
 // `w` can be any packed type, including NUQ, which requires a separate `w_ofs`
-// rather than pointer arithmetic. `vec_aligned` can also be any type, but
-// typically float or BF16. We omit a `v_ofs` because it is 0 in our use cases.
+// rather than pointer arithmetic. `vec` can also be any type, but typically
+// float or BF16. We omit a `v_ofs` because it is 0 in our use cases.
 // `num`, the number of elements to process, need not be a vector multiple.
 //
 // `kernel` is const& so we can pass an rvalue argument, but can contain
 // mutable state, though not vectors (see highway.h). We pass in the four
-// loaded vectors plus eight *f32* state vectors, independent of `D`.
+// loaded vectors plus eight state vectors. The state vectors' lane type is
+// either `double` (required for DotKernelDouble) or `float`.
 template <class D, typename WeightT, typename VecT, class Kernel>
 HWY_INLINE float DecompressAndCall(D d, const PackedSpan<const WeightT>& w,
                                    const size_t w_ofs,
                                    const PackedSpan<const VecT> vec,
                                    const Kernel& kernel) {
   // Decompressed inputs
+  using T = hn::TFromD<D>;
   using V = hn::Vec<D>;
   V w0, w1, w2, w3, v0, v1, v2, v3;
 
   // State for Kernel
-  const hn::Repartition<float, D> df;
-  using VF = hn::Vec<decltype(df)>;
-  VF sum0 = hn::Zero(df);
-  VF sum1 = hn::Zero(df);
-  VF sum2 = hn::Zero(df);
-  VF sum3 = hn::Zero(df);
-  VF comp0 = hn::Zero(df);
-  VF comp1 = hn::Zero(df);
-  VF comp2 = hn::Zero(df);
-  VF comp3 = hn::Zero(df);
+  using StateT = hwy::If<hwy::IsSame<T, double>(), double, float>;
+  const hn::Repartition<StateT, D> ds;
+  using VS = hn::Vec<decltype(ds)>;
+  VS sum0 = hn::Zero(ds);
+  VS sum1 = hn::Zero(ds);
+  VS sum2 = hn::Zero(ds);
+  VS sum3 = hn::Zero(ds);
+  VS comp0 = hn::Zero(ds);
+  VS comp1 = hn::Zero(ds);
+  VS comp2 = hn::Zero(ds);
+  VS comp3 = hn::Zero(ds);
 
   const size_t N = hn::Lanes(d);
   size_t i = 0;
@@ -541,7 +594,6 @@ HWY_INLINE float DecompressAndCall(D d, const PackedSpan<const WeightT>& w,
   size_t remaining = vec.num - i;
   HWY_DASSERT(remaining < 4 * N);
   if (HWY_UNLIKELY(remaining != 0)) {
-    using T = hn::TFromD<D>;
     HWY_ALIGN T padded_w[4 * hn::MaxLanes(d)];
     HWY_ALIGN T padded_v[4 * hn::MaxLanes(d)];
     DecompressAndZeroPad(d, w, w_ofs + i, padded_w, remaining);
@@ -555,7 +607,7 @@ HWY_INLINE float DecompressAndCall(D d, const PackedSpan<const WeightT>& w,
     }
   }
 
-  return kernel.Reduce(df, sum0, sum1, sum2, sum3, comp0, comp1, comp2, comp3);
+  return kernel.Reduce(ds, sum0, sum1, sum2, sum3, comp0, comp1, comp2, comp3);
 }
 
 // Same as above, but single input array. Used by RMSNorm.
@@ -563,20 +615,22 @@ template <class D, typename VecT, class Kernel>
 HWY_INLINE float DecompressAndCall(D d, const PackedSpan<const VecT> vec,
                                    const Kernel& kernel) {
   // Decompressed inputs
+  using T = hn::TFromD<D>;
   using V = hn::Vec<D>;
   V v0, v1, v2, v3;
 
   // State for Kernel
-  const hn::Repartition<float, D> df;
-  using VF = hn::Vec<decltype(df)>;
-  VF sum0 = hn::Zero(d);
-  VF sum1 = hn::Zero(d);
-  VF sum2 = hn::Zero(d);
-  VF sum3 = hn::Zero(d);
-  VF comp0 = hn::Zero(d);
-  VF comp1 = hn::Zero(d);
-  VF comp2 = hn::Zero(d);
-  VF comp3 = hn::Zero(d);
+  using StateT = hwy::If<hwy::IsSame<T, double>(), double, float>;
+  const hn::Repartition<StateT, D> ds;
+  using VS = hn::Vec<decltype(ds)>;
+  VS sum0 = hn::Zero(ds);
+  VS sum1 = hn::Zero(ds);
+  VS sum2 = hn::Zero(ds);
+  VS sum3 = hn::Zero(ds);
+  VS comp0 = hn::Zero(ds);
+  VS comp1 = hn::Zero(ds);
+  VS comp2 = hn::Zero(ds);
+  VS comp3 = hn::Zero(ds);
 
   const size_t N = hn::Lanes(d);
   size_t i = 0;
@@ -593,17 +647,17 @@ HWY_INLINE float DecompressAndCall(D d, const PackedSpan<const VecT> vec,
   size_t remaining = vec.num - i;
   HWY_DASSERT(remaining < 4 * N);
   if (HWY_UNLIKELY(remaining != 0)) {
-    HWY_ALIGN float padded_v[4 * hn::MaxLanes(d)];
+    HWY_ALIGN T padded_v[4 * hn::MaxLanes(d)];
     DecompressAndZeroPad(d, vec, i, padded_v, remaining);
 
     // 1..4 whole vectors, possibly zero-padded.
     for (size_t padded_pos = 0; padded_pos < remaining; padded_pos += N) {
-      const VF v0 = hn::Load(d, padded_v + padded_pos);
+      const V v0 = hn::Load(d, padded_v + padded_pos);
       kernel.Update1(d, v0, v0, sum0, comp0);
     }
   }
 
-  return kernel.Reduce(d, sum0, sum1, sum2, sum3, comp0, comp1, comp2, comp3);
+  return kernel.Reduce(ds, sum0, sum1, sum2, sum3, comp0, comp1, comp2, comp3);
 }
 
 // Functor called for each tensor, which compresses and stores them along with
diff --git a/compression/test_util-inl.h b/compression/test_util-inl.h
index 1f591ff..860644a 100644
--- a/compression/test_util-inl.h
+++ b/compression/test_util-inl.h
@@ -50,6 +50,8 @@ void ForeachRawType() {
   // The argument selects the type to decode to: BF16 or float.
   test(BF16());
   test(float());
+  // Do not include double because it is not supported as an input type - we
+  // would also have to implement double -> Packed Compress().
 }
 
 template
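Usage sketch (illustrative, not part of the patch): the new overloads let a caller decompress f32-packed data directly into f64 lanes, which is what a double-accumulation kernel such as DotKernelDouble needs. The snippet assumes the usual per-target setup in which compress-inl.h is included (the HWY_NAMESPACE block); the function and variable names are hypothetical, and the aggregate init assumes PackedSpan's {ptr, num} member order used above.

// Sketch only: decompress float-packed data to double, as enabled by this
// change. Names below are illustrative, not part of the library.
void PromoteToF64Example(const float* HWY_RESTRICT in, size_t num,
                         double* HWY_RESTRICT out) {
  namespace hn = hwy::HWY_NAMESPACE;
  const hn::ScalableTag<double> dd;
  // Assumed aggregate layout: pointer first, then element count.
  const PackedSpan<const float> packed{in, num};
  // VerifyRawAndPacked() allows TRaw == double only because Packed is float.
  // `out` must have room for `num` rounded up to a whole vector of double.
  DecompressAndZeroPad(dd, packed, /*packed_ofs=*/0, out, num);
}

Decompress2 and DecompressAndCall follow the same rule: a double descriptor is accepted only when the packed/weight/vec types are float, and DecompressAndCall then keeps its eight sum/comp state vectors in double lanes.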