Abstracted some MatVecAdd spec. dupes.

This commit is contained in:
Sam Kaufman 2024-04-29 16:23:38 -07:00
parent f608337fef
commit 6a78a23f4c
1 changed files with 5 additions and 41 deletions

View File

@ -339,51 +339,15 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
// A specialization of MatVecAdd to float32 vectors which first rearranges the // A specialization of MatVecAdd to float32 vectors which first rearranges the
// vector to even-odd layout. // vector to even-odd layout.
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT, template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
typename AddT, typename VecT, typename AddT,
std::enable_if_t<
std::is_same_v<VecT, float> || std::is_same_v<VecT, hwy::bfloat16_t>>
= true,
std::enable_if_t< std::enable_if_t<
CompressTraits<typename ArrayT::value_type>::kSupportsEvenOdd, bool> CompressTraits<typename ArrayT::value_type>::kSupportsEvenOdd, bool>
= true> = true>
HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs, HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
const float* HWY_RESTRICT const vec_aligned, const VecT* HWY_RESTRICT const vec_aligned,
const AddT* HWY_RESTRICT const add,
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
PROFILER_ZONE("MatVecAdd");
const hn::ScalableTag<float> df;
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
const auto vec_dequant = hwy::AllocateAligned<float>(kInner);
ToEvenOddF32(vec_aligned, kInner, vec_dequant.get());
// For each entire strip.
pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
PROFILER_ZONE("MatVec.lambda");
const size_t r0 = strip * kRowsPerStrip;
detail::FullDotProductsForStrip<true, kAdd>(
df, mat, mat_ofs, kInner, r0, kRowsPerStrip, vec_dequant.get(), add,
out + r0);
});
// Remaining rows
const size_t r0 = kNumStrips * kRowsPerStrip;
if (r0 < kOuter) {
PROFILER_ZONE("MatVec remainder");
const size_t num_rows = kOuter - r0;
detail::FullDotProductsForStrip<true, kAdd>(
df, mat, mat_ofs, kInner, r0, num_rows, vec_dequant.get(), add, out + r0);
}
}
// A specialization of MatVecAdd to bf16 vectors which first rearranges the
// vector to even-odd layout.
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
typename AddT,
std::enable_if_t<
CompressTraits<typename ArrayT::value_type>::kSupportsEvenOdd, bool>
= true>
HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
const hwy::bfloat16_t* HWY_RESTRICT const vec_aligned,
const AddT* HWY_RESTRICT const add, const AddT* HWY_RESTRICT const add,
float* HWY_RESTRICT out, hwy::ThreadPool& pool) { float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
PROFILER_ZONE("MatVecAdd"); PROFILER_ZONE("MatVecAdd");