gemma.cpp/ops/matmul-inl.h

513 lines
22 KiB
C++

// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stddef.h>
#include <stdint.h>
#include "ops/matmul.h" // IWYU pragma: export
#include "util/allocator.h"
#include "util/basics.h"
// Include guard for (potentially) SIMD code.
#if defined(THIRD_PARTY_GEMMA_CPP_MATMUL_TOGGLE) == defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_MATMUL_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_MATMUL_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_MATMUL_TOGGLE
#endif
#include "hwy/highway.h"
// After highway.h
#include "compression/compress-inl.h"
#include "ops/ops-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
// Loads two vectors at a time with element type hn::TFromD<DR> from a row of
// transposed B. Called in a loop over col_ab. No bounds checking because
// `kRow` is from B columns, which we checked is a multiple of `kRegCols`.
template <size_t kRow, typename MatTB>
class BRow {
static_assert(kRow < kRegRows); // which unrolled instance we are
public:
BRow(const ConstMat<MatTB>& B, size_t row_b)
: B_(MakeSpan(B.ptr, B.ofs + B.Extents().Area())),
B_ofs_(B.Row(HWY_MIN(row_b + kRow, B.Extents().rows - 1))) {}
template <class DR, class VR = hn::Vec<DR>>
HWY_INLINE void Load2(DR d, size_t col_ab, VR& b0, VR& b1) const {
Decompress2(d, B_, B_ofs_ + col_ab, b0, b1);
}
private:
PackedSpan<const MatTB> B_;
const size_t B_ofs_;
};
// Loads *two* row vectors from A via `Decompress2`, widens to f32, multiplies
// element-wise with `kRegRows` x 2 row vectors from transposed B, and adds
// them to `kRegRows` x `kRegCols` C vectors. The lanes of `C[r,c]` are thus a
// subset of the terms of the dot products that make up the MatMul result at
// `r,c`. No-op for the bottom-most rows whose `kRow >= kNumRows`.
//
// This approach is atypical because it requires a horizontal sum, for which we
// introduce a fast and new(?) vector-length agnostic 'transpose', see
// `AddHorizontalSums`. Most MatMul instead broadcast one element from A and
// multiply with one element from N columns in B to obtain N columns of C.
// This is a poor fit for our setting:
// - `Decompress2` decompresses two vectors at a time;
// - B is column-major, so unit-stride SIMD loads return a column, not values
//   from different columns, i.e. a row.
// - `ReorderWidenMulAccumulate` is important for bf16 performance, but its
//   pairwise adds would add together unrelated terms.
// The first two could be fixed in a packing stage, which is not implemented
// yet, and might not be necessary otherwise. The third seems a fundamental
// mismatch. However, pairwise adds are fine in our setting because C lanes are
// the terms of a single dot product, which can be reordered or pre-reduced.
template <size_t kRow, typename MatTA>
class ALoadAccumulate {
 public:
  static_assert(kRow < kRegRows);  // which unrolled instance we are
  // `First` and `Next` handle a single row of A, so the horizontal sums of
  // their `C0..3` are the (partial) dot products for 4 consecutive values in
  // one row of C.
  static_assert(kRegCols == 4);

  // Prepares to read row `row_ac + kRow` of A. The row index is clamped to
  // the last row of A so that `A_ofs_` stays a valid offset even for
  // instances whose row lies beyond the end; the `First`/`Next` bodies of
  // such instances are compiled out via `kRow < kNumRows`, provided the
  // caller passes the actual number of remaining rows as `kNumRows`.
  ALoadAccumulate(const ConstMat<MatTA>& A, size_t row_ac)
      : A_(MakeSpan(A.ptr, A.ofs + A.Extents().Area())),
        A_ofs_(A.Row(HWY_MIN(row_ac + kRow, A.Extents().rows - 1))) {}

  // First iteration, col_ab = 0: initialize C0..3 instead of updating them.
  // f32 variant: `b*` and `C*` have the same type as the decompressed A.
  // `bc0`/`bc1` are the two B vectors for C column `c`; they are multiplied
  // with A vectors `a0`/`a1` respectively.
  template <size_t kNumRows, class DM, class VM = hn::Vec<DM>, HWY_IF_F32_D(DM)>
  HWY_INLINE void First(DM dm,  //
                        const VM b00, const VM b01, const VM b10, const VM b11,
                        const VM b20, const VM b21, const VM b30, const VM b31,
                        VM& C0, VM& C1, VM& C2, VM& C3) const {
    static_assert(kNumRows <= kRegRows);  // How many rows actually present
    if constexpr (kRow < kNumRows) {
      VM a0, a1;
      // No `col_ab` offset: the first iteration starts at column 0.
      Decompress2(dm, A_, A_ofs_, a0, a1);

      static_assert(kRegCols == 4);
      C0 = hn::Mul(a0, b00);
      C1 = hn::Mul(a0, b10);
      C2 = hn::Mul(a0, b20);
      C3 = hn::Mul(a0, b30);
      C0 = hn::MulAdd(a1, b01, C0);
      C1 = hn::MulAdd(a1, b11, C1);
      C2 = hn::MulAdd(a1, b21, C2);
      C3 = hn::MulAdd(a1, b31, C3);
    }
  }

  // Same as above, but for bf16 `DM` (selected via `HWY_IF_BF16_D`). The
  // products are widened and the accumulators `C0..3` are f32 vectors `VF`.
  template <size_t kNumRows, class DM, class VM = hn::Vec<DM>,
            HWY_IF_BF16_D(DM), class DF = hn::Repartition<float, DM>,
            class VF = hn::Vec<DF>>
  HWY_INLINE void First(DM dm,  //
                        const VM b00, const VM b01, const VM b10, const VM b11,
                        const VM b20, const VM b21, const VM b30, const VM b31,
                        VF& C0, VF& C1, VF& C2, VF& C3) const {
    static_assert(kNumRows <= kRegRows);  // How many rows actually present
    if constexpr (kRow < kNumRows) {
      VM a0, a1;
      Decompress2(dm, A_, A_ofs_, a0, a1);

      const DF df;

      static_assert(kRegCols == 4);
      C0 = hn::WidenMulPairwiseAdd(df, a0, b00);
      C1 = hn::WidenMulPairwiseAdd(df, a0, b10);
      C2 = hn::WidenMulPairwiseAdd(df, a0, b20);
      C3 = hn::WidenMulPairwiseAdd(df, a0, b30);
      if constexpr (HWY_NATIVE_DOT_BF16) {
        // Native ReorderWidenMulAccumulate adds to C0..3 for free.
        VF unused_sum1 = hn::Zero(df);
        C0 = hn::ReorderWidenMulAccumulate(df, a1, b01, C0, unused_sum1);
        C1 = hn::ReorderWidenMulAccumulate(df, a1, b11, C1, unused_sum1);
        C2 = hn::ReorderWidenMulAccumulate(df, a1, b21, C2, unused_sum1);
        C3 = hn::ReorderWidenMulAccumulate(df, a1, b31, C3, unused_sum1);

        // Ensure sum1 was indeed unused.
        HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));
      } else {
        // Without native bf16 dot support, widen and add explicitly.
        C0 = hn::Add(C0, hn::WidenMulPairwiseAdd(df, a1, b01));
        C1 = hn::Add(C1, hn::WidenMulPairwiseAdd(df, a1, b11));
        C2 = hn::Add(C2, hn::WidenMulPairwiseAdd(df, a1, b21));
        C3 = hn::Add(C3, hn::WidenMulPairwiseAdd(df, a1, b31));
      }
    }
  }

  // Non-first iteration: accumulate into C0..3. `col_ab` is the first of the
  // 2 * Lanes(dm) columns of A (and of the B rows) consumed by this call.
  template <size_t kNumRows, class DM, class VM = hn::Vec<DM>, HWY_IF_F32_D(DM)>
  HWY_INLINE void Next(DM dm, size_t col_ab, const VM b00, const VM b01,
                       const VM b10, const VM b11, const VM b20, const VM b21,
                       const VM b30, const VM b31, VM& C0, VM& C1, VM& C2,
                       VM& C3) const {
    static_assert(kNumRows <= kRegRows);  // How many rows actually present
    HWY_DASSERT(col_ab >= 2 * hn::Lanes(dm));  // Should not be first iteration.
    if constexpr (kRow < kNumRows) {
      VM a0, a1;
      Decompress2(dm, A_, A_ofs_ + col_ab, a0, a1);

      static_assert(kRegCols == 4);
      C0 = hn::MulAdd(a0, b00, C0);
      C1 = hn::MulAdd(a0, b10, C1);
      C2 = hn::MulAdd(a0, b20, C2);
      C3 = hn::MulAdd(a0, b30, C3);
      C0 = hn::MulAdd(a1, b01, C0);
      C1 = hn::MulAdd(a1, b11, C1);
      C2 = hn::MulAdd(a1, b21, C2);
      C3 = hn::MulAdd(a1, b31, C3);
    }
  }

  // Same as above, but for bf16 `DM` (selected via `HWY_IF_BF16_D`), with f32
  // accumulators `VF`.
  template <size_t kNumRows, class DM, class VM = hn::Vec<DM>,
            HWY_IF_BF16_D(DM), class DF = hn::Repartition<float, DM>,
            class VF = hn::Vec<DF>>
  HWY_INLINE void Next(DM dm, size_t col_ab, const VM b00, const VM b01,
                       const VM b10, const VM b11, const VM b20, const VM b21,
                       const VM b30, const VM b31, VF& C0, VF& C1, VF& C2,
                       VF& C3) const {
    static_assert(kNumRows <= kRegRows);  // How many rows actually present
    HWY_DASSERT(col_ab >= 2 * hn::Lanes(dm));  // Should not be first iteration.
    if constexpr (kRow < kNumRows) {
      VM a0, a1;
      Decompress2(dm, A_, A_ofs_ + col_ab, a0, a1);

      const DF df;

      static_assert(kRegCols == 4);
      if constexpr (HWY_NATIVE_DOT_BF16) {
        // Native ReorderWidenMulAccumulate adds to C0..3 for free.
        VF unused_sum1 = hn::Zero(df);
        C0 = hn::ReorderWidenMulAccumulate(df, a0, b00, C0, unused_sum1);
        C1 = hn::ReorderWidenMulAccumulate(df, a0, b10, C1, unused_sum1);
        C2 = hn::ReorderWidenMulAccumulate(df, a0, b20, C2, unused_sum1);
        C3 = hn::ReorderWidenMulAccumulate(df, a0, b30, C3, unused_sum1);
        C0 = hn::ReorderWidenMulAccumulate(df, a1, b01, C0, unused_sum1);
        C1 = hn::ReorderWidenMulAccumulate(df, a1, b11, C1, unused_sum1);
        C2 = hn::ReorderWidenMulAccumulate(df, a1, b21, C2, unused_sum1);
        C3 = hn::ReorderWidenMulAccumulate(df, a1, b31, C3, unused_sum1);

        // Ensure sum1 was indeed unused.
        HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));
      } else {
        // Without native bf16 dot support, widen and add explicitly.
        C0 = hn::Add(C0, hn::WidenMulPairwiseAdd(df, a0, b00));
        C1 = hn::Add(C1, hn::WidenMulPairwiseAdd(df, a0, b10));
        C2 = hn::Add(C2, hn::WidenMulPairwiseAdd(df, a0, b20));
        C3 = hn::Add(C3, hn::WidenMulPairwiseAdd(df, a0, b30));
        C0 = hn::Add(C0, hn::WidenMulPairwiseAdd(df, a1, b01));
        C1 = hn::Add(C1, hn::WidenMulPairwiseAdd(df, a1, b11));
        C2 = hn::Add(C2, hn::WidenMulPairwiseAdd(df, a1, b21));
        C3 = hn::Add(C3, hn::WidenMulPairwiseAdd(df, a1, b31));
      }
    }
  }

 private:
  // Span starting at `A.ptr` covering `A.ofs + A.Extents().Area()` elements.
  PackedSpan<const MatTA> A_;
  // Element offset of column 0 of this instance's (clamped) row of A.
  const size_t A_ofs_;
};  // ALoadAccumulate
// Initializes the top `min(kNumRows, kRegRows)` rows of a `kRegCols`-wide
// tile of C whose top left is `pos_c`; consecutive rows are `stride_c` floats
// apart. Each row is set to `add[add_ofs + c]` for c in [0, kRegCols) if
// `kAdd`, otherwise to zero. Rows beyond `kNumRows` are left untouched.
//
// `add` has no scale. If `kAdd`, it is a row vector with one entry per column
// of C, and `add_ofs` is the first C column of this tile. Otherwise `add` is
// nullptr, and adding `add_ofs` to it would be UB. That is why the offset is
// passed separately and `add + add_ofs` is only formed when `kAdd`.
template <size_t kNumRows, bool kAdd>
HWY_INLINE void InitC(const float* HWY_RESTRICT add, size_t add_ofs,
                      float* HWY_RESTRICT pos_c, size_t stride_c) {
  const hn::FixedTag<float, kRegCols> d4;
  // Every row of the tile receives the same kRegCols values, so obtain them
  // once instead of reloading `add` for each row.
  hn::Vec<decltype(d4)> row_init;
  if constexpr (kAdd) {
    row_init = hn::LoadU(d4, add + add_ofs);
  } else {
    row_init = hn::Zero(d4);
  }
  for (size_t r = 0; r < HWY_MIN(kNumRows, kRegRows); ++r) {
    hn::StoreU(row_init, d4, pos_c + r * stride_c);
  }
}
// Accumulates into a tile of C: reduces the 16 `Crc` accumulators of
// `ALoadAccumulate` to scalars, scales them, and adds them to up to
// `kNumRows` rows (4 columns each) of C.
template <size_t kNumRows>
class AddHorizontalSums {
  // These helper functions hoist if() out of the main code below. They have no
  // effect if kRow >= kNumRows.
  //
  // Layout of `buf`: row `kRow` occupies `buf[4 * kRow * N, 4 * (kRow + 1) * N)`,
  // where N is the lane count of the `Crc` vectors.

  // Transposes row `kRow`'s four accumulators into its section of `buf`.
  template <size_t kRow, class DF, class VF = hn::Vec<DF>>
  static void MaybeStoreInterleaved4(DF df, size_t N, VF Cr0, VF Cr1, VF Cr2,
                                     VF Cr3, float* HWY_RESTRICT buf) {
    if constexpr (kRow < kNumRows) {
      hn::StoreInterleaved4(Cr0, Cr1, Cr2, Cr3, df, buf + 4 * kRow * N);
    }
  }

  // Note: N is the number of lanes in the StoreInterleaved4 vectors, not V4.
  // Uses the aligned `hn::Load`, so `buf` (plus the caller's offset) must be
  // aligned for a 4-lane f32 vector.
  template <size_t kRow, class D4, class V4 = hn::Vec<D4>>
  static V4 MaybeLoad(D4 df, size_t N, const float* HWY_RESTRICT buf) {
    if constexpr (kRow < kNumRows) {
      return hn::Load(df, buf + 4 * kRow * N);
    } else {
      return hn::Zero(df);
    }
  }

  // Returns `sum` plus the 4 floats at row `kRow`'s section of `buf`, or
  // `sum` unchanged for absent rows.
  template <size_t kRow, class D4, class V4 = hn::Vec<D4>>
  static V4 MaybeAdd(D4 df, size_t N, V4 sum, const float* HWY_RESTRICT buf) {
    if constexpr (kRow < kNumRows) {
      return hn::Add(sum, hn::Load(df, buf + 4 * kRow * N));
    } else {
      return sum;
    }
  }

  // Read-modify-write of row `kRow` of the tile: C = sum * scale + C.
  template <size_t kRow, class D4, class V4 = hn::Vec<D4>>
  static void MaybeMulAdd(D4 df, V4 sum, V4 scale, float* HWY_RESTRICT tile_c,
                          const size_t stride_c) {
    if constexpr (kRow < kNumRows) {
      const V4 prev_c = hn::LoadU(df, tile_c + kRow * stride_c);
      hn::StoreU(hn::MulAdd(sum, scale, prev_c), df, tile_c + kRow * stride_c);
    }
  }

 public:
  // Adds the contribution from `Crc` accumulators to the 4x4 tile of C whose
  // top left is `tile_c`, after multiplying by `scale`, which is the product of
  // the scales of A and B. C is always f32 to ensure sufficient precision.
  //
  // `Crc` are the 16 combinations of an A row vector indexed by `r`, times a
  // B column vector indexed by `c`. Their elements are thus a subset of the
  // terms of the dot product constituting the final `C[r, c]` result. Thus we
  // compute the horizontal sums of each `Crc`. The elements may be permuted
  // because we multiply bf16 via `ReorderWidenMulAccumulate`, but this does
  // not change their horizontal sum. `buf` is thread-local space for 16 `VF`.
  //
  // The existing contents of `tile_c` are kept and the scaled sums are added
  // to them (see `MaybeMulAdd`); only rows `r < kNumRows` are touched.
  template <class DF, class VF = hn::Vec<DF>>
  HWY_INLINE void operator()(DF df, float scale,  //
                             VF C00, VF C01, VF C02, VF C03,  //
                             VF C10, VF C11, VF C12, VF C13,  //
                             VF C20, VF C21, VF C22, VF C23,  //
                             VF C30, VF C31, VF C32, VF C33,  //
                             float* HWY_RESTRICT buf,
                             float* HWY_RESTRICT tile_c,
                             size_t stride_c) const {
    const size_t N = hn::Lanes(df);
    // Horizontal reductions (`ReduceSum`) are rather expensive, entailing
    // log(N) operations for vectors of length N. Because kRegCols == 4, we can
    // instead use `StoreInterleaved4` for a vector length-agnostic 'transpose':
    // `buf[0, 4 * N)` holds C00[0], C01[0], C02[0], C03[0],
    // C00[1], C01[1], C02[1], C03[1] .. C00[N-1], C01[N-1], C02[N-1], C03[N-1].
    MaybeStoreInterleaved4<0>(df, N, C00, C01, C02, C03, buf);
    MaybeStoreInterleaved4<1>(df, N, C10, C11, C12, C13, buf);
    MaybeStoreInterleaved4<2>(df, N, C20, C21, C22, C23, buf);
    MaybeStoreInterleaved4<3>(df, N, C30, C31, C32, C33, buf);
    // Adding N consecutive V4 yields four horizontal sums of Cr0, Cr1, Cr2, Cr3
    // in the elements of one V4. We have four independent rows `r`, hence the
    // code is effectively unrolled, which increases throughput.
    const hn::FixedTag<float, 4> d4;
    using V4 = hn::Vec<decltype(d4)>;
    V4 sum0 = MaybeLoad<0>(d4, N, buf);
    V4 sum1 = MaybeLoad<1>(d4, N, buf);
    V4 sum2 = MaybeLoad<2>(d4, N, buf);
    V4 sum3 = MaybeLoad<3>(d4, N, buf);

    // `buf + 4 * i` selects the i-th group of four within each row's section.
    for (size_t i = 1; i < N; ++i) {
      sum0 = MaybeAdd<0>(d4, N, sum0, buf + 4 * i);
      sum1 = MaybeAdd<1>(d4, N, sum1, buf + 4 * i);
      sum2 = MaybeAdd<2>(d4, N, sum2, buf + 4 * i);
      sum3 = MaybeAdd<3>(d4, N, sum3, buf + 4 * i);
    }
    // Scale, then store to four elements per row of `tile_c`.
    const V4 vscale = hn::Set(d4, scale);
    MaybeMulAdd<0>(d4, sum0, vscale, tile_c, stride_c);
    MaybeMulAdd<1>(d4, sum1, vscale, tile_c, stride_c);
    MaybeMulAdd<2>(d4, sum2, vscale, tile_c, stride_c);
    MaybeMulAdd<3>(d4, sum3, vscale, tile_c, stride_c);
  }
};
// Streams a `kNumRows` high strip of `A` and the `kRegCols` rows of the
// transposed `B` starting at `row_b_col_c`, then writes a *finished* tile of
// f32 `C` whose top left is (row_ac, row_b_col_c). The tile is first set to
// `add` (or zero), then the scaled dot products are added.
// TODO: loop over sections instead of full rows and accumulate into `tile_c`.
// `buf` is 16 vectors of thread-local storage.
template <size_t kNumRows, bool kAdd, typename MatTA, typename MatTB>
HWY_INLINE void MatMulTile(const ConstMat<MatTA>& A, const size_t row_ac,
                           const ConstMat<MatTB>& B, const size_t row_b_col_c,
                           const float scale, const float* HWY_RESTRICT add,
                           float* HWY_RESTRICT buf, const RowPtr<float>& C) {
  static_assert(kNumRows != 0 && kNumRows <= kRegRows);

  // Decompress A and B to `Raw`, which is then widened to f32 (if bf16),
  // multiplied, and accumulated in f32 vectors (`C00..C33`).
  // NEON_BF16/SVE/AVX3_ZEN4 have instructions for bf16 * bf16 + f32 which are
  // more efficient than f32 * f32 + f32 because they process twice as many
  // lanes at a time. If available, we definitely want to use them. Otherwise,
  // bf16 is still worthwhile if A (activations) are bf16: SFP weights are
  // cheaper to decode to bf16, relative to the minor extra cost of promoting
  // bf16 when multiplying. However, if A is f32, demoting to bf16 can be
  // expensive unless we also have native bf16 dot.
  using Raw = hwy::If<HWY_NATIVE_DOT_BF16 || !IsF32<MatTA>(), BF16, float>;
  const hn::ScalableTag<Raw> dr;
  using VR = hn::Vec<decltype(dr)>;
  const size_t NR = hn::Lanes(dr);

  const Range1D cols_ab(0, A.Extents().cols);
  HWY_DASSERT(row_ac + kNumRows <= A.Extents().rows);
  // B rows are C columns. Every tile covers exactly kRegCols of them
  // (MatMulImpl asserts C.Cols() is a multiple of kRegCols), independently
  // of how many A rows (kNumRows) remain.
  HWY_DASSERT(row_b_col_c + kRegCols <= B.Extents().rows);
  // The first iteration below unconditionally consumes 2 * NR columns and
  // the loop advances by 2 * NR, hence a non-zero multiple is required.
  HWY_DASSERT(cols_ab.end() != 0);
  HWY_DASSERT(cols_ab.end() % (2 * NR) == 0);

  static_assert(kRegRows == 4);
  static_assert(kRegCols == 4);
  const BRow<0, MatTB> b_row0(B, row_b_col_c);
  const BRow<1, MatTB> b_row1(B, row_b_col_c);
  const BRow<2, MatTB> b_row2(B, row_b_col_c);
  const BRow<3, MatTB> b_row3(B, row_b_col_c);

  const ALoadAccumulate<0, MatTA> a_row0(A, row_ac);
  const ALoadAccumulate<1, MatTA> a_row1(A, row_ac);
  const ALoadAccumulate<2, MatTA> a_row2(A, row_ac);
  const ALoadAccumulate<3, MatTA> a_row3(A, row_ac);

  const hn::Repartition<float, decltype(dr)> df;
  using VF = hn::Vec<decltype(df)>;
  VF C00, C01, C02, C03;
  VF C10, C11, C12, C13;
  VF C20, C21, C22, C23;
  VF C30, C31, C32, C33;

  size_t col_ab = cols_ab.begin();
  {  // First iteration initializes the `Crc` vectors.
    VR b00, b01, b10, b11, b20, b21, b30, b31;
    b_row0.Load2(dr, col_ab, b00, b01);
    b_row1.Load2(dr, col_ab, b10, b11);
    b_row2.Load2(dr, col_ab, b20, b21);
    b_row3.Load2(dr, col_ab, b30, b31);

    a_row0.template First<kNumRows>(dr, b00, b01, b10, b11, b20, b21, b30, b31,
                                    C00, C01, C02, C03);
    a_row1.template First<kNumRows>(dr, b00, b01, b10, b11, b20, b21, b30, b31,
                                    C10, C11, C12, C13);
    a_row2.template First<kNumRows>(dr, b00, b01, b10, b11, b20, b21, b30, b31,
                                    C20, C21, C22, C23);
    a_row3.template First<kNumRows>(dr, b00, b01, b10, b11, b20, b21, b30, b31,
                                    C30, C31, C32, C33);
    col_ab += 2 * NR;
  }

  // `2 * NR` per iteration because `Load2` returns two vectors.
  HWY_UNROLL(1)
  for (; col_ab < cols_ab.end(); col_ab += 2 * NR) {
    VR b00, b01, b10, b11, b20, b21, b30, b31;
    b_row0.Load2(dr, col_ab, b00, b01);
    b_row1.Load2(dr, col_ab, b10, b11);
    b_row2.Load2(dr, col_ab, b20, b21);
    b_row3.Load2(dr, col_ab, b30, b31);

    a_row0.template Next<kNumRows>(dr, col_ab, b00, b01, b10, b11, b20, b21,
                                   b30, b31, C00, C01, C02, C03);
    a_row1.template Next<kNumRows>(dr, col_ab, b00, b01, b10, b11, b20, b21,
                                   b30, b31, C10, C11, C12, C13);
    a_row2.template Next<kNumRows>(dr, col_ab, b00, b01, b10, b11, b20, b21,
                                   b30, b31, C20, C21, C22, C23);
    a_row3.template Next<kNumRows>(dr, col_ab, b00, b01, b10, b11, b20, b21,
                                   b30, b31, C30, C31, C32, C33);
  }

  // TODO: hoist into outer loop.
  float* HWY_RESTRICT C_tile = C.Row(row_ac) + row_b_col_c;
  InitC<kNumRows, kAdd>(add, row_b_col_c, C_tile, C.Stride());

  AddHorizontalSums<kNumRows>()(df, scale, C00, C01, C02, C03, C10, C11, C12,
                                C13, C20, C21, C22, C23, C30, C31, C32, C33,
                                buf, C_tile, C.Stride());
}
// Splits C into `kRegRows` x `kRegCols` tiles and runs `MatMulTile` on each
// of them via `env.Pool()`. Tiles are numbered row-major. The final tile row
// may have fewer than `kRegRows` rows of A remaining; the matching
// `MatMulTile<kNumRows>` instantiation is chosen at runtime.
template <bool kAdd, typename MatTA, typename MatTB>
HWY_NOINLINE void MatMulImpl(const ConstMat<MatTA>& A, const ConstMat<MatTB>& B,
                             const float* HWY_RESTRICT add, MatMulEnv& env,
                             const RowPtr<float>& C) {
  // PROFILER_ZONE("Matmul");
  HWY_DASSERT(A.Extents().cols == B.Extents().cols);
  HWY_DASSERT(C.Cols() % kRegCols == 0);
  HWY_DASSERT(C.Stride() >= C.Cols());
  HWY_DASSERT(B.Extents().rows == C.Cols());

  const size_t rows_a = A.Extents().rows;
  // Combined scale of both inputs, applied once per tile.
  const float scale = A.scale * B.scale;

  // We currently write C directly, which touches more memory than fits in L3.
  // TODO: add another level of loops to finish L3-sized pieces of C at a time.
  const size_t num_tile_rows = hwy::DivCeil(rows_a, kRegRows);
  const size_t num_tile_cols = C.Cols() / kRegCols;
  const size_t num_tiles = num_tile_cols * num_tile_rows;

  env.Pool().Run(
      0, num_tiles, [&](const uint64_t tile, size_t thread) HWY_ATTR {
        // TODO: when using PerClusterPool, compute lp from outer and inner.
        float* HWY_RESTRICT buf = env.Buf().Batch(thread);

        const size_t row_ac = (tile / num_tile_cols) * kRegRows;
        const size_t row_b_col_c = (tile % num_tile_cols) * kRegCols;

        // Rows of C still to compute; a tile handles at most 4 of them.
        const size_t rows_left = rows_a - row_ac;
        HWY_DASSERT(rows_left != 0);
        if (rows_left == 1) {
          MatMulTile<1, kAdd>(A, row_ac, B, row_b_col_c, scale, add, buf, C);
        } else if (rows_left == 2) {
          MatMulTile<2, kAdd>(A, row_ac, B, row_b_col_c, scale, add, buf, C);
        } else if (rows_left == 3) {
          MatMulTile<3, kAdd>(A, row_ac, B, row_b_col_c, scale, add, buf, C);
        } else {
          MatMulTile<4, kAdd>(A, row_ac, B, row_b_col_c, scale, add, buf, C);
        }
      });
}
// Computes the matrix product `A * B * scale [+ add]` and stores it in `C`,
// where `scale` is `A.scale * B.scale`. Previous contents of `C` are
// overwritten.
//
// `A` is a row-major matrix and `B` is transposed. Its `B.Extents().cols`,
// which must match `A.Extents().cols`, is the number of rows in the original B.
// `B.Extents().rows` must equal `C.Cols()`, which must be a multiple of
// `kRegCols` (both are checked only in debug builds).
//
// If `add` is non-null, the row-vector `add` is added to each row of `C`.
// A scale for `add` is not supported, so make sure its scale is 1.
//
// `C` is a row-major matrix of size `(A.rows, C.Cols())` with support for
// arbitrary strides.
//
// Updates 4x4 tiles of C in parallel using a work-stealing thread pool.
// Typically `A.rows` is 1..512, `A.Extents().cols` and `B.Extents().rows` are
// 3k or 24k. Must not be called concurrently with the same `env`.
template <typename MatTA, typename MatTB>
HWY_NOINLINE void MatMul(const ConstMat<MatTA>& A, const ConstMat<MatTB>& B,
                         const float* HWY_RESTRICT add, MatMulEnv& env,
                         const RowPtr<float>& C) {
  // Decide once whether a bias is present, so that the per-tile code is
  // specialized and never offsets a null `add`.
  if (add == nullptr) {
    MatMulImpl<false>(A, B, nullptr, env, C);
    return;
  }
  MatMulImpl<true>(A, B, add, env, C);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#endif // NOLINT