diff --git a/gemma/ops.h b/gemma/ops.h index 1ea800b..6401cc4 100644 --- a/gemma/ops.h +++ b/gemma/ops.h @@ -198,9 +198,9 @@ void GEMM_4x4_Static(const MatT* HWY_RESTRICT A, const MatT* HWY_RESTRICT B, constexpr int RM = 4; // tile height constexpr int RN = 4; // tile width - HWY_ASSERT(kM % RM == 0); - HWY_ASSERT(kColsA % N == 0); - HWY_ASSERT(kColsA % RN == 0); + static_assert(kM % RM == 0); + static_assert(kColsA % N == 0); + static_assert(kColsA % RN == 0); int lda = kColsA; int ldb = kColsA; // n instead of k because we're transposing @@ -242,11 +242,11 @@ HWY_NOINLINE void MatMul_4x4_Impl(const MatT* HWY_RESTRICT A, const int tiles = xtiles * ytiles; // 4x4 case requires kM % 4 == 0, kN % N == 0, kK % 4 == 0 - HWY_ASSERT(kM % RM == 0); - HWY_ASSERT(kColsA % N == 0); - HWY_ASSERT(kColsA % RN == 0); - HWY_ASSERT(kK % RN == 0); - HWY_ASSERT(kColsA >= N); + static_assert(kM % RM == 0); + static_assert(kColsA % N == 0); + static_assert(kColsA % RN == 0); + static_assert(kK % RN == 0); + static_assert(kColsA >= N); // Handles a single 4x4 chunk, which is completed and then written into C. pool.Run(0, tiles, [&](const uint64_t chunk, size_t /*thread*/) HWY_ATTR {