MatMul cleanup: Mat struct, simplify args.

Add large benchmark to test, use 4 threads, skip some targets.
Also use Traits::Name instead of typeid.

PiperOrigin-RevId: 657496185
This commit is contained in:
Jan Wassenberg 2024-07-30 01:55:09 -07:00 committed by Copybara-Service
parent d9f86f8e4d
commit 6ea4232b2e
4 changed files with 314 additions and 314 deletions

View File

@ -60,6 +60,7 @@ struct CompressTraits {};
template <>
struct CompressTraits<float> {
using MatT = float;
static const char* Name() { return "f32"; }
static constexpr bool kSupportsEvenOdd = false; // unnecessary
template <class DF, HWY_IF_F32_D(DF)>
@ -123,6 +124,7 @@ struct CompressTraits<float> {
template <>
struct CompressTraits<hwy::bfloat16_t> {
using MatT = hwy::bfloat16_t;
static const char* Name() { return "bf16"; }
static constexpr bool kSupportsEvenOdd = true;
template <class DF, HWY_IF_F32_D(DF)>
@ -292,6 +294,7 @@ struct CompressTraits<hwy::bfloat16_t> {
template <>
struct CompressTraits<SfpStream> {
using MatT = SfpStream;
static const char* Name() { return "sfp"; }
static constexpr bool kSupportsEvenOdd = true;
// Callers are responsible for scaling `in` such that its magnitudes do not
@ -389,6 +392,7 @@ struct CompressTraits<SfpStream> {
template <>
struct CompressTraits<NuqStream> {
using MatT = NuqStream;
static const char* Name() { return "nuq"; }
static constexpr bool kSupportsEvenOdd = false;
template <class DF, HWY_IF_F32_D(DF)>

View File

@ -237,12 +237,11 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
//
// Compute Q only or QKV (if MHA).
// If MHA, this also computes KV, which we copy to the KV cache below.
MatMul_4x4</*kAdd=*/false>(num_interleaved, activations.pre_att_rms_out.All(),
0, kModelDim, layer_weights->qkv_einsum_w.data(),
0, kHeads * kQStride,
layer_weights->qkv_einsum_w.scale(),
activations.q.All(), kHeads * kQStride,
/*add=*/nullptr, pool);
MatMul_4x4</*kAdd=*/false>(
num_interleaved, MakeMat(activations.pre_att_rms_out.All(), kModelDim),
MakeMat(layer_weights->qkv_einsum_w.data(), kModelDim),
layer_weights->qkv_einsum_w.scale(), /*add=*/nullptr,
MakeMat(activations.q.All(), kHeads * kQStride), pool);
// Compute KV if not MHA.
if constexpr (!kIsMHA) {
@ -250,16 +249,16 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
// directly into the KV cache with a stride of kCachePosSize.
if (num_queries == 1 &&
batch_start + num_tokens <= div_seq_len.GetDivisor()) {
const size_t colsBC = kKVHeads * 2 * kQKVDim;
const size_t kv_ofs =
batch_start * kCachePosSize + layer * kCacheLayerSize;
// KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v).
float* HWY_RESTRICT kv = kv_caches[0].kv_cache.get() + kv_ofs;
MatMul_4x4</*kAdd=*/false>(
num_tokens, activations.pre_att_rms_out.All(), 0, kModelDim,
layer_weights->qkv_einsum_w.data(), kHeads * kQKVDim * kModelDim,
colsBC, layer_weights->qkv_einsum_w.scale(), kv, kCachePosSize,
/*add=*/nullptr, pool);
num_tokens, MakeMat(activations.pre_att_rms_out.All(), kModelDim),
MakeMat(layer_weights->qkv_einsum_w.data(), kModelDim, kModelDim,
kHeads * kQKVDim * kModelDim),
layer_weights->qkv_einsum_w.scale(), /*add=*/nullptr,
MakeMat(kv, kKVHeads * 2 * kQKVDim, kCachePosSize), pool);
} else {
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) {
@ -441,14 +440,12 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
// MatMul expects col-major B, which is what we have: kModelDim consecutive
// elements in memory, repeated kFFHiddenDim times.
constexpr size_t kColsA = kModelDim;
constexpr size_t kColsBC = kFFHiddenDim;
HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize());
const auto A = activations.bf_pre_ffw_rms_out.All();
const auto A = MakeMat(activations.bf_pre_ffw_rms_out.All(), kModelDim);
const auto B1 = MakeMat(layer_weights->gating_einsum_w.data(), kModelDim);
const auto B2 = MakeMat(layer_weights->gating_einsum_w.data(), kModelDim,
kModelDim, kModelDim * kFFHiddenDim);
const float scale = layer_weights->gating_einsum_w.scale();
const auto B1 = layer_weights->gating_einsum_w.data();
auto C1 = activations.C1.All();
auto C2 = activations.C2.All();
constexpr bool kAddBias = TConfig::kFFBiases;
const float* bias1 = nullptr;
const float* bias2 = nullptr;
@ -458,24 +455,22 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
bias2 = bias1 + kFFHiddenDim;
output_bias = layer_weights->ffw_output_biases.data_scale1();
}
auto C1 = MakeMat(activations.C1.All(), kFFHiddenDim);
auto C2 = MakeMat(activations.C2.All(), kFFHiddenDim);
const size_t A_ofs = 0; // no offset, using the same activations for both.
// Will go through GELU.
MatMul_4x4<kAddBias>(num_interleaved, A, A_ofs, kColsA, B1,
/*B_ofs=*/0, kColsBC, scale, C1, kColsBC, bias1, pool);
MatMul_4x4<kAddBias>(num_interleaved, A, B1, scale, bias1, C1, pool);
// What to multiply by.
MatMul_4x4<kAddBias>(num_interleaved, A, A_ofs, kColsA, B1,
/*B_ofs=*/kColsA * kColsBC, kColsBC, scale, C2, kColsBC,
bias2, pool);
MatMul_4x4<kAddBias>(num_interleaved, A, B2, scale, bias2, C2, pool);
// Activation (Gelu) and multiply by gate. Store activations in C1.
Activation<TConfig>(C1, C2, kFFHiddenDim * num_interleaved);
Activation<TConfig>(C1.ptr, C2.ptr, kFFHiddenDim * num_interleaved);
// Hidden layer -> output layer.
MatMul_4x4<kAddBias>(num_interleaved, C1, 0, kFFHiddenDim,
layer_weights->linear_w.data(), 0, kModelDim,
layer_weights->linear_w.scale(),
activations.ffw_out.All(), kModelDim, output_bias, pool);
MatMul_4x4<kAddBias>(num_interleaved, C1,
MakeMat(layer_weights->linear_w.data(), kFFHiddenDim),
layer_weights->linear_w.scale(), output_bias,
MakeMat(activations.ffw_out.All(), kModelDim), pool);
}
// `batch_idx` indicates which row of `x` to write to.
@ -1022,12 +1017,11 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
bool all_queries_eos = true;
PROFILER_ZONE("Gen.Embedding");
// Compute logits from last layer activations.
MatMul_4x4</*kAdd=*/false>(num_queries, activations.x.All(), 0, kModelDim,
weights.embedder_input_embedding.data(), 0,
kVocabSize,
weights.embedder_input_embedding.scale(),
activations.logits.All(), kVocabSize,
/*add=*/nullptr, pool);
MatMul_4x4</*kAdd=*/false>(
num_queries, MakeMat(activations.x.All(), kModelDim),
MakeMat(weights.embedder_input_embedding.data(), kModelDim),
weights.embedder_input_embedding.scale(), /*add=*/nullptr,
MakeMat(activations.logits.All(), kVocabSize), pool);
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
if constexpr (TConfig::kFinalCap > 0.0f) {

View File

@ -60,7 +60,7 @@ HWY_INLINE void StoreHorizontalSums(DF df, //
// We are computing the product of (4, 4N) * (4N, 4) = (4, 4) tiles.
// Each entry of C[r,c] is a dot product of A.row and B.col, which reside in
// the lanes of `c$r$c`, so we store their horizontal sum (ReduceSum). This is
// expensive, but only a fraction of the kColsA_RowsB/N FMAs.
// expensive, but only a fraction of the A.cols/N FMAs.
tile_c[stride_c * 0 + 0] = scale * hn::ReduceSum(df, c00);
tile_c[stride_c * 0 + 1] = scale * hn::ReduceSum(df, c01);
tile_c[stride_c * 0 + 2] = scale * hn::ReduceSum(df, c02);
@ -93,14 +93,14 @@ HWY_INLINE void StoreHorizontalSumsAdd(DF df, //
VF c10, VF c11, VF c12, VF c13, //
VF c20, VF c21, VF c22, VF c23, //
VF c30, VF c31, VF c32, VF c33,
const float* HWY_RESTRICT add,
const float scale,
const float* HWY_RESTRICT add,
float* HWY_RESTRICT tile_c,
size_t stride_c) {
// We are computing the product of (4, 4N) * (4N, 4) = (4, 4) tiles.
// Each entry of C[r,c] is a dot product of A.row and B.col, which reside in
// the lanes of `c$r$c`, so we store their horizontal sum (ReduceSum). This is
// expensive, but only a fraction of the kColsA_RowsB/N FMAs.
// expensive, but only a fraction of the A.cols/N FMAs.
const float add0 = add[0];
// TODO: 4x4 transpose, then 128-bit vector FMA?
tile_c[stride_c * 0 + 0] = scale * hn::ReduceSum(df, c00) + add0;
@ -137,12 +137,12 @@ template <bool kAdd, size_t kNumRows, class DF, class VF = hn::Vec<DF>>
HWY_INLINE void StoreHorizontalSumsMaybeAdd(
DF df, VF c00, VF c01, VF c02, VF c03, VF c10, VF c11, VF c12, VF c13,
VF c20, VF c21, VF c22, VF c23, VF c30, VF c31, VF c32, VF c33,
const float* HWY_RESTRICT add, size_t add_offset, const float scale,
const float scale, const float* HWY_RESTRICT add, size_t add_offset,
float* HWY_RESTRICT tile_c, size_t stride_c) {
if constexpr (kAdd) {
StoreHorizontalSumsAdd<kNumRows>(df, c00, c01, c02, c03, c10, c11, c12, c13,
c20, c21, c22, c23, c30, c31, c32, c33,
add + add_offset, scale, tile_c, stride_c);
scale, add + add_offset, tile_c, stride_c);
} else {
StoreHorizontalSums<kNumRows>(df, c00, c01, c02, c03, c10, c11, c12, c13,
c20, c21, c22, c23, c30, c31, c32, c33,
@ -150,6 +150,36 @@ HWY_INLINE void StoreHorizontalSumsMaybeAdd(
}
}
// Wrapper to simplify call sites. T can be const or non-const.
template <typename T>
struct Mat {
bool NotEmpty() const {
return ptr != nullptr && cols != 0 && stride >= cols;
}
size_t Row(size_t r) const { return ofs + stride * r; }
T* HWY_RESTRICT ptr;
size_t cols;
// elements between rows, which is typically the same as `cols`.
size_t stride;
// Offset to add to `ptr`; separate because T=NuqStream does not support
// pointer arithmetic.
size_t ofs;
};
template <typename T>
Mat<T> MakeMat(T* HWY_RESTRICT ptr, size_t cols, size_t stride,
size_t ofs = 0) {
return Mat<T>{.ptr = ptr, .cols = cols, .stride = stride, .ofs = ofs};
}
template <typename T>
Mat<T> MakeMat(T* HWY_RESTRICT ptr, size_t cols) {
return MakeMat(ptr, cols, cols);
}
#undef GEMMA_NATIVE_BF16
#if HWY_IDE || (defined(HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16) == \
defined(HWY_TARGET_TOGGLE))
@ -162,31 +192,18 @@ HWY_INLINE void StoreHorizontalSumsMaybeAdd(
// Specialization for f32 += bf16 * bf16 that avoids promoting to f32.
template <size_t kNumRows, bool kAdd>
HWY_INLINE void GEMM_4x4_Tile(const hwy::bfloat16_t* HWY_RESTRICT A,
const size_t A_ofs,
const hwy::bfloat16_t* HWY_RESTRICT B,
const size_t B_ofs, float* HWY_RESTRICT C,
HWY_INLINE void MatMulTile(const Mat<const hwy::bfloat16_t>& A,
const Mat<const hwy::bfloat16_t>& B,
const size_t row_a, const size_t row_b_col_c,
const float scale, const float* HWY_RESTRICT add,
const size_t idx_tile, const size_t xtiles,
const size_t cols_a, const size_t stride_a,
const size_t stride_b, const size_t stride_c) {
constexpr size_t kRegRows = 4;
constexpr size_t kRegCols = 4;
static_assert(kNumRows <= kRegRows);
// Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
const size_t row_a = idx_tile / xtiles * kRegRows;
const size_t row_b_col_c = idx_tile % xtiles * kRegCols;
const Mat<float>& C) {
const hn::ScalableTag<float> df;
using VF = hn::Vec<decltype(df)>;
// ReorderWidenMulAccumulate does not use its sum1 arg and we can use full
// bf16 vectors.
const hn::Repartition<hwy::bfloat16_t, decltype(df)> d;
VF unused_sum1 = hn::Zero(df);
const size_t N = Lanes(d);
VF unused_sum1 = hn::Zero(df);
VF c00 = hn::Zero(df);
VF c01 = hn::Zero(df);
VF c02 = hn::Zero(df);
@ -207,42 +224,41 @@ HWY_INLINE void GEMM_4x4_Tile(const hwy::bfloat16_t* HWY_RESTRICT A,
VF c32 = hn::Zero(df);
VF c33 = hn::Zero(df);
const hwy::bfloat16_t* HWY_RESTRICT A_tile = A + A_ofs + stride_a * row_a;
const hwy::bfloat16_t* HWY_RESTRICT B_tile =
B + B_ofs + stride_b * row_b_col_c;
const hwy::bfloat16_t* HWY_RESTRICT A_tile = A.ptr + A.Row(row_a);
const hwy::bfloat16_t* HWY_RESTRICT B_tile = B.ptr + B.Row(row_b_col_c);
// Loop over columns of A and columns of the transposed B, in steps of N.
// Accumulates into the c## vectors.
HWY_UNROLL(1)
for (size_t col_ab = 0; col_ab < cols_a; col_ab += N) {
for (size_t col_ab = 0; col_ab < A.cols; col_ab += N) {
using V = hn::Vec<decltype(d)>;
const V b0 = hn::LoadU(d, B_tile + stride_b * 0 + col_ab);
const V b1 = hn::LoadU(d, B_tile + stride_b * 1 + col_ab);
const V b2 = hn::LoadU(d, B_tile + stride_b * 2 + col_ab);
const V b3 = hn::LoadU(d, B_tile + stride_b * 3 + col_ab);
const V b0 = hn::LoadU(d, B_tile + B.stride * 0 + col_ab);
const V b1 = hn::LoadU(d, B_tile + B.stride * 1 + col_ab);
const V b2 = hn::LoadU(d, B_tile + B.stride * 2 + col_ab);
const V b3 = hn::LoadU(d, B_tile + B.stride * 3 + col_ab);
const V a0 = hn::LoadU(d, A_tile + stride_a * 0 + col_ab);
const V a0 = hn::LoadU(d, A_tile + A.stride * 0 + col_ab);
c00 = hn::ReorderWidenMulAccumulate(df, a0, b0, c00, unused_sum1);
c01 = hn::ReorderWidenMulAccumulate(df, a0, b1, c01, unused_sum1);
c02 = hn::ReorderWidenMulAccumulate(df, a0, b2, c02, unused_sum1);
c03 = hn::ReorderWidenMulAccumulate(df, a0, b3, c03, unused_sum1);
if constexpr (kNumRows == 1) continue;
const V a1 = hn::LoadU(d, A_tile + stride_a * 1 + col_ab);
const V a1 = hn::LoadU(d, A_tile + A.stride * 1 + col_ab);
c10 = hn::ReorderWidenMulAccumulate(df, a1, b0, c10, unused_sum1);
c11 = hn::ReorderWidenMulAccumulate(df, a1, b1, c11, unused_sum1);
c12 = hn::ReorderWidenMulAccumulate(df, a1, b2, c12, unused_sum1);
c13 = hn::ReorderWidenMulAccumulate(df, a1, b3, c13, unused_sum1);
if constexpr (kNumRows == 2) continue;
const V a2 = hn::LoadU(d, A_tile + stride_a * 2 + col_ab);
const V a2 = hn::LoadU(d, A_tile + A.stride * 2 + col_ab);
c20 = hn::ReorderWidenMulAccumulate(df, a2, b0, c20, unused_sum1);
c21 = hn::ReorderWidenMulAccumulate(df, a2, b1, c21, unused_sum1);
c22 = hn::ReorderWidenMulAccumulate(df, a2, b2, c22, unused_sum1);
c23 = hn::ReorderWidenMulAccumulate(df, a2, b3, c23, unused_sum1);
if constexpr (kNumRows == 3) continue;
const V a3 = hn::LoadU(d, A_tile + stride_a * 3 + col_ab);
const V a3 = hn::LoadU(d, A_tile + A.stride * 3 + col_ab);
c30 = hn::ReorderWidenMulAccumulate(df, a3, b0, c30, unused_sum1);
c31 = hn::ReorderWidenMulAccumulate(df, a3, b1, c31, unused_sum1);
c32 = hn::ReorderWidenMulAccumulate(df, a3, b2, c32, unused_sum1);
@ -252,10 +268,10 @@ HWY_INLINE void GEMM_4x4_Tile(const hwy::bfloat16_t* HWY_RESTRICT A,
// Ensure sum1 was indeed unused.
HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));
float* HWY_RESTRICT C_tile = C + stride_c * row_a + row_b_col_c;
float* HWY_RESTRICT C_tile = C.ptr + C.Row(row_a) + row_b_col_c;
StoreHorizontalSumsMaybeAdd<kAdd, kNumRows>(
df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30, c31,
c32, c33, add, row_b_col_c, scale, C_tile, stride_c);
c32, c33, scale, add, row_b_col_c, C_tile, C.stride);
}
#endif // GEMMA_NATIVE_BF16
@ -277,32 +293,20 @@ HWY_INLINE void UpdateTileRow(const VF& a0, const VF& a1, const VF& b00,
c3 = hn::MulAdd(a1, b31, c3);
}
// Accumulates a single kNumRows (<= 4) x 4 tile of A x B into C. B is
// transposed, so we iterate over both A and B with consecutive vector loads.
// Streams a `(kNumRows, 4)` strip of `A` and the transposed `B`, then writes a
// finished tile of `C`.
// General case: uses CompressTraits to load from A and B.
template <size_t kNumRows, bool kAdd, typename MatTA, typename MatTB>
HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A, const size_t A_ofs,
const MatTB* HWY_RESTRICT B, const size_t B_ofs,
float* HWY_RESTRICT C, const float scale,
const float* HWY_RESTRICT add,
const size_t idx_tile, const size_t xtiles,
const size_t cols_a, const size_t stride_a,
const size_t stride_b, const size_t stride_c) {
constexpr size_t kRegRows = 4;
constexpr size_t kRegCols = 4;
static_assert(kNumRows <= kRegRows);
using TraitsA = CompressTraits<MatTA>;
using TraitsB = CompressTraits<MatTB>;
// Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
const size_t row_a = idx_tile / xtiles * kRegRows;
const size_t row_b_col_c = idx_tile % xtiles * kRegCols;
HWY_INLINE void MatMulTile(const Mat<MatTA>& A, const Mat<MatTB>& B,
const size_t row_a, const size_t row_b_col_c,
const float scale, const float* HWY_RESTRICT add,
const Mat<float>& C) {
using TraitsA = CompressTraits<hwy::RemoveConst<MatTA>>;
using TraitsB = CompressTraits<hwy::RemoveConst<MatTB>>;
const hn::ScalableTag<float> d32;
const size_t N = hn::Lanes(d32);
using V = hn::Vec<decltype(d32)>;
V c00 = hn::Zero(d32);
V c01 = hn::Zero(d32);
V c02 = hn::Zero(d32);
@ -323,127 +327,118 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A, const size_t A_ofs,
V c32 = hn::Zero(d32);
V c33 = hn::Zero(d32);
const size_t A_tile_ofs = A_ofs + stride_a * row_a;
const size_t B_tile_ofs = B_ofs + stride_b * row_b_col_c;
const size_t A_ofs = A.Row(row_a);
const size_t B_ofs = B.Row(row_b_col_c);
// Loop over columns of A and columns of the transposed B, in steps of 2*N
// (since we are decoding consecutive bytes at each iteration).
// Accumulates into the c## vectors.
// Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c,
// col_ab) for B. Accumulates into the c## vectors.
size_t col_ab = 0;
HWY_UNROLL(1)
for (; col_ab <= cols_a - 2 * N; col_ab += 2 * N) {
for (; col_ab <= A.cols - 2 * N; col_ab += 2 * N) {
V b00, b01;
TraitsB::Decompress2(d32, B, B_tile_ofs + stride_b * 0 + col_ab, b00, b01);
TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 0 + col_ab, b00, b01);
V b10, b11;
TraitsB::Decompress2(d32, B, B_tile_ofs + stride_b * 1 + col_ab, b10, b11);
TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 1 + col_ab, b10, b11);
V b20, b21;
TraitsB::Decompress2(d32, B, B_tile_ofs + stride_b * 2 + col_ab, b20, b21);
TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 2 + col_ab, b20, b21);
V b30, b31;
TraitsB::Decompress2(d32, B, B_tile_ofs + stride_b * 3 + col_ab, b30, b31);
TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 3 + col_ab, b30, b31);
V a00, a01;
TraitsA::Decompress2(d32, A, A_tile_ofs + stride_a * 0 + col_ab, a00, a01);
TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 0 + col_ab, a00, a01);
UpdateTileRow(a00, a01, b00, b01, b10, b11, b20, b21, b30, b31, c00, c01,
c02, c03);
if constexpr (kNumRows == 1) continue;
V a10, a11;
TraitsA::Decompress2(d32, A, A_tile_ofs + stride_a * 1 + col_ab, a10, a11);
TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 1 + col_ab, a10, a11);
UpdateTileRow(a10, a11, b00, b01, b10, b11, b20, b21, b30, b31, c10, c11,
c12, c13);
if constexpr (kNumRows == 2) continue;
V a20, a21;
TraitsA::Decompress2(d32, A, A_tile_ofs + stride_a * 2 + col_ab, a20, a21);
TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 2 + col_ab, a20, a21);
UpdateTileRow(a20, a21, b00, b01, b10, b11, b20, b21, b30, b31, c20, c21,
c22, c23);
if constexpr (kNumRows == 3) continue;
V a30, a31;
TraitsA::Decompress2(d32, A, A_tile_ofs + stride_a * 3 + col_ab, a30, a31);
TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 3 + col_ab, a30, a31);
UpdateTileRow(a30, a31, b00, b01, b10, b11, b20, b21, b30, b31, c30, c31,
c32, c33);
}
float* HWY_RESTRICT C_tile = C + stride_c * row_a + row_b_col_c;
float* HWY_RESTRICT C_tile = C.ptr + C.Row(row_a) + row_b_col_c;
StoreHorizontalSumsMaybeAdd<kAdd, kNumRows>(
d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30, c31,
c32, c33, add, row_b_col_c, scale, C_tile, stride_c);
c32, c33, scale, add, row_b_col_c, C_tile, C.stride);
}
// Tiled 4x4 GEMM: C = A * B * scale [+ add].
// Computes the matrix product of A and B and stores this in C. Processes tiles
// of 4x4 vectors in parallel with a work-stealing thread pool.
// Computes the matrix product `A * B * scale [+ add]` and stores it in `C`.
//
// If kAdd is true, the row-vector `add` is added to each row of C, otherwise
// `add` is ignored and can be nullptr.
// A is a row-major matrix of size (batch_size, colsA_rowsB).
// B is passed transposed (column-major), so a matrix of size
// (colsBC, colsA_rowsB), representing a B of size (colsA_rowsB, colsBC).
// A_ofs and B_ofs are offsets into A and B, respectively; they remain separate
// from the pointers because some MatTA/B such as NuqStream do not support
// pointer arithmetic.
// C is a row-major matrix of size (batch_size, colsBC), with `C_stride`
// elements between rows, which is typically the same as `colsBC`. There is no
// `C_ofs` because callers can simply add it to `C`.
// The product is scaled by `scale` to support CompressedArray with scale != 1,
// the caller can pass the product of the scales of A and B.
// A scale for `add` is not supported, so make sure its scale is 1.
// Typically batch_size is 1..512, colsA_rowsB and colsBC are 3k or 24k.
// `A` is a row-major matrix of shape `(batch_size, A.cols)`.
// `B` is transposed; `B.cols`, which must match `A.cols`, denotes the number of
// rows in the original B, and `C.cols` the number of columns in the original B.
//
// `scale` allows expanding the smaller range of `SfpStream` to the original
// values. When `A` and/or `B` are from CompressedArray, `scale` should be the
// product of their `.scale()` values.
//
// If `kAdd` is true, the row-vector `add` is added to each row of `C`,
// otherwise `add` is ignored and can be nullptr. A scale for `add` is not
// supported, so make sure its scale is 1.
//
// `C` is a row-major matrix of size `(batch_size, C.cols)`.
// Writes 4x4 tiles of C in parallel using a work-stealing thread pool.
// Typically batch_size is 1..512, A.cols and C.cols are 3k or 24k.
template <bool kAdd, typename MatTA, typename MatTB>
HWY_NOINLINE void MatMul_4x4(const size_t batch_size,
const MatTA* HWY_RESTRICT A, const size_t A_ofs,
const size_t colsA_rowsB,
const MatTB* HWY_RESTRICT B, const size_t B_ofs,
const size_t colsBC, const float scale,
float* HWY_RESTRICT C, const size_t C_stride,
const float* HWY_RESTRICT add,
HWY_NOINLINE void MatMul_4x4(const size_t batch_size, const Mat<MatTA>& A,
const Mat<MatTB>& B, const float scale,
const float* HWY_RESTRICT add, const Mat<float>& C,
hwy::ThreadPool& pool) {
PROFILER_ZONE("Matmul");
constexpr size_t kRegRows = 4; // if changing, also update the switch below.
constexpr size_t kRegCols = 4;
HWY_DASSERT(A.NotEmpty() && B.NotEmpty() && C.NotEmpty());
HWY_DASSERT(A.cols == B.cols);
// Use float instead of MatTA/MatTB because we decompress to float here.
const size_t N = hn::Lanes(hn::ScalableTag<float>());
(void)N;
HWY_DASSERT(A.cols % (N * 2) == 0); // For Decompress2.
HWY_DASSERT(C.cols % kRegCols == 0);
// We currently write C directly, which touches more memory than fits in L3.
// TODO: add another level of loops to finish L3-sized pieces of C at a time.
const hn::ScalableTag<MatTA> d;
// Use float instead of MatTA/MatTB because we decompress to float here.
const size_t Nf = hn::Lanes(hn::ScalableTag<float>());
(void)Nf; // For HWY_DASSERT
constexpr size_t kRegRows = 4; // if changing, also update the switch below.
constexpr size_t kRegCols = 4; // in vectors
HWY_DASSERT(colsA_rowsB % (Nf * 2) == 0); // For Decompress2.
HWY_DASSERT(colsBC % kRegCols == 0);
const size_t tilesY = hwy::DivCeil(batch_size, kRegRows);
const size_t tilesX = colsBC / kRegCols;
const size_t strideA = colsA_rowsB;
const size_t strideB = colsA_rowsB;
const size_t tilesX = C.cols / kRegCols;
pool.Run(0, tilesX * tilesY,
[&](const uint64_t idx_tile, size_t /*thread*/) HWY_ATTR {
const size_t tx = idx_tile % tilesX;
const size_t ty = idx_tile / tilesX;
const size_t row_a = ty * kRegRows;
const size_t row_b_col_c = tx * kRegCols;
// How many rows of C are left to compute. If more than 4, this
// tile still only computes 4 rows.
const size_t num_rows = batch_size - idx_tile / tilesX * kRegRows;
HWY_ASSERT(num_rows > 0);
const size_t num_rows = batch_size - row_a;
HWY_DASSERT(num_rows != 0);
switch (num_rows) {
case 1:
GEMM_4x4_Tile<1, kAdd>(A, A_ofs, B, B_ofs, C, scale, add,
idx_tile, tilesX, colsA_rowsB, strideA,
strideB, C_stride);
MatMulTile<1, kAdd>(A, B, row_a, row_b_col_c, scale, add, C);
break;
case 2:
GEMM_4x4_Tile<2, kAdd>(A, A_ofs, B, B_ofs, C, scale, add,
idx_tile, tilesX, colsA_rowsB, strideA,
strideB, C_stride);
MatMulTile<2, kAdd>(A, B, row_a, row_b_col_c, scale, add, C);
break;
case 3:
GEMM_4x4_Tile<3, kAdd>(A, A_ofs, B, B_ofs, C, scale, add,
idx_tile, tilesX, colsA_rowsB, strideA,
strideB, C_stride);
MatMulTile<3, kAdd>(A, B, row_a, row_b_col_c, scale, add, C);
break;
default:
GEMM_4x4_Tile<4, kAdd>(A, A_ofs, B, B_ofs, C, scale, add,
idx_tile, tilesX, colsA_rowsB, strideA,
strideB, C_stride);
MatMulTile<4, kAdd>(A, B, row_a, row_b_col_c, scale, add, C);
}
});
}

View File

@ -14,6 +14,7 @@
// limitations under the License.
#ifndef HWY_DISABLED_TARGETS
// Exclude HWY_SCALAR due to 2x bf16 -> f32.
#define HWY_DISABLED_TARGETS HWY_SCALAR
#endif
@ -48,47 +49,10 @@ namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
template <typename MatT, size_t kOuter, size_t kInner>
CompressedArray<MatT, kOuter * kInner> GenerateMat(size_t offset,
hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws;
CompressedArray<MatT, kOuter * kInner> mat;
std::array<float, kOuter * kInner> content;
const float scale = 1.0f / kInner;
pool.Run(0, kOuter, [&](const size_t i, size_t /*thread*/) {
for (size_t j = 0; j < kInner; j++) {
content[i * kInner + j] =
static_cast<float>((i * kInner + j + offset) * scale);
}
});
Compress(content, ws, mat, pool);
mat.set_scale(1.9f); // Arbitrary value, different from 1.
return mat;
}
template <typename MatT, size_t kOuter, size_t kInner>
CompressedArray<MatT, kOuter * kInner> GenerateZeroMat(hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws;
CompressedArray<MatT, kOuter * kInner> mat;
std::array<MatT, kOuter * kInner> content;
pool.Run(0, kOuter, [&](const size_t i, size_t thread) {
hwy::ZeroBytes(&content[i * kInner], kInner * sizeof(content[0]));
});
Compress(content, ws, mat, pool);
mat.set_scale(1.2f); // Arbitrary value, different from 1.
return mat;
}
template <typename MatT, size_t kOuter, size_t kInner>
std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> GenerateMatHeap(
size_t offset, hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws;
std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> mat =
std::unique_ptr<CompressedArray<MatT, kOuter * kInner>>(
new CompressedArray<MatT, kOuter * kInner>);
hwy::AlignedFreeUniquePtr<float[]> content =
hwy::AllocateAligned<float>(kOuter * kInner);
const float scale = 1.875f / (kInner * kOuter + offset);
@ -99,6 +63,8 @@ std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> GenerateMatHeap(
}
});
std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> mat =
std::make_unique<CompressedArray<MatT, kOuter * kInner>>();
Compress(content.get(), kOuter * kInner, ws, kOuter * kInner, mat->data(), 0,
pool);
mat->set_scale(0.6f); // Arbitrary value, different from 1.
@ -109,9 +75,6 @@ template <typename MatT, size_t kOuter, size_t kInner>
std::unique_ptr<CompressedArray<MatT, kOuter * kInner>>
GenerateTransposeMatHeap(size_t offset, hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws;
std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> mat =
std::unique_ptr<CompressedArray<MatT, kOuter * kInner>>(
new CompressedArray<MatT, kOuter * kInner>);
hwy::AlignedFreeUniquePtr<float[]> content =
hwy::AllocateAligned<float>(kOuter * kInner);
const float scale = 1.875f / (kInner * kOuter + offset);
@ -122,6 +85,8 @@ GenerateTransposeMatHeap(size_t offset, hwy::ThreadPool& pool) {
}
});
std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> mat =
std::make_unique<CompressedArray<MatT, kOuter * kInner>>();
Compress(content.get(), kOuter * kInner, ws, kOuter * kInner, mat->data(), 0,
pool);
// Arbitrary value, different from 1, must match GenerateMatHeap.
@ -133,9 +98,6 @@ template <typename MatT, size_t kOuter, size_t kInner>
std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> GenerateZeroMatHeap(
hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws;
std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> mat =
std::unique_ptr<CompressedArray<MatT, kOuter * kInner>>(
new CompressedArray<MatT, kOuter * kInner>);
hwy::AlignedFreeUniquePtr<float[]> content =
hwy::AllocateAligned<float>(kOuter * kInner);
@ -143,22 +105,14 @@ std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> GenerateZeroMatHeap(
hwy::ZeroBytes(&content[i * kInner], kInner * sizeof(content[0]));
});
std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> mat =
std::make_unique<CompressedArray<MatT, kOuter * kInner>>();
Compress(content.get(), kOuter * kInner, ws, kOuter * kInner, mat->data(), 0,
pool);
mat->set_scale(1.2f); // Arbitrary value, different from 1.
return mat;
}
template <size_t length>
hwy::AlignedFreeUniquePtr<float[]> GenerateVec(size_t offset) {
hwy::AlignedFreeUniquePtr<float[]> vec = hwy::AllocateAligned<float>(length);
HWY_ASSERT(vec);
for (size_t idx = 0; idx < length; idx++) {
vec[idx] = static_cast<float>(idx + offset);
}
return vec;
}
// A simple matrix multiplication. No optimization / tiling.
template <size_t kM, size_t kN, size_t kK>
hwy::AlignedFreeUniquePtr<float[]> SimpleMatMul(
@ -179,27 +133,6 @@ hwy::AlignedFreeUniquePtr<float[]> SimpleMatMul(
return out;
}
template <size_t kOuter, size_t kInner>
hwy::AlignedFreeUniquePtr<float[]> SimpleMatVecAdd(
const CompressedArray<float, kOuter * kInner>& mat,
const hwy::AlignedFreeUniquePtr<float[]>& vec,
const hwy::AlignedFreeUniquePtr<float[]>& add) {
hwy::AlignedFreeUniquePtr<float[]> uncompressed_mat =
hwy::AllocateAligned<float>(kOuter * kInner);
hwy::AlignedFreeUniquePtr<float[]> out = hwy::AllocateAligned<float>(kOuter);
HWY_ASSERT(uncompressed_mat && out);
Decompress(mat, 0, uncompressed_mat.get(), kOuter * kInner);
MulByConst(mat.scale(), uncompressed_mat.get(), kOuter * kInner);
for (size_t idx_row = 0; idx_row < kOuter; idx_row++) {
out[idx_row] = add[idx_row];
for (size_t idx_col = 0; idx_col < kInner; idx_col++) {
out[idx_row] +=
uncompressed_mat[kInner * idx_row + idx_col] * vec[idx_col];
}
}
return out;
}
template <typename MatT>
void AssertClose(const MatT* HWY_RESTRICT expected,
const MatT* HWY_RESTRICT actual, size_t num) {
@ -233,8 +166,7 @@ void AssertClose(const hwy::AlignedFreeUniquePtr<float[]>& a,
}
}
// Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
// ops_test across instruction sets.
// Largely unoptimized; reordered innermost loops nets ~5-10X speedup.
template <size_t kN, size_t kK, typename MatTA, typename MatTB,
HWY_IF_T_SIZE_GT(MatTB, 1)>
HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
@ -271,92 +203,167 @@ HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
MatMulSlowBatch<kN, kK>(batch_size, a, b.get(), scale, add, out);
}
void PrintSpeed(const char* algo, size_t M, size_t N, size_t K,
double elapsed) {
// * 2 because of FMA.
fprintf(stderr, "%s: %f seconds, %f GFLOPS.\n", algo, elapsed,
2E-9 * M * N * K / elapsed);
}
template <size_t kM, size_t kN, size_t kK, bool kAdd, typename MatTA,
typename MatTB = MatTA>
void TestTiledBatchMatMul() {
fprintf(stderr,
"TestTiledBatchMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n",
kM, kN, kK, kAdd, typeid(MatTA).name(), typeid(MatTB).name());
hwy::ThreadPool pool(3);
void TestMatMul(hwy::ThreadPool& pool) {
using TraitsA = CompressTraits<MatTA>;
using TraitsB = CompressTraits<MatTB>;
fprintf(stderr, "TestMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n", kM,
kN, kK, kAdd, TraitsA::Name(), TraitsB::Name());
std::unique_ptr<CompressedArray<MatTA, kM * kN>> a =
GenerateMatHeap<MatTA, kM, kN>(0, pool);
std::unique_ptr<CompressedArray<MatTB, kN * kK>> b_trans =
GenerateTransposeMatHeap<MatTB, kN, kK>(0, pool);
hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
const float scale = a->scale() * b_trans->scale();
std::unique_ptr<CompressedArray<float, kK>> add;
if (kAdd) {
add = GenerateMatHeap<float, 1, kK>(0, pool);
add->set_scale(1.0f);
}
std::unique_ptr<CompressedArray<float, kM * kK>> c_slow;
const bool compare_slow = kN < 2048;
if (compare_slow) {
std::unique_ptr<CompressedArray<MatTB, kN * kK>> b =
GenerateMatHeap<MatTB, kN, kK>(0, pool);
std::unique_ptr<CompressedArray<float, kK>> add =
GenerateMatHeap<float, 1, kK>(0, pool);
add->set_scale(1.0f);
std::unique_ptr<CompressedArray<float, kM * kK>> c_slow =
GenerateZeroMatHeap<float, kM, kK>(pool);
const float scale = a->scale() * b->scale();
HWY_ASSERT_EQ(scale, a->scale() * b->scale());
c_slow = GenerateZeroMatHeap<float, kM, kK>(pool);
const double start_slow = hwy::platform::Now();
MatMulSlowBatch<kN, kK>(kM, a->data(), b->data(), scale,
kAdd ? add->data() : nullptr, c_slow->data());
const double slow_matmul_seconds = hwy::platform::Now() - start_slow;
fprintf(stderr, "MatMulSlowBatch took %f seconds.\n", slow_matmul_seconds);
hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
std::unique_ptr<CompressedArray<MatTB, kN * kK>> b_trans =
GenerateTransposeMatHeap<MatTB, kN, kK>(0, pool);
const double start_tiled = hwy::platform::Now();
EXPECT_EQ(scale, a->scale() * b_trans->scale());
MatMul_4x4<kAdd>(kM, a->data(), 0, kN, b_trans->data(), 0, kK, scale, c.get(),
kK, add->data(), pool);
const double tiled_matmul_seconds = hwy::platform::Now() - start_tiled;
fprintf(stderr, "MatMul_4x4 took %f seconds.\n", tiled_matmul_seconds);
AssertClose(c_slow->data(), c.get(), kM * kK);
PrintSpeed("MatMulSlowBatch", kM, kN, kK,
hwy::platform::Now() - start_slow);
}
void TestAllTiledBatchMatMul() {
double min_elapsed = hwy::HighestValue<double>();
for (int rep = 0; rep < (compare_slow ? 1 : 3); ++rep) {
const double start_tiled = hwy::platform::Now();
MatMul_4x4<kAdd>(kM, MakeMat(a->data(), kN), MakeMat(b_trans->data(), kN),
scale, kAdd ? add->data_scale1() : nullptr,
MakeMat(c.get(), kK), pool);
min_elapsed = HWY_MIN(min_elapsed, hwy::platform::Now() - start_tiled);
}
PrintSpeed("MatMul_4x4", kM, kN, kK, min_elapsed);
if (compare_slow) {
AssertClose(c_slow->data(), c.get(), kM * kK);
}
}
void TestAllMatMul() {
// Skip EMU128 (10x slower than SSE4 for SFP) and older x86.
if (HWY_TARGET == HWY_EMU128 || HWY_TARGET == HWY_SSE4 ||
HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE2) {
return;
}
hwy::ThreadPool pool(4);
using BF16 = hwy::bfloat16_t;
using F32 = float;
using SFP = SfpStream;
// medium-sized square test
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, F32>();
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, BF16>();
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, F32, BF16>();
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, BF16, F32>();
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, F32, SFP>();
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, BF16, SFP>();
// minimal non-square test. kK must be at least 2 vectors.
TestTiledBatchMatMul<35, 128, 32, /*kAdd=*/false, F32>();
TestTiledBatchMatMul<34, 128, 32, /*kAdd=*/true, BF16>();
TestTiledBatchMatMul<33, 128, 32, /*kAdd=*/false, F32, BF16>();
TestTiledBatchMatMul<33, 128, 32, /*kAdd=*/true, BF16, F32>();
TestTiledBatchMatMul<31, 128, 32, /*kAdd=*/false, F32, SFP>();
TestTiledBatchMatMul<29, 128, 32, /*kAdd=*/true, BF16, SFP>();
TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/true, F32>();
TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/false, BF16>();
TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/true, F32, BF16>();
TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/false, BF16, F32>();
TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/true, F32, SFP>();
TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/false, BF16, SFP>();
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, F32>();
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, BF16>();
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, F32, BF16>();
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, BF16, F32>();
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, F32, SFP>();
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, BF16, SFP>();
TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/true, F32>();
TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/false, BF16>();
TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/true, F32, BF16>();
TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/false, BF16, F32>();
TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/true, F32, SFP>();
TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/false, BF16, SFP>();
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, F32>();
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, BF16>();
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, F32, BF16>();
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, BF16, F32>();
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, F32, SFP>();
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, BF16, SFP>();
// large-scale test
// TODO(philculliton): investigate rounding issues with large matrices.
// Causes test timeout.
// TestTiledBatchMatMul<512, 24576, 3072, float>();
TestMatMul<64, 24576, 3072, /*kAdd=*/false, F32, SFP>(pool);
// medium-sized square test
TestMatMul<512, 512, 512, /*kAdd=*/false, F32>(pool);
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16>(pool);
TestMatMul<512, 512, 512, /*kAdd=*/false, F32, BF16>(pool);
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, F32>(pool);
TestMatMul<512, 512, 512, /*kAdd=*/false, F32, SFP>(pool);
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, SFP>(pool);
// minimal non-square test. kK must be at least 2 vectors.
TestMatMul<35, 128, 32, /*kAdd=*/false, F32>(pool);
TestMatMul<34, 128, 32, /*kAdd=*/true, BF16>(pool);
TestMatMul<33, 128, 32, /*kAdd=*/false, F32, BF16>(pool);
TestMatMul<33, 128, 32, /*kAdd=*/true, BF16, F32>(pool);
TestMatMul<31, 128, 32, /*kAdd=*/false, F32, SFP>(pool);
TestMatMul<29, 128, 32, /*kAdd=*/true, BF16, SFP>(pool);
TestMatMul<4, 128, 32, /*kAdd=*/true, F32>(pool);
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16>(pool);
TestMatMul<4, 128, 32, /*kAdd=*/true, F32, BF16>(pool);
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, F32>(pool);
TestMatMul<4, 128, 32, /*kAdd=*/true, F32, SFP>(pool);
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, SFP>(pool);
TestMatMul<3, 128, 32, /*kAdd=*/false, F32>(pool);
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16>(pool);
TestMatMul<3, 128, 32, /*kAdd=*/false, F32, BF16>(pool);
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, F32>(pool);
TestMatMul<3, 128, 32, /*kAdd=*/false, F32, SFP>(pool);
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, SFP>(pool);
TestMatMul<2, 128, 64, /*kAdd=*/true, F32>(pool);
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16>(pool);
TestMatMul<2, 128, 64, /*kAdd=*/true, F32, BF16>(pool);
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, F32>(pool);
TestMatMul<2, 128, 64, /*kAdd=*/true, F32, SFP>(pool);
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, SFP>(pool);
TestMatMul<1, 128, 32, /*kAdd=*/false, F32>(pool);
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16>(pool);
TestMatMul<1, 128, 32, /*kAdd=*/false, F32, BF16>(pool);
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, F32>(pool);
TestMatMul<1, 128, 32, /*kAdd=*/false, F32, SFP>(pool);
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, SFP>(pool);
}
template <size_t kOuter, size_t kInner>
hwy::AlignedFreeUniquePtr<float[]> SimpleMatVecAdd(
const CompressedArray<float, kOuter * kInner>& mat,
const hwy::AlignedFreeUniquePtr<float[]>& vec,
const hwy::AlignedFreeUniquePtr<float[]>& add) {
hwy::AlignedFreeUniquePtr<float[]> uncompressed_mat =
hwy::AllocateAligned<float>(kOuter * kInner);
hwy::AlignedFreeUniquePtr<float[]> out = hwy::AllocateAligned<float>(kOuter);
HWY_ASSERT(uncompressed_mat && out);
Decompress(mat, 0, uncompressed_mat.get(), kOuter * kInner);
MulByConst(mat.scale(), uncompressed_mat.get(), kOuter * kInner);
for (size_t idx_row = 0; idx_row < kOuter; idx_row++) {
out[idx_row] = add[idx_row];
for (size_t idx_col = 0; idx_col < kInner; idx_col++) {
out[idx_row] +=
uncompressed_mat[kInner * idx_row + idx_col] * vec[idx_col];
}
}
return out;
}
template <typename MatT, size_t kOuter, size_t kInner>
CompressedArray<MatT, kOuter * kInner> GenerateMat(size_t offset,
hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws;
CompressedArray<MatT, kOuter * kInner> mat;
std::array<float, kOuter * kInner> content;
const float scale = 1.0f / kInner;
pool.Run(0, kOuter, [&](const size_t i, size_t /*thread*/) {
for (size_t j = 0; j < kInner; j++) {
content[i * kInner + j] =
static_cast<float>((i * kInner + j + offset) * scale);
}
});
Compress(content, ws, mat, pool);
mat.set_scale(1.9f); // Arbitrary value, different from 1.
return mat;
}
template <size_t length>
hwy::AlignedFreeUniquePtr<float[]> GenerateVec(size_t offset) {
hwy::AlignedFreeUniquePtr<float[]> vec = hwy::AllocateAligned<float>(length);
HWY_ASSERT(vec);
for (size_t idx = 0; idx < length; idx++) {
vec[idx] = static_cast<float>(idx + offset);
}
return vec;
}
void TestMatVecAdd() {
@ -441,7 +448,7 @@ HWY_AFTER_NAMESPACE();
namespace gcpp {
HWY_BEFORE_TEST(MatmulTest);
HWY_EXPORT_AND_TEST_P(MatmulTest, TestAllTiledBatchMatMul);
HWY_EXPORT_AND_TEST_P(MatmulTest, TestAllMatMul);
HWY_EXPORT_AND_TEST_P(MatmulTest, TestMatVecAdd);
HWY_EXPORT_AND_TEST_P(MatmulTest, TestTwoMatVecAdd);
HWY_EXPORT_AND_TEST_P(MatmulTest, TestTwoOfsMatVecAddLoop);