mirror of https://github.com/google/gemma.cpp.git
Support mixed (bf16, sfp) tiled MatMul. Same sfp-decompress strategy as in (f32,
sfp) tiled MatMul. PiperOrigin-RevId: 642901844
This commit is contained in:
parent
6e67a6d8a9
commit
38eb452b94
112
gemma/ops.h
112
gemma/ops.h
|
|
@ -345,9 +345,9 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
|
|||
c23, c30, c31, c32, c33, tile_c, stride_c);
|
||||
}
|
||||
|
||||
// As above, for SfpStream.
|
||||
template <size_t kColsA_RowsB>
|
||||
HWY_INLINE void GEMM_4x4_Tile(const float* HWY_RESTRICT A,
|
||||
// Same as above, but with mixed Mat types: (f32, sfp)).
|
||||
template <size_t kColsA_RowsB, typename MatTA, HWY_IF_F32(MatTA)>
|
||||
HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
|
||||
const SfpStream* HWY_RESTRICT B,
|
||||
float* HWY_RESTRICT C, const size_t idx_tile,
|
||||
const size_t xtiles, const size_t stride_a,
|
||||
|
|
@ -359,7 +359,7 @@ HWY_INLINE void GEMM_4x4_Tile(const float* HWY_RESTRICT A,
|
|||
const size_t row_a = idx_tile / xtiles * kRegRows;
|
||||
const size_t row_b_col_c = idx_tile % xtiles * kRegCols;
|
||||
|
||||
const hn::ScalableTag<float> d;
|
||||
const hn::ScalableTag<MatTA> d;
|
||||
const size_t N = Lanes(d);
|
||||
using V = hn::Vec<decltype(d)>;
|
||||
|
||||
|
|
@ -383,7 +383,7 @@ HWY_INLINE void GEMM_4x4_Tile(const float* HWY_RESTRICT A,
|
|||
V c32 = hn::Zero(d);
|
||||
V c33 = hn::Zero(d);
|
||||
|
||||
const float* HWY_RESTRICT tile_a = A + stride_a * row_a;
|
||||
const MatTA* HWY_RESTRICT tile_a = A + stride_a * row_a;
|
||||
|
||||
hwy::AlignedFreeUniquePtr<float[]> tile_b_unique_ptr =
|
||||
hwy::AllocateAligned<float>(kRegRows * kColsA_RowsB);
|
||||
|
|
@ -432,7 +432,103 @@ HWY_INLINE void GEMM_4x4_Tile(const float* HWY_RESTRICT A,
|
|||
c23, c30, c31, c32, c33, tile_c, stride_c);
|
||||
}
|
||||
|
||||
// Same as above, but with mixed Mat types.
|
||||
// Same as above, but with mixed Mat types: (bf16, sfp)).
|
||||
template <size_t kColsA_RowsB, typename MatTA, HWY_IF_BF16(MatTA)>
|
||||
HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
|
||||
const SfpStream* HWY_RESTRICT B,
|
||||
float* HWY_RESTRICT C, const size_t idx_tile,
|
||||
const size_t xtiles, const size_t stride_a,
|
||||
const size_t stride_b, const size_t stride_c) {
|
||||
constexpr size_t kRegRows = 4;
|
||||
constexpr size_t kRegCols = 4;
|
||||
|
||||
// Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
|
||||
const size_t row_a = idx_tile / xtiles * kRegRows;
|
||||
const size_t row_b_col_c = idx_tile % xtiles * kRegCols;
|
||||
|
||||
const hn::ScalableTag<float> d32;
|
||||
const size_t N = Lanes(d32);
|
||||
using V = hn::Vec<decltype(d32)>;
|
||||
// TODO: Using half-vectors for now, it might be faster to
|
||||
// PromoteLower/UpperTo, and more so to PromoteEven/OddTo if we have packed B
|
||||
// accordingly.
|
||||
const hn::Rebind<MatTA, decltype(d32)> d16;
|
||||
HWY_DASSERT(Lanes(d16) == Lanes(d32));
|
||||
|
||||
V c00 = hn::Zero(d32);
|
||||
V c01 = hn::Zero(d32);
|
||||
V c02 = hn::Zero(d32);
|
||||
V c03 = hn::Zero(d32);
|
||||
|
||||
V c10 = hn::Zero(d32);
|
||||
V c11 = hn::Zero(d32);
|
||||
V c12 = hn::Zero(d32);
|
||||
V c13 = hn::Zero(d32);
|
||||
|
||||
V c20 = hn::Zero(d32);
|
||||
V c21 = hn::Zero(d32);
|
||||
V c22 = hn::Zero(d32);
|
||||
V c23 = hn::Zero(d32);
|
||||
|
||||
V c30 = hn::Zero(d32);
|
||||
V c31 = hn::Zero(d32);
|
||||
V c32 = hn::Zero(d32);
|
||||
V c33 = hn::Zero(d32);
|
||||
|
||||
const MatTA* HWY_RESTRICT tile_a = A + stride_a * row_a;
|
||||
|
||||
hwy::AlignedFreeUniquePtr<float[]> tile_b_unique_ptr =
|
||||
hwy::AllocateAligned<float>(kRegRows * kColsA_RowsB);
|
||||
CompressTraits<SfpStream>::Decompress(
|
||||
d32,
|
||||
/*in_capacity=*/0, B, stride_b * row_b_col_c, tile_b_unique_ptr.get(),
|
||||
kRegRows * kColsA_RowsB);
|
||||
const float* HWY_RESTRICT tile_b = tile_b_unique_ptr.get();
|
||||
|
||||
// Loop over columns of A and columns of the transposed B, in steps of N.
|
||||
// Accumulates into the c## vectors.
|
||||
HWY_UNROLL(1)
|
||||
for (size_t col_ab = 0; col_ab < kColsA_RowsB; col_ab += N) {
|
||||
const V b0 = hn::LoadU(d32, tile_b + stride_b * 0 + col_ab);
|
||||
const V b1 = hn::LoadU(d32, tile_b + stride_b * 1 + col_ab);
|
||||
const V b2 = hn::LoadU(d32, tile_b + stride_b * 2 + col_ab);
|
||||
const V b3 = hn::LoadU(d32, tile_b + stride_b * 3 + col_ab);
|
||||
|
||||
const V a0 =
|
||||
hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 0 + col_ab));
|
||||
c00 = hn::MulAdd(a0, b0, c00);
|
||||
c01 = hn::MulAdd(a0, b1, c01);
|
||||
c02 = hn::MulAdd(a0, b2, c02);
|
||||
c03 = hn::MulAdd(a0, b3, c03);
|
||||
|
||||
const V a1 =
|
||||
hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 1 + col_ab));
|
||||
c10 = hn::MulAdd(a1, b0, c10);
|
||||
c11 = hn::MulAdd(a1, b1, c11);
|
||||
c12 = hn::MulAdd(a1, b2, c12);
|
||||
c13 = hn::MulAdd(a1, b3, c13);
|
||||
|
||||
const V a2 =
|
||||
hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 2 + col_ab));
|
||||
c20 = hn::MulAdd(a2, b0, c20);
|
||||
c21 = hn::MulAdd(a2, b1, c21);
|
||||
c22 = hn::MulAdd(a2, b2, c22);
|
||||
c23 = hn::MulAdd(a2, b3, c23);
|
||||
|
||||
const V a3 =
|
||||
hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 3 + col_ab));
|
||||
c30 = hn::MulAdd(a3, b0, c30);
|
||||
c31 = hn::MulAdd(a3, b1, c31);
|
||||
c32 = hn::MulAdd(a3, b2, c32);
|
||||
c33 = hn::MulAdd(a3, b3, c33);
|
||||
}
|
||||
|
||||
float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
|
||||
StoreHorizontalSums(d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21,
|
||||
c22, c23, c30, c31, c32, c33, tile_c, stride_c);
|
||||
}
|
||||
|
||||
// Same as above, but with mixed Mat types: (f32, bf16).
|
||||
template <size_t kColsA_RowsB, typename MatTA, HWY_IF_F32(MatTA),
|
||||
typename MatTB, HWY_IF_BF16(MatTB)>
|
||||
HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
|
||||
|
|
@ -606,8 +702,8 @@ HWY_INLINE void MatMulSlow(const MatTA* HWY_RESTRICT a,
|
|||
}
|
||||
}
|
||||
|
||||
template <size_t kM, size_t kN, size_t kK>
|
||||
HWY_INLINE void MatMulSlow(const float* HWY_RESTRICT a,
|
||||
template <size_t kM, size_t kN, size_t kK, typename MatTA>
|
||||
HWY_INLINE void MatMulSlow(const MatTA* HWY_RESTRICT a,
|
||||
const SfpStream* HWY_RESTRICT b_sfp_stream,
|
||||
float* HWY_RESTRICT out) {
|
||||
const hn::ScalableTag<float> d;
|
||||
|
|
|
|||
|
|
@ -555,12 +555,14 @@ void TestAllTiledMatMul() {
|
|||
TestTiledMatMul<512, 512, 512, hwy::bfloat16_t>();
|
||||
TestTiledMatMul<512, 512, 512, float, hwy::bfloat16_t>();
|
||||
TestTiledMatMul<512, 512, 512, float, SfpStream>();
|
||||
TestTiledMatMul<512, 512, 512, hwy::bfloat16_t, SfpStream>();
|
||||
|
||||
// minimal non-square test
|
||||
TestTiledMatMul<4, 128, 4, float>();
|
||||
TestTiledMatMul<4, 128, 4, hwy::bfloat16_t>();
|
||||
TestTiledMatMul<4, 128, 4, float, hwy::bfloat16_t>();
|
||||
TestTiledMatMul<32, 128, 32, float, SfpStream>();
|
||||
TestTiledMatMul<32, 128, 32, hwy::bfloat16_t, SfpStream>();
|
||||
|
||||
// large-scale test
|
||||
// TODO(philculliton): investigate rounding issues with large matrices
|
||||
|
|
|
|||
Loading…
Reference in New Issue