diff --git a/gemma/ops.h b/gemma/ops.h
index 4b54378..691bb2c 100644
--- a/gemma/ops.h
+++ b/gemma/ops.h
@@ -343,6 +343,99 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
                       c23, c30, c31, c32, c33, tile_c, stride_c);
 }
 
+// Same as above, but with mixed Mat types.
+template <size_t kColsA_RowsB, typename MatTA, typename MatTB,
+          HWY_IF_F32(MatTA), HWY_IF_BF16(MatTB)>
+HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
+                              const MatTB* HWY_RESTRICT B,
+                              float* HWY_RESTRICT C, const size_t idx_tile,
+                              const size_t xtiles, const size_t stride_a,
+                              const size_t stride_b, const size_t stride_c) {
+  constexpr size_t kRegRows = 4;
+  constexpr size_t kRegCols = 4;
+
+  // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
+  const size_t row_a = idx_tile / xtiles * kRegRows;
+  const size_t row_b_col_c = idx_tile % xtiles * kRegCols;
+
+  const hn::ScalableTag<float> d32;
+  using VF = hn::Vec<decltype(d32)>;
+
+  // TODO: Using half-vectors for now, it might be faster to
+  // PromoteLower/UpperTo, and more so to PromoteEven/OddTo if we have packed B
+  // accordingly.
+  const hn::Rebind<MatTB, decltype(d32)> d16;
+  HWY_DASSERT(Lanes(d16) == Lanes(d32));
+
+  const size_t N = Lanes(d16);
+
+  VF c00 = hn::Zero(d32);
+  VF c01 = hn::Zero(d32);
+  VF c02 = hn::Zero(d32);
+  VF c03 = hn::Zero(d32);
+
+  VF c10 = hn::Zero(d32);
+  VF c11 = hn::Zero(d32);
+  VF c12 = hn::Zero(d32);
+  VF c13 = hn::Zero(d32);
+
+  VF c20 = hn::Zero(d32);
+  VF c21 = hn::Zero(d32);
+  VF c22 = hn::Zero(d32);
+  VF c23 = hn::Zero(d32);
+
+  VF c30 = hn::Zero(d32);
+  VF c31 = hn::Zero(d32);
+  VF c32 = hn::Zero(d32);
+  VF c33 = hn::Zero(d32);
+
+  const MatTA* HWY_RESTRICT tile_a = A + stride_a * row_a;
+  const MatTB* HWY_RESTRICT tile_b = B + stride_b * row_b_col_c;
+
+  // Loop over columns of A and columns of the transposed B, in steps of N.
+  // Accumulates into the c## vectors.
+  HWY_UNROLL(1)
+  for (size_t col_ab = 0; col_ab < kColsA_RowsB; col_ab += N) {
+    // Promote bf16 to f32
+    const VF b0 =
+        hn::PromoteTo(d32, hn::LoadU(d16, tile_b + stride_b * 0 + col_ab));
+    const VF b1 =
+        hn::PromoteTo(d32, hn::LoadU(d16, tile_b + stride_b * 1 + col_ab));
+    const VF b2 =
+        hn::PromoteTo(d32, hn::LoadU(d16, tile_b + stride_b * 2 + col_ab));
+    const VF b3 =
+        hn::PromoteTo(d32, hn::LoadU(d16, tile_b + stride_b * 3 + col_ab));
+
+    const VF a0 = hn::LoadU(d32, tile_a + stride_a * 0 + col_ab);
+    c00 = hn::MulAdd(a0, b0, c00);
+    c01 = hn::MulAdd(a0, b1, c01);
+    c02 = hn::MulAdd(a0, b2, c02);
+    c03 = hn::MulAdd(a0, b3, c03);
+
+    const VF a1 = hn::LoadU(d32, tile_a + stride_a * 1 + col_ab);
+    c10 = hn::MulAdd(a1, b0, c10);
+    c11 = hn::MulAdd(a1, b1, c11);
+    c12 = hn::MulAdd(a1, b2, c12);
+    c13 = hn::MulAdd(a1, b3, c13);
+
+    const VF a2 = hn::LoadU(d32, tile_a + stride_a * 2 + col_ab);
+    c20 = hn::MulAdd(a2, b0, c20);
+    c21 = hn::MulAdd(a2, b1, c21);
+    c22 = hn::MulAdd(a2, b2, c22);
+    c23 = hn::MulAdd(a2, b3, c23);
+
+    const VF a3 = hn::LoadU(d32, tile_a + stride_a * 3 + col_ab);
+    c30 = hn::MulAdd(a3, b0, c30);
+    c31 = hn::MulAdd(a3, b1, c31);
+    c32 = hn::MulAdd(a3, b2, c32);
+    c33 = hn::MulAdd(a3, b3, c33);
+  }
+
+  float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
+  StoreHorizontalSums(d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21,
+                      c22, c23, c30, c31, c32, c33, tile_c, stride_c);
+}
+
 // Tiled 4x4 GEMM. Typically kRowsAC is 4..512, kColsA_RowsB is 3k or 24k, and
 // kColsBC is 24k or 3k. Note: B is transposed (column-major).
 // This function loops over all tiles (static scheduling). TODO(janwas): we can
@@ -376,15 +469,15 @@ void GEMM_4x4_Static(const MatT* HWY_RESTRICT A, const MatT* HWY_RESTRICT B,
 // Tiled 4x4 GEMM. Typically kRowsAC is 4..512, kColsA_RowsB is 3k or 24k, and
 // kColsBC is 24k or 3k. Note: B is transposed (column-major).
 // This function processes tiles in parallel with a work-stealing thread pool.
-template <size_t kRowsAC, size_t kColsA_RowsB, size_t kColsBC, typename MatT,
-          typename OutT>
-HWY_NOINLINE void MatMul_4x4(const MatT* HWY_RESTRICT A,
-                             const MatT* HWY_RESTRICT B, OutT* HWY_RESTRICT C,
+template <size_t kRowsAC, size_t kColsA_RowsB, size_t kColsBC,
+          typename MatTA, typename MatTB, typename OutT>
+HWY_NOINLINE void MatMul_4x4(const MatTA* HWY_RESTRICT A,
+                             const MatTB* HWY_RESTRICT B, OutT* HWY_RESTRICT C,
                              hwy::ThreadPool& pool) {
   // Process reg-sized tiles of C in parallel. We currently write C directly,
   // which touches more memory than fits in L3. TODO: add another level of loops
   // so that we finish one L3-sized piece of C at a time.
-  const hn::ScalableTag<MatT> d;
+  const hn::ScalableTag<MatTA> d;
   const size_t N = Lanes(d);
   constexpr size_t kRegRows = 4;
   constexpr size_t kRegCols = 4;  // in vectors
@@ -409,9 +502,9 @@ HWY_NOINLINE void MatMul_4x4(const MatT* HWY_RESTRICT A,
 
 // Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
 // ops_test across instruction sets.
-template <size_t kM, size_t kN, size_t kK, typename MatT>
-HWY_INLINE void MatMulSlow(const MatT* HWY_RESTRICT a,
-                           const MatT* HWY_RESTRICT b,
+template <size_t kM, size_t kN, size_t kK, typename MatTA, typename MatTB>
+HWY_INLINE void MatMulSlow(const MatTA* HWY_RESTRICT a,
+                           const MatTB* HWY_RESTRICT b,
                            float* HWY_RESTRICT out) {
   for (size_t i = 0; i < kM; ++i) {
     for (size_t k = 0; k < kN; ++k) {
@@ -1154,14 +1247,13 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
 
   // Subtract max (avoid precision loss for large exponents) and exponentiate.
   hn::Transform(d, x, mask_pos,
-                [&vmax](const auto d, const auto value) HWY_ATTR {
-                  return hn::Exp(d, hn::Sub(value, vmax));
-                });
+                [&vmax](const auto d, const auto value)
+                    HWY_ATTR { return hn::Exp(d, hn::Sub(value, vmax)); });
 
   auto sum = hn::Zero(d);
-  Foreach(d, x, mask_pos, sum,
-          [&sum](const auto d, const auto value)
-              HWY_ATTR { sum = hn::Add(sum, value); });
+  Foreach(d, x, mask_pos, sum, [&sum](const auto d, const auto value) HWY_ATTR {
+    sum = hn::Add(sum, value);
+  });
 
   // Normalize to probability distribution
   const float mul = 1.0f / hn::ReduceSum(d, sum);
diff --git a/gemma/ops_test.cc b/gemma/ops_test.cc
index dd0b052..a7f47f4 100644
--- a/gemma/ops_test.cc
+++ b/gemma/ops_test.cc
@@ -475,21 +475,21 @@ void AssertClose(const MatT* HWY_RESTRICT expected,
   }
 }
 
-template <typename MatT>
+template <typename MatTA, typename MatTB = MatTA>
 void TestTiledMatMul() {
   hwy::ThreadPool pool(3);
   constexpr size_t kM = 512;  // 384
   constexpr size_t kN = 512;  // * 5;  // 6;  // 768
   constexpr size_t kK = 512;  // * 5;  // 640
-  CompressedArray<MatT, kM * kN> a = GenerateMat<MatT, kM, kN>(0, pool);
-  CompressedArray<MatT, kN * kK> b = GenerateMat<MatT, kN, kK>(0, pool);
+  CompressedArray<MatTA, kM * kN> a = GenerateMat<MatTA, kM, kN>(0, pool);
+  CompressedArray<MatTB, kN * kK> b = GenerateMat<MatTB, kN, kK>(0, pool);
   CompressedArray<float, kM * kK> c_slow =
       GenerateZeroMat<float, kM, kK>(pool);
   MatMulSlow<kM, kN, kK>(a.data(), b.data(), c_slow.data());
 
   hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
-  CompressedArray<MatT, kN * kK> b_trans =
-      GenerateTransposeMat<MatT, kN, kK>(0, pool);
+  CompressedArray<MatTB, kN * kK> b_trans =
+      GenerateTransposeMat<MatTB, kN, kK>(0, pool);
   MatMul_4x4<kM, kN, kK>(a.data(), b_trans.data(), c.get(), pool);
   AssertClose(c_slow.data(), c.get(), kM * kK);
 
@@ -498,6 +498,7 @@ void TestTiledMatMul() {
 
 void TestAllTiledMatMul() {
   TestTiledMatMul<float>();
   TestTiledMatMul<hwy::bfloat16_t>();
+  TestTiledMatMul<float, hwy::bfloat16_t>();
   // TODO(janwas): SFP
 }
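
Note (not part of the patch): the core of the new mixed-type tile is the
half-width Rebind, which makes one bf16 load widen to exactly one f32 vector.
Here is a minimal standalone sketch of that inner-loop pattern, reduced to a
single f32 x bf16 dot product; DotF32BF16 is an illustrative name, and n is
assumed to be a multiple of the f32 lane count:

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Same promotion pattern as the mixed GEMM_4x4_Tile inner loop above.
HWY_ATTR float DotF32BF16(const float* HWY_RESTRICT a,
                          const hwy::bfloat16_t* HWY_RESTRICT b, size_t n) {
  const hn::ScalableTag<float> d32;
  // Half-width bf16 tag with the same lane count as d32, so one bf16 load
  // promotes to exactly one f32 vector.
  const hn::Rebind<hwy::bfloat16_t, decltype(d32)> d16;
  using VF = hn::Vec<decltype(d32)>;
  VF sum = hn::Zero(d32);
  for (size_t i = 0; i < n; i += hn::Lanes(d32)) {
    const VF vb = hn::PromoteTo(d32, hn::LoadU(d16, b + i));  // bf16 -> f32
    const VF va = hn::LoadU(d32, a + i);
    sum = hn::MulAdd(va, vb, sum);  // fused multiply-add accumulate
  }
  return hn::ReduceSum(d32, sum);
}

As the TODO in the kernel notes, packing B so that PromoteEven/OddTo can
consume full bf16 vectors could roughly halve the number of B loads; the
half-vector form is the simpler starting point.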
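Likewise for the scalar reference: with mixed inputs, MatMulSlow only needs a
per-element conversion to f32 before accumulating. A sketch of what the
mixed-type body amounts to (illustrative only, since the real body is
truncated by the hunk above; ToF32 is a hypothetical helper, while
hwy::F32FromBF16 is the existing scalar conversion in hwy/base.h):

#include <stddef.h>

#include "hwy/base.h"

static inline float ToF32(float x) { return x; }
static inline float ToF32(hwy::bfloat16_t x) { return hwy::F32FromBF16(x); }

// out must be zero-initialized, as GenerateZeroMat does in the test.
template <size_t kM, size_t kN, size_t kK, typename MatTA, typename MatTB>
void MatMulSlowSketch(const MatTA* a, const MatTB* b, float* out) {
  for (size_t i = 0; i < kM; ++i) {
    for (size_t k = 0; k < kN; ++k) {
      const float a_ik = ToF32(a[i * kN + k]);  // hoisted; loop order i, k, j
      for (size_t j = 0; j < kK; ++j) {
        out[i * kK + j] += a_ik * ToF32(b[k * kK + j]);  // accumulate in f32
      }
    }
  }
}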