mirror of https://github.com/google/gemma.cpp.git
Abstracted some duplicated MatVecAdd specializations.
This commit is contained in:
parent
f608337fef
commit
6a78a23f4c
46
gemma/ops.h
46
gemma/ops.h
|
|
@@ -339,51 +339,15 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
|
||||||
// A specialization of MatVecAdd to float32 vectors which first rearranges the
|
// A specialization of MatVecAdd to float32 vectors which first rearranges the
|
||||||
// vector to even-odd layout.
|
// vector to even-odd layout.
|
||||||
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
|
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
|
||||||
typename AddT,
|
typename VecT, typename AddT,
|
||||||
|
std::enable_if_t<
|
||||||
|
std::is_same_v<VecT, float> || std::is_same_v<VecT, hwy::bfloat16_t>>
|
||||||
|
= true,
|
||||||
std::enable_if_t<
|
std::enable_if_t<
|
||||||
CompressTraits<typename ArrayT::value_type>::kSupportsEvenOdd, bool>
|
CompressTraits<typename ArrayT::value_type>::kSupportsEvenOdd, bool>
|
||||||
= true>
|
= true>
|
||||||
HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
|
HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
|
||||||
const float* HWY_RESTRICT const vec_aligned,
|
const VecT* HWY_RESTRICT const vec_aligned,
|
||||||
const AddT* HWY_RESTRICT const add,
|
|
||||||
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
|
|
||||||
PROFILER_ZONE("MatVecAdd");
|
|
||||||
|
|
||||||
const hn::ScalableTag<float> df;
|
|
||||||
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
|
|
||||||
constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
|
|
||||||
|
|
||||||
const auto vec_dequant = hwy::AllocateAligned<float>(kInner);
|
|
||||||
ToEvenOddF32(vec_aligned, kInner, vec_dequant.get());
|
|
||||||
|
|
||||||
// For each entire strip.
|
|
||||||
pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
|
|
||||||
PROFILER_ZONE("MatVec.lambda");
|
|
||||||
const size_t r0 = strip * kRowsPerStrip;
|
|
||||||
detail::FullDotProductsForStrip<true, kAdd>(
|
|
||||||
df, mat, mat_ofs, kInner, r0, kRowsPerStrip, vec_dequant.get(), add,
|
|
||||||
out + r0);
|
|
||||||
});
|
|
||||||
|
|
||||||
// Remaining rows
|
|
||||||
const size_t r0 = kNumStrips * kRowsPerStrip;
|
|
||||||
if (r0 < kOuter) {
|
|
||||||
PROFILER_ZONE("MatVec remainder");
|
|
||||||
const size_t num_rows = kOuter - r0;
|
|
||||||
detail::FullDotProductsForStrip<true, kAdd>(
|
|
||||||
df, mat, mat_ofs, kInner, r0, num_rows, vec_dequant.get(), add, out + r0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// A specialization of MatVecAdd to bf16 vectors which first rearranges the
|
|
||||||
// vector to even-odd layout.
|
|
||||||
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
|
|
||||||
typename AddT,
|
|
||||||
std::enable_if_t<
|
|
||||||
CompressTraits<typename ArrayT::value_type>::kSupportsEvenOdd, bool>
|
|
||||||
= true>
|
|
||||||
HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
|
|
||||||
const hwy::bfloat16_t* HWY_RESTRICT const vec_aligned,
|
|
||||||
const AddT* HWY_RESTRICT const add,
|
const AddT* HWY_RESTRICT const add,
|
||||||
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
|
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
|
||||||
PROFILER_ZONE("MatVecAdd");
|
PROFILER_ZONE("MatVecAdd");
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue