diff --git a/gemma/ops.h b/gemma/ops.h index 1b1c29e..3d1867f 100644 --- a/gemma/ops.h +++ b/gemma/ops.h @@ -339,51 +339,15 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs, // A specialization of MatVecAdd to float32 vectors which first rearranges the // vector to even-odd layout. template || std::is_same_v> + = true, std::enable_if_t< CompressTraits::kSupportsEvenOdd, bool> = true> HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs, - const float* HWY_RESTRICT const vec_aligned, - const AddT* HWY_RESTRICT const add, - float* HWY_RESTRICT out, hwy::ThreadPool& pool) { - PROFILER_ZONE("MatVecAdd"); - - const hn::ScalableTag df; - constexpr size_t kRowsPerStrip = RowsPerStrip(); - constexpr size_t kNumStrips = kOuter / kRowsPerStrip; - - const auto vec_dequant = hwy::AllocateAligned(kInner); - ToEvenOddF32(vec_aligned, kInner, vec_dequant.get()); - - // For each entire strip. - pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR { - PROFILER_ZONE("MatVec.lambda"); - const size_t r0 = strip * kRowsPerStrip; - detail::FullDotProductsForStrip( - df, mat, mat_ofs, kInner, r0, kRowsPerStrip, vec_dequant.get(), add, - out + r0); - }); - - // Remaining rows - const size_t r0 = kNumStrips * kRowsPerStrip; - if (r0 < kOuter) { - PROFILER_ZONE("MatVec remainder"); - const size_t num_rows = kOuter - r0; - detail::FullDotProductsForStrip( - df, mat, mat_ofs, kInner, r0, num_rows, vec_dequant.get(), add, out + r0); - } -} - -// A specialization of MatVecAdd to bf16 vectors which first rearranges the -// vector to even-odd layout. -template ::kSupportsEvenOdd, bool> - = true> -HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs, - const hwy::bfloat16_t* HWY_RESTRICT const vec_aligned, + const VecT* HWY_RESTRICT const vec_aligned, const AddT* HWY_RESTRICT const add, float* HWY_RESTRICT out, hwy::ThreadPool& pool) { PROFILER_ZONE("MatVecAdd");