mirror of https://github.com/google/gemma.cpp.git
MatMul cleanup: Mat struct, simplify args.
Add large benchmark to test, use 4 threads, skip some targets. Also use Traits::Name instead of typeid.

PiperOrigin-RevId: 657496185
commit 6ea4232b2e (parent d9f86f8e4d)
@@ -60,6 +60,7 @@ struct CompressTraits {};
 template <>
 struct CompressTraits<float> {
   using MatT = float;
+  static const char* Name() { return "f32"; }
   static constexpr bool kSupportsEvenOdd = false;  // unnecessary

   template <class DF, HWY_IF_F32_D(DF)>
@@ -123,6 +124,7 @@ struct CompressTraits<float> {
 template <>
 struct CompressTraits<hwy::bfloat16_t> {
   using MatT = hwy::bfloat16_t;
+  static const char* Name() { return "bf16"; }
   static constexpr bool kSupportsEvenOdd = true;

   template <class DF, HWY_IF_F32_D(DF)>
@@ -292,6 +294,7 @@ struct CompressTraits<hwy::bfloat16_t> {
 template <>
 struct CompressTraits<SfpStream> {
   using MatT = SfpStream;
+  static const char* Name() { return "sfp"; }
   static constexpr bool kSupportsEvenOdd = true;

   // Callers are responsible for scaling `in` such that its magnitudes do not
@@ -389,6 +392,7 @@ struct CompressTraits<SfpStream> {
 template <>
 struct CompressTraits<NuqStream> {
   using MatT = NuqStream;
+  static const char* Name() { return "nuq"; }
   static constexpr bool kSupportsEvenOdd = false;

   template <class DF, HWY_IF_F32_D(DF)>
@@ -237,12 +237,11 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
   //
   // Compute Q only or QKV (if MHA).
   // If MHA, this also computes KV, which we copy to the KV cache below.
-  MatMul_4x4</*kAdd=*/false>(num_interleaved, activations.pre_att_rms_out.All(),
-                             0, kModelDim, layer_weights->qkv_einsum_w.data(),
-                             0, kHeads * kQStride,
-                             layer_weights->qkv_einsum_w.scale(),
-                             activations.q.All(), kHeads * kQStride,
-                             /*add=*/nullptr, pool);
+  MatMul_4x4</*kAdd=*/false>(
+      num_interleaved, MakeMat(activations.pre_att_rms_out.All(), kModelDim),
+      MakeMat(layer_weights->qkv_einsum_w.data(), kModelDim),
+      layer_weights->qkv_einsum_w.scale(), /*add=*/nullptr,
+      MakeMat(activations.q.All(), kHeads * kQStride), pool);

   // Compute KV if not MHA.
   if constexpr (!kIsMHA) {
@@ -250,16 +249,16 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
     // directly into the KV cache with a stride of kCachePosSize.
     if (num_queries == 1 &&
         batch_start + num_tokens <= div_seq_len.GetDivisor()) {
-      const size_t colsBC = kKVHeads * 2 * kQKVDim;
       const size_t kv_ofs =
          batch_start * kCachePosSize + layer * kCacheLayerSize;
      // KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v).
      float* HWY_RESTRICT kv = kv_caches[0].kv_cache.get() + kv_ofs;
      MatMul_4x4</*kAdd=*/false>(
-          num_tokens, activations.pre_att_rms_out.All(), 0, kModelDim,
-          layer_weights->qkv_einsum_w.data(), kHeads * kQKVDim * kModelDim,
-          colsBC, layer_weights->qkv_einsum_w.scale(), kv, kCachePosSize,
-          /*add=*/nullptr, pool);
+          num_tokens, MakeMat(activations.pre_att_rms_out.All(), kModelDim),
+          MakeMat(layer_weights->qkv_einsum_w.data(), kModelDim, kModelDim,
+                  kHeads * kQKVDim * kModelDim),
+          layer_weights->qkv_einsum_w.scale(), /*add=*/nullptr,
+          MakeMat(kv, kKVHeads * 2 * kQKVDim, kCachePosSize), pool);
    } else {
      for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
           ++interleaved_idx) {
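As an aside on the two less-common MakeMat arguments used in this call (MakeMat itself is introduced later in this change, in ops/matmul-inl.h): `ofs` selects a starting row inside a larger buffer, and `stride` spaces output rows further apart than their logical width. A sketch with made-up dimensions, assuming kModelDim = 2048, kHeads = 8, kKVHeads = 1, kQKVDim = 256, and letting `qkv_w` stand for `layer_weights->qkv_einsum_w.data()`:

// B: the fused QKV weight is row-major with kModelDim columns per row; ofs
// skips the kHeads * kQKVDim rows holding Q so that only the K/V rows are read.
const auto B = MakeMat(qkv_w, /*cols=*/2048, /*stride=*/2048,
                       /*ofs=*/8 * 256 * 2048);
// C: each token writes kKVHeads * 2 * kQKVDim floats, and consecutive rows are
// kCachePosSize elements apart, so the result lands directly in the KV cache.
const auto C = MakeMat(kv, /*cols=*/1 * 2 * 256, /*stride=*/kCachePosSize);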
@@ -441,14 +440,12 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,

   // MatMul expects col-major B, which is what we have: kModelDim consecutive
   // elements in memory, repeated kFFHiddenDim times.
-  constexpr size_t kColsA = kModelDim;
-  constexpr size_t kColsBC = kFFHiddenDim;
   HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize());
-  const auto A = activations.bf_pre_ffw_rms_out.All();
+  const auto A = MakeMat(activations.bf_pre_ffw_rms_out.All(), kModelDim);
+  const auto B1 = MakeMat(layer_weights->gating_einsum_w.data(), kModelDim);
+  const auto B2 = MakeMat(layer_weights->gating_einsum_w.data(), kModelDim,
+                          kModelDim, kModelDim * kFFHiddenDim);
   const float scale = layer_weights->gating_einsum_w.scale();
-  const auto B1 = layer_weights->gating_einsum_w.data();
-  auto C1 = activations.C1.All();
-  auto C2 = activations.C2.All();
   constexpr bool kAddBias = TConfig::kFFBiases;
   const float* bias1 = nullptr;
   const float* bias2 = nullptr;
@@ -458,24 +455,22 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
     bias2 = bias1 + kFFHiddenDim;
     output_bias = layer_weights->ffw_output_biases.data_scale1();
   }
+  auto C1 = MakeMat(activations.C1.All(), kFFHiddenDim);
+  auto C2 = MakeMat(activations.C2.All(), kFFHiddenDim);

-  const size_t A_ofs = 0;  // no offset, using the same activations for both.
   // Will go through GELU.
-  MatMul_4x4<kAddBias>(num_interleaved, A, A_ofs, kColsA, B1,
-                       /*B_ofs=*/0, kColsBC, scale, C1, kColsBC, bias1, pool);
+  MatMul_4x4<kAddBias>(num_interleaved, A, B1, scale, bias1, C1, pool);
   // What to multiply by.
-  MatMul_4x4<kAddBias>(num_interleaved, A, A_ofs, kColsA, B1,
-                       /*B_ofs=*/kColsA * kColsBC, kColsBC, scale, C2, kColsBC,
-                       bias2, pool);
+  MatMul_4x4<kAddBias>(num_interleaved, A, B2, scale, bias2, C2, pool);

   // Activation (Gelu) and multiply by gate. Store activations in C1.
-  Activation<TConfig>(C1, C2, kFFHiddenDim * num_interleaved);
+  Activation<TConfig>(C1.ptr, C2.ptr, kFFHiddenDim * num_interleaved);

   // Hidden layer -> output layer.
-  MatMul_4x4<kAddBias>(num_interleaved, C1, 0, kFFHiddenDim,
-                       layer_weights->linear_w.data(), 0, kModelDim,
-                       layer_weights->linear_w.scale(),
-                       activations.ffw_out.All(), kModelDim, output_bias, pool);
+  MatMul_4x4<kAddBias>(num_interleaved, C1,
+                       MakeMat(layer_weights->linear_w.data(), kFFHiddenDim),
+                       layer_weights->linear_w.scale(), output_bias,
+                       MakeMat(activations.ffw_out.All(), kModelDim), pool);
 }

 // `batch_idx` indicates which row of `x` to write to.
@@ -1022,12 +1017,11 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
     bool all_queries_eos = true;
     PROFILER_ZONE("Gen.Embedding");
     // Compute logits from last layer activations.
-    MatMul_4x4</*kAdd=*/false>(num_queries, activations.x.All(), 0, kModelDim,
-                               weights.embedder_input_embedding.data(), 0,
-                               kVocabSize,
-                               weights.embedder_input_embedding.scale(),
-                               activations.logits.All(), kVocabSize,
-                               /*add=*/nullptr, pool);
+    MatMul_4x4</*kAdd=*/false>(
+        num_queries, MakeMat(activations.x.All(), kModelDim),
+        MakeMat(weights.embedder_input_embedding.data(), kModelDim),
+        weights.embedder_input_embedding.scale(), /*add=*/nullptr,
+        MakeMat(activations.logits.All(), kVocabSize), pool);
     for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
       float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
       if constexpr (TConfig::kFinalCap > 0.0f) {
ops/matmul-inl.h (243 changed lines)
@@ -60,7 +60,7 @@ HWY_INLINE void StoreHorizontalSums(DF df,  //
   // We are computing the product of (4, 4N) * (4N, 4) = (4, 4) tiles.
   // Each entry of C[r,c] is a dot product of A.row and B.col, which reside in
   // the lanes of `c$r$c`, so we store their horizontal sum (ReduceSum). This is
-  // expensive, but only a fraction of the kColsA_RowsB/N FMAs.
+  // expensive, but only a fraction of the A.cols/N FMAs.
   tile_c[stride_c * 0 + 0] = scale * hn::ReduceSum(df, c00);
   tile_c[stride_c * 0 + 1] = scale * hn::ReduceSum(df, c01);
   tile_c[stride_c * 0 + 2] = scale * hn::ReduceSum(df, c02);
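Restating the comment above in the new naming (this is only a summary of what the code already does, not part of the change): with B stored transposed, each accumulator vector holds lanewise partial products of one A row and one B row, and the horizontal reduction plus scaling yields

  C[r,c] = scale * \sum_{k=0}^{A.cols-1} A[r,k] * B[c,k]

with the corresponding entry of the row-vector `add` added on top in the Add variant. That is 16 ReduceSum calls per 4x4 tile, versus roughly A.cols/N fused multiply-add iterations, hence "only a fraction".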
@@ -93,14 +93,14 @@ HWY_INLINE void StoreHorizontalSumsAdd(DF df,  //
                                        VF c10, VF c11, VF c12, VF c13,  //
                                        VF c20, VF c21, VF c22, VF c23,  //
                                        VF c30, VF c31, VF c32, VF c33,
-                                       const float* HWY_RESTRICT add,
                                        const float scale,
+                                       const float* HWY_RESTRICT add,
                                        float* HWY_RESTRICT tile_c,
                                        size_t stride_c) {
   // We are computing the product of (4, 4N) * (4N, 4) = (4, 4) tiles.
   // Each entry of C[r,c] is a dot product of A.row and B.col, which reside in
   // the lanes of `c$r$c`, so we store their horizontal sum (ReduceSum). This is
-  // expensive, but only a fraction of the kColsA_RowsB/N FMAs.
+  // expensive, but only a fraction of the A.cols/N FMAs.
   const float add0 = add[0];
   // TODO: 4x4 transpose, then 128-bit vector FMA?
   tile_c[stride_c * 0 + 0] = scale * hn::ReduceSum(df, c00) + add0;
@@ -137,12 +137,12 @@ template <bool kAdd, size_t kNumRows, class DF, class VF = hn::Vec<DF>>
 HWY_INLINE void StoreHorizontalSumsMaybeAdd(
     DF df, VF c00, VF c01, VF c02, VF c03, VF c10, VF c11, VF c12, VF c13,
     VF c20, VF c21, VF c22, VF c23, VF c30, VF c31, VF c32, VF c33,
-    const float* HWY_RESTRICT add, size_t add_offset, const float scale,
+    const float scale, const float* HWY_RESTRICT add, size_t add_offset,
     float* HWY_RESTRICT tile_c, size_t stride_c) {
   if constexpr (kAdd) {
     StoreHorizontalSumsAdd<kNumRows>(df, c00, c01, c02, c03, c10, c11, c12, c13,
                                      c20, c21, c22, c23, c30, c31, c32, c33,
-                                     add + add_offset, scale, tile_c, stride_c);
+                                     scale, add + add_offset, tile_c, stride_c);
   } else {
     StoreHorizontalSums<kNumRows>(df, c00, c01, c02, c03, c10, c11, c12, c13,
                                   c20, c21, c22, c23, c30, c31, c32, c33,
@@ -150,6 +150,36 @@ HWY_INLINE void StoreHorizontalSumsMaybeAdd(
   }
 }

+// Wrapper to simplify call sites. T can be const or non-const.
+template <typename T>
+struct Mat {
+  bool NotEmpty() const {
+    return ptr != nullptr && cols != 0 && stride >= cols;
+  }
+  size_t Row(size_t r) const { return ofs + stride * r; }
+
+  T* HWY_RESTRICT ptr;
+  size_t cols;
+
+  // elements between rows, which is typically the same as `cols`.
+  size_t stride;
+
+  // Offset to add to `ptr`; separate because T=NuqStream does not support
+  // pointer arithmetic.
+  size_t ofs;
+};
+
+template <typename T>
+Mat<T> MakeMat(T* HWY_RESTRICT ptr, size_t cols, size_t stride,
+               size_t ofs = 0) {
+  return Mat<T>{.ptr = ptr, .cols = cols, .stride = stride, .ofs = ofs};
+}
+
+template <typename T>
+Mat<T> MakeMat(T* HWY_RESTRICT ptr, size_t cols) {
+  return MakeMat(ptr, cols, cols);
+}
+
 #undef GEMMA_NATIVE_BF16
 #if HWY_IDE || (defined(HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16) == \
                 defined(HWY_TARGET_TOGGLE))
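A brief usage sketch of the new helpers (not taken from the commit; the buffer, its sizes, and the <vector> include are assumptions chosen for illustration):

#include <vector>

std::vector<float> buf(4 * 2048);      // 4 row-major rows of 2048 floats.
auto A = MakeMat(buf.data(), 2048);    // cols == stride == 2048, ofs == 0.
// Row() returns an element offset rather than a pointer, so the same code
// also works for packed types such as NuqStream that lack pointer arithmetic.
float* row2 = A.ptr + A.Row(2);        // == buf.data() + 2 * 2048
// Rows may also be spaced further apart than their logical width:
auto view = MakeMat(buf.data(), /*cols=*/512, /*stride=*/2048);

Keeping `ofs` as a separate field (rather than folding it into `ptr`) is what lets the same wrapper describe compressed streams that cannot be advanced by pointer arithmetic.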
@@ -162,31 +192,18 @@ HWY_INLINE void StoreHorizontalSumsMaybeAdd(

 // Specialization for f32 += bf16 * bf16 that avoids promoting to f32.
 template <size_t kNumRows, bool kAdd>
-HWY_INLINE void GEMM_4x4_Tile(const hwy::bfloat16_t* HWY_RESTRICT A,
-                              const size_t A_ofs,
-                              const hwy::bfloat16_t* HWY_RESTRICT B,
-                              const size_t B_ofs, float* HWY_RESTRICT C,
-                              const float scale, const float* HWY_RESTRICT add,
-                              const size_t idx_tile, const size_t xtiles,
-                              const size_t cols_a, const size_t stride_a,
-                              const size_t stride_b, const size_t stride_c) {
-  constexpr size_t kRegRows = 4;
-  constexpr size_t kRegCols = 4;
-  static_assert(kNumRows <= kRegRows);
-
-  // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
-  const size_t row_a = idx_tile / xtiles * kRegRows;
-  const size_t row_b_col_c = idx_tile % xtiles * kRegCols;
-
+HWY_INLINE void MatMulTile(const Mat<const hwy::bfloat16_t>& A,
+                           const Mat<const hwy::bfloat16_t>& B,
+                           const size_t row_a, const size_t row_b_col_c,
+                           const float scale, const float* HWY_RESTRICT add,
+                           const Mat<float>& C) {
   const hn::ScalableTag<float> df;
   using VF = hn::Vec<decltype(df)>;
   // ReorderWidenMulAccumulate does not use its sum1 arg and we can use full
   // bf16 vectors.
   const hn::Repartition<hwy::bfloat16_t, decltype(df)> d;
-  VF unused_sum1 = hn::Zero(df);
-
   const size_t N = Lanes(d);
+  VF unused_sum1 = hn::Zero(df);
   VF c00 = hn::Zero(df);
   VF c01 = hn::Zero(df);
   VF c02 = hn::Zero(df);
@@ -207,42 +224,41 @@ HWY_INLINE void GEMM_4x4_Tile(const hwy::bfloat16_t* HWY_RESTRICT A,
   VF c32 = hn::Zero(df);
   VF c33 = hn::Zero(df);

-  const hwy::bfloat16_t* HWY_RESTRICT A_tile = A + A_ofs + stride_a * row_a;
-  const hwy::bfloat16_t* HWY_RESTRICT B_tile =
-      B + B_ofs + stride_b * row_b_col_c;
+  const hwy::bfloat16_t* HWY_RESTRICT A_tile = A.ptr + A.Row(row_a);
+  const hwy::bfloat16_t* HWY_RESTRICT B_tile = B.ptr + B.Row(row_b_col_c);

   // Loop over columns of A and columns of the transposed B, in steps of N.
   // Accumulates into the c## vectors.
   HWY_UNROLL(1)
-  for (size_t col_ab = 0; col_ab < cols_a; col_ab += N) {
+  for (size_t col_ab = 0; col_ab < A.cols; col_ab += N) {
     using V = hn::Vec<decltype(d)>;
-    const V b0 = hn::LoadU(d, B_tile + stride_b * 0 + col_ab);
-    const V b1 = hn::LoadU(d, B_tile + stride_b * 1 + col_ab);
-    const V b2 = hn::LoadU(d, B_tile + stride_b * 2 + col_ab);
-    const V b3 = hn::LoadU(d, B_tile + stride_b * 3 + col_ab);
+    const V b0 = hn::LoadU(d, B_tile + B.stride * 0 + col_ab);
+    const V b1 = hn::LoadU(d, B_tile + B.stride * 1 + col_ab);
+    const V b2 = hn::LoadU(d, B_tile + B.stride * 2 + col_ab);
+    const V b3 = hn::LoadU(d, B_tile + B.stride * 3 + col_ab);

-    const V a0 = hn::LoadU(d, A_tile + stride_a * 0 + col_ab);
+    const V a0 = hn::LoadU(d, A_tile + A.stride * 0 + col_ab);
     c00 = hn::ReorderWidenMulAccumulate(df, a0, b0, c00, unused_sum1);
     c01 = hn::ReorderWidenMulAccumulate(df, a0, b1, c01, unused_sum1);
     c02 = hn::ReorderWidenMulAccumulate(df, a0, b2, c02, unused_sum1);
     c03 = hn::ReorderWidenMulAccumulate(df, a0, b3, c03, unused_sum1);
     if constexpr (kNumRows == 1) continue;

-    const V a1 = hn::LoadU(d, A_tile + stride_a * 1 + col_ab);
+    const V a1 = hn::LoadU(d, A_tile + A.stride * 1 + col_ab);
     c10 = hn::ReorderWidenMulAccumulate(df, a1, b0, c10, unused_sum1);
     c11 = hn::ReorderWidenMulAccumulate(df, a1, b1, c11, unused_sum1);
     c12 = hn::ReorderWidenMulAccumulate(df, a1, b2, c12, unused_sum1);
     c13 = hn::ReorderWidenMulAccumulate(df, a1, b3, c13, unused_sum1);
     if constexpr (kNumRows == 2) continue;

-    const V a2 = hn::LoadU(d, A_tile + stride_a * 2 + col_ab);
+    const V a2 = hn::LoadU(d, A_tile + A.stride * 2 + col_ab);
     c20 = hn::ReorderWidenMulAccumulate(df, a2, b0, c20, unused_sum1);
     c21 = hn::ReorderWidenMulAccumulate(df, a2, b1, c21, unused_sum1);
     c22 = hn::ReorderWidenMulAccumulate(df, a2, b2, c22, unused_sum1);
     c23 = hn::ReorderWidenMulAccumulate(df, a2, b3, c23, unused_sum1);
     if constexpr (kNumRows == 3) continue;

-    const V a3 = hn::LoadU(d, A_tile + stride_a * 3 + col_ab);
+    const V a3 = hn::LoadU(d, A_tile + A.stride * 3 + col_ab);
     c30 = hn::ReorderWidenMulAccumulate(df, a3, b0, c30, unused_sum1);
     c31 = hn::ReorderWidenMulAccumulate(df, a3, b1, c31, unused_sum1);
     c32 = hn::ReorderWidenMulAccumulate(df, a3, b2, c32, unused_sum1);
@@ -252,10 +268,10 @@ HWY_INLINE void GEMM_4x4_Tile(const hwy::bfloat16_t* HWY_RESTRICT A,
   // Ensure sum1 was indeed unused.
   HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));

-  float* HWY_RESTRICT C_tile = C + stride_c * row_a + row_b_col_c;
+  float* HWY_RESTRICT C_tile = C.ptr + C.Row(row_a) + row_b_col_c;
   StoreHorizontalSumsMaybeAdd<kAdd, kNumRows>(
       df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30, c31,
-      c32, c33, add, row_b_col_c, scale, C_tile, stride_c);
+      c32, c33, scale, add, row_b_col_c, C_tile, C.stride);
 }

 #endif  // GEMMA_NATIVE_BF16
@@ -277,32 +293,20 @@ HWY_INLINE void UpdateTileRow(const VF& a0, const VF& a1, const VF& b00,
   c3 = hn::MulAdd(a1, b31, c3);
 }

-// Accumulates a single kNumRows (<= 4) x 4 tile of A x B into C. B is
-// transposed, so we iterate over both A and B with consecutive vector loads.
+// Streams a `(kNumRows, 4)` strip of `A` and the transposed `B`, then writes a
+// finished tile of `C`.
 // General case: uses CompressTraits to load from A and B.
 template <size_t kNumRows, bool kAdd, typename MatTA, typename MatTB>
-HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A, const size_t A_ofs,
-                              const MatTB* HWY_RESTRICT B, const size_t B_ofs,
-                              float* HWY_RESTRICT C, const float scale,
-                              const float* HWY_RESTRICT add,
-                              const size_t idx_tile, const size_t xtiles,
-                              const size_t cols_a, const size_t stride_a,
-                              const size_t stride_b, const size_t stride_c) {
-  constexpr size_t kRegRows = 4;
-  constexpr size_t kRegCols = 4;
-  static_assert(kNumRows <= kRegRows);
-
-  using TraitsA = CompressTraits<MatTA>;
-  using TraitsB = CompressTraits<MatTB>;
-
-  // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
-  const size_t row_a = idx_tile / xtiles * kRegRows;
-  const size_t row_b_col_c = idx_tile % xtiles * kRegCols;
-
+HWY_INLINE void MatMulTile(const Mat<MatTA>& A, const Mat<MatTB>& B,
+                           const size_t row_a, const size_t row_b_col_c,
+                           const float scale, const float* HWY_RESTRICT add,
+                           const Mat<float>& C) {
+  using TraitsA = CompressTraits<hwy::RemoveConst<MatTA>>;
+  using TraitsB = CompressTraits<hwy::RemoveConst<MatTB>>;
   const hn::ScalableTag<float> d32;
   const size_t N = hn::Lanes(d32);
   using V = hn::Vec<decltype(d32)>;

   V c00 = hn::Zero(d32);
   V c01 = hn::Zero(d32);
   V c02 = hn::Zero(d32);
@@ -323,127 +327,118 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A, const size_t A_ofs,
   V c32 = hn::Zero(d32);
   V c33 = hn::Zero(d32);

-  const size_t A_tile_ofs = A_ofs + stride_a * row_a;
-  const size_t B_tile_ofs = B_ofs + stride_b * row_b_col_c;
+  const size_t A_ofs = A.Row(row_a);
+  const size_t B_ofs = B.Row(row_b_col_c);

   // Loop over columns of A and columns of the transposed B, in steps of 2*N
   // (since we are decoding consecutive bytes at each iteration).
-  // Accumulates into the c## vectors.
+  // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c,
+  // col_ab) for B. Accumulates into the c## vectors.
   size_t col_ab = 0;

   HWY_UNROLL(1)
-  for (; col_ab <= cols_a - 2 * N; col_ab += 2 * N) {
+  for (; col_ab <= A.cols - 2 * N; col_ab += 2 * N) {
     V b00, b01;
-    TraitsB::Decompress2(d32, B, B_tile_ofs + stride_b * 0 + col_ab, b00, b01);
+    TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 0 + col_ab, b00, b01);
     V b10, b11;
-    TraitsB::Decompress2(d32, B, B_tile_ofs + stride_b * 1 + col_ab, b10, b11);
+    TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 1 + col_ab, b10, b11);
     V b20, b21;
-    TraitsB::Decompress2(d32, B, B_tile_ofs + stride_b * 2 + col_ab, b20, b21);
+    TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 2 + col_ab, b20, b21);
     V b30, b31;
-    TraitsB::Decompress2(d32, B, B_tile_ofs + stride_b * 3 + col_ab, b30, b31);
+    TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 3 + col_ab, b30, b31);

     V a00, a01;
-    TraitsA::Decompress2(d32, A, A_tile_ofs + stride_a * 0 + col_ab, a00, a01);
+    TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 0 + col_ab, a00, a01);
     UpdateTileRow(a00, a01, b00, b01, b10, b11, b20, b21, b30, b31, c00, c01,
                   c02, c03);
     if constexpr (kNumRows == 1) continue;

     V a10, a11;
-    TraitsA::Decompress2(d32, A, A_tile_ofs + stride_a * 1 + col_ab, a10, a11);
+    TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 1 + col_ab, a10, a11);
     UpdateTileRow(a10, a11, b00, b01, b10, b11, b20, b21, b30, b31, c10, c11,
                   c12, c13);
     if constexpr (kNumRows == 2) continue;

     V a20, a21;
-    TraitsA::Decompress2(d32, A, A_tile_ofs + stride_a * 2 + col_ab, a20, a21);
+    TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 2 + col_ab, a20, a21);
     UpdateTileRow(a20, a21, b00, b01, b10, b11, b20, b21, b30, b31, c20, c21,
                   c22, c23);
     if constexpr (kNumRows == 3) continue;

     V a30, a31;
-    TraitsA::Decompress2(d32, A, A_tile_ofs + stride_a * 3 + col_ab, a30, a31);
+    TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 3 + col_ab, a30, a31);
     UpdateTileRow(a30, a31, b00, b01, b10, b11, b20, b21, b30, b31, c30, c31,
                   c32, c33);
   }

-  float* HWY_RESTRICT C_tile = C + stride_c * row_a + row_b_col_c;
+  float* HWY_RESTRICT C_tile = C.ptr + C.Row(row_a) + row_b_col_c;
   StoreHorizontalSumsMaybeAdd<kAdd, kNumRows>(
       d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30, c31,
-      c32, c33, add, row_b_col_c, scale, C_tile, stride_c);
+      c32, c33, scale, add, row_b_col_c, C_tile, C.stride);
 }

-// Tiled 4x4 GEMM: C = A * B * scale [+ add].
-// Computes the matrix product of A and B and stores this in C. Processes tiles
-// of 4x4 vectors in parallel with a work-stealing thread pool.
+// Computes the matrix product `A * B * scale [+ add]` and stores it in `C`.
 //
-// If kAdd is true, the row-vector `add` is added to each row of C, otherwise
-// `add` is ignored and can be nullptr.
-// A is a row-major matrix of size (batch_size, colsA_rowsB).
-// B is passed transposed (column-major), so a matrix of size
-// (colsBC, colsA_rowsB), representing a B of size (colsA_rowsB, colsBC).
-// A_ofs and B_ofs are offsets into A and B, respectively; they remain separate
-// from the pointers because some MatTA/B such as NuqStream do not support
-// pointer arithmetic.
-// C is a row-major matrix of size (batch_size, colsBC), with `C_stride`
-// elements between rows, which is typically the same as `colsBC`. There is no
-// `C_ofs` because callers can simply add it to `C`.
-// The product is scaled by `scale` to support CompressedArray with scale != 1,
-// the caller can pass the product of the scales of A and B.
-// A scale for `add` is not supported, so make sure its scale is 1.
-// Typically batch_size is 1..512, colsA_rowsB and colsBC are 3k or 24k.
+// `A` is a row-major matrix of shape `(batch_size, A.cols)`.
+// `B` is transposed; `B.cols`, which must match `A.cols`, denotes the number of
+// rows in the original B, and `C.cols` the number of columns in the original B.
+//
+// `scale` allows expanding the smaller range of `SfpStream` to the original
+// values. When `A` and/or `B` are from CompressedArray, `scale` should be the
+// product of their `.scale()` values.
+//
+// If `kAdd` is true, the row-vector `add` is added to each row of `C`,
+// otherwise `add` is ignored and can be nullptr. A scale for `add` is not
+// supported, so make sure its scale is 1.
+//
+// `C` is a row-major matrix of size `(batch_size, C.cols)`.
+// Writes 4x4 tiles of C in parallel using a work-stealing thread pool.
+// Typically batch_size is 1..512, A.cols and C.cols are 3k or 24k.
 template <bool kAdd, typename MatTA, typename MatTB>
-HWY_NOINLINE void MatMul_4x4(const size_t batch_size,
-                             const MatTA* HWY_RESTRICT A, const size_t A_ofs,
-                             const size_t colsA_rowsB,
-                             const MatTB* HWY_RESTRICT B, const size_t B_ofs,
-                             const size_t colsBC, const float scale,
-                             float* HWY_RESTRICT C, const size_t C_stride,
-                             const float* HWY_RESTRICT add,
+HWY_NOINLINE void MatMul_4x4(const size_t batch_size, const Mat<MatTA>& A,
+                             const Mat<MatTB>& B, const float scale,
+                             const float* HWY_RESTRICT add, const Mat<float>& C,
                              hwy::ThreadPool& pool) {
   PROFILER_ZONE("Matmul");
+  constexpr size_t kRegRows = 4;  // if changing, also update the switch below.
+  constexpr size_t kRegCols = 4;
+
+  HWY_DASSERT(A.NotEmpty() && B.NotEmpty() && C.NotEmpty());
+  HWY_DASSERT(A.cols == B.cols);
+
+  // Use float instead of MatTA/MatTB because we decompress to float here.
+  const size_t N = hn::Lanes(hn::ScalableTag<float>());
+  (void)N;
+  HWY_DASSERT(A.cols % (N * 2) == 0);  // For Decompress2.
+  HWY_DASSERT(C.cols % kRegCols == 0);
+
   // We currently write C directly, which touches more memory than fits in L3.
   // TODO: add another level of loops to finish L3-sized pieces of C at a time.
-  const hn::ScalableTag<MatTA> d;
-  // Use float instead of MatTA/MatTB because we decompress to float here.
-  const size_t Nf = hn::Lanes(hn::ScalableTag<float>());
-  (void)Nf;  // For HWY_DASSERT
-  constexpr size_t kRegRows = 4;  // if changing, also update the switch below.
-  constexpr size_t kRegCols = 4;  // in vectors
-
-  HWY_DASSERT(colsA_rowsB % (Nf * 2) == 0);  // For Decompress2.
-  HWY_DASSERT(colsBC % kRegCols == 0);
   const size_t tilesY = hwy::DivCeil(batch_size, kRegRows);
-  const size_t tilesX = colsBC / kRegCols;
-
-  const size_t strideA = colsA_rowsB;
-  const size_t strideB = colsA_rowsB;
+  const size_t tilesX = C.cols / kRegCols;

   pool.Run(0, tilesX * tilesY,
            [&](const uint64_t idx_tile, size_t /*thread*/) HWY_ATTR {
+             const size_t tx = idx_tile % tilesX;
+             const size_t ty = idx_tile / tilesX;
+             const size_t row_a = ty * kRegRows;
+             const size_t row_b_col_c = tx * kRegCols;
              // How many rows of C are left to compute. If more than 4, this
              // tile still only computes 4 rows.
-             const size_t num_rows = batch_size - idx_tile / tilesX * kRegRows;
-             HWY_ASSERT(num_rows > 0);
+             const size_t num_rows = batch_size - row_a;
+             HWY_DASSERT(num_rows != 0);
              switch (num_rows) {
                case 1:
-                 GEMM_4x4_Tile<1, kAdd>(A, A_ofs, B, B_ofs, C, scale, add,
-                                        idx_tile, tilesX, colsA_rowsB, strideA,
-                                        strideB, C_stride);
+                 MatMulTile<1, kAdd>(A, B, row_a, row_b_col_c, scale, add, C);
                  break;
                case 2:
-                 GEMM_4x4_Tile<2, kAdd>(A, A_ofs, B, B_ofs, C, scale, add,
-                                        idx_tile, tilesX, colsA_rowsB, strideA,
-                                        strideB, C_stride);
+                 MatMulTile<2, kAdd>(A, B, row_a, row_b_col_c, scale, add, C);
                  break;
                case 3:
-                 GEMM_4x4_Tile<3, kAdd>(A, A_ofs, B, B_ofs, C, scale, add,
-                                        idx_tile, tilesX, colsA_rowsB, strideA,
-                                        strideB, C_stride);
+                 MatMulTile<3, kAdd>(A, B, row_a, row_b_col_c, scale, add, C);
                  break;
                default:
-                 GEMM_4x4_Tile<4, kAdd>(A, A_ofs, B, B_ofs, C, scale, add,
-                                        idx_tile, tilesX, colsA_rowsB, strideA,
-                                        strideB, C_stride);
+                 MatMulTile<4, kAdd>(A, B, row_a, row_b_col_c, scale, add, C);
              }
            });
 }
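Pulling the pieces together, a call to the new interface looks roughly like the following sketch. It is not taken from the commit: the sizes, buffers, and pool are hypothetical, the usual includes and SIMD-namespace boilerplate are assumed, and the shapes are chosen to satisfy the asserts above (A.cols a multiple of 2*N, C.cols a multiple of 4).

hwy::ThreadPool pool(4);
std::vector<float> a(4 * 512), b_trans(256 * 512), c(4 * 256);
// C (4 x 256) = A (4 x 512) * B, where b_trans holds the transposed B as
// 256 row-major rows of 512 columns. scale is 1 because nothing here came
// from a CompressedArray, and kAdd=false lets add be nullptr.
MatMul_4x4</*kAdd=*/false>(/*batch_size=*/4, MakeMat(a.data(), 512),
                           MakeMat(b_trans.data(), 512), /*scale=*/1.0f,
                           /*add=*/nullptr, MakeMat(c.data(), 256), pool);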
@@ -14,6 +14,7 @@
 // limitations under the License.

 #ifndef HWY_DISABLED_TARGETS
+// Exclude HWY_SCALAR due to 2x bf16 -> f32.
 #define HWY_DISABLED_TARGETS HWY_SCALAR
 #endif

@@ -48,47 +49,10 @@ namespace HWY_NAMESPACE {

 namespace hn = hwy::HWY_NAMESPACE;

-template <typename MatT, size_t kOuter, size_t kInner>
-CompressedArray<MatT, kOuter * kInner> GenerateMat(size_t offset,
-                                                   hwy::ThreadPool& pool) {
-  gcpp::CompressWorkingSet ws;
-  CompressedArray<MatT, kOuter * kInner> mat;
-  std::array<float, kOuter * kInner> content;
-  const float scale = 1.0f / kInner;
-  pool.Run(0, kOuter, [&](const size_t i, size_t /*thread*/) {
-    for (size_t j = 0; j < kInner; j++) {
-      content[i * kInner + j] =
-          static_cast<float>((i * kInner + j + offset) * scale);
-    }
-  });
-
-  Compress(content, ws, mat, pool);
-  mat.set_scale(1.9f);  // Arbitrary value, different from 1.
-  return mat;
-}
-
-template <typename MatT, size_t kOuter, size_t kInner>
-CompressedArray<MatT, kOuter * kInner> GenerateZeroMat(hwy::ThreadPool& pool) {
-  gcpp::CompressWorkingSet ws;
-  CompressedArray<MatT, kOuter * kInner> mat;
-  std::array<MatT, kOuter * kInner> content;
-
-  pool.Run(0, kOuter, [&](const size_t i, size_t thread) {
-    hwy::ZeroBytes(&content[i * kInner], kInner * sizeof(content[0]));
-  });
-
-  Compress(content, ws, mat, pool);
-  mat.set_scale(1.2f);  // Arbitrary value, different from 1.
-  return mat;
-}
-
 template <typename MatT, size_t kOuter, size_t kInner>
 std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> GenerateMatHeap(
     size_t offset, hwy::ThreadPool& pool) {
   gcpp::CompressWorkingSet ws;
-  std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> mat =
-      std::unique_ptr<CompressedArray<MatT, kOuter * kInner>>(
-          new CompressedArray<MatT, kOuter * kInner>);
   hwy::AlignedFreeUniquePtr<float[]> content =
       hwy::AllocateAligned<float>(kOuter * kInner);
   const float scale = 1.875f / (kInner * kOuter + offset);
@@ -99,6 +63,8 @@ std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> GenerateMatHeap(
     }
   });

+  std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> mat =
+      std::make_unique<CompressedArray<MatT, kOuter * kInner>>();
   Compress(content.get(), kOuter * kInner, ws, kOuter * kInner, mat->data(), 0,
            pool);
   mat->set_scale(0.6f);  // Arbitrary value, different from 1.
@@ -109,9 +75,6 @@ template <typename MatT, size_t kOuter, size_t kInner>
 std::unique_ptr<CompressedArray<MatT, kOuter * kInner>>
 GenerateTransposeMatHeap(size_t offset, hwy::ThreadPool& pool) {
   gcpp::CompressWorkingSet ws;
-  std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> mat =
-      std::unique_ptr<CompressedArray<MatT, kOuter * kInner>>(
-          new CompressedArray<MatT, kOuter * kInner>);
   hwy::AlignedFreeUniquePtr<float[]> content =
       hwy::AllocateAligned<float>(kOuter * kInner);
   const float scale = 1.875f / (kInner * kOuter + offset);
@@ -122,6 +85,8 @@ GenerateTransposeMatHeap(size_t offset, hwy::ThreadPool& pool) {
     }
   });

+  std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> mat =
+      std::make_unique<CompressedArray<MatT, kOuter * kInner>>();
   Compress(content.get(), kOuter * kInner, ws, kOuter * kInner, mat->data(), 0,
            pool);
   // Arbitrary value, different from 1, must match GenerateMatHeap.
|
||||||
std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> GenerateZeroMatHeap(
|
std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> GenerateZeroMatHeap(
|
||||||
hwy::ThreadPool& pool) {
|
hwy::ThreadPool& pool) {
|
||||||
gcpp::CompressWorkingSet ws;
|
gcpp::CompressWorkingSet ws;
|
||||||
std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> mat =
|
|
||||||
std::unique_ptr<CompressedArray<MatT, kOuter * kInner>>(
|
|
||||||
new CompressedArray<MatT, kOuter * kInner>);
|
|
||||||
hwy::AlignedFreeUniquePtr<float[]> content =
|
hwy::AlignedFreeUniquePtr<float[]> content =
|
||||||
hwy::AllocateAligned<float>(kOuter * kInner);
|
hwy::AllocateAligned<float>(kOuter * kInner);
|
||||||
|
|
||||||
|
|
@ -143,22 +105,14 @@ std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> GenerateZeroMatHeap(
|
||||||
hwy::ZeroBytes(&content[i * kInner], kInner * sizeof(content[0]));
|
hwy::ZeroBytes(&content[i * kInner], kInner * sizeof(content[0]));
|
||||||
});
|
});
|
||||||
|
|
||||||
|
std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> mat =
|
||||||
|
std::make_unique<CompressedArray<MatT, kOuter * kInner>>();
|
||||||
Compress(content.get(), kOuter * kInner, ws, kOuter * kInner, mat->data(), 0,
|
Compress(content.get(), kOuter * kInner, ws, kOuter * kInner, mat->data(), 0,
|
||||||
pool);
|
pool);
|
||||||
mat->set_scale(1.2f); // Arbitrary value, different from 1.
|
mat->set_scale(1.2f); // Arbitrary value, different from 1.
|
||||||
return mat;
|
return mat;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <size_t length>
|
|
||||||
hwy::AlignedFreeUniquePtr<float[]> GenerateVec(size_t offset) {
|
|
||||||
hwy::AlignedFreeUniquePtr<float[]> vec = hwy::AllocateAligned<float>(length);
|
|
||||||
HWY_ASSERT(vec);
|
|
||||||
for (size_t idx = 0; idx < length; idx++) {
|
|
||||||
vec[idx] = static_cast<float>(idx + offset);
|
|
||||||
}
|
|
||||||
return vec;
|
|
||||||
}
|
|
||||||
|
|
||||||
// A simple matrix multiplication. No optimization / tiling.
|
// A simple matrix multiplication. No optimization / tiling.
|
||||||
template <size_t kM, size_t kN, size_t kK>
|
template <size_t kM, size_t kN, size_t kK>
|
||||||
hwy::AlignedFreeUniquePtr<float[]> SimpleMatMul(
|
hwy::AlignedFreeUniquePtr<float[]> SimpleMatMul(
|
||||||
|
|
@@ -179,27 +133,6 @@ hwy::AlignedFreeUniquePtr<float[]> SimpleMatMul(
   return out;
 }

-template <size_t kOuter, size_t kInner>
-hwy::AlignedFreeUniquePtr<float[]> SimpleMatVecAdd(
-    const CompressedArray<float, kOuter * kInner>& mat,
-    const hwy::AlignedFreeUniquePtr<float[]>& vec,
-    const hwy::AlignedFreeUniquePtr<float[]>& add) {
-  hwy::AlignedFreeUniquePtr<float[]> uncompressed_mat =
-      hwy::AllocateAligned<float>(kOuter * kInner);
-  hwy::AlignedFreeUniquePtr<float[]> out = hwy::AllocateAligned<float>(kOuter);
-  HWY_ASSERT(uncompressed_mat && out);
-  Decompress(mat, 0, uncompressed_mat.get(), kOuter * kInner);
-  MulByConst(mat.scale(), uncompressed_mat.get(), kOuter * kInner);
-  for (size_t idx_row = 0; idx_row < kOuter; idx_row++) {
-    out[idx_row] = add[idx_row];
-    for (size_t idx_col = 0; idx_col < kInner; idx_col++) {
-      out[idx_row] +=
-          uncompressed_mat[kInner * idx_row + idx_col] * vec[idx_col];
-    }
-  }
-  return out;
-}
-
 template <typename MatT>
 void AssertClose(const MatT* HWY_RESTRICT expected,
                  const MatT* HWY_RESTRICT actual, size_t num) {
@@ -233,8 +166,7 @@ void AssertClose(const hwy::AlignedFreeUniquePtr<float[]>& a,
   }
 }

-// Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
-// ops_test across instruction sets.
+// Largely unoptimized; reordered innermost loops nets ~5-10X speedup.
 template <size_t kN, size_t kK, typename MatTA, typename MatTB,
           HWY_IF_T_SIZE_GT(MatTB, 1)>
 HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
@@ -271,92 +203,167 @@ HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
   MatMulSlowBatch<kN, kK>(batch_size, a, b.get(), scale, add, out);
 }

-template <size_t kM, size_t kN, size_t kK, bool kAdd, typename MatTA,
-          typename MatTB = MatTA>
-void TestTiledBatchMatMul() {
-  fprintf(stderr,
-          "TestTiledBatchMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n",
-          kM, kN, kK, kAdd, typeid(MatTA).name(), typeid(MatTB).name());
-  hwy::ThreadPool pool(3);
-  std::unique_ptr<CompressedArray<MatTA, kM * kN>> a =
-      GenerateMatHeap<MatTA, kM, kN>(0, pool);
-  std::unique_ptr<CompressedArray<MatTB, kN * kK>> b =
-      GenerateMatHeap<MatTB, kN, kK>(0, pool);
-  std::unique_ptr<CompressedArray<float, kK>> add =
-      GenerateMatHeap<float, 1, kK>(0, pool);
-  add->set_scale(1.0f);
-  std::unique_ptr<CompressedArray<float, kM * kK>> c_slow =
-      GenerateZeroMatHeap<float, kM, kK>(pool);
-  const float scale = a->scale() * b->scale();
-
-  const double start_slow = hwy::platform::Now();
-  MatMulSlowBatch<kN, kK>(kM, a->data(), b->data(), scale,
-                          kAdd ? add->data() : nullptr, c_slow->data());
-  const double slow_matmul_seconds = hwy::platform::Now() - start_slow;
-  fprintf(stderr, "MatMulSlowBatch took %f seconds.\n", slow_matmul_seconds);
-
-  hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
-  std::unique_ptr<CompressedArray<MatTB, kN * kK>> b_trans =
-      GenerateTransposeMatHeap<MatTB, kN, kK>(0, pool);
-
-  const double start_tiled = hwy::platform::Now();
-  EXPECT_EQ(scale, a->scale() * b_trans->scale());
-  MatMul_4x4<kAdd>(kM, a->data(), 0, kN, b_trans->data(), 0, kK, scale, c.get(),
-                   kK, add->data(), pool);
-  const double tiled_matmul_seconds = hwy::platform::Now() - start_tiled;
-  fprintf(stderr, "MatMul_4x4 took %f seconds.\n", tiled_matmul_seconds);
-
-  AssertClose(c_slow->data(), c.get(), kM * kK);
+void PrintSpeed(const char* algo, size_t M, size_t N, size_t K,
+                double elapsed) {
+  // * 2 because of FMA.
+  fprintf(stderr, "%s: %f seconds, %f GFLOPS.\n", algo, elapsed,
+          2E-9 * M * N * K / elapsed);
 }

-void TestAllTiledBatchMatMul() {
+template <size_t kM, size_t kN, size_t kK, bool kAdd, typename MatTA,
+          typename MatTB = MatTA>
+void TestMatMul(hwy::ThreadPool& pool) {
+  using TraitsA = CompressTraits<MatTA>;
+  using TraitsB = CompressTraits<MatTB>;
+  fprintf(stderr, "TestMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n", kM,
+          kN, kK, kAdd, TraitsA::Name(), TraitsB::Name());
+
+  std::unique_ptr<CompressedArray<MatTA, kM * kN>> a =
+      GenerateMatHeap<MatTA, kM, kN>(0, pool);
+  std::unique_ptr<CompressedArray<MatTB, kN * kK>> b_trans =
+      GenerateTransposeMatHeap<MatTB, kN, kK>(0, pool);
+  hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
+
+  const float scale = a->scale() * b_trans->scale();
+  std::unique_ptr<CompressedArray<float, kK>> add;
+  if (kAdd) {
+    add = GenerateMatHeap<float, 1, kK>(0, pool);
+    add->set_scale(1.0f);
+  }
+
+  std::unique_ptr<CompressedArray<float, kM * kK>> c_slow;
+  const bool compare_slow = kN < 2048;
+  if (compare_slow) {
+    std::unique_ptr<CompressedArray<MatTB, kN * kK>> b =
+        GenerateMatHeap<MatTB, kN, kK>(0, pool);
+    HWY_ASSERT_EQ(scale, a->scale() * b->scale());
+    c_slow = GenerateZeroMatHeap<float, kM, kK>(pool);
+    const double start_slow = hwy::platform::Now();
+    MatMulSlowBatch<kN, kK>(kM, a->data(), b->data(), scale,
+                            kAdd ? add->data() : nullptr, c_slow->data());
+    PrintSpeed("MatMulSlowBatch", kM, kN, kK,
+               hwy::platform::Now() - start_slow);
+  }
+
+  double min_elapsed = hwy::HighestValue<double>();
+  for (int rep = 0; rep < (compare_slow ? 1 : 3); ++rep) {
+    const double start_tiled = hwy::platform::Now();
+    MatMul_4x4<kAdd>(kM, MakeMat(a->data(), kN), MakeMat(b_trans->data(), kN),
+                     scale, kAdd ? add->data_scale1() : nullptr,
+                     MakeMat(c.get(), kK), pool);
+    min_elapsed = HWY_MIN(min_elapsed, hwy::platform::Now() - start_tiled);
+  }
+  PrintSpeed("MatMul_4x4", kM, kN, kK, min_elapsed);
+
+  if (compare_slow) {
+    AssertClose(c_slow->data(), c.get(), kM * kK);
+  }
+}
+
+void TestAllMatMul() {
+  // Skip EMU128 (10x slower than SSE4 for SFP) and older x86.
+  if (HWY_TARGET == HWY_EMU128 || HWY_TARGET == HWY_SSE4 ||
+      HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE2) {
+    return;
+  }
+
+  hwy::ThreadPool pool(4);
   using BF16 = hwy::bfloat16_t;
   using F32 = float;
   using SFP = SfpStream;
-  // medium-sized square test
-  TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, F32>();
-  TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, BF16>();
-  TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, F32, BF16>();
-  TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, BF16, F32>();
-  TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, F32, SFP>();
-  TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, BF16, SFP>();
-
-  // minimal non-square test. kK must be at least 2 vectors.
-  TestTiledBatchMatMul<35, 128, 32, /*kAdd=*/false, F32>();
-  TestTiledBatchMatMul<34, 128, 32, /*kAdd=*/true, BF16>();
-  TestTiledBatchMatMul<33, 128, 32, /*kAdd=*/false, F32, BF16>();
-  TestTiledBatchMatMul<33, 128, 32, /*kAdd=*/true, BF16, F32>();
-  TestTiledBatchMatMul<31, 128, 32, /*kAdd=*/false, F32, SFP>();
-  TestTiledBatchMatMul<29, 128, 32, /*kAdd=*/true, BF16, SFP>();
-  TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/true, F32>();
-  TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/false, BF16>();
-  TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/true, F32, BF16>();
-  TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/false, BF16, F32>();
-  TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/true, F32, SFP>();
-  TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/false, BF16, SFP>();
-  TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, F32>();
-  TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, BF16>();
-  TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, F32, BF16>();
-  TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, BF16, F32>();
-  TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, F32, SFP>();
-  TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, BF16, SFP>();
-  TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/true, F32>();
-  TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/false, BF16>();
-  TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/true, F32, BF16>();
-  TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/false, BF16, F32>();
-  TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/true, F32, SFP>();
-  TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/false, BF16, SFP>();
-  TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, F32>();
-  TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, BF16>();
-  TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, F32, BF16>();
-  TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, BF16, F32>();
-  TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, F32, SFP>();
-  TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, BF16, SFP>();

   // large-scale test
-  // TODO(philculliton): investigate rounding issues with large matrices.
-  // Causes test timeout.
-  // TestTiledBatchMatMul<512, 24576, 3072, float>();
+  TestMatMul<64, 24576, 3072, /*kAdd=*/false, F32, SFP>(pool);
+  // medium-sized square test
+  TestMatMul<512, 512, 512, /*kAdd=*/false, F32>(pool);
+  TestMatMul<512, 512, 512, /*kAdd=*/true, BF16>(pool);
+  TestMatMul<512, 512, 512, /*kAdd=*/false, F32, BF16>(pool);
+  TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, F32>(pool);
+  TestMatMul<512, 512, 512, /*kAdd=*/false, F32, SFP>(pool);
+  TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, SFP>(pool);
+
+  // minimal non-square test. kK must be at least 2 vectors.
+  TestMatMul<35, 128, 32, /*kAdd=*/false, F32>(pool);
+  TestMatMul<34, 128, 32, /*kAdd=*/true, BF16>(pool);
+  TestMatMul<33, 128, 32, /*kAdd=*/false, F32, BF16>(pool);
+  TestMatMul<33, 128, 32, /*kAdd=*/true, BF16, F32>(pool);
+  TestMatMul<31, 128, 32, /*kAdd=*/false, F32, SFP>(pool);
+  TestMatMul<29, 128, 32, /*kAdd=*/true, BF16, SFP>(pool);
+  TestMatMul<4, 128, 32, /*kAdd=*/true, F32>(pool);
+  TestMatMul<4, 128, 32, /*kAdd=*/false, BF16>(pool);
+  TestMatMul<4, 128, 32, /*kAdd=*/true, F32, BF16>(pool);
+  TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, F32>(pool);
+  TestMatMul<4, 128, 32, /*kAdd=*/true, F32, SFP>(pool);
+  TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, SFP>(pool);
+  TestMatMul<3, 128, 32, /*kAdd=*/false, F32>(pool);
+  TestMatMul<3, 128, 32, /*kAdd=*/true, BF16>(pool);
+  TestMatMul<3, 128, 32, /*kAdd=*/false, F32, BF16>(pool);
+  TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, F32>(pool);
+  TestMatMul<3, 128, 32, /*kAdd=*/false, F32, SFP>(pool);
+  TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, SFP>(pool);
+  TestMatMul<2, 128, 64, /*kAdd=*/true, F32>(pool);
+  TestMatMul<2, 128, 64, /*kAdd=*/false, BF16>(pool);
+  TestMatMul<2, 128, 64, /*kAdd=*/true, F32, BF16>(pool);
+  TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, F32>(pool);
+  TestMatMul<2, 128, 64, /*kAdd=*/true, F32, SFP>(pool);
+  TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, SFP>(pool);
+  TestMatMul<1, 128, 32, /*kAdd=*/false, F32>(pool);
+  TestMatMul<1, 128, 32, /*kAdd=*/true, BF16>(pool);
+  TestMatMul<1, 128, 32, /*kAdd=*/false, F32, BF16>(pool);
+  TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, F32>(pool);
+  TestMatMul<1, 128, 32, /*kAdd=*/false, F32, SFP>(pool);
+  TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, SFP>(pool);
+}
+
+template <size_t kOuter, size_t kInner>
+hwy::AlignedFreeUniquePtr<float[]> SimpleMatVecAdd(
+    const CompressedArray<float, kOuter * kInner>& mat,
+    const hwy::AlignedFreeUniquePtr<float[]>& vec,
+    const hwy::AlignedFreeUniquePtr<float[]>& add) {
+  hwy::AlignedFreeUniquePtr<float[]> uncompressed_mat =
+      hwy::AllocateAligned<float>(kOuter * kInner);
+  hwy::AlignedFreeUniquePtr<float[]> out = hwy::AllocateAligned<float>(kOuter);
+  HWY_ASSERT(uncompressed_mat && out);
+  Decompress(mat, 0, uncompressed_mat.get(), kOuter * kInner);
+  MulByConst(mat.scale(), uncompressed_mat.get(), kOuter * kInner);
+  for (size_t idx_row = 0; idx_row < kOuter; idx_row++) {
+    out[idx_row] = add[idx_row];
+    for (size_t idx_col = 0; idx_col < kInner; idx_col++) {
+      out[idx_row] +=
+          uncompressed_mat[kInner * idx_row + idx_col] * vec[idx_col];
+    }
+  }
+  return out;
+}
+
+template <typename MatT, size_t kOuter, size_t kInner>
+CompressedArray<MatT, kOuter * kInner> GenerateMat(size_t offset,
+                                                   hwy::ThreadPool& pool) {
+  gcpp::CompressWorkingSet ws;
+  CompressedArray<MatT, kOuter * kInner> mat;
+  std::array<float, kOuter * kInner> content;
+  const float scale = 1.0f / kInner;
+  pool.Run(0, kOuter, [&](const size_t i, size_t /*thread*/) {
+    for (size_t j = 0; j < kInner; j++) {
+      content[i * kInner + j] =
+          static_cast<float>((i * kInner + j + offset) * scale);
+    }
+  });
+
+  Compress(content, ws, mat, pool);
+  mat.set_scale(1.9f);  // Arbitrary value, different from 1.
+  return mat;
+}
+
+template <size_t length>
+hwy::AlignedFreeUniquePtr<float[]> GenerateVec(size_t offset) {
+  hwy::AlignedFreeUniquePtr<float[]> vec = hwy::AllocateAligned<float>(length);
+  HWY_ASSERT(vec);
+  for (size_t idx = 0; idx < length; idx++) {
+    vec[idx] = static_cast<float>(idx + offset);
+  }
+  return vec;
 }

 void TestMatVecAdd() {
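A quick worked example of the GFLOPS figure PrintSpeed reports (the timing value here is hypothetical): a product of shape (M, N) x (N, K) performs M * N * K multiply-adds, counted as two floating-point operations each, so the printed rate is 2e-9 * M * N * K / elapsed. For the new large benchmark, TestMatMul<64, 24576, 3072, ...>, that is 2 * 64 * 24576 * 3072 ≈ 9.66e9 FLOPs; an elapsed time of, say, 0.05 s would therefore be reported as roughly 9.66 / 0.05 ≈ 193 GFLOPS.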
@@ -441,7 +448,7 @@ HWY_AFTER_NAMESPACE();

 namespace gcpp {
 HWY_BEFORE_TEST(MatmulTest);
-HWY_EXPORT_AND_TEST_P(MatmulTest, TestAllTiledBatchMatMul);
+HWY_EXPORT_AND_TEST_P(MatmulTest, TestAllMatMul);
 HWY_EXPORT_AND_TEST_P(MatmulTest, TestMatVecAdd);
 HWY_EXPORT_AND_TEST_P(MatmulTest, TestTwoMatVecAdd);
 HWY_EXPORT_AND_TEST_P(MatmulTest, TestTwoOfsMatVecAddLoop);