diff --git a/gemma/ops.h b/gemma/ops.h
index 3d1867f..2b5dc39 100644
--- a/gemma/ops.h
+++ b/gemma/ops.h
@@ -299,6 +299,34 @@ HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat,
   }
 }
 
+template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT,
+          typename AddT>
+HWY_INLINE void MatVecAddInner(const ArrayT& mat, const size_t mat_ofs,
+                               const VecT* HWY_RESTRICT const vec_aligned,
+                               const AddT* HWY_RESTRICT const add,
+                               float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
+  const hn::ScalableTag<float> df;
+  constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
+  constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
+
+  // For each entire strip.
+  pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
+    PROFILER_ZONE("MatVec.lambda");
+    const size_t r0 = strip * kRowsPerStrip;
+    detail::FullDotProductsForStrip(
+        df, mat, mat_ofs, kInner, r0, kRowsPerStrip, vec_aligned, add, out + r0);
+  });
+
+  // Remaining rows
+  const size_t r0 = kNumStrips * kRowsPerStrip;
+  if (r0 < kOuter) {
+    PROFILER_ZONE("MatVec remainder");
+    const size_t num_rows = kOuter - r0;
+    detail::FullDotProductsForStrip(
+        df, mat, mat_ofs, kInner, r0, num_rows, vec_aligned, add, out + r0);
+  }
+}
+
 }  // namespace detail
 
 // Stores dot products of rows with `vec_aligned` + add the values from `add`
@@ -316,65 +344,17 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
   constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
   constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
 
-  // For each entire strip.
-  pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
-    PROFILER_ZONE("MatVec.lambda");
-    const size_t r0 = strip * kRowsPerStrip;
-    detail::FullDotProductsForStrip(df, mat, mat_ofs, kInner, r0,
-                                    kRowsPerStrip, vec_aligned,
-                                    add, out + r0);
-  });
-
-  // Remaining rows
-  const size_t r0 = kNumStrips * kRowsPerStrip;
-  if (r0 < kOuter) {
-    PROFILER_ZONE("MatVec remainder");
-    const size_t num_rows = kOuter - r0;
-    detail::FullDotProductsForStrip(df, mat, mat_ofs, kInner, r0,
-                                    num_rows, vec_aligned, add,
-                                    out + r0);
-  }
-}
-
-// A specialization of MatVecAdd to float32 vectors which first rearranges the
-// vector to even-odd layout.
-template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT,
-          typename AddT,
-          std::enable_if_t<std::is_same_v<VecT, float>
-                               || std::is_same_v<VecT, hwy::bfloat16_t>,
-                           bool> = true,
-          std::enable_if_t<
-              CompressTraits<typename ArrayT::value_type>::kSupportsEvenOdd,
-              bool> = true>
-HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
-                          const VecT* HWY_RESTRICT const vec_aligned,
-                          const AddT* HWY_RESTRICT const add,
-                          float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
-  PROFILER_ZONE("MatVecAdd");
-
-  const hn::ScalableTag<float> df;
-  constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
-  constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
-
-  const auto vec_dequant = hwy::AllocateAligned<float>(kInner);
-  ToEvenOddF32(vec_aligned, kInner, vec_dequant.get());
-
-  // For each entire strip.
-  pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
-    PROFILER_ZONE("MatVec.lambda");
-    const size_t r0 = strip * kRowsPerStrip;
-    detail::FullDotProductsForStrip(
-        df, mat, mat_ofs, kInner, r0, kRowsPerStrip, vec_dequant.get(), add,
-        out + r0);
-  });
-
-  // Remaining rows
-  const size_t r0 = kNumStrips * kRowsPerStrip;
-  if (r0 < kOuter) {
-    PROFILER_ZONE("MatVec remainder");
-    const size_t num_rows = kOuter - r0;
-    detail::FullDotProductsForStrip(
-        df, mat, mat_ofs, kInner, r0, num_rows, vec_dequant.get(), add, out + r0);
+  if constexpr (
+      CompressTraits<typename ArrayT::value_type>::kSupportsEvenOdd
+      && hwy::IsSameEither<VecT, float, hwy::bfloat16_t>()
+  ) {
+    const auto vec_dequant = hwy::AllocateAligned<float>(kInner);
+    ToEvenOddF32(vec_aligned, kInner, vec_dequant.get());
+    detail::MatVecAddInner<kOuter, kInner>(
+        mat, mat_ofs, vec_dequant.get(), add, out, pool);
+  } else {
+    detail::MatVecAddInner<kOuter, kInner>(
+        mat, mat_ofs, vec_aligned, add, out, pool);
   }
 }