mirror of https://github.com/google/gemma.cpp.git
Add non-threading parallel policy.
PiperOrigin-RevId: 800913294
This commit is contained in:
parent
00b70f69c5
commit
bc0c0bac8b
42
ops/matmul.h
42
ops/matmul.h
|
|
@ -71,6 +71,48 @@ struct MMOptions {
|
||||||
uint8_t cluster_idx_ = 0;
|
uint8_t cluster_idx_ = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct MMSequentialPolicy {
|
||||||
|
template <class Func>
|
||||||
|
static void ForPkg(ThreadingContext& ctx, const size_t max_packages,
|
||||||
|
const Func& func) {
|
||||||
|
func(/*pkg_idx=*/0);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class Func>
|
||||||
|
static void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
|
||||||
|
size_t nx_multiple, size_t inner_tasks, size_t pkg_idx,
|
||||||
|
const Func& func) {
|
||||||
|
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
|
||||||
|
const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage();
|
||||||
|
func(range_np, base_idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class Func>
|
||||||
|
static void ForRangesMC_NC(ThreadingContext& ctx,
|
||||||
|
const IndexRangePartition& ranges_mc,
|
||||||
|
const IndexRangePartition& ranges_nc,
|
||||||
|
size_t pkg_idx, const Func& func) {
|
||||||
|
const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage();
|
||||||
|
|
||||||
|
for (size_t i = 0; i < ranges_mc.NumTasks(); ++i) {
|
||||||
|
const IndexRange range_mc = ranges_mc.Range(i);
|
||||||
|
for (size_t j = 0; j < ranges_nc.NumTasks(); ++j) {
|
||||||
|
const IndexRange range_nc = ranges_nc.Range(j);
|
||||||
|
func(range_mc, range_nc, base_idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class Func>
|
||||||
|
static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
|
||||||
|
size_t pkg_idx, const Func& func) {
|
||||||
|
const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage();
|
||||||
|
for (uint64_t row_a = range_mc.begin(); row_a < range_mc.end(); ++row_a) {
|
||||||
|
func(row_a, base_idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct MMNestedParallelPolicy {
|
struct MMNestedParallelPolicy {
|
||||||
template <class Func>
|
template <class Func>
|
||||||
static void ForPkg(ThreadingContext& ctx, const size_t max_packages,
|
static void ForPkg(ThreadingContext& ctx, const size_t max_packages,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue