mirror of https://github.com/google/gemma.cpp.git
Fix transposed-matrix generation and add additional tiled MatMul tests
PiperOrigin-RevId: 641868053
This commit is contained in:
parent
36e6915e18
commit
c5bcb5438c
|
|
@ -414,8 +414,8 @@ GenerateTransposeMatHeap(size_t offset, hwy::ThreadPool& pool) {
|
|||
const float scale = 1.0f / kInner;
|
||||
pool.Run(0, kOuter, [&](const size_t i, size_t /*thread*/) {
|
||||
for (size_t j = 0; j < kInner; j++) {
|
||||
content[i * kInner + j] =
|
||||
static_cast<float>((j * kInner + i + offset) * scale);
|
||||
content[j * kOuter + i] =
|
||||
static_cast<float>((i * kInner + j + offset) * scale);
|
||||
}
|
||||
});
|
||||
|
||||
|
|
@ -525,13 +525,10 @@ void AssertClose(const MatT* HWY_RESTRICT expected,
|
|||
}
|
||||
}
|
||||
|
||||
template <typename MatTA, typename MatTB = MatTA>
|
||||
template <size_t kM, size_t kN, size_t kK, typename MatTA,
|
||||
typename MatTB = MatTA>
|
||||
void TestTiledMatMul() {
|
||||
hwy::ThreadPool pool(3);
|
||||
constexpr size_t kM = 512; // 384
|
||||
constexpr size_t kN = 512; // * 5; // 6; // 768
|
||||
constexpr size_t kK = 512; // * 5; // 640
|
||||
|
||||
std::unique_ptr<CompressedArray<MatTA, kM * kN>> a =
|
||||
GenerateMatHeap<MatTA, kM, kN>(0, pool);
|
||||
std::unique_ptr<CompressedArray<MatTB, kN * kK>> b =
|
||||
|
|
@ -550,9 +547,20 @@ void TestTiledMatMul() {
|
|||
}
|
||||
|
||||
void TestAllTiledMatMul() {
|
||||
TestTiledMatMul<float>();
|
||||
TestTiledMatMul<hwy::bfloat16_t>();
|
||||
TestTiledMatMul<float, hwy::bfloat16_t>();
|
||||
// medium-sized square test
|
||||
TestTiledMatMul<512, 512, 512, float>();
|
||||
TestTiledMatMul<512, 512, 512, hwy::bfloat16_t>();
|
||||
TestTiledMatMul<512, 512, 512, float, hwy::bfloat16_t>();
|
||||
|
||||
// minimal non-square test
|
||||
TestTiledMatMul<4, 128, 4, float>();
|
||||
TestTiledMatMul<4, 128, 4, hwy::bfloat16_t>();
|
||||
TestTiledMatMul<4, 128, 4, float, hwy::bfloat16_t>();
|
||||
|
||||
// large-scale test
|
||||
// TODO(philculliton): investigate rounding issues with large matrices
|
||||
// TestTiledMatMul<512, 24576, 3072, float>();
|
||||
|
||||
// TODO(janwas): SFP
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue