From 0816a1070d3f6ca4e8bfe63b71732b29cea1f575 Mon Sep 17 00:00:00 2001 From: Sam Kaufman Date: Sun, 28 Apr 2024 17:53:16 -0700 Subject: [PATCH 1/8] Even-odd layout MatVecs for bf16 weights. --- compression/compress-inl.h | 80 +++++++++++++- gemma/ops.h | 210 +++++++++++++++++++++++++++++++------ 2 files changed, 252 insertions(+), 38 deletions(-) diff --git a/compression/compress-inl.h b/compression/compress-inl.h index a6a4b7e..631dbd4 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -51,6 +51,20 @@ namespace gcpp { namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; +template +HWY_INLINE void Bf16ToF32EO(const DF df32, + const hwy::bfloat16_t* HWY_RESTRICT in, + hn::Vec& v_even, + hn::Vec& v_odd) { + const hn::Repartition dbf16; + const hn::RebindToUnsigned du32; + + const auto odd = Set(du32, 0xFFFF0000u); + const auto interleaved = BitCast(du32, LoadU(dbf16, in)); + v_even = BitCast(df32, hn::ShiftLeft<16>(interleaved)); + v_odd = BitCast(df32, And(interleaved, odd)); +} + // Enables generic code independent of compression type. template // primary, must specialize struct CompressTraits {}; @@ -58,6 +72,7 @@ struct CompressTraits {}; template <> struct CompressTraits { using MatT = float; + static constexpr bool supports_eo = false; template static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in, @@ -111,6 +126,7 @@ struct CompressTraits { template <> struct CompressTraits { using MatT = hwy::bfloat16_t; + static constexpr bool supports_eo = true; template static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in, @@ -219,11 +235,60 @@ struct CompressTraits { // bf16*bf16. return hn::Dot::Compute(d_vec, vec_aligned, in + in_ofs, num); } + + // Computes the dot product of an even-odd deinterleaved, f32 `vec_aligned` + // and a column- major matrix `in`. `vec_aligned` should be aligned and + // alternate even-indexed `hn::Lanes(df32)` elements followed by odd-indexed + // `hn::Lanes(df32)` elements. 
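+  //
+  // For illustration only (assuming hn::Lanes(df32) == 4): a logical vector
+  //   v = [v0 v1 v2 v3 v4 v5 v6 v7 ...]
+  // is expected in memory as
+  //   [v0 v2 v4 v6 | v1 v3 v5 v7 | ...],
+  // i.e. each block of 2 * hn::Lanes(df32) elements stores its even-indexed
+  // elements first, then its odd-indexed elements. This is the lane order
+  // produced by promoting the even/odd bf16 lanes of `in` to f32.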
+ template + static HWY_INLINE float DotEO( + const DF df32, const hwy::bfloat16_t* HWY_RESTRICT in, size_t in_ofs, + const float* HWY_RESTRICT vec_aligned, size_t num) { + HWY_DASSERT(num >= (hn::Lanes(df32) * 2) && (num % (hn::Lanes(df32) * 2)) == 0); + HWY_DASSERT((in_ofs % (hn::Lanes(df32) * 2)) == 0); + HWY_DASSERT(hn::IsAligned(df32, vec_aligned)); + + const hn::Repartition dbf16; + using VF32 = decltype(Zero(df32)); + const size_t N = Lanes(dbf16); + + VF32 sum0 = Zero(df32); + VF32 sum1 = Zero(df32); + VF32 sum2 = Zero(df32); + VF32 sum3 = Zero(df32); + + const hn::RebindToUnsigned du32; + using VU32 = hn::VFromD; + const VU32 odd = Set(du32, 0xFFFF0000u); + + VF32 be0, bo0, be1, bo1; + for (size_t i = 0; i < num; /* i += 2 * N */) { + const VF32 ae0 = Load(df32, vec_aligned + i); + const VF32 ao0 = Load(df32, vec_aligned + i + (N / 2)); + Bf16ToF32EO(df32, in + in_ofs + i, be0, bo0); + i += N; + sum0 = hn::MulAdd(ae0, be0, sum0); + sum1 = hn::MulAdd(ao0, bo0, sum1); + + const VF32 ae1 = Load(df32, vec_aligned + i); + const VF32 ao1 = Load(df32, vec_aligned + i + (N / 2)); + Bf16ToF32EO(df32, in + in_ofs + i, be1, bo1); + i += N; + sum2 = hn::MulAdd(ae1, be1, sum2); + sum3 = hn::MulAdd(ao1, bo1, sum3); + } + + sum0 = Add(sum0, sum1); + sum2 = Add(sum2, sum3); + sum0 = Add(sum0, sum2); + return ReduceSum(df32, sum0); + } }; template <> struct CompressTraits { using MatT = SfpStream; + static constexpr bool supports_eo = false; template static HWY_INLINE void Compress(DF df, const float* in, size_t num, @@ -273,6 +338,7 @@ struct CompressTraits { template <> struct CompressTraits { using MatT = NuqStream; + static constexpr bool supports_eo = false; template static HWY_INLINE void Compress(DF df, const float* in, size_t num, @@ -425,16 +491,22 @@ HWY_INLINE float Dot(DF df, const ArrayT& compressed, size_t compressed_ofs, } // Returns dot product with `vec_aligned` of length `num`. -template +template HWY_INLINE float Dot(DF df, const CompressedArray& compressed, size_t compressed_ofs, const VecT* vec_aligned, size_t num) { HWY_DASSERT(compressed_ofs + num <= compressed.size()); HWY_DASSERT(hn::IsAligned(df, vec_aligned)); using Traits = CompressTraits; - return (compressed.scale() * Traits::Dot(df, compressed.size(), - compressed.data(), compressed_ofs, - vec_aligned, num)); + float dot_result; + if constexpr (kVecEO) { + dot_result = Traits::DotEO(df, compressed.data(), compressed_ofs, + vec_aligned, num); + } else { + dot_result = Traits::Dot(df, compressed.size(), compressed.data(), + compressed_ofs, vec_aligned, num); + } + return compressed.scale() * dot_result; } // Callback used by ForeachTensor. 
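Note: a scalar model of what DotEO computes, as a sketch for this note only (the helper name, the pre-dequantized weight slice `w`, and the explicit `block` parameter are illustrative and not part of the patch):

#include <cstddef>

// `vec_eo` holds the activation vector in even-odd blocks of `block` floats
// (block == 2 * hn::Lanes(df32)): the first half of each block contains the
// even-indexed elements, the second half the odd-indexed ones. `w` is the
// weight slice starting at `in + in_ofs`, already dequantized to float and in
// its natural element order.
static float DotEvenOddRef(const float* vec_eo, const float* w, size_t num,
                           size_t block) {
  float sum = 0.0f;
  for (size_t i = 0; i < num; i += block) {
    for (size_t j = 0; j < block / 2; ++j) {
      sum += vec_eo[i + j] * w[i + 2 * j];                  // even-indexed
      sum += vec_eo[i + block / 2 + j] * w[i + 2 * j + 1];  // odd-indexed
    }
  }
  return sum;
}

The result equals an ordinary dot product over the logically ordered vector; only the traversal order changes, which is what lets the vectorized loop above consume the even/odd-promoted bf16 lanes directly instead of re-interleaving them.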
diff --git a/gemma/ops.h b/gemma/ops.h index da6a38e..9fa79af 100644 --- a/gemma/ops.h +++ b/gemma/ops.h @@ -92,6 +92,43 @@ HWY_INLINE constexpr size_t RowsPerStrip() { return kRowsPerStrip; } +HWY_INLINE void ToEvenOddF32( + const hwy::bfloat16_t* HWY_RESTRICT vec_aligned, const size_t size, + float* HWY_RESTRICT out) { + const hn::ScalableTag df; + const hn::Repartition dbf16; + const hn::RebindToUnsigned du32; + const auto odd = Set(du32, 0xFFFF0000u); + using VF32 = decltype(hn::Zero(df)); + + HWY_DASSERT(size % hn::Lanes(dbf16) == 0); + HWY_DASSERT(hn::IsAligned(df, vec_aligned)); + + VF32 veven, vodd; + for (size_t i = 0; i < size; i += hn::Lanes(dbf16)) { + Bf16ToF32EO(df, vec_aligned + i, veven, vodd); + hn::Store(veven, df, out + i); + hn::Store(vodd, df, out + i + hn::Lanes(df)); + } +} + +HWY_INLINE void ToEvenOddF32( + const float* HWY_RESTRICT vec_aligned, const size_t size, + float* HWY_RESTRICT out) { + const hn::ScalableTag df; + using VF = hn::Vec; + + HWY_DASSERT(size % (hn::Lanes(df) * 2) == 0); + HWY_DASSERT(hn::IsAligned(df, vec_aligned)); + + VF vec0, vec1; + for (size_t i = 0; i < size; i += hn::Lanes(df) * 2) { + hn::LoadInterleaved2(df, vec_aligned + i, vec0, vec1); + hn::Store(vec0, df, out + i); + hn::Store(vec1, df, out + i + hn::Lanes(df)); + } +} + // Simple version without tiling nor threading. template @@ -113,12 +150,38 @@ HWY_INLINE void MatVecAddLoop(const ArrayT& mat, const size_t mat_ofs, } } +template +HWY_INLINE void MatVecAddLoop( + const CompressedArray& mat, + const size_t mat_ofs, + const VecT* HWY_RESTRICT vec_aligned, + const AddT* HWY_RESTRICT add, + float* HWY_RESTRICT out) { + PROFILER_ZONE("MatVecAddLoop"); + + const hn::ScalableTag df; + + const auto vec_dequant = hwy::AllocateAligned(kInner); + ToEvenOddF32(vec_aligned, kInner, vec_dequant.get()); + + for (size_t idx_row = 0; idx_row < kOuter; ++idx_row) { + const size_t row_ofs = mat_ofs + idx_row * kInner; + if constexpr (kAdd) { + out[idx_row] = hwy::ConvertScalarTo(add[idx_row]) + + Dot(df, mat, row_ofs, vec_dequant.get(), kInner); + } else { + out[idx_row] = Dot(df, mat, row_ofs, vec_dequant.get(), kInner); + } + } +} + template HWY_INLINE void MatVecLoop(const ArrayT& mat, const size_t mat_ofs, const VecT* HWY_RESTRICT vec_aligned, float* HWY_RESTRICT out) { - MatVecAddLoop( - mat, mat_ofs, vec_aligned, /*add=*/nullptr, out); + MatVecAddLoop( + mat, mat_ofs, vec_aligned, /*add=*/(VecT*)nullptr, out); } // Simple version without tiling nor threading, but two offsets/outputs. @@ -166,20 +229,21 @@ namespace detail { // For each i = [0, num_rows), compute partial (length `num_cols`) dot product // of row i with `vec_aligned` and add into `out[i]`. The upper-left coordinate // of the tile is r0, c0. -template +template HWY_INLINE void AccumulatePartialDotProducts( DF df, const ArrayT& mat, size_t mat_ofs, size_t mat_stride, size_t r0, size_t c0, size_t num_rows, size_t num_cols, const VecT* HWY_RESTRICT vec_aligned, float* HWY_RESTRICT out) { for (size_t idx_row = 0; idx_row < num_rows; ++idx_row) { const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat_stride; - out[idx_row] += Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols); + out[idx_row] += Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols); } } -// Same as above, but sets out[i] to the first partial dot product + -// init (if kInit), which avoids having to zero-initialize and accumulate. 
-template +// Same as AccumulatePartialDotProducts, but sets out[i] to the first partial +// dot product + init (if kInit), which avoids having to zero-initialize and +// accumulate. +template HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat, size_t mat_ofs, size_t mat_stride, size_t r0, size_t c0, @@ -191,9 +255,9 @@ HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat, const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat_stride; if constexpr (kInit) { out[idx_row] = hwy::ConvertScalarTo(init[idx_row + r0]) + - Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols); + Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols); } else { - out[idx_row] = Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols); + out[idx_row] = Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols); } } } @@ -202,7 +266,8 @@ HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat, // horizontal strip of the entire matrix); the result is the full dot product // for rows r in [r0, r0 + num_rows) + optionally the add vector, which we store // into in out[r - r0]. -template +template HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat, size_t mat_ofs, size_t mat_stride, size_t r0, size_t num_rows, @@ -211,25 +276,27 @@ HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat, float* HWY_RESTRICT out) { // Tall and skinny: set `out` to the single dot product. if (mat_stride < MaxCols()) { - SetFirstPartialDotProducts(df, mat, mat_ofs, mat_stride, r0, 0, - num_rows, mat_stride, vec_aligned, add, - out); + SetFirstPartialDotProducts(df, mat, mat_ofs, mat_stride, r0, + 0, num_rows, mat_stride, + vec_aligned, add, out); return; } // We have at least MaxCols, so start by setting `out` to that: - SetFirstPartialDotProducts(df, mat, mat_ofs, mat_stride, r0, 0, - num_rows, MaxCols(), vec_aligned, add, out); + SetFirstPartialDotProducts(df, mat, mat_ofs, mat_stride, r0, 0, + num_rows, MaxCols(), vec_aligned, + add, out); // For further multiples of MaxCols, accumulate. Remainders handled below. 
size_t c0 = MaxCols(); for (; c0 <= mat_stride - MaxCols(); c0 += MaxCols()) { - AccumulatePartialDotProducts(df, mat, mat_ofs, mat_stride, r0, c0, num_rows, - MaxCols(), vec_aligned, out); + AccumulatePartialDotProducts(df, mat, mat_ofs, mat_stride, r0, c0, + num_rows, MaxCols(), vec_aligned, out); } if (c0 < mat_stride) { // Final cols - AccumulatePartialDotProducts(df, mat, mat_ofs, mat_stride, r0, c0, num_rows, - mat_stride - c0, vec_aligned, out); + AccumulatePartialDotProducts(df, mat, mat_ofs, mat_stride, r0, c0, + num_rows, mat_stride - c0, vec_aligned, + out); } } @@ -254,9 +321,9 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs, pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR { PROFILER_ZONE("MatVec.lambda"); const size_t r0 = strip * kRowsPerStrip; - detail::FullDotProductsForStrip(df, mat, mat_ofs, kInner, r0, - kRowsPerStrip, vec_aligned, add, - out + r0); + detail::FullDotProductsForStrip(df, mat, mat_ofs, kInner, r0, + kRowsPerStrip, vec_aligned, + add, out + r0); }); // Remaining rows @@ -264,8 +331,83 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs, if (r0 < kOuter) { PROFILER_ZONE("MatVec remainder"); const size_t num_rows = kOuter - r0; - detail::FullDotProductsForStrip(df, mat, mat_ofs, kInner, r0, - num_rows, vec_aligned, add, out + r0); + detail::FullDotProductsForStrip(df, mat, mat_ofs, kInner, r0, + num_rows, vec_aligned, add, + out + r0); + } +} + +// A specialization of MatVecAdd to float32 vectors which first rearranges the +// vector to even-odd layout. +template ::supports_eo, bool> = true> +HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs, + const float* HWY_RESTRICT const vec_aligned, + const AddT* HWY_RESTRICT const add, + float* HWY_RESTRICT out, hwy::ThreadPool& pool) { + PROFILER_ZONE("MatVecAdd"); + + const hn::ScalableTag df; + constexpr size_t kRowsPerStrip = RowsPerStrip(); + constexpr size_t kNumStrips = kOuter / kRowsPerStrip; + + const auto vec_dequant = hwy::AllocateAligned(kInner); + ToEvenOddF32(vec_aligned, kInner, vec_dequant.get()); + + // For each entire strip. + pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR { + PROFILER_ZONE("MatVec.lambda"); + const size_t r0 = strip * kRowsPerStrip; + detail::FullDotProductsForStrip( + df, mat, mat_ofs, kInner, r0, kRowsPerStrip, vec_dequant.get(), add, + out + r0); + }); + + // Remaining rows + const size_t r0 = kNumStrips * kRowsPerStrip; + if (r0 < kOuter) { + PROFILER_ZONE("MatVec remainder"); + const size_t num_rows = kOuter - r0; + detail::FullDotProductsForStrip( + df, mat, mat_ofs, kInner, r0, num_rows, vec_dequant.get(), add, out + r0); + } +} + +// A specialization of MatVecAdd to bf16 vectors which first rearranges the +// vector to even-odd layout. +template ::supports_eo, bool> = true> +HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs, + const hwy::bfloat16_t* HWY_RESTRICT const vec_aligned, + const AddT* HWY_RESTRICT const add, + float* HWY_RESTRICT out, hwy::ThreadPool& pool) { + PROFILER_ZONE("MatVecAdd"); + + const hn::ScalableTag df; + constexpr size_t kRowsPerStrip = RowsPerStrip(); + constexpr size_t kNumStrips = kOuter / kRowsPerStrip; + + const auto vec_dequant = hwy::AllocateAligned(kInner); + ToEvenOddF32(vec_aligned, kInner, vec_dequant.get()); + + // For each entire strip. 
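+  // (Illustrative numbers only: with kOuter == 2048 and kRowsPerStrip == 64,
+  // kNumStrips == 32 full strips are distributed across the pool and the
+  // remainder block below is skipped; with kOuter == 2050, the last two rows
+  // are handled by the remainder block.)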
+ pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR { + PROFILER_ZONE("MatVec.lambda"); + const size_t r0 = strip * kRowsPerStrip; + detail::FullDotProductsForStrip( + df, mat, mat_ofs, kInner, r0, kRowsPerStrip, vec_dequant.get(), add, + out + r0); + }); + + // Remaining rows + const size_t r0 = kNumStrips * kRowsPerStrip; + if (r0 < kOuter) { + PROFILER_ZONE("MatVec remainder"); + const size_t num_rows = kOuter - r0; + detail::FullDotProductsForStrip( + df, mat, mat_ofs, kInner, r0, num_rows, vec_dequant.get(), add, out + r0); } } @@ -273,8 +415,8 @@ template HWY_INLINE void MatVec(const ArrayT& mat, const size_t mat_ofs, const VecT* HWY_RESTRICT const vec_aligned, float* HWY_RESTRICT out, hwy::ThreadPool& pool) { - MatVecAdd( - mat, mat_ofs, vec_aligned, /*add=*/nullptr, out, pool); + MatVecAdd( + mat, mat_ofs, vec_aligned, /*add=*/(VecT *)nullptr, out, pool); } template @@ -401,12 +543,12 @@ HWY_NOINLINE void TwoMatVecAdd( pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR { PROFILER_ZONE("TwoMatVec.lambda"); const size_t r0 = strip * kRowsPerStrip; - detail::FullDotProductsForStrip(df, mat0, mat_ofs, kInner, r0, - kRowsPerStrip, vec_aligned, add0, - out0 + r0); - detail::FullDotProductsForStrip(df, mat1, mat_ofs, kInner, r0, - kRowsPerStrip, vec_aligned, add1, - out1 + r0); + detail::FullDotProductsForStrip( + df, mat0, mat_ofs, kInner, r0, kRowsPerStrip, vec_aligned, add0, + out0 + r0); + detail::FullDotProductsForStrip( + df, mat1, mat_ofs, kInner, r0, kRowsPerStrip, vec_aligned, add1, + out1 + r0); }); // Remaining rows @@ -414,9 +556,9 @@ HWY_NOINLINE void TwoMatVecAdd( if (r0 < kOuter) { PROFILER_ZONE("TwoMatVec remainder"); const size_t num_rows = kOuter - r0; - detail::FullDotProductsForStrip( + detail::FullDotProductsForStrip( df, mat0, mat_ofs, kInner, r0, num_rows, vec_aligned, add0, out0 + r0); - detail::FullDotProductsForStrip( + detail::FullDotProductsForStrip( df, mat1, mat_ofs, kInner, r0, num_rows, vec_aligned, add1, out1 + r0); } } From 5cb63346aa7a328f7f0bc3505f78db56c1e5f80a Mon Sep 17 00:00:00 2001 From: Sam Kaufman Date: Mon, 29 Apr 2024 12:51:35 -0700 Subject: [PATCH 2/8] supports_eo -> kSupportsEvenOdd --- compression/compress-inl.h | 8 ++++---- gemma/ops.h | 8 ++++++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/compression/compress-inl.h b/compression/compress-inl.h index 631dbd4..7ae43b7 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -72,7 +72,7 @@ struct CompressTraits {}; template <> struct CompressTraits { using MatT = float; - static constexpr bool supports_eo = false; + static constexpr bool kSupportsEvenOdd = false; template static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in, @@ -126,7 +126,7 @@ struct CompressTraits { template <> struct CompressTraits { using MatT = hwy::bfloat16_t; - static constexpr bool supports_eo = true; + static constexpr bool kSupportsEvenOdd = true; template static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in, @@ -288,7 +288,7 @@ struct CompressTraits { template <> struct CompressTraits { using MatT = SfpStream; - static constexpr bool supports_eo = false; + static constexpr bool kSupportsEvenOdd = false; template static HWY_INLINE void Compress(DF df, const float* in, size_t num, @@ -338,7 +338,7 @@ struct CompressTraits { template <> struct CompressTraits { using MatT = NuqStream; - static constexpr bool supports_eo = false; + static constexpr bool kSupportsEvenOdd = false; template static HWY_INLINE 
void Compress(DF df, const float* in, size_t num, diff --git a/gemma/ops.h b/gemma/ops.h index 9fa79af..1c09409 100644 --- a/gemma/ops.h +++ b/gemma/ops.h @@ -341,7 +341,9 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs, // vector to even-odd layout. template ::supports_eo, bool> = true> + std::enable_if_t< + CompressTraits::kSupportsEvenOdd, bool> + = true> HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs, const float* HWY_RESTRICT const vec_aligned, const AddT* HWY_RESTRICT const add, @@ -378,7 +380,9 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs, // vector to even-odd layout. template ::supports_eo, bool> = true> + std::enable_if_t< + CompressTraits::kSupportsEvenOdd, bool> + = true> HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs, const hwy::bfloat16_t* HWY_RESTRICT const vec_aligned, const AddT* HWY_RESTRICT const add, From aa0b113214c8daeb9cf32a2c9bfbb0669406be7b Mon Sep 17 00:00:00 2001 From: Sam Kaufman Date: Mon, 29 Apr 2024 12:53:47 -0700 Subject: [PATCH 3/8] (VecT*) to static_cast. --- gemma/ops.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/gemma/ops.h b/gemma/ops.h index 1c09409..0c29cda 100644 --- a/gemma/ops.h +++ b/gemma/ops.h @@ -181,7 +181,7 @@ HWY_INLINE void MatVecLoop(const ArrayT& mat, const size_t mat_ofs, const VecT* HWY_RESTRICT vec_aligned, float* HWY_RESTRICT out) { MatVecAddLoop( - mat, mat_ofs, vec_aligned, /*add=*/(VecT*)nullptr, out); + mat, mat_ofs, vec_aligned, /*add=*/static_cast(nullptr), out); } // Simple version without tiling nor threading, but two offsets/outputs. @@ -420,7 +420,8 @@ HWY_INLINE void MatVec(const ArrayT& mat, const size_t mat_ofs, const VecT* HWY_RESTRICT const vec_aligned, float* HWY_RESTRICT out, hwy::ThreadPool& pool) { MatVecAdd( - mat, mat_ofs, vec_aligned, /*add=*/(VecT *)nullptr, out, pool); + mat, mat_ofs, vec_aligned, /*add=*/static_cast(nullptr), out, + pool); } template From f608337fef6e646c0b9f77809f1d840f7f089f6a Mon Sep 17 00:00:00 2001 From: Sam Kaufman Date: Mon, 29 Apr 2024 14:13:07 -0700 Subject: [PATCH 4/8] Remove Bf16ToF32EO and use PromoteEvenTo and PromoteOddTo. --- compression/compress-inl.h | 26 ++++++-------------------- gemma/ops.h | 7 +++---- 2 files changed, 9 insertions(+), 24 deletions(-) diff --git a/compression/compress-inl.h b/compression/compress-inl.h index 7ae43b7..eb87cf8 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -51,20 +51,6 @@ namespace gcpp { namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; -template -HWY_INLINE void Bf16ToF32EO(const DF df32, - const hwy::bfloat16_t* HWY_RESTRICT in, - hn::Vec& v_even, - hn::Vec& v_odd) { - const hn::Repartition dbf16; - const hn::RebindToUnsigned du32; - - const auto odd = Set(du32, 0xFFFF0000u); - const auto interleaved = BitCast(du32, LoadU(dbf16, in)); - v_even = BitCast(df32, hn::ShiftLeft<16>(interleaved)); - v_odd = BitCast(df32, And(interleaved, odd)); -} - // Enables generic code independent of compression type. 
template // primary, must specialize struct CompressTraits {}; @@ -263,19 +249,19 @@ struct CompressTraits { VF32 be0, bo0, be1, bo1; for (size_t i = 0; i < num; /* i += 2 * N */) { + const auto interleaved0 = hn::LoadU(dbf16, in + in_ofs + i); const VF32 ae0 = Load(df32, vec_aligned + i); const VF32 ao0 = Load(df32, vec_aligned + i + (N / 2)); - Bf16ToF32EO(df32, in + in_ofs + i, be0, bo0); + sum0 = hn::MulAdd(ae0, hn::PromoteEvenTo(df32, interleaved0), sum0); + sum1 = hn::MulAdd(ao0, hn::PromoteOddTo(df32, interleaved0), sum1); i += N; - sum0 = hn::MulAdd(ae0, be0, sum0); - sum1 = hn::MulAdd(ao0, bo0, sum1); + const auto interleaved1 = hn::LoadU(dbf16, in + in_ofs + i); const VF32 ae1 = Load(df32, vec_aligned + i); const VF32 ao1 = Load(df32, vec_aligned + i + (N / 2)); - Bf16ToF32EO(df32, in + in_ofs + i, be1, bo1); + sum2 = hn::MulAdd(ae1, hn::PromoteEvenTo(df32, interleaved1), sum2); + sum3 = hn::MulAdd(ao1, hn::PromoteOddTo(df32, interleaved1), sum3); i += N; - sum2 = hn::MulAdd(ae1, be1, sum2); - sum3 = hn::MulAdd(ao1, bo1, sum3); } sum0 = Add(sum0, sum1); diff --git a/gemma/ops.h b/gemma/ops.h index 0c29cda..1b1c29e 100644 --- a/gemma/ops.h +++ b/gemma/ops.h @@ -104,11 +104,10 @@ HWY_INLINE void ToEvenOddF32( HWY_DASSERT(size % hn::Lanes(dbf16) == 0); HWY_DASSERT(hn::IsAligned(df, vec_aligned)); - VF32 veven, vodd; for (size_t i = 0; i < size; i += hn::Lanes(dbf16)) { - Bf16ToF32EO(df, vec_aligned + i, veven, vodd); - hn::Store(veven, df, out + i); - hn::Store(vodd, df, out + i + hn::Lanes(df)); + const auto interleaved = hn::LoadU(dbf16, vec_aligned + i); + hn::Store(hn::PromoteEvenTo(df, interleaved), df, out + i); + hn::Store(hn::PromoteOddTo(df, interleaved), df, out + i + hn::Lanes(df)); } } From 6a78a23f4c1ec27ecb6f80251f6a6891f4c71685 Mon Sep 17 00:00:00 2001 From: Sam Kaufman Date: Mon, 29 Apr 2024 16:23:38 -0700 Subject: [PATCH 5/8] Abstracted some MatVecAdd spec. dupes. --- gemma/ops.h | 46 +++++----------------------------------------- 1 file changed, 5 insertions(+), 41 deletions(-) diff --git a/gemma/ops.h b/gemma/ops.h index 1b1c29e..3d1867f 100644 --- a/gemma/ops.h +++ b/gemma/ops.h @@ -339,51 +339,15 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs, // A specialization of MatVecAdd to float32 vectors which first rearranges the // vector to even-odd layout. template || std::is_same_v> + = true, std::enable_if_t< CompressTraits::kSupportsEvenOdd, bool> = true> HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs, - const float* HWY_RESTRICT const vec_aligned, - const AddT* HWY_RESTRICT const add, - float* HWY_RESTRICT out, hwy::ThreadPool& pool) { - PROFILER_ZONE("MatVecAdd"); - - const hn::ScalableTag df; - constexpr size_t kRowsPerStrip = RowsPerStrip(); - constexpr size_t kNumStrips = kOuter / kRowsPerStrip; - - const auto vec_dequant = hwy::AllocateAligned(kInner); - ToEvenOddF32(vec_aligned, kInner, vec_dequant.get()); - - // For each entire strip. 
- pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR { - PROFILER_ZONE("MatVec.lambda"); - const size_t r0 = strip * kRowsPerStrip; - detail::FullDotProductsForStrip( - df, mat, mat_ofs, kInner, r0, kRowsPerStrip, vec_dequant.get(), add, - out + r0); - }); - - // Remaining rows - const size_t r0 = kNumStrips * kRowsPerStrip; - if (r0 < kOuter) { - PROFILER_ZONE("MatVec remainder"); - const size_t num_rows = kOuter - r0; - detail::FullDotProductsForStrip( - df, mat, mat_ofs, kInner, r0, num_rows, vec_dequant.get(), add, out + r0); - } -} - -// A specialization of MatVecAdd to bf16 vectors which first rearranges the -// vector to even-odd layout. -template ::kSupportsEvenOdd, bool> - = true> -HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs, - const hwy::bfloat16_t* HWY_RESTRICT const vec_aligned, + const VecT* HWY_RESTRICT const vec_aligned, const AddT* HWY_RESTRICT const add, float* HWY_RESTRICT out, hwy::ThreadPool& pool) { PROFILER_ZONE("MatVecAdd"); From 59ebecce22cd7380c68aacb3ad0dc108f658ddb4 Mon Sep 17 00:00:00 2001 From: Sam Kaufman Date: Tue, 30 Apr 2024 14:58:59 -0700 Subject: [PATCH 6/8] Fix: specialized MatVecAdd was never called. --- gemma/ops.h | 98 +++++++++++++++++++++-------------------------------- 1 file changed, 39 insertions(+), 59 deletions(-) diff --git a/gemma/ops.h b/gemma/ops.h index 3d1867f..2b5dc39 100644 --- a/gemma/ops.h +++ b/gemma/ops.h @@ -299,6 +299,34 @@ HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat, } } +template +HWY_INLINE void MatVecAddInner(const ArrayT& mat, const size_t mat_ofs, + const VecT* HWY_RESTRICT const vec_aligned, + const AddT* HWY_RESTRICT const add, + float* HWY_RESTRICT out, hwy::ThreadPool& pool) { + const hn::ScalableTag df; + constexpr size_t kRowsPerStrip = RowsPerStrip(); + constexpr size_t kNumStrips = kOuter / kRowsPerStrip; + + // For each entire strip. + pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR { + PROFILER_ZONE("MatVec.lambda"); + const size_t r0 = strip * kRowsPerStrip; + detail::FullDotProductsForStrip( + df, mat, mat_ofs, kInner, r0, kRowsPerStrip, vec_aligned, add, out + r0); + }); + + // Remaining rows + const size_t r0 = kNumStrips * kRowsPerStrip; + if (r0 < kOuter) { + PROFILER_ZONE("MatVec remainder"); + const size_t num_rows = kOuter - r0; + detail::FullDotProductsForStrip( + df, mat, mat_ofs, kInner, r0, num_rows, vec_aligned, add, out + r0); + } +} + } // namespace detail // Stores dot products of rows with `vec_aligned` + add the values from `add` @@ -316,65 +344,17 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs, constexpr size_t kRowsPerStrip = RowsPerStrip(); constexpr size_t kNumStrips = kOuter / kRowsPerStrip; - // For each entire strip. - pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR { - PROFILER_ZONE("MatVec.lambda"); - const size_t r0 = strip * kRowsPerStrip; - detail::FullDotProductsForStrip(df, mat, mat_ofs, kInner, r0, - kRowsPerStrip, vec_aligned, - add, out + r0); - }); - - // Remaining rows - const size_t r0 = kNumStrips * kRowsPerStrip; - if (r0 < kOuter) { - PROFILER_ZONE("MatVec remainder"); - const size_t num_rows = kOuter - r0; - detail::FullDotProductsForStrip(df, mat, mat_ofs, kInner, r0, - num_rows, vec_aligned, add, - out + r0); - } -} - -// A specialization of MatVecAdd to float32 vectors which first rearranges the -// vector to even-odd layout. 
-template || std::is_same_v> - = true, - std::enable_if_t< - CompressTraits::kSupportsEvenOdd, bool> - = true> -HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs, - const VecT* HWY_RESTRICT const vec_aligned, - const AddT* HWY_RESTRICT const add, - float* HWY_RESTRICT out, hwy::ThreadPool& pool) { - PROFILER_ZONE("MatVecAdd"); - - const hn::ScalableTag df; - constexpr size_t kRowsPerStrip = RowsPerStrip(); - constexpr size_t kNumStrips = kOuter / kRowsPerStrip; - - const auto vec_dequant = hwy::AllocateAligned(kInner); - ToEvenOddF32(vec_aligned, kInner, vec_dequant.get()); - - // For each entire strip. - pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR { - PROFILER_ZONE("MatVec.lambda"); - const size_t r0 = strip * kRowsPerStrip; - detail::FullDotProductsForStrip( - df, mat, mat_ofs, kInner, r0, kRowsPerStrip, vec_dequant.get(), add, - out + r0); - }); - - // Remaining rows - const size_t r0 = kNumStrips * kRowsPerStrip; - if (r0 < kOuter) { - PROFILER_ZONE("MatVec remainder"); - const size_t num_rows = kOuter - r0; - detail::FullDotProductsForStrip( - df, mat, mat_ofs, kInner, r0, num_rows, vec_dequant.get(), add, out + r0); + if constexpr ( + CompressTraits::kSupportsEvenOdd + && hwy::IsSameEither() + ) { + const auto vec_dequant = hwy::AllocateAligned(kInner); + ToEvenOddF32(vec_aligned, kInner, vec_dequant.get()); + detail::MatVecAddInner( + mat, mat_ofs, vec_dequant.get(), add, out, pool); + } else { + detail::MatVecAddInner( + mat, mat_ofs, vec_aligned, add, out, pool); } } From 2829ef17ad84e19c797786247f28a32c0863c1c2 Mon Sep 17 00:00:00 2001 From: Sam Kaufman Date: Tue, 30 Apr 2024 15:19:28 -0700 Subject: [PATCH 7/8] Check for HWY_NATIVE_DOT_BF16. --- gemma/ops.h | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/gemma/ops.h b/gemma/ops.h index 2b5dc39..1988bd8 100644 --- a/gemma/ops.h +++ b/gemma/ops.h @@ -149,6 +149,8 @@ HWY_INLINE void MatVecAddLoop(const ArrayT& mat, const size_t mat_ofs, } } + +#if !defined(HWY_NATIVE_DOT_BF16) || !HWY_NATIVE_DOT_BF16 template HWY_INLINE void MatVecAddLoop( @@ -174,6 +176,7 @@ HWY_INLINE void MatVecAddLoop( } } } +#endif template HWY_INLINE void MatVecLoop(const ArrayT& mat, const size_t mat_ofs, @@ -344,6 +347,7 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs, constexpr size_t kRowsPerStrip = RowsPerStrip(); constexpr size_t kNumStrips = kOuter / kRowsPerStrip; + #if !defined(HWY_NATIVE_DOT_BF16) || !HWY_NATIVE_DOT_BF16 if constexpr ( CompressTraits::kSupportsEvenOdd && hwy::IsSameEither() @@ -352,10 +356,12 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs, ToEvenOddF32(vec_aligned, kInner, vec_dequant.get()); detail::MatVecAddInner( mat, mat_ofs, vec_dequant.get(), add, out, pool); - } else { - detail::MatVecAddInner( - mat, mat_ofs, vec_aligned, add, out, pool); + return; } + #endif + + detail::MatVecAddInner( + mat, mat_ofs, vec_aligned, add, out, pool); } template From 4a6173d929c04701cc08618fe45151bb01638251 Mon Sep 17 00:00:00 2001 From: Sam Kaufman Date: Thu, 2 May 2024 00:41:44 -0700 Subject: [PATCH 8/8] Remove unused vars. 
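Drops locals left unused by the preceding commits: the `odd` mask in ToEvenOddF32 and the `df` and `kNumStrips` variables in MatVecAdd.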
---
 gemma/ops.h | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/gemma/ops.h b/gemma/ops.h
index c0cceec..bef899e 100644
--- a/gemma/ops.h
+++ b/gemma/ops.h
@@ -98,7 +98,6 @@ HWY_INLINE void ToEvenOddF32(
   const hn::ScalableTag df;
   const hn::Repartition dbf16;
   const hn::RebindToUnsigned du32;
-  const auto odd = Set(du32, 0xFFFF0000u);
   using VF32 = decltype(hn::Zero(df));
   HWY_DASSERT(size % hn::Lanes(dbf16) == 0);
   HWY_DASSERT(hn::IsAligned(df, vec_aligned));
@@ -361,9 +360,7 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
                           float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
   PROFILER_ZONE("MatVecAdd");
-  const hn::ScalableTag df;
   constexpr size_t kRowsPerStrip = RowsPerStrip();
-  constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
 #if !defined(HWY_NATIVE_DOT_BF16) || !HWY_NATIVE_DOT_BF16
   if constexpr (