mirror of https://github.com/google/gemma.cpp.git
Add MMOptions as an argument to Matmul.
PiperOrigin-RevId: 802008198
This commit is contained in:
parent
229bd078a1
commit
0d2e74d74a
|
|
@ -1179,7 +1179,7 @@ struct MMImpl {
|
|||
template <typename TA, typename TB, typename TC>
|
||||
HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
||||
const float* HWY_RESTRICT add, MatMulEnv& env,
|
||||
MatPtrT<TC>& C) {
|
||||
MatPtrT<TC>& C, MMOptions options) {
|
||||
RowPtrs<TC> C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[0]);
|
||||
|
||||
const Allocator& allocator = env.ctx.allocator;
|
||||
|
|
@ -1205,12 +1205,11 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
|||
MMPerKey& per_key = env.per_key[index];
|
||||
MMAutoTune<MMConfig>& tuner = per_key.autotune;
|
||||
|
||||
// Default to nested parallelism.
|
||||
const ParallelismType parallelism_type = ParallelismType::kNested;
|
||||
const MMArgs args(env, per_key, static_cast<double>(A.Scale()) * B.Scale(),
|
||||
add);
|
||||
if (HWY_LIKELY(tuner.Best())) {
|
||||
MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best(), parallelism_type);
|
||||
MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best(),
|
||||
options.parallelism_type);
|
||||
return &per_key;
|
||||
}
|
||||
|
||||
|
|
@ -1239,7 +1238,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
|||
|
||||
const MMConfig& cfg = tuner.NextConfig();
|
||||
const uint64_t t0 = hwy::timer::Start();
|
||||
MMImpl::DoMatMul(A, B, C_rows, args, cfg, parallelism_type);
|
||||
MMImpl::DoMatMul(A, B, C_rows, args, cfg, options.parallelism_type);
|
||||
const uint64_t t1 =
|
||||
env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
|
||||
const double min_elapsed = static_cast<double>(tuner.NotifyTicks(t1 - t0)) /
|
||||
|
|
|
|||
|
|
@ -67,8 +67,8 @@ enum class ParallelismType : uint8_t {
|
|||
};
|
||||
|
||||
struct MMOptions {
|
||||
ParallelismType parallelism_type_ = ParallelismType::kNested;
|
||||
uint8_t cluster_idx_ = 0;
|
||||
ParallelismType parallelism_type = ParallelismType::kNested;
|
||||
uint8_t cluster_idx = 0;
|
||||
};
|
||||
|
||||
struct MMSequentialPolicy {
|
||||
|
|
|
|||
|
|
@ -28,8 +28,8 @@
|
|||
#define GEMMA_MATMUL_DEFINE_ONE(TA, TB, TC) \
|
||||
MMPerKey* MatMulStatic(const MatPtrT<TA>& A, const MatPtrT<TB>& B, \
|
||||
const float* HWY_RESTRICT add, MatMulEnv& env, \
|
||||
MatPtrT<TC>& C) { \
|
||||
return MatMul(A, B, add, env, C); \
|
||||
MatPtrT<TC>& C, MMOptions options) { \
|
||||
return MatMul(A, B, add, env, C, options); \
|
||||
}
|
||||
|
||||
#if defined(THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_MATMUL_STATIC_INL_H_) == \
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@
|
|||
#define GEMMA_MATMUL_DECL_ONE(TA, TB, TC) \
|
||||
MMPerKey* MatMulStatic(const MatPtrT<TA>& A, const MatPtrT<TB>& B, \
|
||||
const float* HWY_RESTRICT add, MatMulEnv& env, \
|
||||
MatPtrT<TC>& C);
|
||||
MatPtrT<TC>& C, MMOptions options);
|
||||
|
||||
// Passed to HWY_VISIT_TARGETS; declares all overloads for all targets.
|
||||
#define GEMMA_MATMUL_DECL(TARGET, NAMESPACE) \
|
||||
|
|
|
|||
|
|
@ -258,8 +258,9 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
|
|||
|
||||
MatMulSlow(A, BT, add_row, env, C_slow);
|
||||
// A few reps to get coverage of the various autotuned code paths.
|
||||
MMOptions options;
|
||||
for (size_t rep = 0; rep < 16; ++rep) {
|
||||
MMPerKey* per_key = MatMulStatic(A, BT, add_row, env, C);
|
||||
MMPerKey* per_key = MatMulStatic(A, BT, add_row, env, C, options);
|
||||
AssertClose(A, BT, C_slow, C, env, line);
|
||||
if (per_key->autotune.Best()) break;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -63,9 +63,10 @@ namespace hn = hwy::HWY_NAMESPACE;
|
|||
template <typename TA, typename TC>
|
||||
MMPerKey* CallMatMul(const MatPtrT<TA>& A, const MatPtr& B,
|
||||
const float* HWY_RESTRICT add, MatMulEnv& env,
|
||||
MatPtrT<TC>& C) {
|
||||
return CallUpcasted(
|
||||
&B, [&](const auto* B_t) { return MatMulStatic(A, *B_t, add, env, C); });
|
||||
MatPtrT<TC>& C, const MMOptions& options = MMOptions()) {
|
||||
return CallUpcasted(&B, [&](const auto* B_t) {
|
||||
return MatMulStatic(A, *B_t, add, env, C, options);
|
||||
});
|
||||
}
|
||||
|
||||
HWY_INLINE double PackTokenAndProb(int32_t token, float prob) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue