mirror of https://github.com/google/gemma.cpp.git
Implement a missing (bf16, f32) tiled MatMul kernel.
PiperOrigin-RevId: 643313676
This commit is contained in:
parent
d3c6a45b59
commit
b17631c95f
102
gemma/ops.h
102
gemma/ops.h
|
|
@ -362,7 +362,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
|
|||
c23, c30, c31, c32, c33, tile_c, stride_c);
|
||||
}
|
||||
|
||||
// Same as above, but with mixed Mat types: (f32, sfp)).
|
||||
// Same as above, but with mixed Mat types: (f32, sfp).
|
||||
template <size_t kNumRows, size_t kColsA_RowsB, typename MatTA,
|
||||
HWY_IF_F32(MatTA)>
|
||||
HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
|
||||
|
|
@ -455,7 +455,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
|
|||
c23, c30, c31, c32, c33, tile_c, stride_c);
|
||||
}
|
||||
|
||||
// Same as above, but with mixed Mat types: (bf16, sfp)).
|
||||
// Same as above, but with mixed Mat types: (bf16, sfp).
|
||||
template <size_t kNumRows, size_t kColsA_RowsB, typename MatTA,
|
||||
HWY_IF_BF16(MatTA)>
|
||||
HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
|
||||
|
|
@ -656,6 +656,104 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
|
|||
c22, c23, c30, c31, c32, c33, tile_c, stride_c);
|
||||
}
|
||||
|
||||
// Same as above, but with mixed Mat types: (bf16, f32).
|
||||
template <size_t kNumRows, size_t kColsA_RowsB, typename MatTA,
|
||||
HWY_IF_BF16(MatTA), typename MatTB, HWY_IF_F32(MatTB)>
|
||||
HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
|
||||
const MatTB* HWY_RESTRICT B,
|
||||
float* HWY_RESTRICT C, const size_t idx_tile,
|
||||
const size_t xtiles, const size_t stride_a,
|
||||
const size_t stride_b, const size_t stride_c) {
|
||||
constexpr size_t kRegRows = 4;
|
||||
constexpr size_t kRegCols = 4;
|
||||
static_assert(kNumRows <= kRegRows);
|
||||
|
||||
// Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
|
||||
const size_t row_a = idx_tile / xtiles * kRegRows;
|
||||
const size_t row_b_col_c = idx_tile % xtiles * kRegCols;
|
||||
|
||||
const hn::ScalableTag<float> d32;
|
||||
using VF = hn::Vec<decltype(d32)>;
|
||||
|
||||
// TODO: Using half-vectors for now, it might be faster to
|
||||
// PromoteLower/UpperTo, and more so to PromoteEven/OddTo if we have packed B
|
||||
// accordingly.
|
||||
const hn::Rebind<MatTA, decltype(d32)> d16;
|
||||
HWY_DASSERT(Lanes(d16) == Lanes(d32));
|
||||
|
||||
const size_t N = Lanes(d16);
|
||||
|
||||
VF c00 = hn::Zero(d32);
|
||||
VF c01 = hn::Zero(d32);
|
||||
VF c02 = hn::Zero(d32);
|
||||
VF c03 = hn::Zero(d32);
|
||||
|
||||
VF c10 = hn::Zero(d32);
|
||||
VF c11 = hn::Zero(d32);
|
||||
VF c12 = hn::Zero(d32);
|
||||
VF c13 = hn::Zero(d32);
|
||||
|
||||
VF c20 = hn::Zero(d32);
|
||||
VF c21 = hn::Zero(d32);
|
||||
VF c22 = hn::Zero(d32);
|
||||
VF c23 = hn::Zero(d32);
|
||||
|
||||
VF c30 = hn::Zero(d32);
|
||||
VF c31 = hn::Zero(d32);
|
||||
VF c32 = hn::Zero(d32);
|
||||
VF c33 = hn::Zero(d32);
|
||||
|
||||
const MatTA* HWY_RESTRICT tile_a = A + stride_a * row_a;
|
||||
const MatTB* HWY_RESTRICT tile_b = B + stride_b * row_b_col_c;
|
||||
|
||||
// Loop over columns of A and columns of the transposed B, in steps of N.
|
||||
// Accumulates into the c## vectors.
|
||||
HWY_UNROLL(1)
|
||||
for (size_t col_ab = 0; col_ab < kColsA_RowsB; col_ab += N) {
|
||||
// Promote bf16 to f32
|
||||
const VF b0 = hn::LoadU(d32, tile_b + stride_b * 0 + col_ab);
|
||||
const VF b1 = hn::LoadU(d32, tile_b + stride_b * 1 + col_ab);
|
||||
const VF b2 = hn::LoadU(d32, tile_b + stride_b * 2 + col_ab);
|
||||
const VF b3 = hn::LoadU(d32, tile_b + stride_b * 3 + col_ab);
|
||||
|
||||
const VF a0 =
|
||||
hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 0 + col_ab));
|
||||
c00 = hn::MulAdd(a0, b0, c00);
|
||||
c01 = hn::MulAdd(a0, b1, c01);
|
||||
c02 = hn::MulAdd(a0, b2, c02);
|
||||
c03 = hn::MulAdd(a0, b3, c03);
|
||||
if (kNumRows == 1) continue;
|
||||
|
||||
const VF a1 =
|
||||
hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 1 + col_ab));
|
||||
c10 = hn::MulAdd(a1, b0, c10);
|
||||
c11 = hn::MulAdd(a1, b1, c11);
|
||||
c12 = hn::MulAdd(a1, b2, c12);
|
||||
c13 = hn::MulAdd(a1, b3, c13);
|
||||
if (kNumRows == 2) continue;
|
||||
|
||||
const VF a2 =
|
||||
hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 2 + col_ab));
|
||||
c20 = hn::MulAdd(a2, b0, c20);
|
||||
c21 = hn::MulAdd(a2, b1, c21);
|
||||
c22 = hn::MulAdd(a2, b2, c22);
|
||||
c23 = hn::MulAdd(a2, b3, c23);
|
||||
if (kNumRows == 3) continue;
|
||||
|
||||
const VF a3 =
|
||||
hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 3 + col_ab));
|
||||
c30 = hn::MulAdd(a3, b0, c30);
|
||||
c31 = hn::MulAdd(a3, b1, c31);
|
||||
c32 = hn::MulAdd(a3, b2, c32);
|
||||
c33 = hn::MulAdd(a3, b3, c33);
|
||||
}
|
||||
|
||||
float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
|
||||
StoreHorizontalSums<kNumRows>(d32, c00, c01, c02, c03, c10, c11, c12, c13,
|
||||
c20, c21, c22, c23, c30, c31, c32, c33, tile_c,
|
||||
stride_c);
|
||||
}
|
||||
|
||||
// Tiled 4x4 GEMM. Typically kRowsAC is 4..512, kColsA_RowsB is 3k or 24k, and
|
||||
// kColsBC is 24k or 3k. Note: B is transposed (column-major).
|
||||
// This function loops over all tiles (static scheduling). TODO(janwas): we can
|
||||
|
|
|
|||
|
|
@ -521,8 +521,8 @@ void AssertClose(const MatT* HWY_RESTRICT expected,
|
|||
|
||||
if (!(expected_value - tolerance <= actual_value &&
|
||||
actual_value <= expected_value + tolerance)) {
|
||||
fprintf(stderr, "expected[%lu]: %f, actual[%lu]: %f\n", idx,
|
||||
expected_value, idx, actual_value);
|
||||
fprintf(stderr, "expected[%lu]: %f, actual[%lu]: %f, tolerance: %f\n",
|
||||
idx, expected_value, idx, actual_value, tolerance);
|
||||
HWY_ASSERT(0);
|
||||
}
|
||||
}
|
||||
|
|
@ -558,6 +558,7 @@ void TestAllTiledMatMul() {
|
|||
TestTiledMatMul<512, 512, 512, float>();
|
||||
TestTiledMatMul<512, 512, 512, hwy::bfloat16_t>();
|
||||
TestTiledMatMul<512, 512, 512, float, hwy::bfloat16_t>();
|
||||
TestTiledMatMul<512, 512, 512, hwy::bfloat16_t, float>();
|
||||
TestTiledMatMul<512, 512, 512, float, SfpStream>();
|
||||
TestTiledMatMul<512, 512, 512, hwy::bfloat16_t, SfpStream>();
|
||||
|
||||
|
|
@ -565,6 +566,7 @@ void TestAllTiledMatMul() {
|
|||
TestTiledMatMul<4, 128, 4, float>();
|
||||
TestTiledMatMul<4, 128, 4, hwy::bfloat16_t>();
|
||||
TestTiledMatMul<4, 128, 4, float, hwy::bfloat16_t>();
|
||||
TestTiledMatMul<4, 128, 4, hwy::bfloat16_t, float>();
|
||||
TestTiledMatMul<32, 128, 32, float, SfpStream>();
|
||||
TestTiledMatMul<32, 128, 32, hwy::bfloat16_t, SfpStream>();
|
||||
|
||||
|
|
@ -577,6 +579,7 @@ void TestAllTiledMatMul() {
|
|||
template <size_t kM, size_t kN, size_t kK, typename MatTA,
|
||||
typename MatTB = MatTA>
|
||||
void TestTiledBatchMatMul() {
|
||||
fprintf(stderr, "TestTiledBatchMatMul %lu, %lu, %lu", kM, kN, kK);
|
||||
hwy::ThreadPool pool(3);
|
||||
std::unique_ptr<CompressedArray<MatTA, kM * kN>> a =
|
||||
GenerateMatHeap<MatTA, kM, kN>(0, pool);
|
||||
|
|
@ -600,33 +603,39 @@ void TestAllTiledBatchMatMul() {
|
|||
TestTiledBatchMatMul<512, 512, 512, float>();
|
||||
TestTiledBatchMatMul<512, 512, 512, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<512, 512, 512, float, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<512, 512, 512, hwy::bfloat16_t, float>();
|
||||
TestTiledBatchMatMul<512, 512, 512, float, SfpStream>();
|
||||
TestTiledBatchMatMul<512, 512, 512, hwy::bfloat16_t, SfpStream>();
|
||||
|
||||
// minimal non-square test
|
||||
TestTiledBatchMatMul<35, 128, 4, float>();
|
||||
TestTiledBatchMatMul<34, 128, 4, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<33, 128, 4, float, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<35, 128, 32, float>();
|
||||
TestTiledBatchMatMul<34, 128, 32, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<33, 128, 32, float, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<33, 128, 32, hwy::bfloat16_t, float>();
|
||||
TestTiledBatchMatMul<31, 128, 32, float, SfpStream>();
|
||||
TestTiledBatchMatMul<29, 128, 32, hwy::bfloat16_t, SfpStream>();
|
||||
TestTiledBatchMatMul<4, 128, 4, float>();
|
||||
TestTiledBatchMatMul<4, 128, 4, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<4, 128, 4, float, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<4, 128, 32, float, SfpStream>();
|
||||
TestTiledBatchMatMul<4, 128, 32, hwy::bfloat16_t, SfpStream>();
|
||||
TestTiledBatchMatMul<3, 128, 4, float>();
|
||||
TestTiledBatchMatMul<3, 128, 4, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<3, 128, 4, float, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<4, 128, 8, float>();
|
||||
TestTiledBatchMatMul<4, 128, 8, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<4, 128, 8, float, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<4, 128, 8, hwy::bfloat16_t, float>();
|
||||
TestTiledBatchMatMul<4, 128, 8, float, SfpStream>();
|
||||
TestTiledBatchMatMul<4, 128, 8, hwy::bfloat16_t, SfpStream>();
|
||||
TestTiledBatchMatMul<3, 128, 32, float>();
|
||||
TestTiledBatchMatMul<3, 128, 32, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<3, 128, 32, float, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<3, 128, 32, hwy::bfloat16_t, float>();
|
||||
TestTiledBatchMatMul<3, 128, 32, float, SfpStream>();
|
||||
TestTiledBatchMatMul<3, 128, 32, hwy::bfloat16_t, SfpStream>();
|
||||
TestTiledBatchMatMul<2, 128, 4, float>();
|
||||
TestTiledBatchMatMul<2, 128, 4, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<2, 128, 4, float, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<2, 128, 32, float, SfpStream>();
|
||||
TestTiledBatchMatMul<2, 128, 32, hwy::bfloat16_t, SfpStream>();
|
||||
TestTiledBatchMatMul<1, 128, 4, float>();
|
||||
TestTiledBatchMatMul<1, 128, 4, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<1, 128, 4, float, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<2, 128, 16, float>();
|
||||
TestTiledBatchMatMul<2, 128, 16, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<2, 128, 16, float, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<2, 128, 16, hwy::bfloat16_t, float>();
|
||||
TestTiledBatchMatMul<2, 128, 16, float, SfpStream>();
|
||||
TestTiledBatchMatMul<2, 128, 16, hwy::bfloat16_t, SfpStream>();
|
||||
TestTiledBatchMatMul<1, 128, 32, float>();
|
||||
TestTiledBatchMatMul<1, 128, 32, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<1, 128, 32, float, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<1, 128, 32, hwy::bfloat16_t, float>();
|
||||
TestTiledBatchMatMul<1, 128, 32, float, SfpStream>();
|
||||
TestTiledBatchMatMul<1, 128, 32, hwy::bfloat16_t, SfpStream>();
|
||||
}
|
||||
|
|
@ -730,7 +739,6 @@ HWY_AFTER_NAMESPACE();
|
|||
|
||||
namespace gcpp {
|
||||
HWY_BEFORE_TEST(OpsTest);
|
||||
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllAddFrom);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulBy);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst);
|
||||
|
|
|
|||
Loading…
Reference in New Issue