mirror of https://github.com/google/gemma.cpp.git
Removed now-redundant non-batch matmul
PiperOrigin-RevId: 643317187
parent b17631c95f
commit 198326a682
gemma/ops.h (94 lines changed)
@@ -753,71 +753,6 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
                      c20, c21, c22, c23, c30, c31, c32, c33, tile_c,
                      stride_c);
}

// Tiled 4x4 GEMM. Typically kRowsAC is 4..512, kColsA_RowsB is 3k or 24k, and
// kColsBC is 24k or 3k. Note: B is transposed (column-major).
// This function loops over all tiles (static scheduling). TODO(janwas): we can
// possibly remove this if ThreadPool(0) is as efficient as the loop.
template <size_t kRowsAC, size_t kColsA_RowsB, size_t kColsBC, typename MatT>
void GEMM_4x4_Static(const MatT* HWY_RESTRICT A, const MatT* HWY_RESTRICT B,
                     MatT* HWY_RESTRICT C) {
  const hn::ScalableTag<MatT> d;
  const size_t N = hn::Lanes(d);  // column step size
  constexpr size_t kRegRows = 4;
  constexpr size_t kRegCols = 4;  // in vectors

  static_assert(kRowsAC % kRegRows == 0);
  static_assert(kColsBC % kRegCols == 0);
  HWY_ASSERT(kColsA_RowsB % (N * kRegCols) == 0);
  constexpr size_t kTilesY = kRowsAC / kRegRows;
  constexpr size_t kTilesX = kColsBC / kRegCols;
  constexpr size_t kTiles = kTilesX * kTilesY;

  constexpr size_t kStrideA = kColsA_RowsB;
  constexpr size_t kStrideB = kColsA_RowsB;  // B is column-major
  constexpr size_t kStrideC = kColsBC;

  HWY_UNROLL(1)
  for (size_t idx_tile = 0; idx_tile < kTiles; ++idx_tile) {
    GEMM_4x4_Tile<kRegRows, kColsA_RowsB>(
        A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
  }
}

// Tiled 4x4 GEMM. Typically kRowsAC is 4..512, kColsA_RowsB is 3k or 24k, and
// kColsBC is 24k or 3k. Note: B is transposed (column-major).
// This function processes tiles in parallel with a work-stealing thread pool.
template <size_t kRowsAC, size_t kColsA_RowsB, size_t kColsBC, typename MatTA,
          typename MatTB, typename OutT>
HWY_NOINLINE void MatMul_4x4(const MatTA* HWY_RESTRICT A,
                             const MatTB* HWY_RESTRICT B, OutT* HWY_RESTRICT C,
                             hwy::ThreadPool& pool) {
  // Process reg-sized tiles of C in parallel. We currently write C directly,
  // which touches more memory than fits in L3. TODO: add another level of loops
  // so that we finish one L3-sized piece of C at a time.
  const hn::ScalableTag<MatTA> d;
  const size_t N = Lanes(d);
  constexpr size_t kRegRows = 4;
  constexpr size_t kRegCols = 4;  // in vectors

  static_assert(kRowsAC % kRegRows == 0);
  static_assert(kColsBC % kRegCols == 0);
  HWY_ASSERT(kColsA_RowsB % (N * kRegCols) == 0);
  const size_t kTilesY = kRowsAC / kRegRows;
  const size_t kTilesX = kColsBC / kRegCols;
  const size_t kTiles = kTilesX * kTilesY;

  constexpr size_t kStrideA = kColsA_RowsB;
  constexpr size_t kStrideB = kColsA_RowsB;
  constexpr size_t kStrideC = kColsBC;

  pool.Run(0, kTiles, [&](const uint64_t idx_tile, size_t /*thread*/) HWY_ATTR {
    // Computes the finished product of one 4x4N tile and writes to C.
    GEMM_4x4_Tile<kRegRows, kColsA_RowsB>(
        A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
  });
}

// Tiled 4x4 GEMM. Typically batch_size is 1..512, kColsA_RowsB is 3k or 24k,
// and kColsBC is 24k or 3k. Note: B is transposed (column-major).
// NOTE that batch_size is the number of rows of A and C.
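With the non-batch entry point removed, callers presumably go through the remaining MatMul_4x4_Batch. A minimal migration sketch, assuming (it is not shown in this hunk) that the batch version keeps the <kColsA_RowsB, kColsBC> template parameters and takes the row count of A and C as a runtime batch_size, as the comment above and the MatMulSlowBatch call pattern in the test suggest:

  // Removed: MatMul_4x4<kRowsAC, kColsA_RowsB, kColsBC>(A, B, C, pool);
  // Assumed replacement:
  MatMul_4x4_Batch<kColsA_RowsB, kColsBC>(/*batch_size=*/kRowsAC, A, B, C, pool);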
@@ -869,35 +804,6 @@ HWY_NOINLINE void MatMul_4x4_Batch(
  });
}

// Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
// ops_test across instruction sets.
template <size_t kM, size_t kN, size_t kK, typename MatTA, typename MatTB>
HWY_INLINE void MatMulSlow(const MatTA* HWY_RESTRICT a,
                           const MatTB* HWY_RESTRICT b,
                           float* HWY_RESTRICT out) {
  for (size_t i = 0; i < kM; ++i) {
    for (size_t k = 0; k < kN; ++k) {
      for (size_t j = 0; j < kK; ++j) {
        const float a1 = hwy::ConvertScalarTo<float>(a[i * kN + k]);
        const float b1 = hwy::ConvertScalarTo<float>(b[k * kK + j]);
        out[i * kK + j] += a1 * b1;
      }
    }
  }
}
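"Reordered innermost loops" refers to the i-k-j order above: the innermost j loop walks b and out contiguously. For contrast, a sketch (not part of the diff) of the textbook i-j-k order, whose inner k loop strides through b by kK elements per step and therefore vectorizes poorly:

  for (size_t i = 0; i < kM; ++i) {
    for (size_t j = 0; j < kK; ++j) {
      float sum = 0.0f;
      for (size_t k = 0; k < kN; ++k) {
        sum += hwy::ConvertScalarTo<float>(a[i * kN + k]) *
               hwy::ConvertScalarTo<float>(b[k * kK + j]);
      }
      out[i * kK + j] = sum;  // overwrites instead of accumulating
    }
  }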

template <size_t kM, size_t kN, size_t kK, typename MatTA>
HWY_INLINE void MatMulSlow(const MatTA* HWY_RESTRICT a,
                           const SfpStream* HWY_RESTRICT b_sfp_stream,
                           float* HWY_RESTRICT out) {
  const hn::ScalableTag<float> d;
  hwy::AlignedFreeUniquePtr<float[]> b = hwy::AllocateAligned<float>(kK * kN);
  CompressTraits<SfpStream>::Decompress(d,
                                        /*in_capacity=*/0, b_sfp_stream, 0,
                                        b.get(), kK * kN);
  MatMulSlow<kM, kN, kK>(a, b.get(), out);
}

// Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
// ops_test across instruction sets.
template <size_t kN, size_t kK, typename MatTA, typename MatTB>
gemma/ops_test.cc
@@ -528,54 +528,6 @@ void AssertClose(const MatT* HWY_RESTRICT expected,
  }
}

template <size_t kM, size_t kN, size_t kK, typename MatTA,
          typename MatTB = MatTA>
void TestTiledMatMul() {
  hwy::ThreadPool pool(hwy::ThreadPool::MaxThreads());
  std::unique_ptr<CompressedArray<MatTA, kM * kN>> a =
      GenerateMatHeap<MatTA, kM, kN>(0, pool);
  std::unique_ptr<CompressedArray<MatTB, kN * kK>> b =
      GenerateMatHeap<MatTB, kN, kK>(0, pool);
  std::unique_ptr<CompressedArray<float, kM * kK>> c_slow =
      GenerateZeroMatHeap<float, kM, kK>(pool);
  std::unique_ptr<CompressedArray<float, kM * kK>> c_slow_batch =
      GenerateZeroMatHeap<float, kM, kK>(pool);

  MatMulSlow<kM, kN, kK>(a->data(), b->data(), c_slow->data());
  MatMulSlowBatch<kN, kK>(kM, a->data(), b->data(), c_slow_batch->data());
  AssertClose(c_slow->data(), c_slow_batch->data(), kM * kK);

  hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
  std::unique_ptr<CompressedArray<MatTB, kN * kK>> b_trans =
      GenerateTransposeMatHeap<MatTB, kN, kK>(0, pool);
  MatMul_4x4<kM, kN, kK>(a->data(), b_trans->data(), c.get(), pool);

  AssertClose(c_slow->data(), c.get(), kM * kK);
}

void TestAllTiledMatMul() {
  // medium-sized square test
  TestTiledMatMul<512, 512, 512, float>();
  TestTiledMatMul<512, 512, 512, hwy::bfloat16_t>();
  TestTiledMatMul<512, 512, 512, float, hwy::bfloat16_t>();
  TestTiledMatMul<512, 512, 512, hwy::bfloat16_t, float>();
  TestTiledMatMul<512, 512, 512, float, SfpStream>();
  TestTiledMatMul<512, 512, 512, hwy::bfloat16_t, SfpStream>();

  // minimal non-square test
  TestTiledMatMul<4, 128, 4, float>();
  TestTiledMatMul<4, 128, 4, hwy::bfloat16_t>();
  TestTiledMatMul<4, 128, 4, float, hwy::bfloat16_t>();
  TestTiledMatMul<4, 128, 4, hwy::bfloat16_t, float>();
  TestTiledMatMul<32, 128, 32, float, SfpStream>();
  TestTiledMatMul<32, 128, 32, hwy::bfloat16_t, SfpStream>();

  // large-scale test
  // TODO(philculliton): investigate rounding issues with large matrices.
  // Causes test timeout.
  // TestTiledMatMul<512, 24576, 3072, float>();
}
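For reading the instantiations above: the parameters are <kM, kN, kK, MatTA, MatTB>, i.e. a kM x kN matrix A of type MatTA times a kN x kK matrix B of type MatTB, producing a kM x kK float C (inferred from the MatMulSlow indexing a[i * kN + k], b[k * kK + j], out[i * kK + j]). For example:

  TestTiledMatMul<32, 128, 32, float, SfpStream>();  // 32x128 float A times 128x32 SFP B -> 32x32 float C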

template <size_t kM, size_t kN, size_t kK, typename MatTA,
          typename MatTB = MatTA>
void TestTiledBatchMatMul() {
@@ -638,6 +590,11 @@ void TestAllTiledBatchMatMul() {
  TestTiledBatchMatMul<1, 128, 32, hwy::bfloat16_t, float>();
  TestTiledBatchMatMul<1, 128, 32, float, SfpStream>();
  TestTiledBatchMatMul<1, 128, 32, hwy::bfloat16_t, SfpStream>();

  // large-scale test
  // TODO(philculliton): investigate rounding issues with large matrices.
  // Causes test timeout.
  // TestTiledBatchMatMul<512, 24576, 3072, float>();
}

void TestMatVecAdd() {
@@ -746,7 +703,6 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllTiledBatchMatMul);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllTiledMatMul);
HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAdd);
HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoMatVecAdd);
HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoOfsMatVecAddLoop);