Fix: specialized MatVecAdd was never called.

This commit is contained in:
Sam Kaufman 2024-04-30 14:58:59 -07:00
parent 6a78a23f4c
commit 59ebecce22
1 changed file with 39 additions and 59 deletions

View File

@ -299,6 +299,34 @@ HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat,
}
}
// Shared row-strip driver: walks the kOuter rows in strips of
// RowsPerStrip<kOuter>() rows and calls FullDotProductsForStrip on each strip.
// Whole strips are submitted through `pool.Run`; rows left over after the last
// whole strip are handled by one direct call made after `pool.Run`.
// kVecIsEvenOdd and kAdd are passed through unchanged to
// FullDotProductsForStrip, as are `mat`, `mat_ofs`, `vec_aligned` and `add`.
// Each call receives `out` advanced by the index of the strip's first row.
template <bool kVecIsEvenOdd, bool kAdd, size_t kOuter, size_t kInner,
          typename ArrayT, typename VecT, typename AddT>
HWY_INLINE void MatVecAddInner(const ArrayT& mat, const size_t mat_ofs,
                               const VecT* HWY_RESTRICT const vec_aligned,
                               const AddT* HWY_RESTRICT const add,
                               float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
  constexpr size_t kStripRows = RowsPerStrip<kOuter>();
  constexpr size_t kWholeStrips = kOuter / kStripRows;
  const hn::ScalableTag<float> df;

  // One task per whole strip.
  pool.Run(0, kWholeStrips,
           [&](const uint64_t strip, size_t thread) HWY_ATTR {
             PROFILER_ZONE("MatVec.lambda");
             const size_t first_row = strip * kStripRows;
             float* HWY_RESTRICT strip_out = out + first_row;
             detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
                 df, mat, mat_ofs, kInner, first_row, kStripRows, vec_aligned,
                 add, strip_out);
           });

  // Rows past the last whole strip (present only when kOuter is not a multiple
  // of the strip height).
  const size_t tail_begin = kWholeStrips * kStripRows;
  if (tail_begin < kOuter) {
    PROFILER_ZONE("MatVec remainder");
    const size_t tail_rows = kOuter - tail_begin;
    float* HWY_RESTRICT tail_out = out + tail_begin;
    detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
        df, mat, mat_ofs, kInner, tail_begin, tail_rows, vec_aligned, add,
        tail_out);
  }
}
} // namespace detail
// Stores dot products of rows with `vec_aligned` + add the values from `add`
@ -316,65 +344,17 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
// For each entire strip.
pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
PROFILER_ZONE("MatVec.lambda");
const size_t r0 = strip * kRowsPerStrip;
detail::FullDotProductsForStrip<false, kAdd>(df, mat, mat_ofs, kInner, r0,
kRowsPerStrip, vec_aligned,
add, out + r0);
});
// Remaining rows
const size_t r0 = kNumStrips * kRowsPerStrip;
if (r0 < kOuter) {
PROFILER_ZONE("MatVec remainder");
const size_t num_rows = kOuter - r0;
detail::FullDotProductsForStrip<false, kAdd>(df, mat, mat_ofs, kInner, r0,
num_rows, vec_aligned, add,
out + r0);
}
}
// A specialization of MatVecAdd for float32 or bfloat16 vectors which first
// rearranges the vector into a float32 buffer in even-odd layout.
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
typename VecT, typename AddT,
std::enable_if_t<
std::is_same_v<VecT, float> || std::is_same_v<VecT, hwy::bfloat16_t>>
= true,
std::enable_if_t<
CompressTraits<typename ArrayT::value_type>::kSupportsEvenOdd, bool>
= true>
HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
const VecT* HWY_RESTRICT const vec_aligned,
const AddT* HWY_RESTRICT const add,
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
PROFILER_ZONE("MatVecAdd");
const hn::ScalableTag<float> df;
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
const auto vec_dequant = hwy::AllocateAligned<float>(kInner);
ToEvenOddF32(vec_aligned, kInner, vec_dequant.get());
// For each entire strip.
pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
PROFILER_ZONE("MatVec.lambda");
const size_t r0 = strip * kRowsPerStrip;
detail::FullDotProductsForStrip<true, kAdd>(
df, mat, mat_ofs, kInner, r0, kRowsPerStrip, vec_dequant.get(), add,
out + r0);
});
// Remaining rows
const size_t r0 = kNumStrips * kRowsPerStrip;
if (r0 < kOuter) {
PROFILER_ZONE("MatVec remainder");
const size_t num_rows = kOuter - r0;
detail::FullDotProductsForStrip<true, kAdd>(
df, mat, mat_ofs, kInner, r0, num_rows, vec_dequant.get(), add, out + r0);
if constexpr (
CompressTraits<typename ArrayT::value_type>::kSupportsEvenOdd
&& hwy::IsSameEither<VecT, float, hwy::bfloat16_t>()
) {
const auto vec_dequant = hwy::AllocateAligned<float>(kInner);
ToEvenOddF32(vec_aligned, kInner, vec_dequant.get());
detail::MatVecAddInner<true, kAdd, kOuter, kInner>(
mat, mat_ofs, vec_dequant.get(), add, out, pool);
} else {
detail::MatVecAddInner<false, kAdd, kOuter, kInner>(
mat, mat_ofs, vec_aligned, add, out, pool);
}
}