mirror of https://github.com/google/gemma.cpp.git
Merge pull request #166 from samkaufman:deinterleave-vecs
PiperOrigin-RevId: 630360778
This commit is contained in:
commit
6eeef2e2d9
|
|
@ -58,6 +58,7 @@ struct CompressTraits {};
|
|||
template <>
|
||||
struct CompressTraits<float> {
|
||||
using MatT = float;
|
||||
static constexpr bool kSupportsEvenOdd = false;
|
||||
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in,
|
||||
|
|
@ -111,6 +112,7 @@ struct CompressTraits<float> {
|
|||
template <>
|
||||
struct CompressTraits<hwy::bfloat16_t> {
|
||||
using MatT = hwy::bfloat16_t;
|
||||
static constexpr bool kSupportsEvenOdd = true;
|
||||
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in,
|
||||
|
|
@ -219,11 +221,60 @@ struct CompressTraits<hwy::bfloat16_t> {
|
|||
// bf16*bf16.
|
||||
return hn::Dot::Compute<kAssumptions>(d_vec, vec_aligned, in + in_ofs, num);
|
||||
}
|
||||
|
||||
// Computes the dot product of an even-odd deinterleaved, f32 `vec_aligned`
|
||||
// and a column- major matrix `in`. `vec_aligned` should be aligned and
|
||||
// alternate even-indexed `hn::Lanes(df32)` elements followed by odd-indexed
|
||||
// `hn::Lanes(df32)` elements.
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE float DotEO(
|
||||
const DF df32, const hwy::bfloat16_t* HWY_RESTRICT in, size_t in_ofs,
|
||||
const float* HWY_RESTRICT vec_aligned, size_t num) {
|
||||
HWY_DASSERT(num >= (hn::Lanes(df32) * 2) && (num % (hn::Lanes(df32) * 2)) == 0);
|
||||
HWY_DASSERT((in_ofs % (hn::Lanes(df32) * 2)) == 0);
|
||||
HWY_DASSERT(hn::IsAligned(df32, vec_aligned));
|
||||
|
||||
const hn::Repartition<hwy::bfloat16_t, DF> dbf16;
|
||||
using VF32 = decltype(Zero(df32));
|
||||
const size_t N = Lanes(dbf16);
|
||||
|
||||
VF32 sum0 = Zero(df32);
|
||||
VF32 sum1 = Zero(df32);
|
||||
VF32 sum2 = Zero(df32);
|
||||
VF32 sum3 = Zero(df32);
|
||||
|
||||
const hn::RebindToUnsigned<decltype(df32)> du32;
|
||||
using VU32 = hn::VFromD<decltype(du32)>;
|
||||
const VU32 odd = Set(du32, 0xFFFF0000u);
|
||||
|
||||
VF32 be0, bo0, be1, bo1;
|
||||
for (size_t i = 0; i < num; /* i += 2 * N */) {
|
||||
const auto interleaved0 = hn::LoadU(dbf16, in + in_ofs + i);
|
||||
const VF32 ae0 = Load(df32, vec_aligned + i);
|
||||
const VF32 ao0 = Load(df32, vec_aligned + i + (N / 2));
|
||||
sum0 = hn::MulAdd(ae0, hn::PromoteEvenTo(df32, interleaved0), sum0);
|
||||
sum1 = hn::MulAdd(ao0, hn::PromoteOddTo(df32, interleaved0), sum1);
|
||||
i += N;
|
||||
|
||||
const auto interleaved1 = hn::LoadU(dbf16, in + in_ofs + i);
|
||||
const VF32 ae1 = Load(df32, vec_aligned + i);
|
||||
const VF32 ao1 = Load(df32, vec_aligned + i + (N / 2));
|
||||
sum2 = hn::MulAdd(ae1, hn::PromoteEvenTo(df32, interleaved1), sum2);
|
||||
sum3 = hn::MulAdd(ao1, hn::PromoteOddTo(df32, interleaved1), sum3);
|
||||
i += N;
|
||||
}
|
||||
|
||||
sum0 = Add(sum0, sum1);
|
||||
sum2 = Add(sum2, sum3);
|
||||
sum0 = Add(sum0, sum2);
|
||||
return ReduceSum(df32, sum0);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CompressTraits<SfpStream> {
|
||||
using MatT = SfpStream;
|
||||
static constexpr bool kSupportsEvenOdd = false;
|
||||
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void Compress(DF df, const float* in, size_t num,
|
||||
|
|
@ -273,6 +324,7 @@ struct CompressTraits<SfpStream> {
|
|||
template <>
|
||||
struct CompressTraits<NuqStream> {
|
||||
using MatT = NuqStream;
|
||||
static constexpr bool kSupportsEvenOdd = false;
|
||||
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void Compress(DF df, const float* in, size_t num,
|
||||
|
|
@ -425,16 +477,22 @@ HWY_INLINE float Dot(DF df, const ArrayT& compressed, size_t compressed_ofs,
|
|||
}
|
||||
|
||||
// Returns dot product with `vec_aligned` of length `num`.
|
||||
template <class DF, typename MatT, size_t kCapacity, typename VecT>
|
||||
template <bool kVecEO, class DF, typename MatT, size_t kCapacity, typename VecT>
|
||||
HWY_INLINE float Dot(DF df, const CompressedArray<MatT, kCapacity>& compressed,
|
||||
size_t compressed_ofs, const VecT* vec_aligned,
|
||||
size_t num) {
|
||||
HWY_DASSERT(compressed_ofs + num <= compressed.size());
|
||||
HWY_DASSERT(hn::IsAligned(df, vec_aligned));
|
||||
using Traits = CompressTraits<MatT>;
|
||||
return (compressed.scale() * Traits::Dot(df, compressed.size(),
|
||||
compressed.data(), compressed_ofs,
|
||||
vec_aligned, num));
|
||||
float dot_result;
|
||||
if constexpr (kVecEO) {
|
||||
dot_result = Traits::DotEO(df, compressed.data(), compressed_ofs,
|
||||
vec_aligned, num);
|
||||
} else {
|
||||
dot_result = Traits::Dot(df, compressed.size(), compressed.data(),
|
||||
compressed_ofs, vec_aligned, num);
|
||||
}
|
||||
return compressed.scale() * dot_result;
|
||||
}
|
||||
|
||||
// Callback used by ForeachTensor.
|
||||
|
|
|
|||
228
gemma/ops.h
228
gemma/ops.h
|
|
@ -25,6 +25,7 @@
|
|||
#include <random>
|
||||
#include <type_traits> // std::enable_if_t
|
||||
|
||||
#include "compression/compress.h" // CompressedArray
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/profiler.h"
|
||||
|
|
@ -92,6 +93,96 @@ HWY_INLINE constexpr size_t RowsPerStrip() {
|
|||
return kRowsPerStrip;
|
||||
}
|
||||
|
||||
// Promotes `size` bf16 elements of `vec_aligned` to f32 and writes them to
// `out` in even/odd-deinterleaved order: for every group of
// hn::Lanes(dbf16) inputs, the even-indexed elements are stored first
// (hn::Lanes(df) floats), followed by the odd-indexed elements.
//
// Preconditions (HWY_DASSERT): `size` is a multiple of hn::Lanes(dbf16);
// `vec_aligned` and `out` are aligned for `df`. `out` must have room for
// `size` floats.
HWY_INLINE void ToEvenOddF32(const hwy::bfloat16_t* HWY_RESTRICT vec_aligned,
                             const size_t size, float* HWY_RESTRICT out) {
  const hn::ScalableTag<float> df;
  const hn::Repartition<hwy::bfloat16_t, decltype(df)> dbf16;
  const size_t kLanesF32 = hn::Lanes(df);
  const size_t kLanesBF16 = hn::Lanes(dbf16);

  HWY_DASSERT(size % kLanesBF16 == 0);
  HWY_DASSERT(hn::IsAligned(df, vec_aligned));
  // `out` is written with aligned stores below, so it is the pointer whose
  // alignment actually matters.
  HWY_DASSERT(hn::IsAligned(df, out));

  for (size_t i = 0; i < size; i += kLanesBF16) {
    const auto interleaved = hn::LoadU(dbf16, vec_aligned + i);
    hn::Store(hn::PromoteEvenTo(df, interleaved), df, out + i);
    hn::Store(hn::PromoteOddTo(df, interleaved), df, out + i + kLanesF32);
  }
}
|
||||
|
||||
// Copies `size` f32 elements of `vec_aligned` to `out` in
// even/odd-deinterleaved order: for every group of 2 * hn::Lanes(df) inputs,
// the even-indexed elements are stored first (hn::Lanes(df) floats), followed
// by the odd-indexed elements.
//
// Preconditions (HWY_DASSERT): `size` is a multiple of 2 * hn::Lanes(df);
// `vec_aligned` and `out` are aligned for `df`. `out` must have room for
// `size` floats.
HWY_INLINE void ToEvenOddF32(const float* HWY_RESTRICT vec_aligned,
                             const size_t size, float* HWY_RESTRICT out) {
  const hn::ScalableTag<float> df;
  using VF = hn::Vec<decltype(df)>;
  const size_t kLanes = hn::Lanes(df);

  HWY_DASSERT(size % (kLanes * 2) == 0);
  HWY_DASSERT(hn::IsAligned(df, vec_aligned));
  // `out` is written with aligned stores below, so it is the pointer whose
  // alignment actually matters.
  HWY_DASSERT(hn::IsAligned(df, out));

  for (size_t i = 0; i < size; i += kLanes * 2) {
    // LoadInterleaved2 splits the input into even-indexed (`even`) and
    // odd-indexed (`odd`) elements.
    VF even, odd;
    hn::LoadInterleaved2(df, vec_aligned + i, even, odd);
    hn::Store(even, df, out + i);
    hn::Store(odd, df, out + i + kLanes);
  }
}
|
||||
|
||||
// Simple version without tiling nor threading.
|
||||
// even_odd is precomputed for the current thread.
|
||||
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
|
||||
typename VecT, typename AddT>
|
||||
HWY_INLINE void MatVecAddLoop(const ArrayT& mat, const size_t mat_ofs,
|
||||
const VecT* HWY_RESTRICT vec_aligned,
|
||||
const AddT* HWY_RESTRICT add,
|
||||
float* HWY_RESTRICT even_odd,
|
||||
float* HWY_RESTRICT out) {
|
||||
PROFILER_ZONE("MatVecAddLoop");
|
||||
const hn::ScalableTag<float> df;
|
||||
|
||||
for (size_t idx_row = 0; idx_row < kOuter; ++idx_row) {
|
||||
const size_t row_ofs = mat_ofs + idx_row * kInner;
|
||||
if constexpr (kAdd) {
|
||||
out[idx_row] = hwy::ConvertScalarTo<float>(add[idx_row]) +
|
||||
Dot(df, mat, row_ofs, vec_aligned, kInner);
|
||||
} else {
|
||||
out[idx_row] = Dot(df, mat, row_ofs, vec_aligned, kInner);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if !defined(HWY_NATIVE_DOT_BF16) || !HWY_NATIVE_DOT_BF16
|
||||
template <bool kAdd, size_t kOuter, size_t kInner, typename VecT, typename AddT,
|
||||
size_t kCapacity>
|
||||
HWY_INLINE void MatVecAddLoop(
|
||||
const CompressedArray<hwy::bfloat16_t, kCapacity>& mat,
|
||||
const size_t mat_ofs, const VecT* HWY_RESTRICT vec_aligned,
|
||||
const AddT* HWY_RESTRICT add, float* HWY_RESTRICT even_odd,
|
||||
float* HWY_RESTRICT out) {
|
||||
PROFILER_ZONE("MatVecAddLoop");
|
||||
constexpr bool kVecIsEvenOdd = true;
|
||||
|
||||
const hn::ScalableTag<float> df;
|
||||
ToEvenOddF32(vec_aligned, kInner, even_odd);
|
||||
for (size_t idx_row = 0; idx_row < kOuter; ++idx_row) {
|
||||
const size_t row_ofs = mat_ofs + idx_row * kInner;
|
||||
if constexpr (kAdd) {
|
||||
out[idx_row] = hwy::ConvertScalarTo<float>(add[idx_row]) +
|
||||
Dot<kVecIsEvenOdd>(df, mat, row_ofs, even_odd, kInner);
|
||||
} else {
|
||||
out[idx_row] = Dot<kVecIsEvenOdd>(df, mat, row_ofs, even_odd, kInner);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// even_odd is precomputed for the current thread.
|
||||
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
|
||||
HWY_INLINE void MatVecLoop(const ArrayT& mat, const size_t mat_ofs,
|
||||
const VecT* HWY_RESTRICT vec_aligned,
|
||||
float* HWY_RESTRICT even_odd,
|
||||
float* HWY_RESTRICT out) {
|
||||
MatVecAddLoop</*kAdd=*/false, kOuter, kInner>(
|
||||
mat, mat_ofs, vec_aligned, /*add=*/static_cast<VecT*>(nullptr), even_odd,
|
||||
out);
|
||||
}
|
||||
|
||||
// Simple version without tiling nor threading, but two offsets/outputs.
|
||||
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
|
||||
typename VecT, typename AddT>
|
||||
|
|
@ -120,25 +211,40 @@ HWY_INLINE void TwoOfsMatVecAddLoop(const ArrayT& mat, const size_t mat_ofs0,
|
|||
}
|
||||
}
|
||||
|
||||
// Simple version without tiling nor threading, but two offsets/outputs.
|
||||
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
|
||||
HWY_INLINE void TwoOfsMatVecLoop(const ArrayT& mat, const size_t mat_ofs0,
|
||||
const size_t mat_ofs1,
|
||||
const VecT* HWY_RESTRICT vec_aligned,
|
||||
float* HWY_RESTRICT out0,
|
||||
float* HWY_RESTRICT out1) {
|
||||
TwoOfsMatVecAddLoop</*kAdd=*/false, kOuter, kInner, ArrayT, VecT, VecT>(
|
||||
mat, mat_ofs0, mat_ofs1, vec_aligned, /*add0=*/nullptr, /*add1=*/nullptr,
|
||||
out0, out1);
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
// For each i = [0, num_rows), compute partial (length `num_cols`) dot product
|
||||
// of row i with `vec_aligned` and add into `out[i]`. The upper-left coordinate
|
||||
// of the tile is r0, c0.
|
||||
template <class DF, typename ArrayT, typename VecT>
|
||||
template <bool kVecEO, class DF, typename ArrayT, typename VecT>
|
||||
HWY_INLINE void AccumulatePartialDotProducts(
|
||||
DF df, const ArrayT& mat, size_t mat_ofs, size_t mat_stride, size_t r0,
|
||||
size_t c0, size_t num_rows, size_t num_cols,
|
||||
const VecT* HWY_RESTRICT vec_aligned, float* HWY_RESTRICT out) {
|
||||
for (size_t idx_row = 0; idx_row < num_rows; ++idx_row) {
|
||||
const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat_stride;
|
||||
out[idx_row] += Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
|
||||
out[idx_row] +=
|
||||
Dot<kVecEO>(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
|
||||
}
|
||||
}
|
||||
|
||||
// Same as above, but sets out[i] to the first partial dot product +
|
||||
// init (if kInit), which avoids having to zero-initialize and accumulate.
|
||||
template <bool kInit, class DF, typename ArrayT, typename VecT, typename InitT>
|
||||
// Same as AccumulatePartialDotProducts, but sets out[i] to the first partial
|
||||
// dot product + init (if kInit), which avoids having to zero-initialize and
|
||||
// accumulate.
|
||||
template <bool kVecEO, bool kInit, class DF, typename ArrayT, typename VecT,
|
||||
typename InitT>
|
||||
HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat,
|
||||
size_t mat_ofs, size_t mat_stride,
|
||||
size_t r0, size_t c0,
|
||||
|
|
@ -149,10 +255,12 @@ HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat,
|
|||
for (size_t idx_row = 0; idx_row < num_rows; ++idx_row) {
|
||||
const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat_stride;
|
||||
if constexpr (kInit) {
|
||||
out[idx_row] = hwy::ConvertScalarTo<float>(init[idx_row + r0]) +
|
||||
Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
|
||||
out[idx_row] =
|
||||
hwy::ConvertScalarTo<float>(init[idx_row + r0]) +
|
||||
Dot<kVecEO>(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
|
||||
} else {
|
||||
out[idx_row] = Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
|
||||
out[idx_row] =
|
||||
Dot<kVecEO>(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -161,7 +269,8 @@ HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat,
|
|||
// horizontal strip of the entire matrix); the result is the full dot product
|
||||
// for rows r in [r0, r0 + num_rows) + optionally the add vector, which we store
|
||||
// into out[r - r0].
|
||||
template <bool kAdd, class DF, typename ArrayT, typename VecT, typename AddT>
|
||||
template <bool kVecEO, bool kAdd, class DF, typename ArrayT, typename VecT,
|
||||
typename AddT>
|
||||
HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat,
|
||||
size_t mat_ofs, size_t mat_stride,
|
||||
size_t r0, size_t num_rows,
|
||||
|
|
@ -170,42 +279,37 @@ HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat,
|
|||
float* HWY_RESTRICT out) {
|
||||
// Tall and skinny: set `out` to the single dot product.
|
||||
if (mat_stride < MaxCols()) {
|
||||
SetFirstPartialDotProducts<kAdd>(df, mat, mat_ofs, mat_stride, r0, 0,
|
||||
num_rows, mat_stride, vec_aligned, add,
|
||||
out);
|
||||
SetFirstPartialDotProducts<kVecEO, kAdd>(df, mat, mat_ofs, mat_stride, r0,
|
||||
0, num_rows, mat_stride,
|
||||
vec_aligned, add, out);
|
||||
return;
|
||||
}
|
||||
|
||||
// We have at least MaxCols, so start by setting `out` to that:
|
||||
SetFirstPartialDotProducts<kAdd>(df, mat, mat_ofs, mat_stride, r0, 0,
|
||||
num_rows, MaxCols(), vec_aligned, add, out);
|
||||
SetFirstPartialDotProducts<kVecEO, kAdd>(df, mat, mat_ofs, mat_stride, r0, 0,
|
||||
num_rows, MaxCols(), vec_aligned,
|
||||
add, out);
|
||||
// For further multiples of MaxCols, accumulate. Remainders handled below.
|
||||
size_t c0 = MaxCols();
|
||||
for (; c0 <= mat_stride - MaxCols(); c0 += MaxCols()) {
|
||||
AccumulatePartialDotProducts(df, mat, mat_ofs, mat_stride, r0, c0, num_rows,
|
||||
MaxCols(), vec_aligned, out);
|
||||
AccumulatePartialDotProducts<kVecEO>(df, mat, mat_ofs, mat_stride, r0, c0,
|
||||
num_rows, MaxCols(), vec_aligned, out);
|
||||
}
|
||||
|
||||
if (c0 < mat_stride) { // Final cols
|
||||
AccumulatePartialDotProducts(df, mat, mat_ofs, mat_stride, r0, c0, num_rows,
|
||||
mat_stride - c0, vec_aligned, out);
|
||||
AccumulatePartialDotProducts<kVecEO>(df, mat, mat_ofs, mat_stride, r0, c0,
|
||||
num_rows, mat_stride - c0, vec_aligned,
|
||||
out);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Stores dot products of rows with `vec_aligned` + add the values from `add`
|
||||
// (if kAdd), then stores them to `out`.
|
||||
// `even_odd` has kInner elements for each thread.
|
||||
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
|
||||
typename VecT, typename AddT>
|
||||
HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
|
||||
const VecT* HWY_RESTRICT const vec_aligned,
|
||||
const AddT* HWY_RESTRICT const add,
|
||||
float* HWY_RESTRICT even_odd, float* HWY_RESTRICT out,
|
||||
hwy::ThreadPool& pool) {
|
||||
PROFILER_ZONE("MatVecAdd");
|
||||
|
||||
template <bool kVecIsEvenOdd, bool kAdd, size_t kOuter, size_t kInner,
|
||||
typename ArrayT, typename VecT, typename AddT>
|
||||
HWY_INLINE void MatVecAddInner(const ArrayT& mat, const size_t mat_ofs,
|
||||
const VecT* HWY_RESTRICT const vec_aligned,
|
||||
const AddT* HWY_RESTRICT const add,
|
||||
float* HWY_RESTRICT even_odd,
|
||||
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
|
||||
const hn::ScalableTag<float> df;
|
||||
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
|
||||
constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
|
||||
|
|
@ -223,9 +327,9 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
|
|||
pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
|
||||
PROFILER_ZONE("MatVec.lambda");
|
||||
const size_t r0 = strip * kRowsPerStrip;
|
||||
detail::FullDotProductsForStrip<kAdd>(df, mat, mat_ofs, kInner, r0,
|
||||
kRowsPerStrip, vec_aligned, add,
|
||||
out + r0);
|
||||
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
|
||||
df, mat, mat_ofs, kInner, r0, kRowsPerStrip, vec_aligned, add,
|
||||
out + r0);
|
||||
});
|
||||
|
||||
// Remaining rows
|
||||
|
|
@ -233,18 +337,47 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
|
|||
if (r0 < kOuter) {
|
||||
PROFILER_ZONE("MatVec remainder");
|
||||
const size_t num_rows = kOuter - r0;
|
||||
detail::FullDotProductsForStrip<kAdd>(df, mat, mat_ofs, kInner, r0,
|
||||
num_rows, vec_aligned, add, out + r0);
|
||||
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
|
||||
df, mat, mat_ofs, kInner, r0, num_rows, vec_aligned, add, out + r0);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Stores dot products of rows with `vec_aligned` + add the values from `add`
|
||||
// (if kAdd), then stores them to `out`.
|
||||
//
|
||||
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
|
||||
typename VecT, typename AddT>
|
||||
HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
|
||||
const VecT* HWY_RESTRICT const vec_aligned,
|
||||
const AddT* HWY_RESTRICT const add,
|
||||
float* HWY_RESTRICT even_odd, float* HWY_RESTRICT out,
|
||||
hwy::ThreadPool& pool) {
|
||||
PROFILER_ZONE("MatVecAdd");
|
||||
|
||||
#if !defined(HWY_NATIVE_DOT_BF16) || !HWY_NATIVE_DOT_BF16
|
||||
if constexpr (CompressTraits<typename ArrayT::value_type>::kSupportsEvenOdd &&
|
||||
hwy::IsSameEither<VecT, float, hwy::bfloat16_t>()) {
|
||||
ToEvenOddF32(vec_aligned, kInner, even_odd);
|
||||
detail::MatVecAddInner</*kVecIsEvenOdd=*/true, kAdd, kOuter, kInner>(
|
||||
mat, mat_ofs, even_odd, add, even_odd, out, pool);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
detail::MatVecAddInner</*kVecIsEvenOdd=*/false, kAdd, kOuter, kInner>(
|
||||
mat, mat_ofs, vec_aligned, add, even_odd, out, pool);
|
||||
}
|
||||
|
||||
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
|
||||
HWY_INLINE void MatVec(const ArrayT& mat, const size_t mat_ofs,
|
||||
const VecT* HWY_RESTRICT const vec_aligned,
|
||||
float* HWY_RESTRICT even_odd, float* HWY_RESTRICT out,
|
||||
hwy::ThreadPool& pool) {
|
||||
MatVecAdd</*kAdd=*/false, kOuter, kInner, ArrayT, VecT, VecT>(
|
||||
mat, mat_ofs, vec_aligned, /*add=*/nullptr, even_odd, out, pool);
|
||||
MatVecAdd</*kAdd=*/false, kOuter, kInner>(mat, mat_ofs, vec_aligned,
|
||||
/*add=*/static_cast<VecT*>(nullptr),
|
||||
even_odd, out, pool);
|
||||
}
|
||||
|
||||
template <class D, HWY_IF_F32_D(D)>
|
||||
|
|
@ -366,17 +499,18 @@ HWY_NOINLINE void TwoMatVecAdd(
|
|||
const hn::ScalableTag<float> df;
|
||||
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
|
||||
constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
|
||||
constexpr bool kVecIsEvenOdd = false;
|
||||
|
||||
// For each entire strip.
|
||||
pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
|
||||
PROFILER_ZONE("TwoMatVec.lambda");
|
||||
const size_t r0 = strip * kRowsPerStrip;
|
||||
detail::FullDotProductsForStrip<kAdd>(df, mat0, mat_ofs, kInner, r0,
|
||||
kRowsPerStrip, vec_aligned, add0,
|
||||
out0 + r0);
|
||||
detail::FullDotProductsForStrip<kAdd>(df, mat1, mat_ofs, kInner, r0,
|
||||
kRowsPerStrip, vec_aligned, add1,
|
||||
out1 + r0);
|
||||
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
|
||||
df, mat0, mat_ofs, kInner, r0, kRowsPerStrip, vec_aligned, add0,
|
||||
out0 + r0);
|
||||
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
|
||||
df, mat1, mat_ofs, kInner, r0, kRowsPerStrip, vec_aligned, add1,
|
||||
out1 + r0);
|
||||
});
|
||||
|
||||
// Remaining rows
|
||||
|
|
@ -384,9 +518,9 @@ HWY_NOINLINE void TwoMatVecAdd(
|
|||
if (r0 < kOuter) {
|
||||
PROFILER_ZONE("TwoMatVec remainder");
|
||||
const size_t num_rows = kOuter - r0;
|
||||
detail::FullDotProductsForStrip<kAdd>(
|
||||
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
|
||||
df, mat0, mat_ofs, kInner, r0, num_rows, vec_aligned, add0, out0 + r0);
|
||||
detail::FullDotProductsForStrip<kAdd>(
|
||||
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
|
||||
df, mat1, mat_ofs, kInner, r0, num_rows, vec_aligned, add1, out1 + r0);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue