mirror of https://github.com/google/gemma.cpp.git
Implement mixed mode matmul: f32 * bf16
PiperOrigin-RevId: 640940962
This commit is contained in:
parent
57c2cd8b52
commit
39d4115717
120
gemma/ops.h
120
gemma/ops.h
|
|
@ -343,6 +343,99 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
|
|||
c23, c30, c31, c32, c33, tile_c, stride_c);
|
||||
}
|
||||
|
||||
// Same as above, but with mixed Mat types.
|
||||
template <size_t kColsA_RowsB, typename MatTA, HWY_IF_F32(MatTA),
|
||||
typename MatTB, HWY_IF_BF16(MatTB)>
|
||||
HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
|
||||
const MatTB* HWY_RESTRICT B,
|
||||
float* HWY_RESTRICT C, const size_t idx_tile,
|
||||
const size_t xtiles, const size_t stride_a,
|
||||
const size_t stride_b, const size_t stride_c) {
|
||||
constexpr size_t kRegRows = 4;
|
||||
constexpr size_t kRegCols = 4;
|
||||
|
||||
// Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
|
||||
const size_t row_a = idx_tile / xtiles * kRegRows;
|
||||
const size_t row_b_col_c = idx_tile % xtiles * kRegCols;
|
||||
|
||||
const hn::ScalableTag<float> d32;
|
||||
using VF = hn::Vec<decltype(d32)>;
|
||||
|
||||
// TODO: Using half-vectors for now, it might be faster to
|
||||
// PromoteLower/UpperTo, and more so to PromoteEven/OddTo if we have packed B
|
||||
// accordingly.
|
||||
const hn::Rebind<MatTB, decltype(d32)> d16;
|
||||
HWY_DASSERT(Lanes(d16) == Lanes(d32));
|
||||
|
||||
const size_t N = Lanes(d16);
|
||||
|
||||
VF c00 = hn::Zero(d32);
|
||||
VF c01 = hn::Zero(d32);
|
||||
VF c02 = hn::Zero(d32);
|
||||
VF c03 = hn::Zero(d32);
|
||||
|
||||
VF c10 = hn::Zero(d32);
|
||||
VF c11 = hn::Zero(d32);
|
||||
VF c12 = hn::Zero(d32);
|
||||
VF c13 = hn::Zero(d32);
|
||||
|
||||
VF c20 = hn::Zero(d32);
|
||||
VF c21 = hn::Zero(d32);
|
||||
VF c22 = hn::Zero(d32);
|
||||
VF c23 = hn::Zero(d32);
|
||||
|
||||
VF c30 = hn::Zero(d32);
|
||||
VF c31 = hn::Zero(d32);
|
||||
VF c32 = hn::Zero(d32);
|
||||
VF c33 = hn::Zero(d32);
|
||||
|
||||
const MatTA* HWY_RESTRICT tile_a = A + stride_a * row_a;
|
||||
const MatTB* HWY_RESTRICT tile_b = B + stride_b * row_b_col_c;
|
||||
|
||||
// Loop over columns of A and columns of the transposed B, in steps of N.
|
||||
// Accumulates into the c## vectors.
|
||||
HWY_UNROLL(1)
|
||||
for (size_t col_ab = 0; col_ab < kColsA_RowsB; col_ab += N) {
|
||||
// Promote bf16 to f32
|
||||
const VF b0 =
|
||||
hn::PromoteTo(d32, hn::LoadU(d16, tile_b + stride_b * 0 + col_ab));
|
||||
const VF b1 =
|
||||
hn::PromoteTo(d32, hn::LoadU(d16, tile_b + stride_b * 1 + col_ab));
|
||||
const VF b2 =
|
||||
hn::PromoteTo(d32, hn::LoadU(d16, tile_b + stride_b * 2 + col_ab));
|
||||
const VF b3 =
|
||||
hn::PromoteTo(d32, hn::LoadU(d16, tile_b + stride_b * 3 + col_ab));
|
||||
|
||||
const VF a0 = hn::LoadU(d32, tile_a + stride_a * 0 + col_ab);
|
||||
c00 = hn::MulAdd(a0, b0, c00);
|
||||
c01 = hn::MulAdd(a0, b1, c01);
|
||||
c02 = hn::MulAdd(a0, b2, c02);
|
||||
c03 = hn::MulAdd(a0, b3, c03);
|
||||
|
||||
const VF a1 = hn::LoadU(d32, tile_a + stride_a * 1 + col_ab);
|
||||
c10 = hn::MulAdd(a1, b0, c10);
|
||||
c11 = hn::MulAdd(a1, b1, c11);
|
||||
c12 = hn::MulAdd(a1, b2, c12);
|
||||
c13 = hn::MulAdd(a1, b3, c13);
|
||||
|
||||
const VF a2 = hn::LoadU(d32, tile_a + stride_a * 2 + col_ab);
|
||||
c20 = hn::MulAdd(a2, b0, c20);
|
||||
c21 = hn::MulAdd(a2, b1, c21);
|
||||
c22 = hn::MulAdd(a2, b2, c22);
|
||||
c23 = hn::MulAdd(a2, b3, c23);
|
||||
|
||||
const VF a3 = hn::LoadU(d32, tile_a + stride_a * 3 + col_ab);
|
||||
c30 = hn::MulAdd(a3, b0, c30);
|
||||
c31 = hn::MulAdd(a3, b1, c31);
|
||||
c32 = hn::MulAdd(a3, b2, c32);
|
||||
c33 = hn::MulAdd(a3, b3, c33);
|
||||
}
|
||||
|
||||
float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
|
||||
StoreHorizontalSums(d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21,
|
||||
c22, c23, c30, c31, c32, c33, tile_c, stride_c);
|
||||
}
|
||||
|
||||
// Tiled 4x4 GEMM. Typically kRowsAC is 4..512, kColsA_RowsB is 3k or 24k, and
|
||||
// kColsBC is 24k or 3k. Note: B is transposed (column-major).
|
||||
// This function loops over all tiles (static scheduling). TODO(janwas): we can
|
||||
|
|
@ -376,15 +469,15 @@ void GEMM_4x4_Static(const MatT* HWY_RESTRICT A, const MatT* HWY_RESTRICT B,
|
|||
// Tiled 4x4 GEMM. Typically kRowsAC is 4..512, kColsA_RowsB is 3k or 24k, and
|
||||
// kColsBC is 24k or 3k. Note: B is transposed (column-major).
|
||||
// This function processes tiles in parallel with a work-stealing thread pool.
|
||||
template <size_t kRowsAC, size_t kColsA_RowsB, size_t kColsBC, typename MatT,
|
||||
typename OutT>
|
||||
HWY_NOINLINE void MatMul_4x4(const MatT* HWY_RESTRICT A,
|
||||
const MatT* HWY_RESTRICT B, OutT* HWY_RESTRICT C,
|
||||
template <size_t kRowsAC, size_t kColsA_RowsB, size_t kColsBC, typename MatTA,
|
||||
typename MatTB, typename OutT>
|
||||
HWY_NOINLINE void MatMul_4x4(const MatTA* HWY_RESTRICT A,
|
||||
const MatTB* HWY_RESTRICT B, OutT* HWY_RESTRICT C,
|
||||
hwy::ThreadPool& pool) {
|
||||
// Process reg-sized tiles of C in parallel. We currently write C directly,
|
||||
// which touches more memory than fits in L3. TODO: add another level of loops
|
||||
// so that we finish one L3-sized piece of C at a time.
|
||||
const hn::ScalableTag<MatT> d;
|
||||
const hn::ScalableTag<MatTA> d;
|
||||
const size_t N = Lanes(d);
|
||||
constexpr size_t kRegRows = 4;
|
||||
constexpr size_t kRegCols = 4; // in vectors
|
||||
|
|
@ -409,9 +502,9 @@ HWY_NOINLINE void MatMul_4x4(const MatT* HWY_RESTRICT A,
|
|||
|
||||
// Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
|
||||
// ops_test across instruction sets.
|
||||
template <size_t kM, size_t kN, size_t kK, typename MatT>
|
||||
HWY_INLINE void MatMulSlow(const MatT* HWY_RESTRICT a,
|
||||
const MatT* HWY_RESTRICT b,
|
||||
template <size_t kM, size_t kN, size_t kK, typename MatTA, typename MatTB>
|
||||
HWY_INLINE void MatMulSlow(const MatTA* HWY_RESTRICT a,
|
||||
const MatTB* HWY_RESTRICT b,
|
||||
float* HWY_RESTRICT out) {
|
||||
for (size_t i = 0; i < kM; ++i) {
|
||||
for (size_t k = 0; k < kN; ++k) {
|
||||
|
|
@ -1154,14 +1247,13 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
|
|||
|
||||
// Subtract max (avoid precision loss for large exponents) and exponentiate.
|
||||
hn::Transform(d, x, mask_pos,
|
||||
[&vmax](const auto d, const auto value) HWY_ATTR {
|
||||
return hn::Exp(d, hn::Sub(value, vmax));
|
||||
});
|
||||
[&vmax](const auto d, const auto value)
|
||||
HWY_ATTR { return hn::Exp(d, hn::Sub(value, vmax)); });
|
||||
|
||||
auto sum = hn::Zero(d);
|
||||
Foreach(d, x, mask_pos, sum,
|
||||
[&sum](const auto d, const auto value)
|
||||
HWY_ATTR { sum = hn::Add(sum, value); });
|
||||
Foreach(d, x, mask_pos, sum, [&sum](const auto d, const auto value) HWY_ATTR {
|
||||
sum = hn::Add(sum, value);
|
||||
});
|
||||
|
||||
// Normalize to probability distribution
|
||||
const float mul = 1.0f / hn::ReduceSum(d, sum);
|
||||
|
|
|
|||
|
|
@ -475,21 +475,21 @@ void AssertClose(const MatT* HWY_RESTRICT expected,
|
|||
}
|
||||
}
|
||||
|
||||
template <typename MatT>
|
||||
template <typename MatTA, typename MatTB = MatTA>
|
||||
void TestTiledMatMul() {
|
||||
hwy::ThreadPool pool(3);
|
||||
constexpr size_t kM = 512; // 384
|
||||
constexpr size_t kN = 512; // * 5; // 6; // 768
|
||||
constexpr size_t kK = 512; // * 5; // 640
|
||||
|
||||
CompressedArray<MatT, kM * kN> a = GenerateMat<MatT, kM, kN>(0, pool);
|
||||
CompressedArray<MatT, kN * kK> b = GenerateMat<MatT, kN, kK>(0, pool);
|
||||
CompressedArray<MatTA, kM * kN> a = GenerateMat<MatTA, kM, kN>(0, pool);
|
||||
CompressedArray<MatTB, kN * kK> b = GenerateMat<MatTB, kN, kK>(0, pool);
|
||||
CompressedArray<float, kM * kK> c_slow = GenerateZeroMat<float, kM, kK>(pool);
|
||||
MatMulSlow<kM, kN, kK>(a.data(), b.data(), c_slow.data());
|
||||
|
||||
hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
|
||||
CompressedArray<MatT, kN * kK> b_trans =
|
||||
GenerateTransposeMat<MatT, kN, kK>(0, pool);
|
||||
CompressedArray<MatTB, kN * kK> b_trans =
|
||||
GenerateTransposeMat<MatTB, kN, kK>(0, pool);
|
||||
MatMul_4x4<kM, kN, kK>(a.data(), b_trans.data(), c.get(), pool);
|
||||
|
||||
AssertClose(c_slow.data(), c.get(), kM * kK);
|
||||
|
|
@ -498,6 +498,7 @@ void TestTiledMatMul() {
|
|||
void TestAllTiledMatMul() {
|
||||
TestTiledMatMul<float>();
|
||||
TestTiledMatMul<hwy::bfloat16_t>();
|
||||
TestTiledMatMul<float, hwy::bfloat16_t>();
|
||||
// TODO(janwas): SFP
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue