mirror of https://github.com/google/gemma.cpp.git
Removed now redundant non-batch matmul
PiperOrigin-RevId: 643317187
This commit is contained in:
parent
b17631c95f
commit
198326a682
94
gemma/ops.h
94
gemma/ops.h
|
|
@ -753,71 +753,6 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
|
||||||
c20, c21, c22, c23, c30, c31, c32, c33, tile_c,
|
c20, c21, c22, c23, c30, c31, c32, c33, tile_c,
|
||||||
stride_c);
|
stride_c);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tiled 4x4 GEMM. Typically kRowsAC is 4..512, kColsA_RowsB is 3k or 24k, and
|
|
||||||
// kColsBC is 24k or 3k. Note: B is transposed (column-major).
|
|
||||||
// This function loops over all tiles (static scheduling). TODO(janwas): we can
|
|
||||||
// possibly remove this if ThreadPool(0) is as efficient as the loop.
|
|
||||||
template <size_t kRowsAC, size_t kColsA_RowsB, size_t kColsBC, typename MatT>
|
|
||||||
void GEMM_4x4_Static(const MatT* HWY_RESTRICT A, const MatT* HWY_RESTRICT B,
|
|
||||||
MatT* HWY_RESTRICT C) {
|
|
||||||
const hn::ScalableTag<MatT> d;
|
|
||||||
const size_t N = hn::Lanes(d); // column step size
|
|
||||||
constexpr size_t kRegRows = 4;
|
|
||||||
constexpr size_t kRegCols = 4; // in vectors
|
|
||||||
|
|
||||||
static_assert(kRowsAC % kRegRows == 0);
|
|
||||||
static_assert(kColsBC % kRegCols == 0);
|
|
||||||
HWY_ASSERT(kColsA_RowsB % (N * kRegCols) == 0);
|
|
||||||
constexpr size_t kTilesY = kRowsAC / kRegRows;
|
|
||||||
constexpr size_t kTilesX = kColsBC / kRegCols;
|
|
||||||
constexpr size_t kTiles = kTilesX * kTilesY;
|
|
||||||
|
|
||||||
constexpr size_t kStrideA = kColsA_RowsB;
|
|
||||||
constexpr size_t kStrideB = kColsA_RowsB; // B is column-major
|
|
||||||
constexpr size_t kStrideC = kColsBC;
|
|
||||||
|
|
||||||
HWY_UNROLL(1)
|
|
||||||
for (size_t idx_tile = 0; idx_tile < kTiles; ++idx_tile) {
|
|
||||||
GEMM_4x4_Tile<kRegRows, kColsA_RowsB>(
|
|
||||||
A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Tiled 4x4 GEMM. Typically kRowsAC is 4..512, kColsA_RowsB is 3k or 24k, and
|
|
||||||
// kColsBC is 24k or 3k. Note: B is transposed (column-major).
|
|
||||||
// This function processes tiles in parallel with a work-stealing thread pool.
|
|
||||||
template <size_t kRowsAC, size_t kColsA_RowsB, size_t kColsBC, typename MatTA,
|
|
||||||
typename MatTB, typename OutT>
|
|
||||||
HWY_NOINLINE void MatMul_4x4(const MatTA* HWY_RESTRICT A,
|
|
||||||
const MatTB* HWY_RESTRICT B, OutT* HWY_RESTRICT C,
|
|
||||||
hwy::ThreadPool& pool) {
|
|
||||||
// Process reg-sized tiles of C in parallel. We currently write C directly,
|
|
||||||
// which touches more memory than fits in L3. TODO: add another level of loops
|
|
||||||
// so that we finish one L3-sized piece of C at a time.
|
|
||||||
const hn::ScalableTag<MatTA> d;
|
|
||||||
const size_t N = Lanes(d);
|
|
||||||
constexpr size_t kRegRows = 4;
|
|
||||||
constexpr size_t kRegCols = 4; // in vectors
|
|
||||||
|
|
||||||
static_assert(kRowsAC % kRegRows == 0);
|
|
||||||
static_assert(kColsBC % kRegCols == 0);
|
|
||||||
HWY_ASSERT(kColsA_RowsB % (N * kRegCols) == 0);
|
|
||||||
const size_t kTilesY = kRowsAC / kRegRows;
|
|
||||||
const size_t kTilesX = kColsBC / kRegCols;
|
|
||||||
const size_t kTiles = kTilesX * kTilesY;
|
|
||||||
|
|
||||||
constexpr size_t kStrideA = kColsA_RowsB;
|
|
||||||
constexpr size_t kStrideB = kColsA_RowsB;
|
|
||||||
constexpr size_t kStrideC = kColsBC;
|
|
||||||
|
|
||||||
pool.Run(0, kTiles, [&](const uint64_t idx_tile, size_t /*thread*/) HWY_ATTR {
|
|
||||||
// Computes the finished product of one 4x4N tile and writes to C.
|
|
||||||
GEMM_4x4_Tile<kRegRows, kColsA_RowsB>(
|
|
||||||
A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Tiled 4x4 GEMM. Typically batch_size is 1..512, kColsA_RowsB is 3k or 24k,
|
// Tiled 4x4 GEMM. Typically batch_size is 1..512, kColsA_RowsB is 3k or 24k,
|
||||||
// and kColsBC is 24k or 3k. Note: B is transposed (column-major).
|
// and kColsBC is 24k or 3k. Note: B is transposed (column-major).
|
||||||
// NOTE that batch_size is the number of rows of A and C.
|
// NOTE that batch_size is the number of rows of A and C.
|
||||||
|
|
@ -869,35 +804,6 @@ HWY_NOINLINE void MatMul_4x4_Batch(
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
|
|
||||||
// ops_test across instruction sets.
|
|
||||||
template <size_t kM, size_t kN, size_t kK, typename MatTA, typename MatTB>
|
|
||||||
HWY_INLINE void MatMulSlow(const MatTA* HWY_RESTRICT a,
|
|
||||||
const MatTB* HWY_RESTRICT b,
|
|
||||||
float* HWY_RESTRICT out) {
|
|
||||||
for (size_t i = 0; i < kM; ++i) {
|
|
||||||
for (size_t k = 0; k < kN; ++k) {
|
|
||||||
for (size_t j = 0; j < kK; ++j) {
|
|
||||||
const float a1 = hwy::ConvertScalarTo<float>(a[i * kN + k]);
|
|
||||||
const float b1 = hwy::ConvertScalarTo<float>(b[k * kK + j]);
|
|
||||||
out[i * kK + j] += a1 * b1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <size_t kM, size_t kN, size_t kK, typename MatTA>
|
|
||||||
HWY_INLINE void MatMulSlow(const MatTA* HWY_RESTRICT a,
|
|
||||||
const SfpStream* HWY_RESTRICT b_sfp_stream,
|
|
||||||
float* HWY_RESTRICT out) {
|
|
||||||
const hn::ScalableTag<float> d;
|
|
||||||
hwy::AlignedFreeUniquePtr<float[]> b = hwy::AllocateAligned<float>(kK * kN);
|
|
||||||
CompressTraits<SfpStream>::Decompress(d,
|
|
||||||
/*in_capacity=*/0, b_sfp_stream, 0,
|
|
||||||
b.get(), kK * kN);
|
|
||||||
MatMulSlow<kM, kN, kK>(a, b.get(), out);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
|
// Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
|
||||||
// ops_test across instruction sets.
|
// ops_test across instruction sets.
|
||||||
template <size_t kN, size_t kK, typename MatTA, typename MatTB>
|
template <size_t kN, size_t kK, typename MatTA, typename MatTB>
|
||||||
|
|
|
||||||
|
|
@ -528,54 +528,6 @@ void AssertClose(const MatT* HWY_RESTRICT expected,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <size_t kM, size_t kN, size_t kK, typename MatTA,
|
|
||||||
typename MatTB = MatTA>
|
|
||||||
void TestTiledMatMul() {
|
|
||||||
hwy::ThreadPool pool(hwy::ThreadPool::MaxThreads());
|
|
||||||
std::unique_ptr<CompressedArray<MatTA, kM * kN>> a =
|
|
||||||
GenerateMatHeap<MatTA, kM, kN>(0, pool);
|
|
||||||
std::unique_ptr<CompressedArray<MatTB, kN * kK>> b =
|
|
||||||
GenerateMatHeap<MatTB, kN, kK>(0, pool);
|
|
||||||
std::unique_ptr<CompressedArray<float, kM * kK>> c_slow =
|
|
||||||
GenerateZeroMatHeap<float, kM, kK>(pool);
|
|
||||||
std::unique_ptr<CompressedArray<float, kM * kK>> c_slow_batch =
|
|
||||||
GenerateZeroMatHeap<float, kM, kK>(pool);
|
|
||||||
|
|
||||||
MatMulSlow<kM, kN, kK>(a->data(), b->data(), c_slow->data());
|
|
||||||
MatMulSlowBatch<kN, kK>(kM, a->data(), b->data(), c_slow_batch->data());
|
|
||||||
AssertClose(c_slow->data(), c_slow_batch->data(), kM * kK);
|
|
||||||
|
|
||||||
hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
|
|
||||||
std::unique_ptr<CompressedArray<MatTB, kN * kK>> b_trans =
|
|
||||||
GenerateTransposeMatHeap<MatTB, kN, kK>(0, pool);
|
|
||||||
MatMul_4x4<kM, kN, kK>(a->data(), b_trans->data(), c.get(), pool);
|
|
||||||
|
|
||||||
AssertClose(c_slow->data(), c.get(), kM * kK);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Runs TestTiledMatMul over representative shape/type combinations.
void TestAllTiledMatMul() {
  // Medium-sized square matrices, covering every supported A/B type pairing.
  TestTiledMatMul<512, 512, 512, float>();
  TestTiledMatMul<512, 512, 512, hwy::bfloat16_t>();
  TestTiledMatMul<512, 512, 512, float, hwy::bfloat16_t>();
  TestTiledMatMul<512, 512, 512, hwy::bfloat16_t, float>();
  TestTiledMatMul<512, 512, 512, float, SfpStream>();
  TestTiledMatMul<512, 512, 512, hwy::bfloat16_t, SfpStream>();

  // Minimal non-square shapes (SFP requires a somewhat larger 32x128x32).
  TestTiledMatMul<4, 128, 4, float>();
  TestTiledMatMul<4, 128, 4, hwy::bfloat16_t>();
  TestTiledMatMul<4, 128, 4, float, hwy::bfloat16_t>();
  TestTiledMatMul<4, 128, 4, hwy::bfloat16_t, float>();
  TestTiledMatMul<32, 128, 32, float, SfpStream>();
  TestTiledMatMul<32, 128, 32, hwy::bfloat16_t, SfpStream>();

  // Large-scale test, disabled:
  // TODO(philculliton): investigate rounding issues with large matrices.
  // Causes test timeout.
  // TestTiledMatMul<512, 24576, 3072, float>();
}
|
|
||||||
|
|
||||||
template <size_t kM, size_t kN, size_t kK, typename MatTA,
|
template <size_t kM, size_t kN, size_t kK, typename MatTA,
|
||||||
typename MatTB = MatTA>
|
typename MatTB = MatTA>
|
||||||
void TestTiledBatchMatMul() {
|
void TestTiledBatchMatMul() {
|
||||||
|
|
@ -638,6 +590,11 @@ void TestAllTiledBatchMatMul() {
|
||||||
TestTiledBatchMatMul<1, 128, 32, hwy::bfloat16_t, float>();
|
TestTiledBatchMatMul<1, 128, 32, hwy::bfloat16_t, float>();
|
||||||
TestTiledBatchMatMul<1, 128, 32, float, SfpStream>();
|
TestTiledBatchMatMul<1, 128, 32, float, SfpStream>();
|
||||||
TestTiledBatchMatMul<1, 128, 32, hwy::bfloat16_t, SfpStream>();
|
TestTiledBatchMatMul<1, 128, 32, hwy::bfloat16_t, SfpStream>();
|
||||||
|
|
||||||
|
// large-scale test
|
||||||
|
// TODO(philculliton): investigate rounding issues with large matrices.
|
||||||
|
// Causes test timeout.
|
||||||
|
// TestTiledBatchMatMul<512, 24576, 3072, float>();
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestMatVecAdd() {
|
void TestMatVecAdd() {
|
||||||
|
|
@ -746,7 +703,6 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllTiledBatchMatMul);
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllTiledBatchMatMul);
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllTiledMatMul);
|
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAdd);
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAdd);
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoMatVecAdd);
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoMatVecAdd);
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoOfsMatVecAddLoop);
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoOfsMatVecAddLoop);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue