mirror of https://github.com/google/gemma.cpp.git
static_assert shape constraints in MatMul 4x4
PiperOrigin-RevId: 639069345
This commit is contained in:
parent
c616abe628
commit
5feacf120c
16
gemma/ops.h
16
gemma/ops.h
|
|
@ -198,9 +198,9 @@ void GEMM_4x4_Static(const MatT* HWY_RESTRICT A, const MatT* HWY_RESTRICT B,
|
|||
constexpr int RM = 4; // tile height
|
||||
constexpr int RN = 4; // tile width
|
||||
|
||||
HWY_ASSERT(kM % RM == 0);
|
||||
HWY_ASSERT(kColsA % N == 0);
|
||||
HWY_ASSERT(kColsA % RN == 0);
|
||||
static_assert(kM % RM == 0);
|
||||
static_assert(kColsA % N == 0);
|
||||
static_assert(kColsA % RN == 0);
|
||||
|
||||
int lda = kColsA;
|
||||
int ldb = kColsA; // n instead of k because we're transposing
|
||||
|
|
@ -242,11 +242,11 @@ HWY_NOINLINE void MatMul_4x4_Impl(const MatT* HWY_RESTRICT A,
|
|||
const int tiles = xtiles * ytiles;
|
||||
|
||||
// 4x4 case requires kM % RM == 0, kColsA % N == 0, kColsA % RN == 0, kK % RN == 0, and kColsA >= N
|
||||
HWY_ASSERT(kM % RM == 0);
|
||||
HWY_ASSERT(kColsA % N == 0);
|
||||
HWY_ASSERT(kColsA % RN == 0);
|
||||
HWY_ASSERT(kK % RN == 0);
|
||||
HWY_ASSERT(kColsA >= N);
|
||||
static_assert(kM % RM == 0);
|
||||
static_assert(kColsA % N == 0);
|
||||
static_assert(kColsA % RN == 0);
|
||||
static_assert(kK % RN == 0);
|
||||
static_assert(kColsA >= N);
|
||||
|
||||
// Handles a single 4x4 chunk, which is completed and then written into C.
|
||||
pool.Run(0, tiles, [&](const uint64_t chunk, size_t /*thread*/) HWY_ATTR {
|
||||
|
|
|
|||
Loading…
Reference in New Issue