static_assert shape constraints in MatMul 4x4

PiperOrigin-RevId: 639069345
This commit is contained in:
Paul Chang 2024-05-31 10:02:06 -07:00 committed by Copybara-Service
parent c616abe628
commit 5feacf120c
1 changed file with 8 additions and 8 deletions

View File

@@ -198,9 +198,9 @@ void GEMM_4x4_Static(const MatT* HWY_RESTRICT A, const MatT* HWY_RESTRICT B,
constexpr int RM = 4; // tile height constexpr int RM = 4; // tile height
constexpr int RN = 4; // tile width constexpr int RN = 4; // tile width
HWY_ASSERT(kM % RM == 0); static_assert(kM % RM == 0);
HWY_ASSERT(kColsA % N == 0); static_assert(kColsA % N == 0);
HWY_ASSERT(kColsA % RN == 0); static_assert(kColsA % RN == 0);
int lda = kColsA; int lda = kColsA;
int ldb = kColsA; // n instead of k because we're transposing int ldb = kColsA; // n instead of k because we're transposing
@@ -242,11 +242,11 @@ HWY_NOINLINE void MatMul_4x4_Impl(const MatT* HWY_RESTRICT A,
const int tiles = xtiles * ytiles; const int tiles = xtiles * ytiles;
// 4x4 case requires kM % 4 == 0, kN % N == 0, kK % 4 == 0 // 4x4 case requires kM % 4 == 0, kN % N == 0, kK % 4 == 0
HWY_ASSERT(kM % RM == 0); static_assert(kM % RM == 0);
HWY_ASSERT(kColsA % N == 0); static_assert(kColsA % N == 0);
HWY_ASSERT(kColsA % RN == 0); static_assert(kColsA % RN == 0);
HWY_ASSERT(kK % RN == 0); static_assert(kK % RN == 0);
HWY_ASSERT(kColsA >= N); static_assert(kColsA >= N);
// Handles a single 4x4 chunk, which is completed and then written into C. // Handles a single 4x4 chunk, which is completed and then written into C.
pool.Run(0, tiles, [&](const uint64_t chunk, size_t /*thread*/) HWY_ATTR { pool.Run(0, tiles, [&](const uint64_t chunk, size_t /*thread*/) HWY_ATTR {