mirror of https://github.com/google/gemma.cpp.git
Fix for transpose matrix creation and additional tests
PiperOrigin-RevId: 641868053
This commit is contained in:
parent
36e6915e18
commit
c5bcb5438c
|
|
@ -414,8 +414,8 @@ GenerateTransposeMatHeap(size_t offset, hwy::ThreadPool& pool) {
|
||||||
const float scale = 1.0f / kInner;
|
const float scale = 1.0f / kInner;
|
||||||
pool.Run(0, kOuter, [&](const size_t i, size_t /*thread*/) {
|
pool.Run(0, kOuter, [&](const size_t i, size_t /*thread*/) {
|
||||||
for (size_t j = 0; j < kInner; j++) {
|
for (size_t j = 0; j < kInner; j++) {
|
||||||
content[i * kInner + j] =
|
content[j * kOuter + i] =
|
||||||
static_cast<float>((j * kInner + i + offset) * scale);
|
static_cast<float>((i * kInner + j + offset) * scale);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -525,13 +525,10 @@ void AssertClose(const MatT* HWY_RESTRICT expected,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename MatTA, typename MatTB = MatTA>
|
template <size_t kM, size_t kN, size_t kK, typename MatTA,
|
||||||
|
typename MatTB = MatTA>
|
||||||
void TestTiledMatMul() {
|
void TestTiledMatMul() {
|
||||||
hwy::ThreadPool pool(3);
|
hwy::ThreadPool pool(3);
|
||||||
constexpr size_t kM = 512; // 384
|
|
||||||
constexpr size_t kN = 512; // * 5; // 6; // 768
|
|
||||||
constexpr size_t kK = 512; // * 5; // 640
|
|
||||||
|
|
||||||
std::unique_ptr<CompressedArray<MatTA, kM * kN>> a =
|
std::unique_ptr<CompressedArray<MatTA, kM * kN>> a =
|
||||||
GenerateMatHeap<MatTA, kM, kN>(0, pool);
|
GenerateMatHeap<MatTA, kM, kN>(0, pool);
|
||||||
std::unique_ptr<CompressedArray<MatTB, kN * kK>> b =
|
std::unique_ptr<CompressedArray<MatTB, kN * kK>> b =
|
||||||
|
|
@ -550,9 +547,20 @@ void TestTiledMatMul() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestAllTiledMatMul() {
|
void TestAllTiledMatMul() {
|
||||||
TestTiledMatMul<float>();
|
// medium-sized square test
|
||||||
TestTiledMatMul<hwy::bfloat16_t>();
|
TestTiledMatMul<512, 512, 512, float>();
|
||||||
TestTiledMatMul<float, hwy::bfloat16_t>();
|
TestTiledMatMul<512, 512, 512, hwy::bfloat16_t>();
|
||||||
|
TestTiledMatMul<512, 512, 512, float, hwy::bfloat16_t>();
|
||||||
|
|
||||||
|
// minimal non-square test
|
||||||
|
TestTiledMatMul<4, 128, 4, float>();
|
||||||
|
TestTiledMatMul<4, 128, 4, hwy::bfloat16_t>();
|
||||||
|
TestTiledMatMul<4, 128, 4, float, hwy::bfloat16_t>();
|
||||||
|
|
||||||
|
// large-scale test
|
||||||
|
// TODO(philculliton): investigate rounding issues with large matrices
|
||||||
|
// TestTiledMatMul<512, 24576, 3072, float>();
|
||||||
|
|
||||||
// TODO(janwas): SFP
|
// TODO(janwas): SFP
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue