From 0816a1070d3f6ca4e8bfe63b71732b29cea1f575 Mon Sep 17 00:00:00 2001
From: Sam Kaufman
Date: Sun, 28 Apr 2024 17:53:16 -0700
Subject: [PATCH] Even-odd layout MatVecs for bf16 weights.

---
 compression/compress-inl.h |  80 +++++++++++++-
 gemma/ops.h                | 210 +++++++++++++++++++++++++++++++------
 2 files changed, 252 insertions(+), 38 deletions(-)

diff --git a/compression/compress-inl.h b/compression/compress-inl.h
index a6a4b7e..631dbd4 100644
--- a/compression/compress-inl.h
+++ b/compression/compress-inl.h
@@ -51,6 +51,20 @@ namespace gcpp {
 namespace HWY_NAMESPACE {
 namespace hn = hwy::HWY_NAMESPACE;
 
+template <class DF>
+HWY_INLINE void Bf16ToF32EO(const DF df32,
+                            const hwy::bfloat16_t* HWY_RESTRICT in,
+                            hn::Vec<DF>& v_even,
+                            hn::Vec<DF>& v_odd) {
+  const hn::Repartition<hwy::bfloat16_t, DF> dbf16;
+  const hn::RebindToUnsigned<DF> du32;
+
+  const auto odd = Set(du32, 0xFFFF0000u);
+  const auto interleaved = BitCast(du32, LoadU(dbf16, in));
+  v_even = BitCast(df32, hn::ShiftLeft<16>(interleaved));
+  v_odd = BitCast(df32, And(interleaved, odd));
+}
+
 // Enables generic code independent of compression type.
 template <typename MatT>  // primary, must specialize
 struct CompressTraits {};
@@ -58,6 +72,7 @@ struct CompressTraits {};
 template <>
 struct CompressTraits<float> {
   using MatT = float;
+  static constexpr bool supports_eo = false;
 
   template <class DF>
   static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in,
@@ -111,6 +126,7 @@ struct CompressTraits {
 template <>
 struct CompressTraits<hwy::bfloat16_t> {
   using MatT = hwy::bfloat16_t;
+  static constexpr bool supports_eo = true;
 
   template <class DF>
   static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in,
@@ -219,11 +235,60 @@ struct CompressTraits {
     // bf16*bf16.
     return hn::Dot::Compute(d_vec, vec_aligned, in + in_ofs, num);
   }
+
+  // Computes the dot product of an even-odd deinterleaved f32 `vec_aligned`
+  // and a column-major matrix `in`. `vec_aligned` should be aligned and
+  // alternate even-indexed `hn::Lanes(df32)` elements followed by odd-indexed
+  // `hn::Lanes(df32)` elements.
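+  // For example, with `hn::Lanes(df32) == 4`, the first eight logical
+  // elements v0..v7 of the vector are expected in the order
+  // [v0, v2, v4, v6, v1, v3, v5, v7]; this is the layout produced by
+  // ToEvenOddF32 in gemma/ops.h.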
+  template <class DF>
+  static HWY_INLINE float DotEO(
+      const DF df32, const hwy::bfloat16_t* HWY_RESTRICT in, size_t in_ofs,
+      const float* HWY_RESTRICT vec_aligned, size_t num) {
+    HWY_DASSERT(num >= (hn::Lanes(df32) * 4) &&
+                (num % (hn::Lanes(df32) * 4)) == 0);
+    HWY_DASSERT((in_ofs % (hn::Lanes(df32) * 2)) == 0);
+    HWY_DASSERT(hn::IsAligned(df32, vec_aligned));
+
+    const hn::Repartition<hwy::bfloat16_t, DF> dbf16;
+    using VF32 = decltype(Zero(df32));
+    const size_t N = Lanes(dbf16);
+
+    VF32 sum0 = Zero(df32);
+    VF32 sum1 = Zero(df32);
+    VF32 sum2 = Zero(df32);
+    VF32 sum3 = Zero(df32);
+
+    const hn::RebindToUnsigned<DF> du32;
+    using VU32 = hn::VFromD<decltype(du32)>;
+    const VU32 odd = Set(du32, 0xFFFF0000u);
+
+    VF32 be0, bo0, be1, bo1;
+    for (size_t i = 0; i < num; /* i += 2 * N */) {
+      const VF32 ae0 = Load(df32, vec_aligned + i);
+      const VF32 ao0 = Load(df32, vec_aligned + i + (N / 2));
+      Bf16ToF32EO(df32, in + in_ofs + i, be0, bo0);
+      i += N;
+      sum0 = hn::MulAdd(ae0, be0, sum0);
+      sum1 = hn::MulAdd(ao0, bo0, sum1);
+
+      const VF32 ae1 = Load(df32, vec_aligned + i);
+      const VF32 ao1 = Load(df32, vec_aligned + i + (N / 2));
+      Bf16ToF32EO(df32, in + in_ofs + i, be1, bo1);
+      i += N;
+      sum2 = hn::MulAdd(ae1, be1, sum2);
+      sum3 = hn::MulAdd(ao1, bo1, sum3);
+    }
+
+    sum0 = Add(sum0, sum1);
+    sum2 = Add(sum2, sum3);
+    sum0 = Add(sum0, sum2);
+    return ReduceSum(df32, sum0);
+  }
 };
 
 template <>
 struct CompressTraits<SfpStream> {
   using MatT = SfpStream;
+  static constexpr bool supports_eo = false;
 
   template <class DF>
   static HWY_INLINE void Compress(DF df, const float* in, size_t num,
@@ -273,6 +338,7 @@ struct CompressTraits {
 template <>
 struct CompressTraits<NuqStream> {
   using MatT = NuqStream;
+  static constexpr bool supports_eo = false;
 
   template <class DF>
   static HWY_INLINE void Compress(DF df, const float* in, size_t num,
@@ -425,16 +491,22 @@ HWY_INLINE float Dot(DF df, const ArrayT& compressed, size_t compressed_ofs,
 }
 
 // Returns dot product with `vec_aligned` of length `num`.
-template <class DF, typename MatT, size_t kCapacity, typename VecT>
+template <bool kVecEO, class DF, typename MatT, size_t kCapacity, typename VecT>
 HWY_INLINE float Dot(DF df, const CompressedArray<MatT, kCapacity>& compressed,
                      size_t compressed_ofs, const VecT* vec_aligned,
                      size_t num) {
   HWY_DASSERT(compressed_ofs + num <= compressed.size());
   HWY_DASSERT(hn::IsAligned(df, vec_aligned));
   using Traits = CompressTraits<MatT>;
-  return (compressed.scale() * Traits::Dot(df, compressed.size(),
-                                           compressed.data(), compressed_ofs,
-                                           vec_aligned, num));
+  float dot_result;
+  if constexpr (kVecEO) {
+    dot_result = Traits::DotEO(df, compressed.data(), compressed_ofs,
+                               vec_aligned, num);
+  } else {
+    dot_result = Traits::Dot(df, compressed.size(), compressed.data(),
+                             compressed_ofs, vec_aligned, num);
+  }
+  return compressed.scale() * dot_result;
 }
 
 // Callback used by ForeachTensor.
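For intuition, the even-odd `Dot` path above computes the same value as an ordinary dot product; only the layout of `vec_aligned` changes. A minimal scalar sketch of the expected result is shown below. It is not part of the patch: `EvenOddDotRef` is an illustrative name, the bf16 row is shown as already-converted floats, and the `compressed.scale()` factor applied by `Dot` is omitted.

    #include <cstddef>

    // Reference for the even-odd dot product: `vec_eo` stores each block of
    // `block` elements as [evens..., odds...], where `block` plays the role of
    // 2 * hn::Lanes(df32) in the SIMD code; `row` is the matrix row in its
    // original element order. Assumes `num` is a multiple of `block`.
    float EvenOddDotRef(const float* vec_eo, const float* row, size_t num,
                        size_t block) {
      float sum = 0.0f;
      for (size_t i = 0; i < num; i += block) {
        const size_t half = block / 2;
        for (size_t j = 0; j < half; ++j) {
          sum += vec_eo[i + j] * row[i + 2 * j];             // even elements
          sum += vec_eo[i + half + j] * row[i + 2 * j + 1];  // odd elements
        }
      }
      return sum;
    }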
diff --git a/gemma/ops.h b/gemma/ops.h
index da6a38e..9fa79af 100644
--- a/gemma/ops.h
+++ b/gemma/ops.h
@@ -92,6 +92,43 @@ HWY_INLINE constexpr size_t RowsPerStrip() {
   return kRowsPerStrip;
 }
 
+HWY_INLINE void ToEvenOddF32(
+    const hwy::bfloat16_t* HWY_RESTRICT vec_aligned, const size_t size,
+    float* HWY_RESTRICT out) {
+  const hn::ScalableTag<float> df;
+  const hn::Repartition<hwy::bfloat16_t, decltype(df)> dbf16;
+  const hn::RebindToUnsigned<decltype(df)> du32;
+  const auto odd = Set(du32, 0xFFFF0000u);
+  using VF32 = decltype(hn::Zero(df));
+
+  HWY_DASSERT(size % hn::Lanes(dbf16) == 0);
+  HWY_DASSERT(hn::IsAligned(df, vec_aligned));
+
+  VF32 veven, vodd;
+  for (size_t i = 0; i < size; i += hn::Lanes(dbf16)) {
+    Bf16ToF32EO(df, vec_aligned + i, veven, vodd);
+    hn::Store(veven, df, out + i);
+    hn::Store(vodd, df, out + i + hn::Lanes(df));
+  }
+}
+
+HWY_INLINE void ToEvenOddF32(
+    const float* HWY_RESTRICT vec_aligned, const size_t size,
+    float* HWY_RESTRICT out) {
+  const hn::ScalableTag<float> df;
+  using VF = hn::Vec<decltype(df)>;
+
+  HWY_DASSERT(size % (hn::Lanes(df) * 2) == 0);
+  HWY_DASSERT(hn::IsAligned(df, vec_aligned));
+
+  VF vec0, vec1;
+  for (size_t i = 0; i < size; i += hn::Lanes(df) * 2) {
+    hn::LoadInterleaved2(df, vec_aligned + i, vec0, vec1);
+    hn::Store(vec0, df, out + i);
+    hn::Store(vec1, df, out + i + hn::Lanes(df));
+  }
+}
+
 // Simple version without tiling nor threading.
 template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
           typename VecT, typename AddT>
@@ -113,12 +150,38 @@ HWY_INLINE void MatVecAddLoop(const ArrayT& mat, const size_t mat_ofs,
   }
 }
 
+template
+HWY_INLINE void MatVecAddLoop(
+    const CompressedArray& mat,
+    const size_t mat_ofs,
+    const VecT* HWY_RESTRICT vec_aligned,
+    const AddT* HWY_RESTRICT add,
+    float* HWY_RESTRICT out) {
+  PROFILER_ZONE("MatVecAddLoop");
+
+  const hn::ScalableTag<float> df;
+
+  const auto vec_dequant = hwy::AllocateAligned<float>(kInner);
+  ToEvenOddF32(vec_aligned, kInner, vec_dequant.get());
+
+  for (size_t idx_row = 0; idx_row < kOuter; ++idx_row) {
+    const size_t row_ofs = mat_ofs + idx_row * kInner;
+    if constexpr (kAdd) {
+      out[idx_row] = hwy::ConvertScalarTo<float>(add[idx_row]) +
+                     Dot<true>(df, mat, row_ofs, vec_dequant.get(), kInner);
+    } else {
+      out[idx_row] = Dot<true>(df, mat, row_ofs, vec_dequant.get(), kInner);
+    }
+  }
+}
+
 template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
 HWY_INLINE void MatVecLoop(const ArrayT& mat, const size_t mat_ofs,
                            const VecT* HWY_RESTRICT vec_aligned,
                            float* HWY_RESTRICT out) {
-  MatVecAddLoop(
-      mat, mat_ofs, vec_aligned, /*add=*/nullptr, out);
+  MatVecAddLoop(
+      mat, mat_ofs, vec_aligned, /*add=*/(VecT*)nullptr, out);
 }
 
 // Simple version without tiling nor threading, but two offsets/outputs.
@@ -166,20 +229,21 @@ namespace detail {
 // For each i = [0, num_rows), compute partial (length `num_cols`) dot product
 // of row i with `vec_aligned` and add into `out[i]`. The upper-left coordinate
 // of the tile is r0, c0.
-template <class DF, typename ArrayT, typename VecT>
+template <bool kVecEO, class DF, typename ArrayT, typename VecT>
 HWY_INLINE void AccumulatePartialDotProducts(
     DF df, const ArrayT& mat, size_t mat_ofs, size_t mat_stride, size_t r0,
     size_t c0, size_t num_rows, size_t num_cols,
     const VecT* HWY_RESTRICT vec_aligned, float* HWY_RESTRICT out) {
   for (size_t idx_row = 0; idx_row < num_rows; ++idx_row) {
     const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat_stride;
-    out[idx_row] += Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
+    out[idx_row] += Dot<kVecEO>(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
   }
 }
 
-// Same as above, but sets out[i] to the first partial dot product +
-// init (if kInit), which avoids having to zero-initialize and accumulate.
-template <bool kInit, class DF, typename ArrayT, typename VecT, typename AddT>
+// Same as AccumulatePartialDotProducts, but sets out[i] to the first partial
+// dot product + init (if kInit), which avoids having to zero-initialize and
+// accumulate.
+template <bool kVecEO, bool kInit, class DF, typename ArrayT, typename VecT,
+          typename AddT>
 HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat,
                                            size_t mat_ofs, size_t mat_stride,
                                            size_t r0, size_t c0,
@@ -191,9 +255,9 @@ HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat,
     const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat_stride;
     if constexpr (kInit) {
       out[idx_row] = hwy::ConvertScalarTo<float>(init[idx_row + r0]) +
-                     Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
+                     Dot<kVecEO>(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
     } else {
-      out[idx_row] = Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
+      out[idx_row] = Dot<kVecEO>(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
     }
   }
 }
@@ -202,7 +266,8 @@ HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat,
 // horizontal strip of the entire matrix); the result is the full dot product
 // for rows r in [r0, r0 + num_rows) + optionally the add vector, which we store
 // into in out[r - r0].
-template <bool kAdd, class DF, typename ArrayT, typename VecT, typename AddT>
+template <bool kVecEO, bool kAdd, class DF, typename ArrayT, typename VecT,
+          typename AddT>
 HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat,
                                         size_t mat_ofs, size_t mat_stride,
                                         size_t r0, size_t num_rows,
@@ -211,25 +276,27 @@ HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat,
                                         float* HWY_RESTRICT out) {
   // Tall and skinny: set `out` to the single dot product.
   if (mat_stride < MaxCols()) {
-    SetFirstPartialDotProducts<kAdd>(df, mat, mat_ofs, mat_stride, r0, 0,
-                                     num_rows, mat_stride, vec_aligned, add,
-                                     out);
+    SetFirstPartialDotProducts<kVecEO, kAdd>(df, mat, mat_ofs, mat_stride, r0,
+                                             0, num_rows, mat_stride,
+                                             vec_aligned, add, out);
     return;
   }
   // We have at least MaxCols, so start by setting `out` to that:
-  SetFirstPartialDotProducts<kAdd>(df, mat, mat_ofs, mat_stride, r0, 0,
-                                   num_rows, MaxCols(), vec_aligned, add, out);
+  SetFirstPartialDotProducts<kVecEO, kAdd>(df, mat, mat_ofs, mat_stride, r0, 0,
+                                           num_rows, MaxCols(), vec_aligned,
+                                           add, out);
   // For further multiples of MaxCols, accumulate. Remainders handled below.
   size_t c0 = MaxCols();
   for (; c0 <= mat_stride - MaxCols(); c0 += MaxCols()) {
-    AccumulatePartialDotProducts(df, mat, mat_ofs, mat_stride, r0, c0, num_rows,
-                                 MaxCols(), vec_aligned, out);
+    AccumulatePartialDotProducts<kVecEO>(df, mat, mat_ofs, mat_stride, r0, c0,
+                                         num_rows, MaxCols(), vec_aligned, out);
   }
   if (c0 < mat_stride) {  // Final cols
-    AccumulatePartialDotProducts(df, mat, mat_ofs, mat_stride, r0, c0, num_rows,
-                                 mat_stride - c0, vec_aligned, out);
+    AccumulatePartialDotProducts<kVecEO>(df, mat, mat_ofs, mat_stride, r0, c0,
+                                         num_rows, mat_stride - c0, vec_aligned,
+                                         out);
   }
 }
 
@@ -254,9 +321,9 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
   pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
     PROFILER_ZONE("MatVec.lambda");
     const size_t r0 = strip * kRowsPerStrip;
-    detail::FullDotProductsForStrip<kAdd>(df, mat, mat_ofs, kInner, r0,
-                                          kRowsPerStrip, vec_aligned, add,
-                                          out + r0);
+    detail::FullDotProductsForStrip<false, kAdd>(df, mat, mat_ofs, kInner, r0,
+                                                 kRowsPerStrip, vec_aligned,
+                                                 add, out + r0);
   });
 
   // Remaining rows
@@ -264,8 +331,83 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
   if (r0 < kOuter) {
     PROFILER_ZONE("MatVec remainder");
     const size_t num_rows = kOuter - r0;
-    detail::FullDotProductsForStrip<kAdd>(df, mat, mat_ofs, kInner, r0,
-                                          num_rows, vec_aligned, add, out + r0);
+    detail::FullDotProductsForStrip<false, kAdd>(df, mat, mat_ofs, kInner, r0,
+                                                 num_rows, vec_aligned, add,
+                                                 out + r0);
   }
 }
 
+// A specialization of MatVecAdd to float32 vectors which first rearranges the
+// vector to even-odd layout.
+template ::supports_eo, bool> = true>
+HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
+                          const float* HWY_RESTRICT const vec_aligned,
+                          const AddT* HWY_RESTRICT const add,
+                          float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
+  PROFILER_ZONE("MatVecAdd");
+
+  const hn::ScalableTag<float> df;
+  constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
+  constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
+
+  const auto vec_dequant = hwy::AllocateAligned<float>(kInner);
+  ToEvenOddF32(vec_aligned, kInner, vec_dequant.get());
+
+  // For each entire strip.
+  pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
+    PROFILER_ZONE("MatVec.lambda");
+    const size_t r0 = strip * kRowsPerStrip;
+    detail::FullDotProductsForStrip(
+        df, mat, mat_ofs, kInner, r0, kRowsPerStrip, vec_dequant.get(), add,
+        out + r0);
+  });
+
+  // Remaining rows
+  const size_t r0 = kNumStrips * kRowsPerStrip;
+  if (r0 < kOuter) {
+    PROFILER_ZONE("MatVec remainder");
+    const size_t num_rows = kOuter - r0;
+    detail::FullDotProductsForStrip(
+        df, mat, mat_ofs, kInner, r0, num_rows, vec_dequant.get(), add, out + r0);
+  }
+}
+
+// A specialization of MatVecAdd to bf16 vectors which first rearranges the
+// vector to even-odd layout.
+template ::supports_eo, bool> = true>
+HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
+                          const hwy::bfloat16_t* HWY_RESTRICT const vec_aligned,
+                          const AddT* HWY_RESTRICT const add,
+                          float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
+  PROFILER_ZONE("MatVecAdd");
+
+  const hn::ScalableTag<float> df;
+  constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
+  constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
+
+  const auto vec_dequant = hwy::AllocateAligned<float>(kInner);
+  ToEvenOddF32(vec_aligned, kInner, vec_dequant.get());
+
+  // For each entire strip.
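+  // `vec_dequant` (the even-odd f32 copy of `vec_aligned`) is computed once
+  // above and is only read, never written, by the worker threads below.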
+  pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
+    PROFILER_ZONE("MatVec.lambda");
+    const size_t r0 = strip * kRowsPerStrip;
+    detail::FullDotProductsForStrip(
+        df, mat, mat_ofs, kInner, r0, kRowsPerStrip, vec_dequant.get(), add,
+        out + r0);
+  });
+
+  // Remaining rows
+  const size_t r0 = kNumStrips * kRowsPerStrip;
+  if (r0 < kOuter) {
+    PROFILER_ZONE("MatVec remainder");
+    const size_t num_rows = kOuter - r0;
+    detail::FullDotProductsForStrip(
+        df, mat, mat_ofs, kInner, r0, num_rows, vec_dequant.get(), add, out + r0);
+  }
+}
 
@@ -273,8 +415,8 @@
 template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
 HWY_INLINE void MatVec(const ArrayT& mat, const size_t mat_ofs,
                        const VecT* HWY_RESTRICT const vec_aligned,
                        float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
-  MatVecAdd(
-      mat, mat_ofs, vec_aligned, /*add=*/nullptr, out, pool);
+  MatVecAdd(
+      mat, mat_ofs, vec_aligned, /*add=*/(VecT*)nullptr, out, pool);
 }
 
 template
@@ -401,12 +543,12 @@ HWY_NOINLINE void TwoMatVecAdd(
   pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
     PROFILER_ZONE("TwoMatVec.lambda");
     const size_t r0 = strip * kRowsPerStrip;
-    detail::FullDotProductsForStrip<kAdd>(df, mat0, mat_ofs, kInner, r0,
-                                          kRowsPerStrip, vec_aligned, add0,
-                                          out0 + r0);
-    detail::FullDotProductsForStrip<kAdd>(df, mat1, mat_ofs, kInner, r0,
-                                          kRowsPerStrip, vec_aligned, add1,
-                                          out1 + r0);
+    detail::FullDotProductsForStrip<false, kAdd>(
+        df, mat0, mat_ofs, kInner, r0, kRowsPerStrip, vec_aligned, add0,
+        out0 + r0);
+    detail::FullDotProductsForStrip<false, kAdd>(
+        df, mat1, mat_ofs, kInner, r0, kRowsPerStrip, vec_aligned, add1,
+        out1 + r0);
   });
 
   // Remaining rows
@@ -414,9 +556,9 @@ HWY_NOINLINE void TwoMatVecAdd(
   if (r0 < kOuter) {
     PROFILER_ZONE("TwoMatVec remainder");
     const size_t num_rows = kOuter - r0;
-    detail::FullDotProductsForStrip<kAdd>(
+    detail::FullDotProductsForStrip<false, kAdd>(
         df, mat0, mat_ofs, kInner, r0, num_rows, vec_aligned, add0, out0 + r0);
-    detail::FullDotProductsForStrip<kAdd>(
+    detail::FullDotProductsForStrip<false, kAdd>(
        df, mat1, mat_ofs, kInner, r0, num_rows, vec_aligned, add1, out1 + r0);
   }
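As a companion to the SIMD ToEvenOddF32 overloads above, the reordering they perform can be written in scalar form as follows. This is a sketch for illustration only, not part of the patch; `ToEvenOddF32Ref` and `block` are illustrative names, with `block` standing in for 2 * hn::Lanes(df), matching the HWY_DASSERTs on `size`.

    #include <cstddef>

    // Within each block of `block` consecutive elements, write the even-indexed
    // elements first, followed by the odd-indexed ones. Assumes `size` is a
    // multiple of `block`.
    void ToEvenOddF32Ref(const float* in, size_t size, size_t block, float* out) {
      for (size_t i = 0; i < size; i += block) {
        const size_t half = block / 2;
        for (size_t j = 0; j < half; ++j) {
          out[i + j] = in[i + 2 * j];             // evens of this block
          out[i + half + j] = in[i + 2 * j + 1];  // odds of this block
        }
      }
    }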