Add MMOptions as an argument to MatMul.

PiperOrigin-RevId: 802008198
This commit is contained in:
Marie White 2025-09-01 23:46:07 -07:00 committed by Copybara-Service
parent 229bd078a1
commit 0d2e74d74a
6 changed files with 15 additions and 14 deletions

View File

@ -1179,7 +1179,7 @@ struct MMImpl {
template <typename TA, typename TB, typename TC>
HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
const float* HWY_RESTRICT add, MatMulEnv& env,
MatPtrT<TC>& C) {
MatPtrT<TC>& C, MMOptions options) {
RowPtrs<TC> C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[0]);
const Allocator& allocator = env.ctx.allocator;
@ -1205,12 +1205,11 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
MMPerKey& per_key = env.per_key[index];
MMAutoTune<MMConfig>& tuner = per_key.autotune;
// Default to nested parallelism.
const ParallelismType parallelism_type = ParallelismType::kNested;
const MMArgs args(env, per_key, static_cast<double>(A.Scale()) * B.Scale(),
add);
if (HWY_LIKELY(tuner.Best())) {
MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best(), parallelism_type);
MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best(),
options.parallelism_type);
return &per_key;
}
@ -1239,7 +1238,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
const MMConfig& cfg = tuner.NextConfig();
const uint64_t t0 = hwy::timer::Start();
MMImpl::DoMatMul(A, B, C_rows, args, cfg, parallelism_type);
MMImpl::DoMatMul(A, B, C_rows, args, cfg, options.parallelism_type);
const uint64_t t1 =
env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
const double min_elapsed = static_cast<double>(tuner.NotifyTicks(t1 - t0)) /

View File

@ -67,8 +67,8 @@ enum class ParallelismType : uint8_t {
};
struct MMOptions {
ParallelismType parallelism_type_ = ParallelismType::kNested;
uint8_t cluster_idx_ = 0;
ParallelismType parallelism_type = ParallelismType::kNested;
uint8_t cluster_idx = 0;
};
struct MMSequentialPolicy {

View File

@ -28,8 +28,8 @@
#define GEMMA_MATMUL_DEFINE_ONE(TA, TB, TC) \
MMPerKey* MatMulStatic(const MatPtrT<TA>& A, const MatPtrT<TB>& B, \
const float* HWY_RESTRICT add, MatMulEnv& env, \
MatPtrT<TC>& C) { \
return MatMul(A, B, add, env, C); \
MatPtrT<TC>& C, MMOptions options) { \
return MatMul(A, B, add, env, C, options); \
}
#if defined(THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_MATMUL_STATIC_INL_H_) == \

View File

@ -35,7 +35,7 @@
#define GEMMA_MATMUL_DECL_ONE(TA, TB, TC) \
MMPerKey* MatMulStatic(const MatPtrT<TA>& A, const MatPtrT<TB>& B, \
const float* HWY_RESTRICT add, MatMulEnv& env, \
MatPtrT<TC>& C);
MatPtrT<TC>& C, MMOptions options);
// Passed to HWY_VISIT_TARGETS; declares all overloads for all targets.
#define GEMMA_MATMUL_DECL(TARGET, NAMESPACE) \

View File

@ -258,8 +258,9 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
MatMulSlow(A, BT, add_row, env, C_slow);
// A few reps to get coverage of the various autotuned code paths.
MMOptions options;
for (size_t rep = 0; rep < 16; ++rep) {
MMPerKey* per_key = MatMulStatic(A, BT, add_row, env, C);
MMPerKey* per_key = MatMulStatic(A, BT, add_row, env, C, options);
AssertClose(A, BT, C_slow, C, env, line);
if (per_key->autotune.Best()) break;
}

View File

@ -63,9 +63,10 @@ namespace hn = hwy::HWY_NAMESPACE;
template <typename TA, typename TC>
MMPerKey* CallMatMul(const MatPtrT<TA>& A, const MatPtr& B,
const float* HWY_RESTRICT add, MatMulEnv& env,
MatPtrT<TC>& C) {
return CallUpcasted(
&B, [&](const auto* B_t) { return MatMulStatic(A, *B_t, add, env, C); });
MatPtrT<TC>& C, const MMOptions& options = MMOptions()) {
return CallUpcasted(&B, [&](const auto* B_t) {
return MatMulStatic(A, *B_t, add, env, C, options);
});
}
HWY_INLINE double PackTokenAndProb(int32_t token, float prob) {