Fix for transpose matrix creation and additional tests

PiperOrigin-RevId: 641868053
This commit is contained in:
Phil Culliton 2024-06-10 05:23:25 -07:00 committed by Copybara-Service
parent 36e6915e18
commit c5bcb5438c
1 changed files with 18 additions and 10 deletions

View File

@ -414,8 +414,8 @@ GenerateTransposeMatHeap(size_t offset, hwy::ThreadPool& pool) {
const float scale = 1.0f / kInner; const float scale = 1.0f / kInner;
pool.Run(0, kOuter, [&](const size_t i, size_t /*thread*/) { pool.Run(0, kOuter, [&](const size_t i, size_t /*thread*/) {
for (size_t j = 0; j < kInner; j++) { for (size_t j = 0; j < kInner; j++) {
content[i * kInner + j] = content[j * kOuter + i] =
static_cast<float>((j * kInner + i + offset) * scale); static_cast<float>((i * kInner + j + offset) * scale);
} }
}); });
@ -525,13 +525,10 @@ void AssertClose(const MatT* HWY_RESTRICT expected,
} }
} }
template <typename MatTA, typename MatTB = MatTA> template <size_t kM, size_t kN, size_t kK, typename MatTA,
typename MatTB = MatTA>
void TestTiledMatMul() { void TestTiledMatMul() {
hwy::ThreadPool pool(3); hwy::ThreadPool pool(3);
constexpr size_t kM = 512; // 384
constexpr size_t kN = 512; // * 5; // 6; // 768
constexpr size_t kK = 512; // * 5; // 640
std::unique_ptr<CompressedArray<MatTA, kM * kN>> a = std::unique_ptr<CompressedArray<MatTA, kM * kN>> a =
GenerateMatHeap<MatTA, kM, kN>(0, pool); GenerateMatHeap<MatTA, kM, kN>(0, pool);
std::unique_ptr<CompressedArray<MatTB, kN * kK>> b = std::unique_ptr<CompressedArray<MatTB, kN * kK>> b =
@ -550,9 +547,20 @@ void TestTiledMatMul() {
} }
void TestAllTiledMatMul() { void TestAllTiledMatMul() {
TestTiledMatMul<float>(); // medium-sized square test
TestTiledMatMul<hwy::bfloat16_t>(); TestTiledMatMul<512, 512, 512, float>();
TestTiledMatMul<float, hwy::bfloat16_t>(); TestTiledMatMul<512, 512, 512, hwy::bfloat16_t>();
TestTiledMatMul<512, 512, 512, float, hwy::bfloat16_t>();
// minimal non-square test
TestTiledMatMul<4, 128, 4, float>();
TestTiledMatMul<4, 128, 4, hwy::bfloat16_t>();
TestTiledMatMul<4, 128, 4, float, hwy::bfloat16_t>();
// large-scale test
// TODO(philculliton): investigate rounding issues with large matrices
// TestTiledMatMul<512, 24576, 3072, float>();
// TODO(janwas): SFP // TODO(janwas): SFP
} }