Implement a missing (bf16, f32) tiled MatMul kernel.

PiperOrigin-RevId: 643313676
This commit is contained in:
Andrey Vlasov 2024-06-14 04:54:14 -07:00 committed by Copybara-Service
parent d3c6a45b59
commit b17631c95f
2 changed files with 130 additions and 24 deletions

View File

@ -362,7 +362,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
c23, c30, c31, c32, c33, tile_c, stride_c);
}
// Same as above, but with mixed Mat types: (f32, sfp)).
// Same as above, but with mixed Mat types: (f32, sfp).
template <size_t kNumRows, size_t kColsA_RowsB, typename MatTA,
HWY_IF_F32(MatTA)>
HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
@ -455,7 +455,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
c23, c30, c31, c32, c33, tile_c, stride_c);
}
// Same as above, but with mixed Mat types: (bf16, sfp)).
// Same as above, but with mixed Mat types: (bf16, sfp).
template <size_t kNumRows, size_t kColsA_RowsB, typename MatTA,
HWY_IF_BF16(MatTA)>
HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
@ -656,6 +656,104 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
c22, c23, c30, c31, c32, c33, tile_c, stride_c);
}
// Same as above, but with mixed Mat types: (bf16, f32).
//
// Computes a kNumRows x 4 tile of C from a 4-row strip of A (bf16) and four
// rows of the transposed/column-major B (f32). Partial products are kept in
// f32 vector accumulators and reduced horizontally once at the end.
template <size_t kNumRows, size_t kColsA_RowsB, typename MatTA,
          HWY_IF_BF16(MatTA), typename MatTB, HWY_IF_F32(MatTB)>
HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
                              const MatTB* HWY_RESTRICT B,
                              float* HWY_RESTRICT C, const size_t idx_tile,
                              const size_t xtiles, const size_t stride_a,
                              const size_t stride_b, const size_t stride_c) {
  constexpr size_t kTileRows = 4;
  constexpr size_t kTileCols = 4;
  static_assert(kNumRows <= kTileRows);
  // Top-left of the tile: (row0_ac, col_ab) in A, (col0_bc, col_ab) in B.
  const size_t row0_ac = idx_tile / xtiles * kTileRows;
  const size_t col0_bc = idx_tile % xtiles * kTileCols;

  const hn::ScalableTag<float> df;
  using V32 = hn::Vec<decltype(df)>;
  // TODO: Using half-vectors for now, it might be faster to
  // PromoteLower/UpperTo, and more so to PromoteEven/OddTo if we have packed B
  // accordingly.
  // Half-width bf16 vector with the same lane count as the f32 vector, so a
  // single PromoteTo yields one full f32 vector per load.
  const hn::Rebind<MatTA, decltype(df)> dbf;
  HWY_DASSERT(Lanes(dbf) == Lanes(df));
  const size_t lanes = Lanes(dbf);

  // 4x4 grid of f32 accumulators, one per (row of A, row of B^T) pair.
  V32 acc00 = hn::Zero(df);
  V32 acc01 = hn::Zero(df);
  V32 acc02 = hn::Zero(df);
  V32 acc03 = hn::Zero(df);
  V32 acc10 = hn::Zero(df);
  V32 acc11 = hn::Zero(df);
  V32 acc12 = hn::Zero(df);
  V32 acc13 = hn::Zero(df);
  V32 acc20 = hn::Zero(df);
  V32 acc21 = hn::Zero(df);
  V32 acc22 = hn::Zero(df);
  V32 acc23 = hn::Zero(df);
  V32 acc30 = hn::Zero(df);
  V32 acc31 = hn::Zero(df);
  V32 acc32 = hn::Zero(df);
  V32 acc33 = hn::Zero(df);

  const MatTA* HWY_RESTRICT a_tile = A + stride_a * row0_ac;
  const MatTB* HWY_RESTRICT b_tile = B + stride_b * col0_bc;
  // March along the shared dimension in vector-sized steps, accumulating
  // elementwise products into the acc## grid.
  HWY_UNROLL(1)
  for (size_t col_ab = 0; col_ab < kColsA_RowsB; col_ab += lanes) {
    // B is already f32; load its four rows directly.
    const V32 b0 = hn::LoadU(df, b_tile + stride_b * 0 + col_ab);
    const V32 b1 = hn::LoadU(df, b_tile + stride_b * 1 + col_ab);
    const V32 b2 = hn::LoadU(df, b_tile + stride_b * 2 + col_ab);
    const V32 b3 = hn::LoadU(df, b_tile + stride_b * 3 + col_ab);
    // A is bf16: load a half-vector and promote it to f32.
    const V32 a0 =
        hn::PromoteTo(df, hn::LoadU(dbf, a_tile + stride_a * 0 + col_ab));
    acc00 = hn::MulAdd(a0, b0, acc00);
    acc01 = hn::MulAdd(a0, b1, acc01);
    acc02 = hn::MulAdd(a0, b2, acc02);
    acc03 = hn::MulAdd(a0, b3, acc03);
    // kNumRows is a compile-time constant; the early-outs below prune the
    // unused rows for short tiles.
    if (kNumRows == 1) continue;
    const V32 a1 =
        hn::PromoteTo(df, hn::LoadU(dbf, a_tile + stride_a * 1 + col_ab));
    acc10 = hn::MulAdd(a1, b0, acc10);
    acc11 = hn::MulAdd(a1, b1, acc11);
    acc12 = hn::MulAdd(a1, b2, acc12);
    acc13 = hn::MulAdd(a1, b3, acc13);
    if (kNumRows == 2) continue;
    const V32 a2 =
        hn::PromoteTo(df, hn::LoadU(dbf, a_tile + stride_a * 2 + col_ab));
    acc20 = hn::MulAdd(a2, b0, acc20);
    acc21 = hn::MulAdd(a2, b1, acc21);
    acc22 = hn::MulAdd(a2, b2, acc22);
    acc23 = hn::MulAdd(a2, b3, acc23);
    if (kNumRows == 3) continue;
    const V32 a3 =
        hn::PromoteTo(df, hn::LoadU(dbf, a_tile + stride_a * 3 + col_ab));
    acc30 = hn::MulAdd(a3, b0, acc30);
    acc31 = hn::MulAdd(a3, b1, acc31);
    acc32 = hn::MulAdd(a3, b2, acc32);
    acc33 = hn::MulAdd(a3, b3, acc33);
  }

  // Reduce each accumulator to a scalar and write the kNumRows x 4 tile of C.
  float* HWY_RESTRICT c_tile = C + stride_c * row0_ac + col0_bc;
  StoreHorizontalSums<kNumRows>(df, acc00, acc01, acc02, acc03, acc10, acc11,
                                acc12, acc13, acc20, acc21, acc22, acc23,
                                acc30, acc31, acc32, acc33, c_tile, stride_c);
}
// Tiled 4x4 GEMM. Typically kRowsAC is 4..512, kColsA_RowsB is 3k or 24k, and
// kColsBC is 24k or 3k. Note: B is transposed (column-major).
// This function loops over all tiles (static scheduling). TODO(janwas): we can

View File

@ -521,8 +521,8 @@ void AssertClose(const MatT* HWY_RESTRICT expected,
if (!(expected_value - tolerance <= actual_value &&
actual_value <= expected_value + tolerance)) {
fprintf(stderr, "expected[%lu]: %f, actual[%lu]: %f\n", idx,
expected_value, idx, actual_value);
fprintf(stderr, "expected[%lu]: %f, actual[%lu]: %f, tolerance: %f\n",
idx, expected_value, idx, actual_value, tolerance);
HWY_ASSERT(0);
}
}
@ -558,6 +558,7 @@ void TestAllTiledMatMul() {
TestTiledMatMul<512, 512, 512, float>();
TestTiledMatMul<512, 512, 512, hwy::bfloat16_t>();
TestTiledMatMul<512, 512, 512, float, hwy::bfloat16_t>();
TestTiledMatMul<512, 512, 512, hwy::bfloat16_t, float>();
TestTiledMatMul<512, 512, 512, float, SfpStream>();
TestTiledMatMul<512, 512, 512, hwy::bfloat16_t, SfpStream>();
@ -565,6 +566,7 @@ void TestAllTiledMatMul() {
TestTiledMatMul<4, 128, 4, float>();
TestTiledMatMul<4, 128, 4, hwy::bfloat16_t>();
TestTiledMatMul<4, 128, 4, float, hwy::bfloat16_t>();
TestTiledMatMul<4, 128, 4, hwy::bfloat16_t, float>();
TestTiledMatMul<32, 128, 32, float, SfpStream>();
TestTiledMatMul<32, 128, 32, hwy::bfloat16_t, SfpStream>();
@ -577,6 +579,7 @@ void TestAllTiledMatMul() {
template <size_t kM, size_t kN, size_t kK, typename MatTA,
typename MatTB = MatTA>
void TestTiledBatchMatMul() {
fprintf(stderr, "TestTiledBatchMatMul %lu, %lu, %lu", kM, kN, kK);
hwy::ThreadPool pool(3);
std::unique_ptr<CompressedArray<MatTA, kM * kN>> a =
GenerateMatHeap<MatTA, kM, kN>(0, pool);
@ -600,33 +603,39 @@ void TestAllTiledBatchMatMul() {
TestTiledBatchMatMul<512, 512, 512, float>();
TestTiledBatchMatMul<512, 512, 512, hwy::bfloat16_t>();
TestTiledBatchMatMul<512, 512, 512, float, hwy::bfloat16_t>();
TestTiledBatchMatMul<512, 512, 512, hwy::bfloat16_t, float>();
TestTiledBatchMatMul<512, 512, 512, float, SfpStream>();
TestTiledBatchMatMul<512, 512, 512, hwy::bfloat16_t, SfpStream>();
// minimal non-square test
TestTiledBatchMatMul<35, 128, 4, float>();
TestTiledBatchMatMul<34, 128, 4, hwy::bfloat16_t>();
TestTiledBatchMatMul<33, 128, 4, float, hwy::bfloat16_t>();
TestTiledBatchMatMul<35, 128, 32, float>();
TestTiledBatchMatMul<34, 128, 32, hwy::bfloat16_t>();
TestTiledBatchMatMul<33, 128, 32, float, hwy::bfloat16_t>();
TestTiledBatchMatMul<33, 128, 32, hwy::bfloat16_t, float>();
TestTiledBatchMatMul<31, 128, 32, float, SfpStream>();
TestTiledBatchMatMul<29, 128, 32, hwy::bfloat16_t, SfpStream>();
TestTiledBatchMatMul<4, 128, 4, float>();
TestTiledBatchMatMul<4, 128, 4, hwy::bfloat16_t>();
TestTiledBatchMatMul<4, 128, 4, float, hwy::bfloat16_t>();
TestTiledBatchMatMul<4, 128, 32, float, SfpStream>();
TestTiledBatchMatMul<4, 128, 32, hwy::bfloat16_t, SfpStream>();
TestTiledBatchMatMul<3, 128, 4, float>();
TestTiledBatchMatMul<3, 128, 4, hwy::bfloat16_t>();
TestTiledBatchMatMul<3, 128, 4, float, hwy::bfloat16_t>();
TestTiledBatchMatMul<4, 128, 8, float>();
TestTiledBatchMatMul<4, 128, 8, hwy::bfloat16_t>();
TestTiledBatchMatMul<4, 128, 8, float, hwy::bfloat16_t>();
TestTiledBatchMatMul<4, 128, 8, hwy::bfloat16_t, float>();
TestTiledBatchMatMul<4, 128, 8, float, SfpStream>();
TestTiledBatchMatMul<4, 128, 8, hwy::bfloat16_t, SfpStream>();
TestTiledBatchMatMul<3, 128, 32, float>();
TestTiledBatchMatMul<3, 128, 32, hwy::bfloat16_t>();
TestTiledBatchMatMul<3, 128, 32, float, hwy::bfloat16_t>();
TestTiledBatchMatMul<3, 128, 32, hwy::bfloat16_t, float>();
TestTiledBatchMatMul<3, 128, 32, float, SfpStream>();
TestTiledBatchMatMul<3, 128, 32, hwy::bfloat16_t, SfpStream>();
TestTiledBatchMatMul<2, 128, 4, float>();
TestTiledBatchMatMul<2, 128, 4, hwy::bfloat16_t>();
TestTiledBatchMatMul<2, 128, 4, float, hwy::bfloat16_t>();
TestTiledBatchMatMul<2, 128, 32, float, SfpStream>();
TestTiledBatchMatMul<2, 128, 32, hwy::bfloat16_t, SfpStream>();
TestTiledBatchMatMul<1, 128, 4, float>();
TestTiledBatchMatMul<1, 128, 4, hwy::bfloat16_t>();
TestTiledBatchMatMul<1, 128, 4, float, hwy::bfloat16_t>();
TestTiledBatchMatMul<2, 128, 16, float>();
TestTiledBatchMatMul<2, 128, 16, hwy::bfloat16_t>();
TestTiledBatchMatMul<2, 128, 16, float, hwy::bfloat16_t>();
TestTiledBatchMatMul<2, 128, 16, hwy::bfloat16_t, float>();
TestTiledBatchMatMul<2, 128, 16, float, SfpStream>();
TestTiledBatchMatMul<2, 128, 16, hwy::bfloat16_t, SfpStream>();
TestTiledBatchMatMul<1, 128, 32, float>();
TestTiledBatchMatMul<1, 128, 32, hwy::bfloat16_t>();
TestTiledBatchMatMul<1, 128, 32, float, hwy::bfloat16_t>();
TestTiledBatchMatMul<1, 128, 32, hwy::bfloat16_t, float>();
TestTiledBatchMatMul<1, 128, 32, float, SfpStream>();
TestTiledBatchMatMul<1, 128, 32, hwy::bfloat16_t, SfpStream>();
}
@ -730,7 +739,6 @@ HWY_AFTER_NAMESPACE();
namespace gcpp {
HWY_BEFORE_TEST(OpsTest);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllAddFrom);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulBy);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst);