mirror of https://github.com/google/gemma.cpp.git
Add MMOptions as an argument to MatMul.
PiperOrigin-RevId: 802008198
This commit is contained in:
parent
229bd078a1
commit
0d2e74d74a
|
|
@ -1179,7 +1179,7 @@ struct MMImpl {
|
||||||
template <typename TA, typename TB, typename TC>
|
template <typename TA, typename TB, typename TC>
|
||||||
HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
||||||
const float* HWY_RESTRICT add, MatMulEnv& env,
|
const float* HWY_RESTRICT add, MatMulEnv& env,
|
||||||
MatPtrT<TC>& C) {
|
MatPtrT<TC>& C, MMOptions options) {
|
||||||
RowPtrs<TC> C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[0]);
|
RowPtrs<TC> C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[0]);
|
||||||
|
|
||||||
const Allocator& allocator = env.ctx.allocator;
|
const Allocator& allocator = env.ctx.allocator;
|
||||||
|
|
@ -1205,12 +1205,11 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
||||||
MMPerKey& per_key = env.per_key[index];
|
MMPerKey& per_key = env.per_key[index];
|
||||||
MMAutoTune<MMConfig>& tuner = per_key.autotune;
|
MMAutoTune<MMConfig>& tuner = per_key.autotune;
|
||||||
|
|
||||||
// Default to nested parallelism.
|
|
||||||
const ParallelismType parallelism_type = ParallelismType::kNested;
|
|
||||||
const MMArgs args(env, per_key, static_cast<double>(A.Scale()) * B.Scale(),
|
const MMArgs args(env, per_key, static_cast<double>(A.Scale()) * B.Scale(),
|
||||||
add);
|
add);
|
||||||
if (HWY_LIKELY(tuner.Best())) {
|
if (HWY_LIKELY(tuner.Best())) {
|
||||||
MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best(), parallelism_type);
|
MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best(),
|
||||||
|
options.parallelism_type);
|
||||||
return &per_key;
|
return &per_key;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1239,7 +1238,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
||||||
|
|
||||||
const MMConfig& cfg = tuner.NextConfig();
|
const MMConfig& cfg = tuner.NextConfig();
|
||||||
const uint64_t t0 = hwy::timer::Start();
|
const uint64_t t0 = hwy::timer::Start();
|
||||||
MMImpl::DoMatMul(A, B, C_rows, args, cfg, parallelism_type);
|
MMImpl::DoMatMul(A, B, C_rows, args, cfg, options.parallelism_type);
|
||||||
const uint64_t t1 =
|
const uint64_t t1 =
|
||||||
env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
|
env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
|
||||||
const double min_elapsed = static_cast<double>(tuner.NotifyTicks(t1 - t0)) /
|
const double min_elapsed = static_cast<double>(tuner.NotifyTicks(t1 - t0)) /
|
||||||
|
|
|
||||||
|
|
@ -67,8 +67,8 @@ enum class ParallelismType : uint8_t {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct MMOptions {
|
struct MMOptions {
|
||||||
ParallelismType parallelism_type_ = ParallelismType::kNested;
|
ParallelismType parallelism_type = ParallelismType::kNested;
|
||||||
uint8_t cluster_idx_ = 0;
|
uint8_t cluster_idx = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct MMSequentialPolicy {
|
struct MMSequentialPolicy {
|
||||||
|
|
|
||||||
|
|
@ -28,8 +28,8 @@
|
||||||
#define GEMMA_MATMUL_DEFINE_ONE(TA, TB, TC) \
|
#define GEMMA_MATMUL_DEFINE_ONE(TA, TB, TC) \
|
||||||
MMPerKey* MatMulStatic(const MatPtrT<TA>& A, const MatPtrT<TB>& B, \
|
MMPerKey* MatMulStatic(const MatPtrT<TA>& A, const MatPtrT<TB>& B, \
|
||||||
const float* HWY_RESTRICT add, MatMulEnv& env, \
|
const float* HWY_RESTRICT add, MatMulEnv& env, \
|
||||||
MatPtrT<TC>& C) { \
|
MatPtrT<TC>& C, MMOptions options) { \
|
||||||
return MatMul(A, B, add, env, C); \
|
return MatMul(A, B, add, env, C, options); \
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined(THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_MATMUL_STATIC_INL_H_) == \
|
#if defined(THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_MATMUL_STATIC_INL_H_) == \
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,7 @@
|
||||||
#define GEMMA_MATMUL_DECL_ONE(TA, TB, TC) \
|
#define GEMMA_MATMUL_DECL_ONE(TA, TB, TC) \
|
||||||
MMPerKey* MatMulStatic(const MatPtrT<TA>& A, const MatPtrT<TB>& B, \
|
MMPerKey* MatMulStatic(const MatPtrT<TA>& A, const MatPtrT<TB>& B, \
|
||||||
const float* HWY_RESTRICT add, MatMulEnv& env, \
|
const float* HWY_RESTRICT add, MatMulEnv& env, \
|
||||||
MatPtrT<TC>& C);
|
MatPtrT<TC>& C, MMOptions options);
|
||||||
|
|
||||||
// Passed to HWY_VISIT_TARGETS; declares all overloads for all targets.
|
// Passed to HWY_VISIT_TARGETS; declares all overloads for all targets.
|
||||||
#define GEMMA_MATMUL_DECL(TARGET, NAMESPACE) \
|
#define GEMMA_MATMUL_DECL(TARGET, NAMESPACE) \
|
||||||
|
|
|
||||||
|
|
@ -258,8 +258,9 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
|
||||||
|
|
||||||
MatMulSlow(A, BT, add_row, env, C_slow);
|
MatMulSlow(A, BT, add_row, env, C_slow);
|
||||||
// A few reps to get coverage of the various autotuned code paths.
|
// A few reps to get coverage of the various autotuned code paths.
|
||||||
|
MMOptions options;
|
||||||
for (size_t rep = 0; rep < 16; ++rep) {
|
for (size_t rep = 0; rep < 16; ++rep) {
|
||||||
MMPerKey* per_key = MatMulStatic(A, BT, add_row, env, C);
|
MMPerKey* per_key = MatMulStatic(A, BT, add_row, env, C, options);
|
||||||
AssertClose(A, BT, C_slow, C, env, line);
|
AssertClose(A, BT, C_slow, C, env, line);
|
||||||
if (per_key->autotune.Best()) break;
|
if (per_key->autotune.Best()) break;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -63,9 +63,10 @@ namespace hn = hwy::HWY_NAMESPACE;
|
||||||
template <typename TA, typename TC>
|
template <typename TA, typename TC>
|
||||||
MMPerKey* CallMatMul(const MatPtrT<TA>& A, const MatPtr& B,
|
MMPerKey* CallMatMul(const MatPtrT<TA>& A, const MatPtr& B,
|
||||||
const float* HWY_RESTRICT add, MatMulEnv& env,
|
const float* HWY_RESTRICT add, MatMulEnv& env,
|
||||||
MatPtrT<TC>& C) {
|
MatPtrT<TC>& C, const MMOptions& options = MMOptions()) {
|
||||||
return CallUpcasted(
|
return CallUpcasted(&B, [&](const auto* B_t) {
|
||||||
&B, [&](const auto* B_t) { return MatMulStatic(A, *B_t, add, env, C); });
|
return MatMulStatic(A, *B_t, add, env, C, options);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
HWY_INLINE double PackTokenAndProb(int32_t token, float prob) {
|
HWY_INLINE double PackTokenAndProb(int32_t token, float prob) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue