mirror of https://github.com/google/gemma.cpp.git
static_assert shape constraints in MatMul 4x4
PiperOrigin-RevId: 639069345
This commit is contained in:
parent
c616abe628
commit
5feacf120c
16
gemma/ops.h
16
gemma/ops.h
|
|
@ -198,9 +198,9 @@ void GEMM_4x4_Static(const MatT* HWY_RESTRICT A, const MatT* HWY_RESTRICT B,
|
||||||
constexpr int RM = 4; // tile height
|
constexpr int RM = 4; // tile height
|
||||||
constexpr int RN = 4; // tile width
|
constexpr int RN = 4; // tile width
|
||||||
|
|
||||||
HWY_ASSERT(kM % RM == 0);
|
static_assert(kM % RM == 0);
|
||||||
HWY_ASSERT(kColsA % N == 0);
|
static_assert(kColsA % N == 0);
|
||||||
HWY_ASSERT(kColsA % RN == 0);
|
static_assert(kColsA % RN == 0);
|
||||||
|
|
||||||
int lda = kColsA;
|
int lda = kColsA;
|
||||||
int ldb = kColsA; // n instead of k because we're transposing
|
int ldb = kColsA; // n instead of k because we're transposing
|
||||||
|
|
@ -242,11 +242,11 @@ HWY_NOINLINE void MatMul_4x4_Impl(const MatT* HWY_RESTRICT A,
|
||||||
const int tiles = xtiles * ytiles;
|
const int tiles = xtiles * ytiles;
|
||||||
|
|
||||||
// 4x4 case requires kM % 4 == 0, kN % N == 0, kK % 4 == 0
|
// 4x4 case requires kM % 4 == 0, kN % N == 0, kK % 4 == 0
|
||||||
HWY_ASSERT(kM % RM == 0);
|
static_assert(kM % RM == 0);
|
||||||
HWY_ASSERT(kColsA % N == 0);
|
static_assert(kColsA % N == 0);
|
||||||
HWY_ASSERT(kColsA % RN == 0);
|
static_assert(kColsA % RN == 0);
|
||||||
HWY_ASSERT(kK % RN == 0);
|
static_assert(kK % RN == 0);
|
||||||
HWY_ASSERT(kColsA >= N);
|
static_assert(kColsA >= N);
|
||||||
|
|
||||||
// Handles a single 4x4 chunk, which is completed and then written into C.
|
// Handles a single 4x4 chunk, which is completed and then written into C.
|
||||||
pool.Run(0, tiles, [&](const uint64_t chunk, size_t /*thread*/) HWY_ATTR {
|
pool.Run(0, tiles, [&](const uint64_t chunk, size_t /*thread*/) HWY_ATTR {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue