Replace last ConstMat with MatPtr

This is to reduce the number of MatMul overloads in preparation for de-templatizing.

PiperOrigin-RevId: 758288589
Author: Jan Wassenberg, 2025-05-13 10:54:48 -07:00 (committed by Copybara-Service)
Parent: 0a6a7e4cd6
Commit: 38a08d8095
5 changed files with 42 additions and 115 deletions
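
The call-site pattern that changes throughout the diff is the construction of a read-only strided view over the KV cache: a `ConstMat<float>` built from (pointer, extents, stride) becomes a named `MatPtrT<float>` whose pointer and stride are attached via `SetPtr`. A minimal before/after sketch, lifted from the attention hunk below; the surrounding names (`kv_cache`, `kv_head_offset`, `qkv_dim`, `cache_pos_size_`) are assumed to be in scope as in that code:

// Before: read-only strided view over one KV head (removed in this commit).
ConstMat<float> k(kv_cache.kv_cache.get() + kv_head_offset,
                  Extents2D(kv_cache.seq_len, qkv_dim),
                  /*stride=*/cache_pos_size_);

// After: the same view as a MatPtrT. Rows are still addressed via k.Row(pos),
// which now returns a pointer rather than an offset to add to k.ptr.
MatPtrT<float> k("k_view", Extents2D(kv_cache.seq_len, qkv_dim));
k.SetPtr(kv_cache.kv_cache.get() + kv_head_offset, /*stride=*/cache_pos_size_);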

View File

@@ -328,20 +328,20 @@ class GemmaAttention {
   // Computes Q.K scores, which are "logits" (or scores) stored to att.
   // `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
   HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
-                        const float* HWY_RESTRICT q, const ConstMat<float>& k,
+                        const float* HWY_RESTRICT q, const MatPtrT<float>& k,
                         float* HWY_RESTRICT att) {
     const size_t qkv_dim = layer_config_.qkv_dim;
     if (HWY_LIKELY(last_pos < activations_.seq_len)) {
       // Slightly faster: no wraparound.
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-        const float* HWY_RESTRICT k_ptr = k.ptr + k.Row(pos);
+        const float* HWY_RESTRICT k_ptr = k.Row(pos);
         const float score = Dot(q, k_ptr, qkv_dim);
         att[pos] = score;
       }
     } else {
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
         const size_t cache_pos = div_seq_len_.Remainder(pos);
-        const float* HWY_RESTRICT k_ptr = k.ptr + k.Row(cache_pos);
+        const float* HWY_RESTRICT k_ptr = k.Row(cache_pos);
         const float score = Dot(q, k_ptr, qkv_dim);
         att[pos % activations_.seq_len] = score;
       }
@@ -354,7 +354,7 @@ class GemmaAttention {
   // `v` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
   HWY_INLINE void WeightedSumV(const size_t start_pos, const size_t last_pos,
                                const float* HWY_RESTRICT att,
-                               const ConstMat<float>& v,
+                               const MatPtrT<float>& v,
                                float* HWY_RESTRICT att_out) const {
     const size_t qkv_dim = layer_config_.qkv_dim;
     hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
@@ -362,13 +362,13 @@ class GemmaAttention {
     if (HWY_LIKELY(last_pos < activations_.seq_len)) {
       // Slightly faster: no wraparound.
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-        const float* HWY_RESTRICT v_ptr = v.ptr + v.Row(pos);
+        const float* HWY_RESTRICT v_ptr = v.Row(pos);
         MulByConstAndAdd(att[pos], v_ptr, att_out, qkv_dim);
       }
     } else {
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
         const size_t cache_pos = div_seq_len_.Remainder(pos);
-        const float* HWY_RESTRICT v_ptr = v.ptr + v.Row(cache_pos);
+        const float* HWY_RESTRICT v_ptr = v.Row(cache_pos);
         MulByConstAndAdd(att[pos % activations_.seq_len], v_ptr, att_out,
                          qkv_dim);
       }
@@ -378,7 +378,7 @@ class GemmaAttention {
  public:
   // Calculates the attention outputs for a single q.
   HWY_INLINE void SingleDotSoftmaxWeightedSum(
-      float* HWY_RESTRICT q, const ConstMat<float>& k, const ConstMat<float>& v,
+      float* HWY_RESTRICT q, const MatPtrT<float>& k, const MatPtrT<float>& v,
       float* HWY_RESTRICT att, float* HWY_RESTRICT att_out,
       const float query_scale, const size_t pos, const size_t start_pos,
       const size_t last_pos) {
@@ -432,13 +432,14 @@ class GemmaAttention {
         KVCache& kv_cache = kv_caches_[query_idx];
         const size_t kv_head_offset =
             layer_ * cache_layer_size_ + head_offset;
-        ConstMat<float> k(kv_cache.kv_cache.get() + kv_head_offset,
-                          Extents2D(kv_cache.seq_len, qkv_dim),
-                          /*stride=*/cache_pos_size_);
-        ConstMat<float> v(
-            kv_cache.kv_cache.get() + kv_head_offset + qkv_dim,
-            Extents2D(kv_cache.seq_len, qkv_dim),
-            /*stride=*/cache_pos_size_);
+        MatPtrT<float> k("k_view",
+                         Extents2D(kv_cache.seq_len, qkv_dim));
+        k.SetPtr(kv_cache.kv_cache.get() + kv_head_offset,
+                 /*stride=*/cache_pos_size_);
+        MatPtrT<float> v("v_view",
+                         Extents2D(kv_cache.seq_len, qkv_dim));
+        v.SetPtr(kv_cache.kv_cache.get() + kv_head_offset + qkv_dim,
+                 /*stride=*/cache_pos_size_);
         // Find the token position in the query and calculate the range
         // of cache positions to attend to.

View File

@@ -115,8 +115,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
   // Ensure usage conditions are set before autotuning. Both binding and
   // spinning may materially affect the choice of config. No harm in calling
   // BindB/C if there is a single package: they will be a no-op.
-  BindB(allocator, B_extents.rows, sizeof(TC), ConstMat<TB>(b_trans),
-        env.parallel);
+  BindB(allocator, sizeof(TC), b_trans, env.parallel);
   BindC(allocator, A_extents.rows, C, env.parallel);
   Tristate use_spinning = Tristate::kDefault;

View File

@@ -882,20 +882,17 @@ class MMPerPackage {
   // B is decompressed several call layers lower, but not all member functions
   // depend on TB, so pass it as an argument instead of templating the class.
   template <typename TB, typename TC>
-  HWY_NOINLINE void operator()(const ConstMat<TB>& B,
+  HWY_NOINLINE void operator()(const MatPtrT<TB>& B,
                                const RowPtr<TC>& C) const {
-    // TODO: include NUQ tables? NumPacked in ConstMat?
-    const size_t num_packed_B = B.ofs + B.Stride() * B.Extents().rows;
     switch (order_) {
       case MMOrder::kNT:
-        return DoNT(B, num_packed_B, C);
+        return DoNT(B, C);
      case MMOrder::kNT_K:
-        return DoNT_K(B, num_packed_B, C);
+        return DoNT_K(B, C);
      case MMOrder::kNT_MT:
-        return DoNT_MT(B, num_packed_B, C);
+        return DoNT_MT(B, C);
      case MMOrder::kNT_MT_K:
-        return DoNT_MT_K(B, num_packed_B, C);
+        return DoNT_MT_K(B, C);
      default:
        HWY_UNREACHABLE;
    }
@@ -916,8 +913,7 @@ class MMPerPackage {
   // Single M and K, parallel N. Fills all of C directly.
   template <typename TB, typename TC>
-  HWY_INLINE void DoNT(const ConstMat<TB>& B, size_t num_packed_B,
-                       const RowPtr<TC>& C) const {
+  HWY_INLINE void DoNT(const MatPtrT<TB>& B, const RowPtr<TC>& C) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT", args_);
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
@@ -941,7 +937,7 @@ class MMPerPackage {
       {
         MMZone zone;
         zone.MaybeEnter("MM.NT.DecB", args_);
-        DecompressB(B, num_packed_B, row_b, range_K, B_view);
+        DecompressB(B, row_b, range_K, B_view);
       }
       MMKernel::A2C0(A_view, B_view, mr_, range_M, row_b, K, MMSetC(),
                      args_, C);
@@ -953,8 +949,7 @@ class MMPerPackage {
   // Single M, parallel N, sequential K. Fills all of partial.
   template <typename TB, typename TC>
-  HWY_INLINE void DoNT_K(const ConstMat<TB>& B, size_t num_packed_B,
-                         const RowPtr<TC>& C) const {
+  HWY_INLINE void DoNT_K(const MatPtrT<TB>& B, const RowPtr<TC>& C) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT_K", args_);
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
@@ -977,7 +972,7 @@ class MMPerPackage {
       {
         MMZone zone;
         zone.MaybeEnter("MM.NT_K.DecB", args_);
-        DecompressB(B, num_packed_B, row_b, range_kc, B_view);
+        DecompressB(B, row_b, range_kc, B_view);
       }
       MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
                      C);
@@ -1018,8 +1013,7 @@ class MMPerPackage {
   // Parallel loops over mc/nc blocks of M/range_np, single K.
   // Fills `mc x nc` sections of C directly, in parallel.
   template <typename TB, typename TC>
-  HWY_INLINE void DoNT_MT(const ConstMat<TB>& B, size_t num_packed_B,
-                          const RowPtr<TC>& C) const {
+  HWY_INLINE void DoNT_MT(const MatPtrT<TB>& B, const RowPtr<TC>& C) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT_MT", args_);
     HWY_DASSERT(ranges_kc_.NumTasks() == 1);
@@ -1042,7 +1036,7 @@ class MMPerPackage {
       {
         MMZone zone;
         zone.MaybeEnter("MM.NT_MT.DecB", args_);
-        DecompressB(B, num_packed_B, row_b, range_K, B_view);
+        DecompressB(B, row_b, range_K, B_view);
       }
       MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, K, MMSetC(),
                      args_, C);
@@ -1055,8 +1049,7 @@ class MMPerPackage {
   // Parallel loops over mc/nc blocks of M/range_np, sequential K.
   // Fills `mc x nc` sections of `partial`, then `C`, in parallel.
   template <typename TB, typename TC>
-  HWY_INLINE void DoNT_MT_K(const ConstMat<TB>& B, size_t num_packed_B,
-                            const RowPtr<TC>& C) const {
+  HWY_INLINE void DoNT_MT_K(const MatPtrT<TB>& B, const RowPtr<TC>& C) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT_MT_K", args_);
     const size_t kc_max = ranges_kc_.TaskSize();
@@ -1078,7 +1071,7 @@ class MMPerPackage {
       {
         MMZone zone;
         zone.MaybeEnter("MM.NT_MT_K.DecB", args_);
-        DecompressB(B, num_packed_B, row_b, range_kc, B_view);
+        DecompressB(B, row_b, range_kc, B_view);
       }
       MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
                      C);
@@ -1210,18 +1203,18 @@ class MMPerPackage {
   // col 0 of `B_view`. Decompressing SFP is relatively cheap on `AVX3_DL`
   // thanks to its large table lookups, and less so on other targets.
   template <typename TB>
-  HWY_INLINE void DecompressB(const ConstMat<TB>& B, size_t num_packed_B,
-                              const size_t row_b, const IndexRange& range_kc,
+  HWY_INLINE void DecompressB(const MatPtrT<TB>& B, const size_t row_b,
+                              const IndexRange& range_kc,
                               const RowPtrBF& B_view) const {
     const hn::ScalableTag<BF16> dbf;
-    const PackedSpan<const TB> B_span = MakeSpan(B.ptr, num_packed_B);
+    const PackedSpan<const TB> B_span = B.PaddedSpan();
     const size_t kc = range_kc.Num();
     const size_t col0 = range_kc.begin();
     for (size_t r = 0; r < kNR; ++r) {
-      const size_t packed_ofs = B.Row(row_b + r) + col0;
+      const size_t packed_ofs = (row_b + r) * B.Stride() + col0;
       BF16* HWY_RESTRICT to = B_view.Row(r);
       DecompressAndZeroPad(dbf, B_span, packed_ofs, to, kc);
       // Verify that we zero-padded.
@@ -1264,7 +1257,7 @@ struct MMImpl {
  // Called from `MatMul` from two places: either with the next autotune config,
  // or with the best config.
  template <typename TA, typename TB, typename TC>
-  static HWY_NOINLINE void DoMatMul(const MatPtrT<TA>& A, const ConstMat<TB>& B,
+  static HWY_NOINLINE void DoMatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                                     const RowPtr<TC>& C, const MMArgs& args,
                                     const MMConfig& config) {
    MMZone matmul_zone;
@@ -1300,7 +1293,7 @@ struct MMImpl {
 //
 // Uses considerable stack space: at least 40 KiB per thread.
 template <typename TA, typename TB, typename TC>
-HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const ConstMat<TB>& B,
+HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                               const float* HWY_RESTRICT add, MatMulEnv& env,
                               const RowPtr<TC>& C) {
   const Allocator& allocator = env.ctx.allocator;
@@ -1327,8 +1320,8 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const ConstMat<TB>& B,
   MMPerKey& per_key = env.per_key[index];
   MMAutoTune<MMConfig>& tuner = per_key.autotune;
-  const MMArgs args(env, per_key, static_cast<double>(A.Scale()) * B.scale, add,
-                    env.storage.Partial());
+  const MMArgs args(env, per_key, static_cast<double>(A.Scale()) * B.Scale(),
+                    add, env.storage.Partial());
   if (HWY_LIKELY(tuner.Best())) {
     MMImpl::DoMatMul(A, B, C, args, *tuner.Best());
     return &per_key;
@@ -1383,13 +1376,6 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const ConstMat<TB>& B,
   return &per_key;
 }
-template <typename TA, typename TB, typename TC>
-HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
-                              const float* HWY_RESTRICT add, MatMulEnv& env,
-                              const RowPtr<TC>& C) {
-  return MatMul(A, ConstMat<TB>(B), add, env, C);
-}
 // NOLINTNEXTLINE(google-readability-namespace-comments)
 }  // namespace HWY_NAMESPACE
 }  // namespace gcpp

View File

@@ -681,67 +681,18 @@ struct MMZone {
 };
 #endif  // PROFILER_ENABLED
-// Used for the A and B arguments of `MatMul`, which are always const.
-// This differs from `RowPtr` in supporting the `ofs` required for compressed T.
-// TODO: remove after splitting W1/W2 and updating QDotK to RowPtr.
-template <typename T>
-struct ConstMat {
-  ConstMat() = default;
-  ConstMat(const T* ptr, Extents2D extents, size_t stride)
-      : ptr(ptr), extents(extents), stride(stride), ofs(0) {
-    HWY_DASSERT(ptr != nullptr);
-    HWY_DASSERT(stride >= extents.cols);
-  }
-  // Non-explicit so that we can pass `MatPtr` directly to MatMul.
-  ConstMat(const MatPtrT<T>& m)
-      : ConstMat(const_cast<T*>(m.Row(0)), m.Extents(), m.Stride()) {
-    scale = m.Scale();
-  }
-  size_t Row(size_t r) const {
-    if constexpr (HWY_IS_DEBUG_BUILD) {
-      if (r >= extents.rows) {
-        HWY_ABORT("ConstMat::Row %zu out of bounds %zu", r, extents.rows);
-      }
-    }
-    return ofs + r * stride;
-  }
-  const Extents2D& Extents() const { return extents; }
-  size_t Stride() const { return stride; }
-  float Scale() const { return scale; }
-  // So that matvec-inl.h can use the same interface as MatPtrT:
-  size_t Rows() const { return extents.rows; }
-  size_t Cols() const { return extents.cols; }
-  const T* HWY_RESTRICT ptr;
-  Extents2D extents;
-  size_t stride;
-  // `scale` allows expanding the smaller range of `SfpStream` to the original
-  // values. MatFromWeights sets this from `MatPtr`.
-  float scale = 1.0f;
-  // Offset to add to `ptr`; separate because T=NuqStream does not support
-  // pointer arithmetic. This is in units of weights, and does not have anything
-  // to do with the interleaved NUQ tables. It should be computed via `Row()`
-  // to take into account the stride.
-  size_t ofs;
-};
 template <typename TB>
-void BindB(const Allocator& allocator, size_t N, size_t sizeof_TC,
-           const ConstMat<TB>& B, MMParallel& parallel) {
+void BindB(const Allocator& allocator, size_t sizeof_TC, const MatPtrT<TB>& B,
+           MMParallel& parallel) {
   if (!allocator.ShouldBind()) return;
   const IndexRangePartition ranges_np =
-      parallel.RangesOfNP(MMParallel::kMaxPackages, N, sizeof_TC, kNR);
+      parallel.RangesOfNP(MMParallel::kMaxPackages, B.Rows(), sizeof_TC, kNR);
   const size_t quantum = allocator.Quantum<TB>();
   for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
     const IndexRange& rows_b = ranges_np.Range(pkg_idx);
     const size_t node = parallel.Node(pkg_idx);
-    uintptr_t begin =
-        reinterpret_cast<uintptr_t>(B.ptr + B.Row(rows_b.begin()));
+    uintptr_t begin = reinterpret_cast<uintptr_t>(B.Row(rows_b.begin()));
     uintptr_t end = begin + rows_b.Num() * B.Stride() * sizeof(TB);
     // B is not yet guaranteed to have padded rows, so only bind the
     // subset that is page-aligned.

View File

@@ -47,16 +47,6 @@ namespace gcpp {
 namespace HWY_NAMESPACE {
 namespace hn = hwy::HWY_NAMESPACE;
-// Adapter so that gemma-inl.h can pass ConstMat.
-// TODO: remove after changing ComputeQKV to MatMul.
-template <typename WT, typename VT>
-HWY_INLINE float Dot(const ConstMat<WT>& w, size_t w_ofs, const VT* vec_aligned,
-                     size_t num) {
-  const hn::ScalableTag<VT> d;
-  HWY_DASSERT(num <= w.Stride());  // Single row, else padding is an issue.
-  const auto span = MakeSpan(w.ptr, w_ofs + w.extents.rows * w.Stride());
-  return w.Scale() * Dot(d, span, w_ofs, vec_aligned, num);
-}
 // For callers that pass `MatPtrT`, which is not necessarily packed - callers
 // should use Stride() to compute `w_ofs`.
 template <typename WT, typename VT>
@@ -66,7 +56,7 @@ HWY_INLINE float Dot(const MatPtrT<WT>& w, size_t w_ofs, const VT* vec_aligned,
   return w.Scale() * Dot(d, w.PaddedSpan(), w_ofs, vec_aligned, num);
 }
-// ArrayT is either MatPtrT or ConstMat.
+// ArrayT is MatPtrT.
 // Simple version without tiling nor threading, but two offsets/outputs and
 // always with addition.