From c5c9fc300cbda02d8b0d65a6e32229df70a4ea14 Mon Sep 17 00:00:00 2001
From: Jan Wassenberg
Date: Wed, 8 May 2024 07:08:25 -0700
Subject: [PATCH] Enable even/odd for SFP. Refs #166

Disable it for float32 because there is not enough benefit.

PiperOrigin-RevId: 631788326
---
 compression/compress-inl.h | 57 ++++++++++++++++++++-----
 compression/sfp-inl.h      | 86 ++++++++++++++++++++++++++++++++++++++
 compression/sfp_test.cc    | 57 +++++++++++++++++++------
 gemma/ops.h                |  9 +++-
 4 files changed, 183 insertions(+), 26 deletions(-)

diff --git a/compression/compress-inl.h b/compression/compress-inl.h
index 6d2e39f..4bef8c7 100644
--- a/compression/compress-inl.h
+++ b/compression/compress-inl.h
@@ -58,7 +58,7 @@ struct CompressTraits {};
 template <>
 struct CompressTraits<float> {
   using MatT = float;
-  static constexpr bool kSupportsEvenOdd = false;
+  static constexpr bool kSupportsEvenOdd = false;  // unnecessary

   template <class DF>
   static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in,
@@ -230,7 +230,8 @@ struct CompressTraits<hwy::bfloat16_t> {
   static HWY_INLINE float DotEO(
       const DF df32, const hwy::bfloat16_t* HWY_RESTRICT in, size_t in_ofs,
       const float* HWY_RESTRICT vec_aligned, size_t num) {
-    HWY_DASSERT(num >= (hn::Lanes(df32) * 2) && (num % (hn::Lanes(df32) * 2)) == 0);
+    HWY_DASSERT(num >= (hn::Lanes(df32) * 2) &&
+                (num % (hn::Lanes(df32) * 2)) == 0);
     HWY_DASSERT((in_ofs % (hn::Lanes(df32) * 2)) == 0);
     HWY_DASSERT(hn::IsAligned(df32, vec_aligned));

@@ -273,13 +274,13 @@ struct CompressTraits<hwy::bfloat16_t> {
 template <>
 struct CompressTraits<SfpStream> {
   using MatT = SfpStream;
-  static constexpr bool kSupportsEvenOdd = false;
+  static constexpr bool kSupportsEvenOdd = true;

   template <class DF>
-  static HWY_INLINE void Compress(DF df, const float* in, size_t num,
-                                  CompressPerThread& tls,
-                                  size_t /*out_capacity*/, MatT* out,
-                                  size_t out_ofs) {
+  static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in,
+                                  size_t num, CompressPerThread& tls,
+                                  size_t /*out_capacity*/,
+                                  MatT* HWY_RESTRICT out, size_t out_ofs) {
     SfpCodec::Enc(df, in, num, out + out_ofs);

     if (COMPRESS_STATS) {
@@ -295,15 +296,21 @@ struct CompressTraits<SfpStream> {
   }

   template <class D, typename OutT>
-  static HWY_INLINE void Decompress(D d, size_t /*in_capacity*/, const MatT* in,
-                                    size_t in_ofs, OutT* out, size_t num) {
+  static HWY_INLINE void Decompress(D d, size_t /*in_capacity*/,
+                                    const MatT* HWY_RESTRICT in, size_t in_ofs,
+                                    OutT* HWY_RESTRICT out, size_t num) {
     SfpCodec::Dec(d, in + in_ofs, num, out);
   }

   template <class DF, typename VecT>
-  static HWY_INLINE float Dot(DF df, size_t /*in_capacity*/, const MatT* in,
-                              size_t in_ofs, const VecT* vec_aligned,
+  static HWY_INLINE float Dot(DF df, size_t /*in_capacity*/,
+                              const MatT* HWY_RESTRICT in, size_t in_ofs,
+                              const VecT* HWY_RESTRICT vec_aligned,
                               size_t num) {
+    HWY_DASSERT(num >= hn::Lanes(df) && (num % hn::Lanes(df)) == 0);
+    HWY_DASSERT((in_ofs % hn::Lanes(df)) == 0);
+    HWY_DASSERT(hn::IsAligned(df, vec_aligned));
+
     using VF = hn::Vec<decltype(df)>;
     VF sum0 = hn::Zero(df);
     VF sum1 = hn::Zero(df);
@@ -318,6 +325,34 @@ struct CompressTraits<SfpStream> {
     sum0 = hn::Add(sum0, sum2);
     return hn::ReduceSum(df, sum0);
   }
+
+  // Computes the dot product of an even/odd deinterleaved f32 or bf16
+  // `vec_aligned` and a column-major matrix `in`. `vec_aligned` must be
+  // aligned and alternate `hn::Lanes(df)` even-indexed elements with
+  // `hn::Lanes(df)` odd-indexed elements.
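+  // For example, with `hn::Lanes(df)` == 4, the first 16 elements are stored
+  // as v0 v2 v4 v6 | v1 v3 v5 v7 | v8 v10 v12 v14 | v9 v11 v13 v15.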
+  template <class DF, typename VecT>
+  static HWY_INLINE float DotEO(const DF df, const MatT* HWY_RESTRICT in,
+                                size_t in_ofs,
+                                const VecT* HWY_RESTRICT vec_aligned,
+                                size_t num) {
+    HWY_DASSERT(num >= (hn::Lanes(df) * 2) && (num % (hn::Lanes(df) * 2)) == 0);
+    HWY_DASSERT((in_ofs % (hn::Lanes(df) * 2)) == 0);
+    HWY_DASSERT(hn::IsAligned(df, vec_aligned));
+
+    using VF = hn::Vec<decltype(df)>;
+    VF sum0 = hn::Zero(df);
+    VF sum1 = hn::Zero(df);
+    VF sum2 = hn::Zero(df);
+    VF sum3 = hn::Zero(df);
+
+    SfpCodec::DotEO(df, in + in_ofs, num, vec_aligned, sum0, sum1, sum2, sum3);
+
+    // Reduction tree: sum of all accumulators, then their lanes.
+    sum0 = hn::Add(sum0, sum1);
+    sum2 = hn::Add(sum2, sum3);
+    sum0 = hn::Add(sum0, sum2);
+    return hn::ReduceSum(df, sum0);
+  }
 };

 template <>
diff --git a/compression/sfp-inl.h b/compression/sfp-inl.h
index 438a1cc..4e3350e 100644
--- a/compression/sfp-inl.h
+++ b/compression/sfp-inl.h
@@ -475,6 +475,61 @@ class SfpCodec {
     }
   }

+  // Fused decode and dot product with even-odd bf16 into four f32 accumulators.
+  template <class DF>
+  static HWY_INLINE void DotEO(DF df, const SfpStream* HWY_RESTRICT in_packed,
+                               size_t num,
+                               const hwy::bfloat16_t* HWY_RESTRICT vec_aligned,
+                               hn::Vec<DF>& sum0, hn::Vec<DF>& sum1,
+                               hn::Vec<DF>& sum2, hn::Vec<DF>& sum3) {
+    const hn::Repartition<uint8_t, DF> d8;
+    const hn::Repartition<hwy::bfloat16_t, DF> dbf;
+    using V8 = hn::Vec<decltype(d8)>;
+    using VBF = hn::Vec<decltype(dbf)>;
+    const size_t N16 = hn::Lanes(dbf);
+    HWY_DASSERT(num % (2 * N16) == 0);  // whole SFP vector -> 2x bf16
+
+    HWY_UNROLL(1)
+    for (size_t i = 0; i < num; i += 2 * N16) {
+      const V8 packed = hn::LoadU(d8, &in_packed->byte + i);
+      const VBF ve = hn::LoadU(dbf, vec_aligned + i);
+      const VBF vo = hn::LoadU(dbf, vec_aligned + i + N16);
+      VBF be, bo;
+      DecEvenOdd(dbf, packed, be, bo);
+      sum0 = hn::ReorderWidenMulAccumulate(df, be, ve, sum0, sum1);
+      sum2 = hn::ReorderWidenMulAccumulate(df, bo, vo, sum2, sum3);
+    }
+  }
+
+  // Fused decode and dot product with even-odd f32 into four f32 accumulators.
+  template <class DF>
+  static HWY_INLINE void DotEO(DF df, const SfpStream* HWY_RESTRICT in_packed,
+                               size_t num,
+                               const float* HWY_RESTRICT vec_aligned,
+                               hn::Vec<DF>& sum0, hn::Vec<DF>& sum1,
+                               hn::Vec<DF>& sum2, hn::Vec<DF>& sum3) {
+    const hn::Repartition<uint8_t, DF> d8;
+    using V8 = hn::Vec<decltype(d8)>;
+    using VF = hn::Vec<DF>;
+    const size_t NF = hn::Lanes(df);
+    HWY_DASSERT(num % (4 * NF) == 0);  // whole SFP vector -> 4x f32
+
+    HWY_UNROLL(1)
+    for (size_t i = 0; i < num; i += 4 * NF) {
+      const V8 packed = hn::LoadU(d8, &in_packed->byte + i);
+      const VF ve0 = hn::LoadU(df, vec_aligned + i + NF * 0);
+      const VF vo0 = hn::LoadU(df, vec_aligned + i + NF * 1);
+      const VF ve1 = hn::LoadU(df, vec_aligned + i + NF * 2);
+      const VF vo1 = hn::LoadU(df, vec_aligned + i + NF * 3);
+      VF fe0, fo0, fe1, fo1;
+      DecEvenOddF(df, packed, fe0, fo0, fe1, fo1);
+      sum0 = hn::MulAdd(fe0, ve0, sum0);
+      sum1 = hn::MulAdd(fo0, vo0, sum1);
+      sum2 = hn::MulAdd(fe1, ve1, sum2);
+      sum3 = hn::MulAdd(fo1, vo1, sum3);
+    }
+  }
+
  private:
   // Wrappers to avoid code duplication across float/bf16 input types and
   // the main loop/remainder.
@@ -574,6 +629,37 @@ class SfpCodec {
     f2 = hn::PromoteLowerTo(df, bf1);
     f3 = hn::PromoteUpperTo(df, bf1);
   }
+
+  template <class DBF, class V8 = hn::Vec<hn::Repartition<uint8_t, DBF>>>
+  static HWY_INLINE void DecEvenOdd(DBF dbf, V8 packed, hn::Vec<DBF>& even,
+                                    hn::Vec<DBF>& odd) {
+    const hn::Repartition<uint8_t, DBF> d8;
+    V8 lo, hi;
+    DecBytes(d8, packed, lo, hi);
+#if HWY_MAJOR > 1 || HWY_MINOR >= 2
+    even = hn::BitCast(dbf, hn::InterleaveEven(d8, lo, hi));
+    odd = hn::BitCast(dbf, hn::InterleaveOdd(d8, lo, hi));
+#else
+    even = hn::BitCast(dbf, hn::OddEven(hn::DupEven(hi), lo));
+    odd = hn::BitCast(dbf, hn::OddEven(hi, hn::DupOdd(lo)));
+#endif
+  }
+
+  template <class DF, class V8 = hn::Vec<hn::Repartition<uint8_t, DF>>>
+  static HWY_INLINE void DecEvenOddF(DF df, V8 packed, hn::Vec<DF>& even0,
+                                     hn::Vec<DF>& odd0, hn::Vec<DF>& even1,
+                                     hn::Vec<DF>& odd1) {
+    const hn::Repartition<hwy::bfloat16_t, DF> dbf;
+    using VBF = hn::Vec<decltype(dbf)>;
+    VBF even_bf, odd_bf;
+    DecEvenOdd(dbf, packed, even_bf, odd_bf);
+    even0 = hn::PromoteLowerTo(df, even_bf);
+    odd0 = hn::PromoteLowerTo(df, odd_bf);
+    even1 = hn::PromoteUpperTo(df, even_bf);
+    odd1 = hn::PromoteUpperTo(df, odd_bf);
+  }
 };  // SfpCodec

 // NOLINTNEXTLINE(google-readability-namespace-comments)
diff --git a/compression/sfp_test.cc b/compression/sfp_test.cc
index 8e50098..1324075 100644
--- a/compression/sfp_test.cc
+++ b/compression/sfp_test.cc
@@ -401,11 +401,13 @@ struct TestDot {
   HWY_INLINE void operator()(T /*unused*/, D d) {
     const hn::Repartition<float, D> df;
     const size_t num = 1024;  // not too many for GeometricMean overflow.
+    const size_t N = hn::Lanes(d);
     auto in = hwy::AllocateAligned<T>(num);
     auto dec = hwy::AllocateAligned<T>(num);
     auto vec = hwy::AllocateAligned<T>(num);
+    auto vec_eo = hwy::AllocateAligned<T>(num);
     auto sfp = hwy::AllocateAligned<SfpStream>(num);
-    HWY_ASSERT(in && dec && vec && sfp);
+    HWY_ASSERT(in && dec && vec && vec_eo && sfp);

     // Generate inputs and verify their distribution.
     hwy::RandomState rng;
@@ -422,27 +424,54 @@ struct TestDot {
     }
     VerifyGaussian(in_stats);

+    // Convert vec to even/odd for DotEO.
+    for (size_t i = 0; i < num; i += 2 * N) {
+      hn::Vec<D> ve, vo;
+      hn::LoadInterleaved2(d, vec.get() + i, ve, vo);
+      hn::Store(ve, d, vec_eo.get() + i + 0);
+      hn::Store(vo, d, vec_eo.get() + i + N);
+    }
+
     SfpCodec::Enc(d, in.get(), num, sfp.get());

     // Compute dot product without decompression.
     float actual = 0.0f;
+    float actual_eo = 0.0f;
     double elapsed = hwy::HighestValue<double>();
+    double elapsed_eo = hwy::HighestValue<double>();
     for (size_t rep = 0; rep < 200; ++rep) {
-      hn::Vec<decltype(df)> sum0 = hn::Zero(df);
-      hn::Vec<decltype(df)> sum1 = hn::Zero(df);
-      hn::Vec<decltype(df)> sum2 = hn::Zero(df);
-      hn::Vec<decltype(df)> sum3 = hn::Zero(df);
-      const double t0 = hwy::platform::Now();
-      SfpCodec::Dot(df, sfp.get(), num, vec.get(), sum0, sum1, sum2, sum3);
-      const double t1 = hwy::platform::Now();
-      elapsed = HWY_MIN(elapsed, t1 - t0);
-      sum0 = hn::Add(hn::Add(sum0, sum1), hn::Add(sum2, sum3));
-      actual = hn::ReduceSum(df, sum0);
+      {
+        hn::Vec<decltype(df)> sum0 = hn::Zero(df);
+        hn::Vec<decltype(df)> sum1 = hn::Zero(df);
+        hn::Vec<decltype(df)> sum2 = hn::Zero(df);
+        hn::Vec<decltype(df)> sum3 = hn::Zero(df);
+        const double t0 = hwy::platform::Now();
+        SfpCodec::Dot(df, sfp.get(), num, vec.get(), sum0, sum1, sum2, sum3);
+        const double t1 = hwy::platform::Now();
+        elapsed = HWY_MIN(elapsed, t1 - t0);
+        sum0 = hn::Add(hn::Add(sum0, sum1), hn::Add(sum2, sum3));
+        actual = hn::ReduceSum(df, sum0);
+      }
+      {
+        hn::Vec<decltype(df)> sum0 = hn::Zero(df);
+        hn::Vec<decltype(df)> sum1 = hn::Zero(df);
+        hn::Vec<decltype(df)> sum2 = hn::Zero(df);
+        hn::Vec<decltype(df)> sum3 = hn::Zero(df);
+        const double t0 = hwy::platform::Now();
+        SfpCodec::DotEO(df, sfp.get(), num, vec_eo.get(), sum0, sum1, sum2,
+                        sum3);
+        const double t1 = hwy::platform::Now();
+        elapsed_eo = HWY_MIN(elapsed_eo, t1 - t0);
+        sum0 = hn::Add(hn::Add(sum0, sum1), hn::Add(sum2, sum3));
+        actual_eo = hn::ReduceSum(df, sum0);
+      }
     }

     SfpCodec::Dec(d, sfp.get(), num, dec.get());

-    fprintf(stderr, "Vec %zu Dot %zu-bit %.2f MB/s\n", Lanes(d) * sizeof(T),
-            sizeof(T) * 8, num * sizeof(T) * 1E-6 / elapsed);
+    fprintf(stderr, "Vec %zu Dot %zu-bit %.2f ; %.2f MB/s\n",
+            Lanes(d) * sizeof(T), sizeof(T) * 8,
+            num * sizeof(T) * 1E-6 / elapsed,
+            num * sizeof(T) * 1E-6 / elapsed_eo);

     // Exact and decompressed dot products for comparison.
     float exact = 0.0f;  // using original input
@@ -479,6 +508,8 @@ struct TestDot {
     HWY_ASSERT(gcpp::IsInside(0.87f, 1.0f, final_ratio));
     // Decompressed and uncompressed dot should match exactly.
     HWY_ASSERT(gcpp::IsNear(expected, actual, 1E-4f));
+    // Even/odd dot should also match.
+    HWY_ASSERT(gcpp::IsNear(actual, actual_eo, 1E-4f));
     // Geomean of ratios for each i should be very close to one.
     HWY_ASSERT(dot_snr >= (isBF ? 70.0 : 1000.0));

diff --git a/gemma/ops.h b/gemma/ops.h
index 1be0170..7a8f240 100644
--- a/gemma/ops.h
+++ b/gemma/ops.h
@@ -25,6 +25,7 @@
 #include
 #include <type_traits>  // std::enable_if_t
+#include "compression/sfp.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/profiler.h"
@@ -286,8 +287,12 @@ HWY_INLINE void MatVecT(const ArrayT& mat, const size_t mat_ofs,
   PROFILER_ZONE("MatVecAdd");

 #if !defined(HWY_NATIVE_DOT_BF16) || !HWY_NATIVE_DOT_BF16
-  if constexpr (CompressTraits<typename ArrayT::value_type>::kSupportsEvenOdd &&
-                hwy::IsSameEither<VecT, float, hwy::bfloat16_t>()) {
+  using MatT = typename ArrayT::value_type;
+  // Sfp -> float does not benefit enough to recoup the cost of ToEvenOddF32.
+  if constexpr (CompressTraits<MatT>::kSupportsEvenOdd &&
+                hwy::IsSameEither<VecT, float, hwy::bfloat16_t>() &&
+                !(hwy::IsSame<MatT, SfpStream>() &&
+                  hwy::IsSame<VecT, float>())) {
     ToEvenOddF32(vec_aligned, kInner, even_odd);
     detail::MatVecAddInner(
         mat, mat_ofs, even_odd, add, out, pool);
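
For reference, here is a scalar sketch of the even/odd layout that the test's LoadInterleaved2/Store loop produces and that DotEO consumes. The function and parameter names (ToEvenOddScalar, kLanes) are illustrative only and not part of the patch; kLanes stands in for hn::Lanes(df).

  #include <cstddef>
  #include <vector>

  // For each block of 2 * kLanes elements, the even-indexed elements are
  // stored first, then the odd-indexed ones. v.size() must be a multiple of
  // 2 * kLanes.
  std::vector<float> ToEvenOddScalar(const std::vector<float>& v,
                                     size_t kLanes) {
    const size_t num = v.size();
    std::vector<float> eo(num);
    for (size_t i = 0; i < num; i += 2 * kLanes) {
      for (size_t j = 0; j < kLanes; ++j) {
        eo[i + j] = v[i + 2 * j];               // even-indexed elements
        eo[i + kLanes + j] = v[i + 2 * j + 1];  // odd-indexed elements
      }
    }
    return eo;
  }

Running SfpCodec::DotEO on eo should then match SfpCodec::Dot on the original v within rounding error, which is what the new HWY_ASSERT(gcpp::IsNear(actual, actual_eo, 1E-4f)) in sfp_test.cc verifies.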
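The two preprocessor branches in DecEvenOdd select the same bytes; only the ops available in older Highway versions differ. Assuming Highway's lane semantics for InterleaveEven/InterleaveOdd and OddEven/DupEven/DupOdd, and assuming DecBytes yields the low and high byte of each decoded bf16 value (an inference from how DotEO bit-casts and multiplies the results), both branches amount to this scalar permutation (illustrative code, not part of the patch):

  #include <cstddef>
  #include <cstdint>
  #include <vector>

  // After bit-casting to bf16, lane j of `even` holds decoded element 2j and
  // lane j of `odd` holds decoded element 2j + 1.
  void DecEvenOddScalar(const std::vector<uint8_t>& lo,
                        const std::vector<uint8_t>& hi,
                        std::vector<uint8_t>& even,
                        std::vector<uint8_t>& odd) {
    const size_t n = lo.size();  // number of byte lanes, a multiple of 2
    even.resize(n);
    odd.resize(n);
    for (size_t k = 0; k < n; k += 2) {
      even[k] = lo[k];         // InterleaveEven: low byte of element k
      even[k + 1] = hi[k];     //                 high byte of element k
      odd[k] = lo[k + 1];      // InterleaveOdd:  low byte of element k + 1
      odd[k + 1] = hi[k + 1];  //                 high byte of element k + 1
    }
  }

The pre-1.2 fallback (OddEven(DupEven(hi), lo) and OddEven(hi, DupOdd(lo))) performs the same byte selection, so both paths produce identical bf16 lanes.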