MatMul cleanup: Mat struct, simplify args.

Add large benchmark to test, use 4 threads, skip some targets.
Also use Traits::Name instead of typeid.

PiperOrigin-RevId: 657496185
This commit is contained in:
Jan Wassenberg 2024-07-30 01:55:09 -07:00 committed by Copybara-Service
parent d9f86f8e4d
commit 6ea4232b2e
4 changed files with 314 additions and 314 deletions

View File

@ -60,6 +60,7 @@ struct CompressTraits {};
template <> template <>
struct CompressTraits<float> { struct CompressTraits<float> {
using MatT = float; using MatT = float;
static const char* Name() { return "f32"; }
static constexpr bool kSupportsEvenOdd = false; // unnecessary static constexpr bool kSupportsEvenOdd = false; // unnecessary
template <class DF, HWY_IF_F32_D(DF)> template <class DF, HWY_IF_F32_D(DF)>
@ -123,6 +124,7 @@ struct CompressTraits<float> {
template <> template <>
struct CompressTraits<hwy::bfloat16_t> { struct CompressTraits<hwy::bfloat16_t> {
using MatT = hwy::bfloat16_t; using MatT = hwy::bfloat16_t;
static const char* Name() { return "bf16"; }
static constexpr bool kSupportsEvenOdd = true; static constexpr bool kSupportsEvenOdd = true;
template <class DF, HWY_IF_F32_D(DF)> template <class DF, HWY_IF_F32_D(DF)>
@ -292,6 +294,7 @@ struct CompressTraits<hwy::bfloat16_t> {
template <> template <>
struct CompressTraits<SfpStream> { struct CompressTraits<SfpStream> {
using MatT = SfpStream; using MatT = SfpStream;
static const char* Name() { return "sfp"; }
static constexpr bool kSupportsEvenOdd = true; static constexpr bool kSupportsEvenOdd = true;
// Callers are responsible for scaling `in` such that its magnitudes do not // Callers are responsible for scaling `in` such that its magnitudes do not
@ -389,6 +392,7 @@ struct CompressTraits<SfpStream> {
template <> template <>
struct CompressTraits<NuqStream> { struct CompressTraits<NuqStream> {
using MatT = NuqStream; using MatT = NuqStream;
static const char* Name() { return "nuq"; }
static constexpr bool kSupportsEvenOdd = false; static constexpr bool kSupportsEvenOdd = false;
template <class DF, HWY_IF_F32_D(DF)> template <class DF, HWY_IF_F32_D(DF)>

View File

@ -237,12 +237,11 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
// //
// Compute Q only or QKV (if MHA). // Compute Q only or QKV (if MHA).
// If MHA, this also computes KV, which we copy to the KV cache below. // If MHA, this also computes KV, which we copy to the KV cache below.
MatMul_4x4</*kAdd=*/false>(num_interleaved, activations.pre_att_rms_out.All(), MatMul_4x4</*kAdd=*/false>(
0, kModelDim, layer_weights->qkv_einsum_w.data(), num_interleaved, MakeMat(activations.pre_att_rms_out.All(), kModelDim),
0, kHeads * kQStride, MakeMat(layer_weights->qkv_einsum_w.data(), kModelDim),
layer_weights->qkv_einsum_w.scale(), layer_weights->qkv_einsum_w.scale(), /*add=*/nullptr,
activations.q.All(), kHeads * kQStride, MakeMat(activations.q.All(), kHeads * kQStride), pool);
/*add=*/nullptr, pool);
// Compute KV if not MHA. // Compute KV if not MHA.
if constexpr (!kIsMHA) { if constexpr (!kIsMHA) {
@ -250,16 +249,16 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
// directly into the KV cache with a stride of kCachePosSize. // directly into the KV cache with a stride of kCachePosSize.
if (num_queries == 1 && if (num_queries == 1 &&
batch_start + num_tokens <= div_seq_len.GetDivisor()) { batch_start + num_tokens <= div_seq_len.GetDivisor()) {
const size_t colsBC = kKVHeads * 2 * kQKVDim;
const size_t kv_ofs = const size_t kv_ofs =
batch_start * kCachePosSize + layer * kCacheLayerSize; batch_start * kCachePosSize + layer * kCacheLayerSize;
// KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v). // KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v).
float* HWY_RESTRICT kv = kv_caches[0].kv_cache.get() + kv_ofs; float* HWY_RESTRICT kv = kv_caches[0].kv_cache.get() + kv_ofs;
MatMul_4x4</*kAdd=*/false>( MatMul_4x4</*kAdd=*/false>(
num_tokens, activations.pre_att_rms_out.All(), 0, kModelDim, num_tokens, MakeMat(activations.pre_att_rms_out.All(), kModelDim),
layer_weights->qkv_einsum_w.data(), kHeads * kQKVDim * kModelDim, MakeMat(layer_weights->qkv_einsum_w.data(), kModelDim, kModelDim,
colsBC, layer_weights->qkv_einsum_w.scale(), kv, kCachePosSize, kHeads * kQKVDim * kModelDim),
/*add=*/nullptr, pool); layer_weights->qkv_einsum_w.scale(), /*add=*/nullptr,
MakeMat(kv, kKVHeads * 2 * kQKVDim, kCachePosSize), pool);
} else { } else {
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved; for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) { ++interleaved_idx) {
@ -441,14 +440,12 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
// MatMul expects col-major B, which is what we have: kModelDim consecutive // MatMul expects col-major B, which is what we have: kModelDim consecutive
// elements in memory, repeated kFFHiddenDim times. // elements in memory, repeated kFFHiddenDim times.
constexpr size_t kColsA = kModelDim;
constexpr size_t kColsBC = kFFHiddenDim;
HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize()); HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize());
const auto A = activations.bf_pre_ffw_rms_out.All(); const auto A = MakeMat(activations.bf_pre_ffw_rms_out.All(), kModelDim);
const auto B1 = MakeMat(layer_weights->gating_einsum_w.data(), kModelDim);
const auto B2 = MakeMat(layer_weights->gating_einsum_w.data(), kModelDim,
kModelDim, kModelDim * kFFHiddenDim);
const float scale = layer_weights->gating_einsum_w.scale(); const float scale = layer_weights->gating_einsum_w.scale();
const auto B1 = layer_weights->gating_einsum_w.data();
auto C1 = activations.C1.All();
auto C2 = activations.C2.All();
constexpr bool kAddBias = TConfig::kFFBiases; constexpr bool kAddBias = TConfig::kFFBiases;
const float* bias1 = nullptr; const float* bias1 = nullptr;
const float* bias2 = nullptr; const float* bias2 = nullptr;
@ -458,24 +455,22 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
bias2 = bias1 + kFFHiddenDim; bias2 = bias1 + kFFHiddenDim;
output_bias = layer_weights->ffw_output_biases.data_scale1(); output_bias = layer_weights->ffw_output_biases.data_scale1();
} }
auto C1 = MakeMat(activations.C1.All(), kFFHiddenDim);
auto C2 = MakeMat(activations.C2.All(), kFFHiddenDim);
const size_t A_ofs = 0; // no offset, using the same activations for both.
// Will go through GELU. // Will go through GELU.
MatMul_4x4<kAddBias>(num_interleaved, A, A_ofs, kColsA, B1, MatMul_4x4<kAddBias>(num_interleaved, A, B1, scale, bias1, C1, pool);
/*B_ofs=*/0, kColsBC, scale, C1, kColsBC, bias1, pool);
// What to multiply by. // What to multiply by.
MatMul_4x4<kAddBias>(num_interleaved, A, A_ofs, kColsA, B1, MatMul_4x4<kAddBias>(num_interleaved, A, B2, scale, bias2, C2, pool);
/*B_ofs=*/kColsA * kColsBC, kColsBC, scale, C2, kColsBC,
bias2, pool);
// Activation (Gelu) and multiply by gate. Store activations in C1. // Activation (Gelu) and multiply by gate. Store activations in C1.
Activation<TConfig>(C1, C2, kFFHiddenDim * num_interleaved); Activation<TConfig>(C1.ptr, C2.ptr, kFFHiddenDim * num_interleaved);
// Hidden layer -> output layer. // Hidden layer -> output layer.
MatMul_4x4<kAddBias>(num_interleaved, C1, 0, kFFHiddenDim, MatMul_4x4<kAddBias>(num_interleaved, C1,
layer_weights->linear_w.data(), 0, kModelDim, MakeMat(layer_weights->linear_w.data(), kFFHiddenDim),
layer_weights->linear_w.scale(), layer_weights->linear_w.scale(), output_bias,
activations.ffw_out.All(), kModelDim, output_bias, pool); MakeMat(activations.ffw_out.All(), kModelDim), pool);
} }
// `batch_idx` indicates which row of `x` to write to. // `batch_idx` indicates which row of `x` to write to.
@ -1022,12 +1017,11 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
bool all_queries_eos = true; bool all_queries_eos = true;
PROFILER_ZONE("Gen.Embedding"); PROFILER_ZONE("Gen.Embedding");
// Compute logits from last layer activations. // Compute logits from last layer activations.
MatMul_4x4</*kAdd=*/false>(num_queries, activations.x.All(), 0, kModelDim, MatMul_4x4</*kAdd=*/false>(
weights.embedder_input_embedding.data(), 0, num_queries, MakeMat(activations.x.All(), kModelDim),
kVocabSize, MakeMat(weights.embedder_input_embedding.data(), kModelDim),
weights.embedder_input_embedding.scale(), weights.embedder_input_embedding.scale(), /*add=*/nullptr,
activations.logits.All(), kVocabSize, MakeMat(activations.logits.All(), kVocabSize), pool);
/*add=*/nullptr, pool);
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
float* HWY_RESTRICT logits = activations.logits.Batch(query_idx); float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
if constexpr (TConfig::kFinalCap > 0.0f) { if constexpr (TConfig::kFinalCap > 0.0f) {

View File

@ -60,7 +60,7 @@ HWY_INLINE void StoreHorizontalSums(DF df, //
// We are computing the product of (4, 4N) * (4N, 4) = (4, 4) tiles. // We are computing the product of (4, 4N) * (4N, 4) = (4, 4) tiles.
// Each entry of C[r,c] is a dot product of A.row and B.col, which reside in // Each entry of C[r,c] is a dot product of A.row and B.col, which reside in
// the lanes of `c$r$c`, so we store their horizontal sum (ReduceSum). This is // the lanes of `c$r$c`, so we store their horizontal sum (ReduceSum). This is
// expensive, but only a fraction of the kColsA_RowsB/N FMAs. // expensive, but only a fraction of the A.cols/N FMAs.
tile_c[stride_c * 0 + 0] = scale * hn::ReduceSum(df, c00); tile_c[stride_c * 0 + 0] = scale * hn::ReduceSum(df, c00);
tile_c[stride_c * 0 + 1] = scale * hn::ReduceSum(df, c01); tile_c[stride_c * 0 + 1] = scale * hn::ReduceSum(df, c01);
tile_c[stride_c * 0 + 2] = scale * hn::ReduceSum(df, c02); tile_c[stride_c * 0 + 2] = scale * hn::ReduceSum(df, c02);
@ -93,14 +93,14 @@ HWY_INLINE void StoreHorizontalSumsAdd(DF df, //
VF c10, VF c11, VF c12, VF c13, // VF c10, VF c11, VF c12, VF c13, //
VF c20, VF c21, VF c22, VF c23, // VF c20, VF c21, VF c22, VF c23, //
VF c30, VF c31, VF c32, VF c33, VF c30, VF c31, VF c32, VF c33,
const float* HWY_RESTRICT add,
const float scale, const float scale,
const float* HWY_RESTRICT add,
float* HWY_RESTRICT tile_c, float* HWY_RESTRICT tile_c,
size_t stride_c) { size_t stride_c) {
// We are computing the product of (4, 4N) * (4N, 4) = (4, 4) tiles. // We are computing the product of (4, 4N) * (4N, 4) = (4, 4) tiles.
// Each entry of C[r,c] is a dot product of A.row and B.col, which reside in // Each entry of C[r,c] is a dot product of A.row and B.col, which reside in
// the lanes of `c$r$c`, so we store their horizontal sum (ReduceSum). This is // the lanes of `c$r$c`, so we store their horizontal sum (ReduceSum). This is
// expensive, but only a fraction of the kColsA_RowsB/N FMAs. // expensive, but only a fraction of the A.cols/N FMAs.
const float add0 = add[0]; const float add0 = add[0];
// TODO: 4x4 transpose, then 128-bit vector FMA? // TODO: 4x4 transpose, then 128-bit vector FMA?
tile_c[stride_c * 0 + 0] = scale * hn::ReduceSum(df, c00) + add0; tile_c[stride_c * 0 + 0] = scale * hn::ReduceSum(df, c00) + add0;
@ -137,12 +137,12 @@ template <bool kAdd, size_t kNumRows, class DF, class VF = hn::Vec<DF>>
HWY_INLINE void StoreHorizontalSumsMaybeAdd( HWY_INLINE void StoreHorizontalSumsMaybeAdd(
DF df, VF c00, VF c01, VF c02, VF c03, VF c10, VF c11, VF c12, VF c13, DF df, VF c00, VF c01, VF c02, VF c03, VF c10, VF c11, VF c12, VF c13,
VF c20, VF c21, VF c22, VF c23, VF c30, VF c31, VF c32, VF c33, VF c20, VF c21, VF c22, VF c23, VF c30, VF c31, VF c32, VF c33,
const float* HWY_RESTRICT add, size_t add_offset, const float scale, const float scale, const float* HWY_RESTRICT add, size_t add_offset,
float* HWY_RESTRICT tile_c, size_t stride_c) { float* HWY_RESTRICT tile_c, size_t stride_c) {
if constexpr (kAdd) { if constexpr (kAdd) {
StoreHorizontalSumsAdd<kNumRows>(df, c00, c01, c02, c03, c10, c11, c12, c13, StoreHorizontalSumsAdd<kNumRows>(df, c00, c01, c02, c03, c10, c11, c12, c13,
c20, c21, c22, c23, c30, c31, c32, c33, c20, c21, c22, c23, c30, c31, c32, c33,
add + add_offset, scale, tile_c, stride_c); scale, add + add_offset, tile_c, stride_c);
} else { } else {
StoreHorizontalSums<kNumRows>(df, c00, c01, c02, c03, c10, c11, c12, c13, StoreHorizontalSums<kNumRows>(df, c00, c01, c02, c03, c10, c11, c12, c13,
c20, c21, c22, c23, c30, c31, c32, c33, c20, c21, c22, c23, c30, c31, c32, c33,
@ -150,6 +150,36 @@ HWY_INLINE void StoreHorizontalSumsMaybeAdd(
} }
} }
// Wrapper to simplify call sites. T can be const or non-const.
template <typename T>
struct Mat {
bool NotEmpty() const {
return ptr != nullptr && cols != 0 && stride >= cols;
}
size_t Row(size_t r) const { return ofs + stride * r; }
T* HWY_RESTRICT ptr;
size_t cols;
// elements between rows, which is typically the same as `cols`.
size_t stride;
// Offset to add to `ptr`; separate because T=NuqStream does not support
// pointer arithmetic.
size_t ofs;
};
template <typename T>
Mat<T> MakeMat(T* HWY_RESTRICT ptr, size_t cols, size_t stride,
size_t ofs = 0) {
return Mat<T>{.ptr = ptr, .cols = cols, .stride = stride, .ofs = ofs};
}
template <typename T>
Mat<T> MakeMat(T* HWY_RESTRICT ptr, size_t cols) {
return MakeMat(ptr, cols, cols);
}
#undef GEMMA_NATIVE_BF16 #undef GEMMA_NATIVE_BF16
#if HWY_IDE || (defined(HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16) == \ #if HWY_IDE || (defined(HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16) == \
defined(HWY_TARGET_TOGGLE)) defined(HWY_TARGET_TOGGLE))
@ -162,31 +192,18 @@ HWY_INLINE void StoreHorizontalSumsMaybeAdd(
// Specialization for f32 += bf16 * bf16 that avoids promoting to f32. // Specialization for f32 += bf16 * bf16 that avoids promoting to f32.
template <size_t kNumRows, bool kAdd> template <size_t kNumRows, bool kAdd>
HWY_INLINE void GEMM_4x4_Tile(const hwy::bfloat16_t* HWY_RESTRICT A, HWY_INLINE void MatMulTile(const Mat<const hwy::bfloat16_t>& A,
const size_t A_ofs, const Mat<const hwy::bfloat16_t>& B,
const hwy::bfloat16_t* HWY_RESTRICT B, const size_t row_a, const size_t row_b_col_c,
const size_t B_ofs, float* HWY_RESTRICT C,
const float scale, const float* HWY_RESTRICT add, const float scale, const float* HWY_RESTRICT add,
const size_t idx_tile, const size_t xtiles, const Mat<float>& C) {
const size_t cols_a, const size_t stride_a,
const size_t stride_b, const size_t stride_c) {
constexpr size_t kRegRows = 4;
constexpr size_t kRegCols = 4;
static_assert(kNumRows <= kRegRows);
// Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
const size_t row_a = idx_tile / xtiles * kRegRows;
const size_t row_b_col_c = idx_tile % xtiles * kRegCols;
const hn::ScalableTag<float> df; const hn::ScalableTag<float> df;
using VF = hn::Vec<decltype(df)>; using VF = hn::Vec<decltype(df)>;
// ReorderWidenMulAccumulate does not use its sum1 arg and we can use full // ReorderWidenMulAccumulate does not use its sum1 arg and we can use full
// bf16 vectors. // bf16 vectors.
const hn::Repartition<hwy::bfloat16_t, decltype(df)> d; const hn::Repartition<hwy::bfloat16_t, decltype(df)> d;
VF unused_sum1 = hn::Zero(df);
const size_t N = Lanes(d); const size_t N = Lanes(d);
VF unused_sum1 = hn::Zero(df);
VF c00 = hn::Zero(df); VF c00 = hn::Zero(df);
VF c01 = hn::Zero(df); VF c01 = hn::Zero(df);
VF c02 = hn::Zero(df); VF c02 = hn::Zero(df);
@ -207,42 +224,41 @@ HWY_INLINE void GEMM_4x4_Tile(const hwy::bfloat16_t* HWY_RESTRICT A,
VF c32 = hn::Zero(df); VF c32 = hn::Zero(df);
VF c33 = hn::Zero(df); VF c33 = hn::Zero(df);
const hwy::bfloat16_t* HWY_RESTRICT A_tile = A + A_ofs + stride_a * row_a; const hwy::bfloat16_t* HWY_RESTRICT A_tile = A.ptr + A.Row(row_a);
const hwy::bfloat16_t* HWY_RESTRICT B_tile = const hwy::bfloat16_t* HWY_RESTRICT B_tile = B.ptr + B.Row(row_b_col_c);
B + B_ofs + stride_b * row_b_col_c;
// Loop over columns of A and columns of the transposed B, in steps of N. // Loop over columns of A and columns of the transposed B, in steps of N.
// Accumulates into the c## vectors. // Accumulates into the c## vectors.
HWY_UNROLL(1) HWY_UNROLL(1)
for (size_t col_ab = 0; col_ab < cols_a; col_ab += N) { for (size_t col_ab = 0; col_ab < A.cols; col_ab += N) {
using V = hn::Vec<decltype(d)>; using V = hn::Vec<decltype(d)>;
const V b0 = hn::LoadU(d, B_tile + stride_b * 0 + col_ab); const V b0 = hn::LoadU(d, B_tile + B.stride * 0 + col_ab);
const V b1 = hn::LoadU(d, B_tile + stride_b * 1 + col_ab); const V b1 = hn::LoadU(d, B_tile + B.stride * 1 + col_ab);
const V b2 = hn::LoadU(d, B_tile + stride_b * 2 + col_ab); const V b2 = hn::LoadU(d, B_tile + B.stride * 2 + col_ab);
const V b3 = hn::LoadU(d, B_tile + stride_b * 3 + col_ab); const V b3 = hn::LoadU(d, B_tile + B.stride * 3 + col_ab);
const V a0 = hn::LoadU(d, A_tile + stride_a * 0 + col_ab); const V a0 = hn::LoadU(d, A_tile + A.stride * 0 + col_ab);
c00 = hn::ReorderWidenMulAccumulate(df, a0, b0, c00, unused_sum1); c00 = hn::ReorderWidenMulAccumulate(df, a0, b0, c00, unused_sum1);
c01 = hn::ReorderWidenMulAccumulate(df, a0, b1, c01, unused_sum1); c01 = hn::ReorderWidenMulAccumulate(df, a0, b1, c01, unused_sum1);
c02 = hn::ReorderWidenMulAccumulate(df, a0, b2, c02, unused_sum1); c02 = hn::ReorderWidenMulAccumulate(df, a0, b2, c02, unused_sum1);
c03 = hn::ReorderWidenMulAccumulate(df, a0, b3, c03, unused_sum1); c03 = hn::ReorderWidenMulAccumulate(df, a0, b3, c03, unused_sum1);
if constexpr (kNumRows == 1) continue; if constexpr (kNumRows == 1) continue;
const V a1 = hn::LoadU(d, A_tile + stride_a * 1 + col_ab); const V a1 = hn::LoadU(d, A_tile + A.stride * 1 + col_ab);
c10 = hn::ReorderWidenMulAccumulate(df, a1, b0, c10, unused_sum1); c10 = hn::ReorderWidenMulAccumulate(df, a1, b0, c10, unused_sum1);
c11 = hn::ReorderWidenMulAccumulate(df, a1, b1, c11, unused_sum1); c11 = hn::ReorderWidenMulAccumulate(df, a1, b1, c11, unused_sum1);
c12 = hn::ReorderWidenMulAccumulate(df, a1, b2, c12, unused_sum1); c12 = hn::ReorderWidenMulAccumulate(df, a1, b2, c12, unused_sum1);
c13 = hn::ReorderWidenMulAccumulate(df, a1, b3, c13, unused_sum1); c13 = hn::ReorderWidenMulAccumulate(df, a1, b3, c13, unused_sum1);
if constexpr (kNumRows == 2) continue; if constexpr (kNumRows == 2) continue;
const V a2 = hn::LoadU(d, A_tile + stride_a * 2 + col_ab); const V a2 = hn::LoadU(d, A_tile + A.stride * 2 + col_ab);
c20 = hn::ReorderWidenMulAccumulate(df, a2, b0, c20, unused_sum1); c20 = hn::ReorderWidenMulAccumulate(df, a2, b0, c20, unused_sum1);
c21 = hn::ReorderWidenMulAccumulate(df, a2, b1, c21, unused_sum1); c21 = hn::ReorderWidenMulAccumulate(df, a2, b1, c21, unused_sum1);
c22 = hn::ReorderWidenMulAccumulate(df, a2, b2, c22, unused_sum1); c22 = hn::ReorderWidenMulAccumulate(df, a2, b2, c22, unused_sum1);
c23 = hn::ReorderWidenMulAccumulate(df, a2, b3, c23, unused_sum1); c23 = hn::ReorderWidenMulAccumulate(df, a2, b3, c23, unused_sum1);
if constexpr (kNumRows == 3) continue; if constexpr (kNumRows == 3) continue;
const V a3 = hn::LoadU(d, A_tile + stride_a * 3 + col_ab); const V a3 = hn::LoadU(d, A_tile + A.stride * 3 + col_ab);
c30 = hn::ReorderWidenMulAccumulate(df, a3, b0, c30, unused_sum1); c30 = hn::ReorderWidenMulAccumulate(df, a3, b0, c30, unused_sum1);
c31 = hn::ReorderWidenMulAccumulate(df, a3, b1, c31, unused_sum1); c31 = hn::ReorderWidenMulAccumulate(df, a3, b1, c31, unused_sum1);
c32 = hn::ReorderWidenMulAccumulate(df, a3, b2, c32, unused_sum1); c32 = hn::ReorderWidenMulAccumulate(df, a3, b2, c32, unused_sum1);
@ -252,10 +268,10 @@ HWY_INLINE void GEMM_4x4_Tile(const hwy::bfloat16_t* HWY_RESTRICT A,
// Ensure sum1 was indeed unused. // Ensure sum1 was indeed unused.
HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df)))); HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));
float* HWY_RESTRICT C_tile = C + stride_c * row_a + row_b_col_c; float* HWY_RESTRICT C_tile = C.ptr + C.Row(row_a) + row_b_col_c;
StoreHorizontalSumsMaybeAdd<kAdd, kNumRows>( StoreHorizontalSumsMaybeAdd<kAdd, kNumRows>(
df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30, c31, df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30, c31,
c32, c33, add, row_b_col_c, scale, C_tile, stride_c); c32, c33, scale, add, row_b_col_c, C_tile, C.stride);
} }
#endif // GEMMA_NATIVE_BF16 #endif // GEMMA_NATIVE_BF16
@ -277,32 +293,20 @@ HWY_INLINE void UpdateTileRow(const VF& a0, const VF& a1, const VF& b00,
c3 = hn::MulAdd(a1, b31, c3); c3 = hn::MulAdd(a1, b31, c3);
} }
// Accumulates a single kNumRows (<= 4) x 4 tile of A x B into C. B is // Streams a `(kNumRows, 4)` strip of `A` and the transposed `B`, then writes a
// transposed, so we iterate over both A and B with consecutive vector loads. // finished tile of `C`.
// General case: uses CompressTraits to load from A and B. // General case: uses CompressTraits to load from A and B.
template <size_t kNumRows, bool kAdd, typename MatTA, typename MatTB> template <size_t kNumRows, bool kAdd, typename MatTA, typename MatTB>
HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A, const size_t A_ofs, HWY_INLINE void MatMulTile(const Mat<MatTA>& A, const Mat<MatTB>& B,
const MatTB* HWY_RESTRICT B, const size_t B_ofs, const size_t row_a, const size_t row_b_col_c,
float* HWY_RESTRICT C, const float scale, const float scale, const float* HWY_RESTRICT add,
const float* HWY_RESTRICT add, const Mat<float>& C) {
const size_t idx_tile, const size_t xtiles, using TraitsA = CompressTraits<hwy::RemoveConst<MatTA>>;
const size_t cols_a, const size_t stride_a, using TraitsB = CompressTraits<hwy::RemoveConst<MatTB>>;
const size_t stride_b, const size_t stride_c) {
constexpr size_t kRegRows = 4;
constexpr size_t kRegCols = 4;
static_assert(kNumRows <= kRegRows);
using TraitsA = CompressTraits<MatTA>;
using TraitsB = CompressTraits<MatTB>;
// Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
const size_t row_a = idx_tile / xtiles * kRegRows;
const size_t row_b_col_c = idx_tile % xtiles * kRegCols;
const hn::ScalableTag<float> d32; const hn::ScalableTag<float> d32;
const size_t N = hn::Lanes(d32); const size_t N = hn::Lanes(d32);
using V = hn::Vec<decltype(d32)>; using V = hn::Vec<decltype(d32)>;
V c00 = hn::Zero(d32); V c00 = hn::Zero(d32);
V c01 = hn::Zero(d32); V c01 = hn::Zero(d32);
V c02 = hn::Zero(d32); V c02 = hn::Zero(d32);
@ -323,127 +327,118 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A, const size_t A_ofs,
V c32 = hn::Zero(d32); V c32 = hn::Zero(d32);
V c33 = hn::Zero(d32); V c33 = hn::Zero(d32);
const size_t A_tile_ofs = A_ofs + stride_a * row_a; const size_t A_ofs = A.Row(row_a);
const size_t B_tile_ofs = B_ofs + stride_b * row_b_col_c; const size_t B_ofs = B.Row(row_b_col_c);
// Loop over columns of A and columns of the transposed B, in steps of 2*N // Loop over columns of A and columns of the transposed B, in steps of 2*N
// (since we are decoding consecutive bytes at each iteration). // (since we are decoding consecutive bytes at each iteration).
// Accumulates into the c## vectors. // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c,
// col_ab) for B. Accumulates into the c## vectors.
size_t col_ab = 0; size_t col_ab = 0;
HWY_UNROLL(1) HWY_UNROLL(1)
for (; col_ab <= cols_a - 2 * N; col_ab += 2 * N) { for (; col_ab <= A.cols - 2 * N; col_ab += 2 * N) {
V b00, b01; V b00, b01;
TraitsB::Decompress2(d32, B, B_tile_ofs + stride_b * 0 + col_ab, b00, b01); TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 0 + col_ab, b00, b01);
V b10, b11; V b10, b11;
TraitsB::Decompress2(d32, B, B_tile_ofs + stride_b * 1 + col_ab, b10, b11); TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 1 + col_ab, b10, b11);
V b20, b21; V b20, b21;
TraitsB::Decompress2(d32, B, B_tile_ofs + stride_b * 2 + col_ab, b20, b21); TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 2 + col_ab, b20, b21);
V b30, b31; V b30, b31;
TraitsB::Decompress2(d32, B, B_tile_ofs + stride_b * 3 + col_ab, b30, b31); TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 3 + col_ab, b30, b31);
V a00, a01; V a00, a01;
TraitsA::Decompress2(d32, A, A_tile_ofs + stride_a * 0 + col_ab, a00, a01); TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 0 + col_ab, a00, a01);
UpdateTileRow(a00, a01, b00, b01, b10, b11, b20, b21, b30, b31, c00, c01, UpdateTileRow(a00, a01, b00, b01, b10, b11, b20, b21, b30, b31, c00, c01,
c02, c03); c02, c03);
if constexpr (kNumRows == 1) continue; if constexpr (kNumRows == 1) continue;
V a10, a11; V a10, a11;
TraitsA::Decompress2(d32, A, A_tile_ofs + stride_a * 1 + col_ab, a10, a11); TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 1 + col_ab, a10, a11);
UpdateTileRow(a10, a11, b00, b01, b10, b11, b20, b21, b30, b31, c10, c11, UpdateTileRow(a10, a11, b00, b01, b10, b11, b20, b21, b30, b31, c10, c11,
c12, c13); c12, c13);
if constexpr (kNumRows == 2) continue; if constexpr (kNumRows == 2) continue;
V a20, a21; V a20, a21;
TraitsA::Decompress2(d32, A, A_tile_ofs + stride_a * 2 + col_ab, a20, a21); TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 2 + col_ab, a20, a21);
UpdateTileRow(a20, a21, b00, b01, b10, b11, b20, b21, b30, b31, c20, c21, UpdateTileRow(a20, a21, b00, b01, b10, b11, b20, b21, b30, b31, c20, c21,
c22, c23); c22, c23);
if constexpr (kNumRows == 3) continue; if constexpr (kNumRows == 3) continue;
V a30, a31; V a30, a31;
TraitsA::Decompress2(d32, A, A_tile_ofs + stride_a * 3 + col_ab, a30, a31); TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 3 + col_ab, a30, a31);
UpdateTileRow(a30, a31, b00, b01, b10, b11, b20, b21, b30, b31, c30, c31, UpdateTileRow(a30, a31, b00, b01, b10, b11, b20, b21, b30, b31, c30, c31,
c32, c33); c32, c33);
} }
float* HWY_RESTRICT C_tile = C + stride_c * row_a + row_b_col_c; float* HWY_RESTRICT C_tile = C.ptr + C.Row(row_a) + row_b_col_c;
StoreHorizontalSumsMaybeAdd<kAdd, kNumRows>( StoreHorizontalSumsMaybeAdd<kAdd, kNumRows>(
d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30, c31, d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30, c31,
c32, c33, add, row_b_col_c, scale, C_tile, stride_c); c32, c33, scale, add, row_b_col_c, C_tile, C.stride);
} }
// Tiled 4x4 GEMM: C = A * B * scale [+ add]. // Computes the matrix product `A * B * scale [+ add]` and stores it in `C`.
// Computes the matrix product of A and B and stores this in C. Processes tiles
// of 4x4 vectors in parallel with a work-stealing thread pool.
// //
// If kAdd is true, the row-vector `add` is added to each row of C, otherwise // `A` is a row-major matrix of shape `(batch_size, A.cols)`.
// `add` is ignored and can be nullptr. // `B` is transposed; `B.cols`, which must match `A.cols`, denotes the number of
// A is a row-major matrix of size (batch_size, colsA_rowsB). // rows in the original B, and `C.cols` the number of columns in the original B.
// B is passed transposed (column-major), so a matrix of size //
// (colsBC, colsA_rowsB), representing a B of size (colsA_rowsB, colsBC). // `scale` allows expanding the smaller range of `SfpStream` to the original
// A_ofs and B_ofs are offsets into A and B, respectively; they remain separate // values. When `A` and/or `B` are from CompressedArray, `scale` should be the
// from the pointers because some MatTA/B such as NuqStream do not support // product of their `.scale()` values.
// pointer arithmetic. //
// C is a row-major matrix of size (batch_size, colsBC), with `C_stride` // If `kAdd` is true, the row-vector `add` is added to each row of `C`,
// elements between rows, which is typically the same as `colsBC`. There is no // otherwise `add` is ignored and can be nullptr. A scale for `add` is not
// `C_ofs` because callers can simply add it to `C`. // supported, so make sure its scale is 1.
// The product is scaled by `scale` to support CompressedArray with scale != 1, //
// the caller can pass the product of the scales of A and B. // `C` is a row-major matrix of size `(batch_size, C.cols)`.
// A scale for `add` is not supported, so make sure its scale is 1. // Writes 4x4 tiles of C in parallel using a work-stealing thread pool.
// Typically batch_size is 1..512, colsA_rowsB and colsBC are 3k or 24k. // Typically batch_size is 1..512, A.cols and C.cols are 3k or 24k.
template <bool kAdd, typename MatTA, typename MatTB> template <bool kAdd, typename MatTA, typename MatTB>
HWY_NOINLINE void MatMul_4x4(const size_t batch_size, HWY_NOINLINE void MatMul_4x4(const size_t batch_size, const Mat<MatTA>& A,
const MatTA* HWY_RESTRICT A, const size_t A_ofs, const Mat<MatTB>& B, const float scale,
const size_t colsA_rowsB, const float* HWY_RESTRICT add, const Mat<float>& C,
const MatTB* HWY_RESTRICT B, const size_t B_ofs,
const size_t colsBC, const float scale,
float* HWY_RESTRICT C, const size_t C_stride,
const float* HWY_RESTRICT add,
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
PROFILER_ZONE("Matmul"); PROFILER_ZONE("Matmul");
constexpr size_t kRegRows = 4; // if changing, also update the switch below.
constexpr size_t kRegCols = 4;
HWY_DASSERT(A.NotEmpty() && B.NotEmpty() && C.NotEmpty());
HWY_DASSERT(A.cols == B.cols);
// Use float instead of MatTA/MatTB because we decompress to float here.
const size_t N = hn::Lanes(hn::ScalableTag<float>());
(void)N;
HWY_DASSERT(A.cols % (N * 2) == 0); // For Decompress2.
HWY_DASSERT(C.cols % kRegCols == 0);
// We currently write C directly, which touches more memory than fits in L3. // We currently write C directly, which touches more memory than fits in L3.
// TODO: add another level of loops to finish L3-sized pieces of C at a time. // TODO: add another level of loops to finish L3-sized pieces of C at a time.
const hn::ScalableTag<MatTA> d;
// Use float instead of MatTA/MatTB because we decompress to float here.
const size_t Nf = hn::Lanes(hn::ScalableTag<float>());
(void)Nf; // For HWY_DASSERT
constexpr size_t kRegRows = 4; // if changing, also update the switch below.
constexpr size_t kRegCols = 4; // in vectors
HWY_DASSERT(colsA_rowsB % (Nf * 2) == 0); // For Decompress2.
HWY_DASSERT(colsBC % kRegCols == 0);
const size_t tilesY = hwy::DivCeil(batch_size, kRegRows); const size_t tilesY = hwy::DivCeil(batch_size, kRegRows);
const size_t tilesX = colsBC / kRegCols; const size_t tilesX = C.cols / kRegCols;
const size_t strideA = colsA_rowsB;
const size_t strideB = colsA_rowsB;
pool.Run(0, tilesX * tilesY, pool.Run(0, tilesX * tilesY,
[&](const uint64_t idx_tile, size_t /*thread*/) HWY_ATTR { [&](const uint64_t idx_tile, size_t /*thread*/) HWY_ATTR {
const size_t tx = idx_tile % tilesX;
const size_t ty = idx_tile / tilesX;
const size_t row_a = ty * kRegRows;
const size_t row_b_col_c = tx * kRegCols;
// How many rows of C are left to compute. If more than 4, this // How many rows of C are left to compute. If more than 4, this
// tile still only computes 4 rows. // tile still only computes 4 rows.
const size_t num_rows = batch_size - idx_tile / tilesX * kRegRows; const size_t num_rows = batch_size - row_a;
HWY_ASSERT(num_rows > 0); HWY_DASSERT(num_rows != 0);
switch (num_rows) { switch (num_rows) {
case 1: case 1:
GEMM_4x4_Tile<1, kAdd>(A, A_ofs, B, B_ofs, C, scale, add, MatMulTile<1, kAdd>(A, B, row_a, row_b_col_c, scale, add, C);
idx_tile, tilesX, colsA_rowsB, strideA,
strideB, C_stride);
break; break;
case 2: case 2:
GEMM_4x4_Tile<2, kAdd>(A, A_ofs, B, B_ofs, C, scale, add, MatMulTile<2, kAdd>(A, B, row_a, row_b_col_c, scale, add, C);
idx_tile, tilesX, colsA_rowsB, strideA,
strideB, C_stride);
break; break;
case 3: case 3:
GEMM_4x4_Tile<3, kAdd>(A, A_ofs, B, B_ofs, C, scale, add, MatMulTile<3, kAdd>(A, B, row_a, row_b_col_c, scale, add, C);
idx_tile, tilesX, colsA_rowsB, strideA,
strideB, C_stride);
break; break;
default: default:
GEMM_4x4_Tile<4, kAdd>(A, A_ofs, B, B_ofs, C, scale, add, MatMulTile<4, kAdd>(A, B, row_a, row_b_col_c, scale, add, C);
idx_tile, tilesX, colsA_rowsB, strideA,
strideB, C_stride);
} }
}); });
} }

View File

@ -14,6 +14,7 @@
// limitations under the License. // limitations under the License.
#ifndef HWY_DISABLED_TARGETS #ifndef HWY_DISABLED_TARGETS
// Exclude HWY_SCALAR due to 2x bf16 -> f32.
#define HWY_DISABLED_TARGETS HWY_SCALAR #define HWY_DISABLED_TARGETS HWY_SCALAR
#endif #endif
@ -48,47 +49,10 @@ namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
template <typename MatT, size_t kOuter, size_t kInner>
CompressedArray<MatT, kOuter * kInner> GenerateMat(size_t offset,
hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws;
CompressedArray<MatT, kOuter * kInner> mat;
std::array<float, kOuter * kInner> content;
const float scale = 1.0f / kInner;
pool.Run(0, kOuter, [&](const size_t i, size_t /*thread*/) {
for (size_t j = 0; j < kInner; j++) {
content[i * kInner + j] =
static_cast<float>((i * kInner + j + offset) * scale);
}
});
Compress(content, ws, mat, pool);
mat.set_scale(1.9f); // Arbitrary value, different from 1.
return mat;
}
template <typename MatT, size_t kOuter, size_t kInner>
CompressedArray<MatT, kOuter * kInner> GenerateZeroMat(hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws;
CompressedArray<MatT, kOuter * kInner> mat;
std::array<MatT, kOuter * kInner> content;
pool.Run(0, kOuter, [&](const size_t i, size_t thread) {
hwy::ZeroBytes(&content[i * kInner], kInner * sizeof(content[0]));
});
Compress(content, ws, mat, pool);
mat.set_scale(1.2f); // Arbitrary value, different from 1.
return mat;
}
template <typename MatT, size_t kOuter, size_t kInner> template <typename MatT, size_t kOuter, size_t kInner>
std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> GenerateMatHeap( std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> GenerateMatHeap(
size_t offset, hwy::ThreadPool& pool) { size_t offset, hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws; gcpp::CompressWorkingSet ws;
std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> mat =
std::unique_ptr<CompressedArray<MatT, kOuter * kInner>>(
new CompressedArray<MatT, kOuter * kInner>);
hwy::AlignedFreeUniquePtr<float[]> content = hwy::AlignedFreeUniquePtr<float[]> content =
hwy::AllocateAligned<float>(kOuter * kInner); hwy::AllocateAligned<float>(kOuter * kInner);
const float scale = 1.875f / (kInner * kOuter + offset); const float scale = 1.875f / (kInner * kOuter + offset);
@ -99,6 +63,8 @@ std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> GenerateMatHeap(
} }
}); });
std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> mat =
std::make_unique<CompressedArray<MatT, kOuter * kInner>>();
Compress(content.get(), kOuter * kInner, ws, kOuter * kInner, mat->data(), 0, Compress(content.get(), kOuter * kInner, ws, kOuter * kInner, mat->data(), 0,
pool); pool);
mat->set_scale(0.6f); // Arbitrary value, different from 1. mat->set_scale(0.6f); // Arbitrary value, different from 1.
@ -109,9 +75,6 @@ template <typename MatT, size_t kOuter, size_t kInner>
std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> std::unique_ptr<CompressedArray<MatT, kOuter * kInner>>
GenerateTransposeMatHeap(size_t offset, hwy::ThreadPool& pool) { GenerateTransposeMatHeap(size_t offset, hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws; gcpp::CompressWorkingSet ws;
std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> mat =
std::unique_ptr<CompressedArray<MatT, kOuter * kInner>>(
new CompressedArray<MatT, kOuter * kInner>);
hwy::AlignedFreeUniquePtr<float[]> content = hwy::AlignedFreeUniquePtr<float[]> content =
hwy::AllocateAligned<float>(kOuter * kInner); hwy::AllocateAligned<float>(kOuter * kInner);
const float scale = 1.875f / (kInner * kOuter + offset); const float scale = 1.875f / (kInner * kOuter + offset);
@ -122,6 +85,8 @@ GenerateTransposeMatHeap(size_t offset, hwy::ThreadPool& pool) {
} }
}); });
std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> mat =
std::make_unique<CompressedArray<MatT, kOuter * kInner>>();
Compress(content.get(), kOuter * kInner, ws, kOuter * kInner, mat->data(), 0, Compress(content.get(), kOuter * kInner, ws, kOuter * kInner, mat->data(), 0,
pool); pool);
// Arbitrary value, different from 1, must match GenerateMatHeap. // Arbitrary value, different from 1, must match GenerateMatHeap.
@ -133,9 +98,6 @@ template <typename MatT, size_t kOuter, size_t kInner>
std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> GenerateZeroMatHeap( std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> GenerateZeroMatHeap(
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws; gcpp::CompressWorkingSet ws;
std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> mat =
std::unique_ptr<CompressedArray<MatT, kOuter * kInner>>(
new CompressedArray<MatT, kOuter * kInner>);
hwy::AlignedFreeUniquePtr<float[]> content = hwy::AlignedFreeUniquePtr<float[]> content =
hwy::AllocateAligned<float>(kOuter * kInner); hwy::AllocateAligned<float>(kOuter * kInner);
@ -143,22 +105,14 @@ std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> GenerateZeroMatHeap(
hwy::ZeroBytes(&content[i * kInner], kInner * sizeof(content[0])); hwy::ZeroBytes(&content[i * kInner], kInner * sizeof(content[0]));
}); });
std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> mat =
std::make_unique<CompressedArray<MatT, kOuter * kInner>>();
Compress(content.get(), kOuter * kInner, ws, kOuter * kInner, mat->data(), 0, Compress(content.get(), kOuter * kInner, ws, kOuter * kInner, mat->data(), 0,
pool); pool);
mat->set_scale(1.2f); // Arbitrary value, different from 1. mat->set_scale(1.2f); // Arbitrary value, different from 1.
return mat; return mat;
} }
template <size_t length>
hwy::AlignedFreeUniquePtr<float[]> GenerateVec(size_t offset) {
hwy::AlignedFreeUniquePtr<float[]> vec = hwy::AllocateAligned<float>(length);
HWY_ASSERT(vec);
for (size_t idx = 0; idx < length; idx++) {
vec[idx] = static_cast<float>(idx + offset);
}
return vec;
}
// A simple matrix multiplication. No optimization / tiling. // A simple matrix multiplication. No optimization / tiling.
template <size_t kM, size_t kN, size_t kK> template <size_t kM, size_t kN, size_t kK>
hwy::AlignedFreeUniquePtr<float[]> SimpleMatMul( hwy::AlignedFreeUniquePtr<float[]> SimpleMatMul(
@ -179,27 +133,6 @@ hwy::AlignedFreeUniquePtr<float[]> SimpleMatMul(
return out; return out;
} }
template <size_t kOuter, size_t kInner>
hwy::AlignedFreeUniquePtr<float[]> SimpleMatVecAdd(
const CompressedArray<float, kOuter * kInner>& mat,
const hwy::AlignedFreeUniquePtr<float[]>& vec,
const hwy::AlignedFreeUniquePtr<float[]>& add) {
hwy::AlignedFreeUniquePtr<float[]> uncompressed_mat =
hwy::AllocateAligned<float>(kOuter * kInner);
hwy::AlignedFreeUniquePtr<float[]> out = hwy::AllocateAligned<float>(kOuter);
HWY_ASSERT(uncompressed_mat && out);
Decompress(mat, 0, uncompressed_mat.get(), kOuter * kInner);
MulByConst(mat.scale(), uncompressed_mat.get(), kOuter * kInner);
for (size_t idx_row = 0; idx_row < kOuter; idx_row++) {
out[idx_row] = add[idx_row];
for (size_t idx_col = 0; idx_col < kInner; idx_col++) {
out[idx_row] +=
uncompressed_mat[kInner * idx_row + idx_col] * vec[idx_col];
}
}
return out;
}
template <typename MatT> template <typename MatT>
void AssertClose(const MatT* HWY_RESTRICT expected, void AssertClose(const MatT* HWY_RESTRICT expected,
const MatT* HWY_RESTRICT actual, size_t num) { const MatT* HWY_RESTRICT actual, size_t num) {
@ -233,8 +166,7 @@ void AssertClose(const hwy::AlignedFreeUniquePtr<float[]>& a,
} }
} }
// Largely unoptimized; reordered innermost loops nets ~5-10X speedup on // Largely unoptimized; reordered innermost loops nets ~5-10X speedup.
// ops_test across instruction sets.
template <size_t kN, size_t kK, typename MatTA, typename MatTB, template <size_t kN, size_t kK, typename MatTA, typename MatTB,
HWY_IF_T_SIZE_GT(MatTB, 1)> HWY_IF_T_SIZE_GT(MatTB, 1)>
HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a, HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
@ -271,92 +203,167 @@ HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
MatMulSlowBatch<kN, kK>(batch_size, a, b.get(), scale, add, out); MatMulSlowBatch<kN, kK>(batch_size, a, b.get(), scale, add, out);
} }
void PrintSpeed(const char* algo, size_t M, size_t N, size_t K,
double elapsed) {
// * 2 because of FMA.
fprintf(stderr, "%s: %f seconds, %f GFLOPS.\n", algo, elapsed,
2E-9 * M * N * K / elapsed);
}
template <size_t kM, size_t kN, size_t kK, bool kAdd, typename MatTA, template <size_t kM, size_t kN, size_t kK, bool kAdd, typename MatTA,
typename MatTB = MatTA> typename MatTB = MatTA>
void TestTiledBatchMatMul() { void TestMatMul(hwy::ThreadPool& pool) {
fprintf(stderr, using TraitsA = CompressTraits<MatTA>;
"TestTiledBatchMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n", using TraitsB = CompressTraits<MatTB>;
kM, kN, kK, kAdd, typeid(MatTA).name(), typeid(MatTB).name()); fprintf(stderr, "TestMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n", kM,
hwy::ThreadPool pool(3); kN, kK, kAdd, TraitsA::Name(), TraitsB::Name());
std::unique_ptr<CompressedArray<MatTA, kM * kN>> a = std::unique_ptr<CompressedArray<MatTA, kM * kN>> a =
GenerateMatHeap<MatTA, kM, kN>(0, pool); GenerateMatHeap<MatTA, kM, kN>(0, pool);
std::unique_ptr<CompressedArray<MatTB, kN * kK>> b_trans =
GenerateTransposeMatHeap<MatTB, kN, kK>(0, pool);
hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
const float scale = a->scale() * b_trans->scale();
std::unique_ptr<CompressedArray<float, kK>> add;
if (kAdd) {
add = GenerateMatHeap<float, 1, kK>(0, pool);
add->set_scale(1.0f);
}
std::unique_ptr<CompressedArray<float, kM * kK>> c_slow;
const bool compare_slow = kN < 2048;
if (compare_slow) {
std::unique_ptr<CompressedArray<MatTB, kN * kK>> b = std::unique_ptr<CompressedArray<MatTB, kN * kK>> b =
GenerateMatHeap<MatTB, kN, kK>(0, pool); GenerateMatHeap<MatTB, kN, kK>(0, pool);
std::unique_ptr<CompressedArray<float, kK>> add = HWY_ASSERT_EQ(scale, a->scale() * b->scale());
GenerateMatHeap<float, 1, kK>(0, pool); c_slow = GenerateZeroMatHeap<float, kM, kK>(pool);
add->set_scale(1.0f);
std::unique_ptr<CompressedArray<float, kM * kK>> c_slow =
GenerateZeroMatHeap<float, kM, kK>(pool);
const float scale = a->scale() * b->scale();
const double start_slow = hwy::platform::Now(); const double start_slow = hwy::platform::Now();
MatMulSlowBatch<kN, kK>(kM, a->data(), b->data(), scale, MatMulSlowBatch<kN, kK>(kM, a->data(), b->data(), scale,
kAdd ? add->data() : nullptr, c_slow->data()); kAdd ? add->data() : nullptr, c_slow->data());
const double slow_matmul_seconds = hwy::platform::Now() - start_slow; PrintSpeed("MatMulSlowBatch", kM, kN, kK,
fprintf(stderr, "MatMulSlowBatch took %f seconds.\n", slow_matmul_seconds); hwy::platform::Now() - start_slow);
hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
std::unique_ptr<CompressedArray<MatTB, kN * kK>> b_trans =
GenerateTransposeMatHeap<MatTB, kN, kK>(0, pool);
const double start_tiled = hwy::platform::Now();
EXPECT_EQ(scale, a->scale() * b_trans->scale());
MatMul_4x4<kAdd>(kM, a->data(), 0, kN, b_trans->data(), 0, kK, scale, c.get(),
kK, add->data(), pool);
const double tiled_matmul_seconds = hwy::platform::Now() - start_tiled;
fprintf(stderr, "MatMul_4x4 took %f seconds.\n", tiled_matmul_seconds);
AssertClose(c_slow->data(), c.get(), kM * kK);
} }
void TestAllTiledBatchMatMul() { double min_elapsed = hwy::HighestValue<double>();
for (int rep = 0; rep < (compare_slow ? 1 : 3); ++rep) {
const double start_tiled = hwy::platform::Now();
MatMul_4x4<kAdd>(kM, MakeMat(a->data(), kN), MakeMat(b_trans->data(), kN),
scale, kAdd ? add->data_scale1() : nullptr,
MakeMat(c.get(), kK), pool);
min_elapsed = HWY_MIN(min_elapsed, hwy::platform::Now() - start_tiled);
}
PrintSpeed("MatMul_4x4", kM, kN, kK, min_elapsed);
if (compare_slow) {
AssertClose(c_slow->data(), c.get(), kM * kK);
}
}
void TestAllMatMul() {
// Skip EMU128 (10x slower than SSE4 for SFP) and older x86.
if (HWY_TARGET == HWY_EMU128 || HWY_TARGET == HWY_SSE4 ||
HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE2) {
return;
}
hwy::ThreadPool pool(4);
using BF16 = hwy::bfloat16_t; using BF16 = hwy::bfloat16_t;
using F32 = float; using F32 = float;
using SFP = SfpStream; using SFP = SfpStream;
// medium-sized square test
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, F32>();
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, BF16>();
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, F32, BF16>();
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, BF16, F32>();
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, F32, SFP>();
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, BF16, SFP>();
// minimal non-square test. kK must be at least 2 vectors.
TestTiledBatchMatMul<35, 128, 32, /*kAdd=*/false, F32>();
TestTiledBatchMatMul<34, 128, 32, /*kAdd=*/true, BF16>();
TestTiledBatchMatMul<33, 128, 32, /*kAdd=*/false, F32, BF16>();
TestTiledBatchMatMul<33, 128, 32, /*kAdd=*/true, BF16, F32>();
TestTiledBatchMatMul<31, 128, 32, /*kAdd=*/false, F32, SFP>();
TestTiledBatchMatMul<29, 128, 32, /*kAdd=*/true, BF16, SFP>();
TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/true, F32>();
TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/false, BF16>();
TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/true, F32, BF16>();
TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/false, BF16, F32>();
TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/true, F32, SFP>();
TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/false, BF16, SFP>();
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, F32>();
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, BF16>();
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, F32, BF16>();
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, BF16, F32>();
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, F32, SFP>();
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, BF16, SFP>();
TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/true, F32>();
TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/false, BF16>();
TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/true, F32, BF16>();
TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/false, BF16, F32>();
TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/true, F32, SFP>();
TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/false, BF16, SFP>();
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, F32>();
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, BF16>();
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, F32, BF16>();
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, BF16, F32>();
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, F32, SFP>();
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, BF16, SFP>();
// large-scale test // large-scale test
// TODO(philculliton): investigate rounding issues with large matrices. TestMatMul<64, 24576, 3072, /*kAdd=*/false, F32, SFP>(pool);
// Causes test timeout.
// TestTiledBatchMatMul<512, 24576, 3072, float>(); // medium-sized square test
TestMatMul<512, 512, 512, /*kAdd=*/false, F32>(pool);
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16>(pool);
TestMatMul<512, 512, 512, /*kAdd=*/false, F32, BF16>(pool);
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, F32>(pool);
TestMatMul<512, 512, 512, /*kAdd=*/false, F32, SFP>(pool);
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, SFP>(pool);
// minimal non-square test. kK must be at least 2 vectors.
TestMatMul<35, 128, 32, /*kAdd=*/false, F32>(pool);
TestMatMul<34, 128, 32, /*kAdd=*/true, BF16>(pool);
TestMatMul<33, 128, 32, /*kAdd=*/false, F32, BF16>(pool);
TestMatMul<33, 128, 32, /*kAdd=*/true, BF16, F32>(pool);
TestMatMul<31, 128, 32, /*kAdd=*/false, F32, SFP>(pool);
TestMatMul<29, 128, 32, /*kAdd=*/true, BF16, SFP>(pool);
TestMatMul<4, 128, 32, /*kAdd=*/true, F32>(pool);
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16>(pool);
TestMatMul<4, 128, 32, /*kAdd=*/true, F32, BF16>(pool);
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, F32>(pool);
TestMatMul<4, 128, 32, /*kAdd=*/true, F32, SFP>(pool);
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, SFP>(pool);
TestMatMul<3, 128, 32, /*kAdd=*/false, F32>(pool);
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16>(pool);
TestMatMul<3, 128, 32, /*kAdd=*/false, F32, BF16>(pool);
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, F32>(pool);
TestMatMul<3, 128, 32, /*kAdd=*/false, F32, SFP>(pool);
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, SFP>(pool);
TestMatMul<2, 128, 64, /*kAdd=*/true, F32>(pool);
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16>(pool);
TestMatMul<2, 128, 64, /*kAdd=*/true, F32, BF16>(pool);
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, F32>(pool);
TestMatMul<2, 128, 64, /*kAdd=*/true, F32, SFP>(pool);
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, SFP>(pool);
TestMatMul<1, 128, 32, /*kAdd=*/false, F32>(pool);
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16>(pool);
TestMatMul<1, 128, 32, /*kAdd=*/false, F32, BF16>(pool);
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, F32>(pool);
TestMatMul<1, 128, 32, /*kAdd=*/false, F32, SFP>(pool);
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, SFP>(pool);
}
template <size_t kOuter, size_t kInner>
hwy::AlignedFreeUniquePtr<float[]> SimpleMatVecAdd(
const CompressedArray<float, kOuter * kInner>& mat,
const hwy::AlignedFreeUniquePtr<float[]>& vec,
const hwy::AlignedFreeUniquePtr<float[]>& add) {
hwy::AlignedFreeUniquePtr<float[]> uncompressed_mat =
hwy::AllocateAligned<float>(kOuter * kInner);
hwy::AlignedFreeUniquePtr<float[]> out = hwy::AllocateAligned<float>(kOuter);
HWY_ASSERT(uncompressed_mat && out);
Decompress(mat, 0, uncompressed_mat.get(), kOuter * kInner);
MulByConst(mat.scale(), uncompressed_mat.get(), kOuter * kInner);
for (size_t idx_row = 0; idx_row < kOuter; idx_row++) {
out[idx_row] = add[idx_row];
for (size_t idx_col = 0; idx_col < kInner; idx_col++) {
out[idx_row] +=
uncompressed_mat[kInner * idx_row + idx_col] * vec[idx_col];
}
}
return out;
}
template <typename MatT, size_t kOuter, size_t kInner>
CompressedArray<MatT, kOuter * kInner> GenerateMat(size_t offset,
hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws;
CompressedArray<MatT, kOuter * kInner> mat;
std::array<float, kOuter * kInner> content;
const float scale = 1.0f / kInner;
pool.Run(0, kOuter, [&](const size_t i, size_t /*thread*/) {
for (size_t j = 0; j < kInner; j++) {
content[i * kInner + j] =
static_cast<float>((i * kInner + j + offset) * scale);
}
});
Compress(content, ws, mat, pool);
mat.set_scale(1.9f); // Arbitrary value, different from 1.
return mat;
}
template <size_t length>
hwy::AlignedFreeUniquePtr<float[]> GenerateVec(size_t offset) {
hwy::AlignedFreeUniquePtr<float[]> vec = hwy::AllocateAligned<float>(length);
HWY_ASSERT(vec);
for (size_t idx = 0; idx < length; idx++) {
vec[idx] = static_cast<float>(idx + offset);
}
return vec;
} }
void TestMatVecAdd() { void TestMatVecAdd() {
@ -441,7 +448,7 @@ HWY_AFTER_NAMESPACE();
namespace gcpp { namespace gcpp {
HWY_BEFORE_TEST(MatmulTest); HWY_BEFORE_TEST(MatmulTest);
HWY_EXPORT_AND_TEST_P(MatmulTest, TestAllTiledBatchMatMul); HWY_EXPORT_AND_TEST_P(MatmulTest, TestAllMatMul);
HWY_EXPORT_AND_TEST_P(MatmulTest, TestMatVecAdd); HWY_EXPORT_AND_TEST_P(MatmulTest, TestMatVecAdd);
HWY_EXPORT_AND_TEST_P(MatmulTest, TestTwoMatVecAdd); HWY_EXPORT_AND_TEST_P(MatmulTest, TestTwoMatVecAdd);
HWY_EXPORT_AND_TEST_P(MatmulTest, TestTwoOfsMatVecAddLoop); HWY_EXPORT_AND_TEST_P(MatmulTest, TestTwoOfsMatVecAddLoop);