diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 8f91114..95277db 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -1179,7 +1179,7 @@ struct MMImpl { template HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, const float* HWY_RESTRICT add, MatMulEnv& env, - MatPtrT& C) { + MatPtrT& C, MMOptions options) { RowPtrs C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[0]); const Allocator& allocator = env.ctx.allocator; @@ -1205,12 +1205,11 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, MMPerKey& per_key = env.per_key[index]; MMAutoTune& tuner = per_key.autotune; - // Default to nested parallelism. - const ParallelismType parallelism_type = ParallelismType::kNested; const MMArgs args(env, per_key, static_cast(A.Scale()) * B.Scale(), add); if (HWY_LIKELY(tuner.Best())) { - MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best(), parallelism_type); + MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best(), + options.parallelism_type); return &per_key; } @@ -1239,7 +1238,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, const MMConfig& cfg = tuner.NextConfig(); const uint64_t t0 = hwy::timer::Start(); - MMImpl::DoMatMul(A, B, C_rows, args, cfg, parallelism_type); + MMImpl::DoMatMul(A, B, C_rows, args, cfg, options.parallelism_type); const uint64_t t1 = env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start(); const double min_elapsed = static_cast(tuner.NotifyTicks(t1 - t0)) / diff --git a/ops/matmul.h b/ops/matmul.h index dda9673..620d382 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -67,8 +67,8 @@ enum class ParallelismType : uint8_t { }; struct MMOptions { - ParallelismType parallelism_type_ = ParallelismType::kNested; - uint8_t cluster_idx_ = 0; + ParallelismType parallelism_type = ParallelismType::kNested; + uint8_t cluster_idx = 0; }; struct MMSequentialPolicy { diff --git a/ops/matmul_static-inl.h b/ops/matmul_static-inl.h index 28b21cf..ba09e0c 100644 --- a/ops/matmul_static-inl.h +++ b/ops/matmul_static-inl.h @@ -28,8 +28,8 @@ #define GEMMA_MATMUL_DEFINE_ONE(TA, TB, TC) \ MMPerKey* MatMulStatic(const MatPtrT& A, const MatPtrT& B, \ const float* HWY_RESTRICT add, MatMulEnv& env, \ - MatPtrT& C) { \ - return MatMul(A, B, add, env, C); \ + MatPtrT& C, MMOptions options) { \ + return MatMul(A, B, add, env, C, options); \ } #if defined(THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_MATMUL_STATIC_INL_H_) == \ diff --git a/ops/matmul_static.h b/ops/matmul_static.h index c06b87a..61dc505 100644 --- a/ops/matmul_static.h +++ b/ops/matmul_static.h @@ -35,7 +35,7 @@ #define GEMMA_MATMUL_DECL_ONE(TA, TB, TC) \ MMPerKey* MatMulStatic(const MatPtrT& A, const MatPtrT& B, \ const float* HWY_RESTRICT add, MatMulEnv& env, \ - MatPtrT& C); + MatPtrT& C, MMOptions options); // Passed to HWY_VISIT_TARGETS; declares all overloads for all targets. #define GEMMA_MATMUL_DECL(TARGET, NAMESPACE) \ diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc index 3a8528a..dc6f559 100644 --- a/ops/matmul_test.cc +++ b/ops/matmul_test.cc @@ -258,8 +258,9 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add, MatMulSlow(A, BT, add_row, env, C_slow); // A few reps to get coverage of the various autotuned code paths. + MMOptions options; for (size_t rep = 0; rep < 16; ++rep) { - MMPerKey* per_key = MatMulStatic(A, BT, add_row, env, C); + MMPerKey* per_key = MatMulStatic(A, BT, add_row, env, C, options); AssertClose(A, BT, C_slow, C, env, line); if (per_key->autotune.Best()) break; } diff --git a/ops/ops-inl.h b/ops/ops-inl.h index caf8041..0438bf7 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -63,9 +63,10 @@ namespace hn = hwy::HWY_NAMESPACE; template MMPerKey* CallMatMul(const MatPtrT& A, const MatPtr& B, const float* HWY_RESTRICT add, MatMulEnv& env, - MatPtrT& C) { - return CallUpcasted( - &B, [&](const auto* B_t) { return MatMulStatic(A, *B_t, add, env, C); }); + MatPtrT& C, const MMOptions& options = MMOptions()) { + return CallUpcasted(&B, [&](const auto* B_t) { + return MatMulStatic(A, *B_t, add, env, C, options); + }); } HWY_INLINE double PackTokenAndProb(int32_t token, float prob) {