diff --git a/gemma/ops.h b/gemma/ops.h
index 0bc2083..93a9a4b 100644
--- a/gemma/ops.h
+++ b/gemma/ops.h
@@ -753,71 +753,6 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
                       c20, c21, c22, c23, c30, c31, c32, c33, tile_c,
                       stride_c);
 }
-
-// Tiled 4x4 GEMM. Typically kRowsAC is 4..512, kColsA_RowsB is 3k or 24k, and
-// kColsBC is 24k or 3k. Note: B is transposed (column-major).
-// This function loops over all tiles (static scheduling). TODO(janwas): we can
-// possibly remove this if ThreadPool(0) is as efficient as the loop.
-template <size_t kRowsAC, size_t kColsA_RowsB, size_t kColsBC, typename MatT>
-void GEMM_4x4_Static(const MatT* HWY_RESTRICT A, const MatT* HWY_RESTRICT B,
-                     MatT* HWY_RESTRICT C) {
-  const hn::ScalableTag<MatT> d;
-  const size_t N = hn::Lanes(d);  // column step size
-  constexpr size_t kRegRows = 4;
-  constexpr size_t kRegCols = 4;  // in vectors
-
-  static_assert(kRowsAC % kRegRows == 0);
-  static_assert(kColsBC % kRegCols == 0);
-  HWY_ASSERT(kColsA_RowsB % (N * kRegCols) == 0);
-  constexpr size_t kTilesY = kRowsAC / kRegRows;
-  constexpr size_t kTilesX = kColsBC / kRegCols;
-  constexpr size_t kTiles = kTilesX * kTilesY;
-
-  constexpr size_t kStrideA = kColsA_RowsB;
-  constexpr size_t kStrideB = kColsA_RowsB;  // B is column-major
-  constexpr size_t kStrideC = kColsBC;
-
-  HWY_UNROLL(1)
-  for (size_t idx_tile = 0; idx_tile < kTiles; ++idx_tile) {
-    GEMM_4x4_Tile<kColsA_RowsB>(
-        A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
-  }
-}
-
-// Tiled 4x4 GEMM. Typically kRowsAC is 4..512, kColsA_RowsB is 3k or 24k, and
-// kColsBC is 24k or 3k. Note: B is transposed (column-major).
-// This function processes tiles in parallel with a work-stealing thread pool.
-template <size_t kRowsAC, size_t kColsA_RowsB, size_t kColsBC, typename MatTA,
-          typename MatTB, typename OutT>
-HWY_NOINLINE void MatMul_4x4(const MatTA* HWY_RESTRICT A,
-                             const MatTB* HWY_RESTRICT B, OutT* HWY_RESTRICT C,
-                             hwy::ThreadPool& pool) {
-  // Process reg-sized tiles of C in parallel. We currently write C directly,
-  // which touches more memory than fits in L3. TODO: add another level of loops
-  // so that we finish one L3-sized piece of C at a time.
-  const hn::ScalableTag<MatTA> d;
-  const size_t N = Lanes(d);
-  constexpr size_t kRegRows = 4;
-  constexpr size_t kRegCols = 4;  // in vectors
-
-  static_assert(kRowsAC % kRegRows == 0);
-  static_assert(kColsBC % kRegCols == 0);
-  HWY_ASSERT(kColsA_RowsB % (N * kRegCols) == 0);
-  const size_t kTilesY = kRowsAC / kRegRows;
-  const size_t kTilesX = kColsBC / kRegCols;
-  const size_t kTiles = kTilesX * kTilesY;
-
-  constexpr size_t kStrideA = kColsA_RowsB;
-  constexpr size_t kStrideB = kColsA_RowsB;
-  constexpr size_t kStrideC = kColsBC;
-
-  pool.Run(0, kTiles, [&](const uint64_t idx_tile, size_t /*thread*/) HWY_ATTR {
-    // Computes the finished product of one 4x4N tile and writes to C.
-    GEMM_4x4_Tile<kColsA_RowsB>(
-        A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
-  });
-}
-
 // Tiled 4x4 GEMM. Typically batch_size is 1..512, kColsA_RowsB is 3k or 24k,
 // and kColsBC is 24k or 3k. Note: B is transposed (column-major).
 // NOTE that batch_size is the number of rows of A and C.
@@ -869,35 +804,6 @@ HWY_NOINLINE void MatMul_4x4_Batch(
   });
 }
 
-// Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
-// ops_test across instruction sets.
-template <size_t kM, size_t kN, size_t kK, typename MatTA, typename MatTB>
-HWY_INLINE void MatMulSlow(const MatTA* HWY_RESTRICT a,
-                           const MatTB* HWY_RESTRICT b,
-                           float* HWY_RESTRICT out) {
-  for (size_t i = 0; i < kM; ++i) {
-    for (size_t k = 0; k < kN; ++k) {
-      for (size_t j = 0; j < kK; ++j) {
-        const float a1 = hwy::ConvertScalarTo<float>(a[i * kN + k]);
-        const float b1 = hwy::ConvertScalarTo<float>(b[k * kK + j]);
-        out[i * kK + j] += a1 * b1;
-      }
-    }
-  }
-}
-
-template <size_t kM, size_t kN, size_t kK, typename MatTA>
-HWY_INLINE void MatMulSlow(const MatTA* HWY_RESTRICT a,
-                           const SfpStream* HWY_RESTRICT b_sfp_stream,
-                           float* HWY_RESTRICT out) {
-  const hn::ScalableTag<float> d;
-  hwy::AlignedFreeUniquePtr<float[]> b = hwy::AllocateAligned<float>(kK * kN);
-  CompressTraits<SfpStream>::Decompress(d,
-                                        /*in_capacity=*/0, b_sfp_stream, 0,
-                                        b.get(), kK * kN);
-  MatMulSlow<kM, kN, kK>(a, b.get(), out);
-}
-
 // Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
 // ops_test across instruction sets.
 template <size_t kN, size_t kK, typename MatTA, typename MatTB>
diff --git a/gemma/ops_test.cc b/gemma/ops_test.cc
index 930c3a1..c9efde1 100644
--- a/gemma/ops_test.cc
+++ b/gemma/ops_test.cc
@@ -528,54 +528,6 @@ void AssertClose(const MatT* HWY_RESTRICT expected,
   }
 }
 
-template <size_t kM, size_t kN, size_t kK, typename MatTA,
-          typename MatTB = MatTA>
-void TestTiledMatMul() {
-  hwy::ThreadPool pool(hwy::ThreadPool::MaxThreads());
-  std::unique_ptr<CompressedArray<MatTA, kM * kN>> a =
-      GenerateMatHeap<MatTA, kM, kN>(0, pool);
-  std::unique_ptr<CompressedArray<MatTB, kN * kK>> b =
-      GenerateMatHeap<MatTB, kN, kK>(0, pool);
-  std::unique_ptr<CompressedArray<float, kM * kK>> c_slow =
-      GenerateZeroMatHeap<float, kM, kK>(pool);
-  std::unique_ptr<CompressedArray<float, kM * kK>> c_slow_batch =
-      GenerateZeroMatHeap<float, kM, kK>(pool);
-
-  MatMulSlow<kM, kN, kK>(a->data(), b->data(), c_slow->data());
-  MatMulSlowBatch<kN, kK>(kM, a->data(), b->data(), c_slow_batch->data());
-  AssertClose(c_slow->data(), c_slow_batch->data(), kM * kK);
-
-  hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
-  std::unique_ptr<CompressedArray<MatTB, kN * kK>> b_trans =
-      GenerateTransposeMatHeap<MatTB, kN, kK>(0, pool);
-  MatMul_4x4<kM, kN, kK>(a->data(), b_trans->data(), c.get(), pool);
-
-  AssertClose(c_slow->data(), c.get(), kM * kK);
-}
-
-void TestAllTiledMatMul() {
-  // medium-sized square test
-  TestTiledMatMul<512, 512, 512, float>();
-  TestTiledMatMul<512, 512, 512, hwy::bfloat16_t>();
-  TestTiledMatMul<512, 512, 512, float, hwy::bfloat16_t>();
-  TestTiledMatMul<512, 512, 512, hwy::bfloat16_t, float>();
-  TestTiledMatMul<512, 512, 512, float, SfpStream>();
-  TestTiledMatMul<512, 512, 512, hwy::bfloat16_t, SfpStream>();
-
-  // minimal non-square test
-  TestTiledMatMul<4, 128, 4, float>();
-  TestTiledMatMul<4, 128, 4, hwy::bfloat16_t>();
-  TestTiledMatMul<4, 128, 4, float, hwy::bfloat16_t>();
-  TestTiledMatMul<4, 128, 4, hwy::bfloat16_t, float>();
-  TestTiledMatMul<32, 128, 32, float, SfpStream>();
-  TestTiledMatMul<32, 128, 32, hwy::bfloat16_t, SfpStream>();
-
-  // large-scale test
-  // TODO(philculliton): investigate rounding issues with large matrices.
-  // Causes test timeout.
-  // TestTiledMatMul<512, 24576, 3072, float>();
-}
-
 template <size_t kM, size_t kN, size_t kK, typename MatTA,
           typename MatTB = MatTA>
 void TestTiledBatchMatMul() {
@@ -638,6 +590,11 @@ void TestAllTiledBatchMatMul() {
   TestTiledBatchMatMul<1, 128, 32, hwy::bfloat16_t, float>();
   TestTiledBatchMatMul<1, 128, 32, float, SfpStream>();
   TestTiledBatchMatMul<1, 128, 32, hwy::bfloat16_t, SfpStream>();
+
+  // large-scale test
+  // TODO(philculliton): investigate rounding issues with large matrices.
+  // Causes test timeout.
+  // TestTiledBatchMatMul<512, 24576, 3072, float>();
 }
 
 void TestMatVecAdd() {
@@ -746,7 +703,6 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllTiledBatchMatMul);
-HWY_EXPORT_AND_TEST_P(OpsTest, TestAllTiledMatMul);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAdd);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoMatVecAdd);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoOfsMatVecAddLoop);
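
Note for callers (not part of the patch): with MatMul_4x4 removed, the retained MatMul_4x4_Batch covers the same use case, taking the number of rows of A and C as a runtime batch_size rather than the compile-time kRowsAC. The sketch below is a minimal migration example under assumptions: it assumes the remaining template parameters are <kColsA_RowsB, kColsBC> (the full MatMul_4x4_Batch signature is truncated in this diff) and that the code lives inside the same HWY_NAMESPACE block as ops.h.

  // Sketch: replacing a former MatMul_4x4 call with the retained batch entry
  // point. Per the retained comment, B is transposed (column-major) and
  // batch_size is the number of rows of A and C (typically 1..512).
  constexpr size_t kColsA_RowsB = 3072;  // shared inner dimension
  constexpr size_t kColsBC = 24576;      // columns of B^T and of C
  const size_t batch_size = 4;           // rows of A and C

  hwy::ThreadPool pool(hwy::ThreadPool::MaxThreads());
  auto a = hwy::AllocateAligned<float>(batch_size * kColsA_RowsB);
  auto b_trans = hwy::AllocateAligned<float>(kColsBC * kColsA_RowsB);
  auto c = hwy::AllocateAligned<float>(batch_size * kColsBC);

  // Before: MatMul_4x4<kRowsAC, kColsA_RowsB, kColsBC>(a.get(), b_trans.get(),
  //                                                    c.get(), pool);
  MatMul_4x4_Batch<kColsA_RowsB, kColsBC>(batch_size, a.get(), b_trans.get(),
                                          c.get(), pool);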