From b17631c95f229088853e00e58102ebd245d5ae99 Mon Sep 17 00:00:00 2001
From: Andrey Vlasov
Date: Fri, 14 Jun 2024 04:54:14 -0700
Subject: [PATCH] Implement a missing (bf16, f32) tiled MatMul kernel.

PiperOrigin-RevId: 643313676
---
 gemma/ops.h       | 102 +++++++++++++++++++++++++++++++++++++++++++++-
 gemma/ops_test.cc |  52 +++++++++++++----------
 2 files changed, 130 insertions(+), 24 deletions(-)

diff --git a/gemma/ops.h b/gemma/ops.h
index dc00da6..0bc2083 100644
--- a/gemma/ops.h
+++ b/gemma/ops.h
@@ -362,7 +362,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
                               c23, c30, c31, c32, c33, tile_c, stride_c);
 }
 
-// Same as above, but with mixed Mat types: (f32, sfp)).
+// Same as above, but with mixed Mat types: (f32, sfp).
 template <size_t kNumRows, size_t kColsA_RowsB, typename MatTA,
           HWY_IF_F32(MatTA), typename MatTB, HWY_IF_SFP(MatTB)>
 HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
@@ -455,7 +455,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
                               c23, c30, c31, c32, c33, tile_c, stride_c);
 }
 
-// Same as above, but with mixed Mat types: (bf16, sfp)).
+// Same as above, but with mixed Mat types: (bf16, sfp).
 template <size_t kNumRows, size_t kColsA_RowsB, typename MatTA,
           HWY_IF_BF16(MatTA), typename MatTB, HWY_IF_SFP(MatTB)>
 HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
@@ -656,6 +656,104 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
                               c22, c23, c30, c31, c32, c33, tile_c, stride_c);
 }
 
+// Same as above, but with mixed Mat types: (bf16, f32).
+template <size_t kNumRows, size_t kColsA_RowsB, typename MatTA,
+          HWY_IF_BF16(MatTA), typename MatTB, HWY_IF_F32(MatTB)>
+HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
+                              const MatTB* HWY_RESTRICT B,
+                              float* HWY_RESTRICT C, const size_t idx_tile,
+                              const size_t xtiles, const size_t stride_a,
+                              const size_t stride_b, const size_t stride_c) {
+  constexpr size_t kRegRows = 4;
+  constexpr size_t kRegCols = 4;
+  static_assert(kNumRows <= kRegRows);
+
+  // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
+  const size_t row_a = idx_tile / xtiles * kRegRows;
+  const size_t row_b_col_c = idx_tile % xtiles * kRegCols;
+
+  const hn::ScalableTag<float> d32;
+  using VF = hn::Vec<decltype(d32)>;
+
+  // TODO: Using half-vectors for now, it might be faster to
+  // PromoteLower/UpperTo, and more so to PromoteEven/OddTo if we have packed B
+  // accordingly.
+  const hn::Rebind<hwy::bfloat16_t, decltype(d32)> d16;
+  HWY_DASSERT(Lanes(d16) == Lanes(d32));
+
+  const size_t N = Lanes(d16);
+
+  VF c00 = hn::Zero(d32);
+  VF c01 = hn::Zero(d32);
+  VF c02 = hn::Zero(d32);
+  VF c03 = hn::Zero(d32);
+
+  VF c10 = hn::Zero(d32);
+  VF c11 = hn::Zero(d32);
+  VF c12 = hn::Zero(d32);
+  VF c13 = hn::Zero(d32);
+
+  VF c20 = hn::Zero(d32);
+  VF c21 = hn::Zero(d32);
+  VF c22 = hn::Zero(d32);
+  VF c23 = hn::Zero(d32);
+
+  VF c30 = hn::Zero(d32);
+  VF c31 = hn::Zero(d32);
+  VF c32 = hn::Zero(d32);
+  VF c33 = hn::Zero(d32);
+
+  const MatTA* HWY_RESTRICT tile_a = A + stride_a * row_a;
+  const MatTB* HWY_RESTRICT tile_b = B + stride_b * row_b_col_c;
+
+  // Loop over columns of A and columns of the transposed B, in steps of N.
+  // Accumulates into the c## vectors.
+  HWY_UNROLL(1)
+  for (size_t col_ab = 0; col_ab < kColsA_RowsB; col_ab += N) {
+    // Promote bf16 to f32
+    const VF b0 = hn::LoadU(d32, tile_b + stride_b * 0 + col_ab);
+    const VF b1 = hn::LoadU(d32, tile_b + stride_b * 1 + col_ab);
+    const VF b2 = hn::LoadU(d32, tile_b + stride_b * 2 + col_ab);
+    const VF b3 = hn::LoadU(d32, tile_b + stride_b * 3 + col_ab);
+
+    const VF a0 =
+        hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 0 + col_ab));
+    c00 = hn::MulAdd(a0, b0, c00);
+    c01 = hn::MulAdd(a0, b1, c01);
+    c02 = hn::MulAdd(a0, b2, c02);
+    c03 = hn::MulAdd(a0, b3, c03);
+    if (kNumRows == 1) continue;
+
+    const VF a1 =
+        hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 1 + col_ab));
+    c10 = hn::MulAdd(a1, b0, c10);
+    c11 = hn::MulAdd(a1, b1, c11);
+    c12 = hn::MulAdd(a1, b2, c12);
+    c13 = hn::MulAdd(a1, b3, c13);
+    if (kNumRows == 2) continue;
+
+    const VF a2 =
+        hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 2 + col_ab));
+    c20 = hn::MulAdd(a2, b0, c20);
+    c21 = hn::MulAdd(a2, b1, c21);
+    c22 = hn::MulAdd(a2, b2, c22);
+    c23 = hn::MulAdd(a2, b3, c23);
+    if (kNumRows == 3) continue;
+
+    const VF a3 =
+        hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 3 + col_ab));
+    c30 = hn::MulAdd(a3, b0, c30);
+    c31 = hn::MulAdd(a3, b1, c31);
+    c32 = hn::MulAdd(a3, b2, c32);
+    c33 = hn::MulAdd(a3, b3, c33);
+  }
+
+  float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
+  StoreHorizontalSums(d32, c00, c01, c02, c03, c10, c11, c12, c13,
+                      c20, c21, c22, c23, c30, c31, c32, c33, tile_c,
+                      stride_c);
+}
+
 // Tiled 4x4 GEMM. Typically kRowsAC is 4..512, kColsA_RowsB is 3k or 24k, and
 // kColsBC is 24k or 3k. Note: B is transposed (column-major).
 // This function loops over all tiles (static scheduling). TODO(janwas): we can
diff --git a/gemma/ops_test.cc b/gemma/ops_test.cc
index f027364..930c3a1 100644
--- a/gemma/ops_test.cc
+++ b/gemma/ops_test.cc
@@ -521,8 +521,8 @@ void AssertClose(const MatT* HWY_RESTRICT expected,
     if (!(expected_value - tolerance <= actual_value &&
           actual_value <= expected_value + tolerance)) {
-      fprintf(stderr, "expected[%lu]: %f, actual[%lu]: %f\n", idx,
-              expected_value, idx, actual_value);
+      fprintf(stderr, "expected[%lu]: %f, actual[%lu]: %f, tolerance: %f\n",
+              idx, expected_value, idx, actual_value, tolerance);
       HWY_ASSERT(0);
     }
   }
 
@@ -558,6 +558,7 @@ void TestAllTiledMatMul() {
   TestTiledMatMul<512, 512, 512, float>();
   TestTiledMatMul<512, 512, 512, hwy::bfloat16_t>();
   TestTiledMatMul<512, 512, 512, float, hwy::bfloat16_t>();
+  TestTiledMatMul<512, 512, 512, hwy::bfloat16_t, float>();
   TestTiledMatMul<512, 512, 512, float, SfpStream>();
   TestTiledMatMul<512, 512, 512, hwy::bfloat16_t, SfpStream>();
 
@@ -565,6 +566,7 @@
   TestTiledMatMul<4, 128, 4, float>();
   TestTiledMatMul<4, 128, 4, hwy::bfloat16_t>();
   TestTiledMatMul<4, 128, 4, float, hwy::bfloat16_t>();
+  TestTiledMatMul<4, 128, 4, hwy::bfloat16_t, float>();
   TestTiledMatMul<32, 128, 32, float, SfpStream>();
   TestTiledMatMul<32, 128, 32, hwy::bfloat16_t, SfpStream>();
 
@@ -577,6 +579,7 @@
 template <size_t kM, size_t kN, size_t kK, typename MatTA,
           typename MatTB = MatTA>
 void TestTiledBatchMatMul() {
+  fprintf(stderr, "TestTiledBatchMatMul %lu, %lu, %lu\n", kM, kN, kK);
   hwy::ThreadPool pool(3);
   std::unique_ptr<CompressedArray<MatTA, kM * kN>> a =
       GenerateMatHeap<MatTA, kM, kN>(0, pool);
@@ -600,33 +603,39 @@ void TestAllTiledBatchMatMul() {
   TestTiledBatchMatMul<512, 512, 512, float>();
   TestTiledBatchMatMul<512, 512, 512, hwy::bfloat16_t>();
   TestTiledBatchMatMul<512, 512, 512, float, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<512, 512, 512, hwy::bfloat16_t, float>();
   TestTiledBatchMatMul<512, 512, 512, float, SfpStream>();
   TestTiledBatchMatMul<512, 512, 512, hwy::bfloat16_t, SfpStream>();
 
   // minimal non-square test
-  TestTiledBatchMatMul<35, 128, 4, float>();
-  TestTiledBatchMatMul<34, 128, 4, hwy::bfloat16_t>();
-  TestTiledBatchMatMul<33, 128, 4, float, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<35, 128, 32, float>();
+  TestTiledBatchMatMul<34, 128, 32, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<33, 128, 32, float, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<33, 128, 32, hwy::bfloat16_t, float>();
   TestTiledBatchMatMul<31, 128, 32, float, SfpStream>();
   TestTiledBatchMatMul<29, 128, 32, hwy::bfloat16_t, SfpStream>();
-  TestTiledBatchMatMul<4, 128, 4, float>();
-  TestTiledBatchMatMul<4, 128, 4, hwy::bfloat16_t>();
-  TestTiledBatchMatMul<4, 128, 4, float, hwy::bfloat16_t>();
-  TestTiledBatchMatMul<4, 128, 32, float, SfpStream>();
-  TestTiledBatchMatMul<4, 128, 32, hwy::bfloat16_t, SfpStream>();
-  TestTiledBatchMatMul<3, 128, 4, float>();
-  TestTiledBatchMatMul<3, 128, 4, hwy::bfloat16_t>();
-  TestTiledBatchMatMul<3, 128, 4, float, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<4, 128, 8, float>();
+  TestTiledBatchMatMul<4, 128, 8, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<4, 128, 8, float, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<4, 128, 8, hwy::bfloat16_t, float>();
+  TestTiledBatchMatMul<4, 128, 8, float, SfpStream>();
+  TestTiledBatchMatMul<4, 128, 8, hwy::bfloat16_t, SfpStream>();
+  TestTiledBatchMatMul<3, 128, 32, float>();
+  TestTiledBatchMatMul<3, 128, 32, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<3, 128, 32, float, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<3, 128, 32, hwy::bfloat16_t, float>();
   TestTiledBatchMatMul<3, 128, 32, float, SfpStream>();
   TestTiledBatchMatMul<3, 128, 32, hwy::bfloat16_t, SfpStream>();
-  TestTiledBatchMatMul<2, 128, 4, float>();
-  TestTiledBatchMatMul<2, 128, 4, hwy::bfloat16_t>();
-  TestTiledBatchMatMul<2, 128, 4, float, hwy::bfloat16_t>();
-  TestTiledBatchMatMul<2, 128, 32, float, SfpStream>();
-  TestTiledBatchMatMul<2, 128, 32, hwy::bfloat16_t, SfpStream>();
-  TestTiledBatchMatMul<1, 128, 4, float>();
-  TestTiledBatchMatMul<1, 128, 4, hwy::bfloat16_t>();
-  TestTiledBatchMatMul<1, 128, 4, float, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<2, 128, 16, float>();
+  TestTiledBatchMatMul<2, 128, 16, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<2, 128, 16, float, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<2, 128, 16, hwy::bfloat16_t, float>();
+  TestTiledBatchMatMul<2, 128, 16, float, SfpStream>();
+  TestTiledBatchMatMul<2, 128, 16, hwy::bfloat16_t, SfpStream>();
+  TestTiledBatchMatMul<1, 128, 32, float>();
+  TestTiledBatchMatMul<1, 128, 32, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<1, 128, 32, float, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<1, 128, 32, hwy::bfloat16_t, float>();
   TestTiledBatchMatMul<1, 128, 32, float, SfpStream>();
   TestTiledBatchMatMul<1, 128, 32, hwy::bfloat16_t, SfpStream>();
 }
@@ -730,7 +739,6 @@ HWY_AFTER_NAMESPACE();
 
 namespace gcpp {
 HWY_BEFORE_TEST(OpsTest);
-
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllAddFrom);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulBy);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst);