mirror of https://github.com/google/gemma.cpp.git
Fix: specialized MatVecAdd was never called.
This commit is contained in:
parent
6a78a23f4c
commit
59ebecce22
98
gemma/ops.h
98
gemma/ops.h
|
|
@ -299,6 +299,34 @@ HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat,
|
|||
}
|
||||
}
|
||||
|
||||
template <bool kVecIsEvenOdd, bool kAdd, size_t kOuter, size_t kInner,
|
||||
typename ArrayT, typename VecT, typename AddT>
|
||||
HWY_INLINE void MatVecAddInner(const ArrayT& mat, const size_t mat_ofs,
|
||||
const VecT* HWY_RESTRICT const vec_aligned,
|
||||
const AddT* HWY_RESTRICT const add,
|
||||
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
|
||||
const hn::ScalableTag<float> df;
|
||||
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
|
||||
constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
|
||||
|
||||
// For each entire strip.
|
||||
pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
|
||||
PROFILER_ZONE("MatVec.lambda");
|
||||
const size_t r0 = strip * kRowsPerStrip;
|
||||
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
|
||||
df, mat, mat_ofs, kInner, r0, kRowsPerStrip, vec_aligned, add, out + r0);
|
||||
});
|
||||
|
||||
// Remaining rows
|
||||
const size_t r0 = kNumStrips * kRowsPerStrip;
|
||||
if (r0 < kOuter) {
|
||||
PROFILER_ZONE("MatVec remainder");
|
||||
const size_t num_rows = kOuter - r0;
|
||||
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
|
||||
df, mat, mat_ofs, kInner, r0, num_rows, vec_aligned, add, out + r0);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Stores dot products of rows with `vec_aligned` + add the values from `add`
|
||||
|
|
@ -316,65 +344,17 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
|
|||
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
|
||||
constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
|
||||
|
||||
// For each entire strip.
|
||||
pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
|
||||
PROFILER_ZONE("MatVec.lambda");
|
||||
const size_t r0 = strip * kRowsPerStrip;
|
||||
detail::FullDotProductsForStrip<false, kAdd>(df, mat, mat_ofs, kInner, r0,
|
||||
kRowsPerStrip, vec_aligned,
|
||||
add, out + r0);
|
||||
});
|
||||
|
||||
// Remaining rows
|
||||
const size_t r0 = kNumStrips * kRowsPerStrip;
|
||||
if (r0 < kOuter) {
|
||||
PROFILER_ZONE("MatVec remainder");
|
||||
const size_t num_rows = kOuter - r0;
|
||||
detail::FullDotProductsForStrip<false, kAdd>(df, mat, mat_ofs, kInner, r0,
|
||||
num_rows, vec_aligned, add,
|
||||
out + r0);
|
||||
}
|
||||
}
|
||||
|
||||
// A specialization of MatVecAdd to float32 vectors which first rearranges the
|
||||
// vector to even-odd layout.
|
||||
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
|
||||
typename VecT, typename AddT,
|
||||
std::enable_if_t<
|
||||
std::is_same_v<VecT, float> || std::is_same_v<VecT, hwy::bfloat16_t>>
|
||||
= true,
|
||||
std::enable_if_t<
|
||||
CompressTraits<typename ArrayT::value_type>::kSupportsEvenOdd, bool>
|
||||
= true>
|
||||
HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
|
||||
const VecT* HWY_RESTRICT const vec_aligned,
|
||||
const AddT* HWY_RESTRICT const add,
|
||||
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
|
||||
PROFILER_ZONE("MatVecAdd");
|
||||
|
||||
const hn::ScalableTag<float> df;
|
||||
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
|
||||
constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
|
||||
|
||||
const auto vec_dequant = hwy::AllocateAligned<float>(kInner);
|
||||
ToEvenOddF32(vec_aligned, kInner, vec_dequant.get());
|
||||
|
||||
// For each entire strip.
|
||||
pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
|
||||
PROFILER_ZONE("MatVec.lambda");
|
||||
const size_t r0 = strip * kRowsPerStrip;
|
||||
detail::FullDotProductsForStrip<true, kAdd>(
|
||||
df, mat, mat_ofs, kInner, r0, kRowsPerStrip, vec_dequant.get(), add,
|
||||
out + r0);
|
||||
});
|
||||
|
||||
// Remaining rows
|
||||
const size_t r0 = kNumStrips * kRowsPerStrip;
|
||||
if (r0 < kOuter) {
|
||||
PROFILER_ZONE("MatVec remainder");
|
||||
const size_t num_rows = kOuter - r0;
|
||||
detail::FullDotProductsForStrip<true, kAdd>(
|
||||
df, mat, mat_ofs, kInner, r0, num_rows, vec_dequant.get(), add, out + r0);
|
||||
if constexpr (
|
||||
CompressTraits<typename ArrayT::value_type>::kSupportsEvenOdd
|
||||
&& hwy::IsSameEither<VecT, float, hwy::bfloat16_t>()
|
||||
) {
|
||||
const auto vec_dequant = hwy::AllocateAligned<float>(kInner);
|
||||
ToEvenOddF32(vec_aligned, kInner, vec_dequant.get());
|
||||
detail::MatVecAddInner<true, kAdd, kOuter, kInner>(
|
||||
mat, mat_ofs, vec_dequant.get(), add, out, pool);
|
||||
} else {
|
||||
detail::MatVecAddInner<false, kAdd, kOuter, kInner>(
|
||||
mat, mat_ofs, vec_aligned, add, out, pool);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue