mirror of https://github.com/google/gemma.cpp.git
Create separate MMStorage objects per cluster.
PiperOrigin-RevId: 802588625
This commit is contained in:
parent b7b3d353db
commit 74ffe079c4
@@ -599,7 +599,8 @@ class MMPerPackage {
     } else {
       // Always decompress. To reduce code size/compile time, we no longer
       // support a separate F32 kernel; most A are already BF16.
-      const StridedViewBF A_view = args_.env->storage.A(pkg_idx_, A.Extents());
+      const StridedViewBF A_view =
+          args_.env->storage[cluster_idx_].A(pkg_idx_, A.Extents());
       DecompressA<MMParallelPolicyT>(A, A_view);
       DispatchOrder(parallel_policy, A_view, B, C_rows);
     }
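The hunk above switches the A view from a single shared MMStorage to the storage slot owned by the calling cluster. Below is a minimal stand-alone sketch of that indexing pattern, with invented types (`Scratch`) and plain std::thread workers instead of gemma.cpp's thread pools: each cluster writes only into its own slot, so concurrent decompression on different clusters cannot clobber a shared buffer.

#include <cstddef>
#include <thread>
#include <vector>

// Illustrative stand-in for MMStorage: one decompression buffer per cluster.
struct Scratch {
  std::vector<float> a;
};

int main() {
  const size_t num_clusters = 4;
  std::vector<Scratch> storage(num_clusters);  // one Scratch per cluster
  for (Scratch& s : storage) s.a.resize(1024);

  std::vector<std::thread> workers;
  for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) {
    workers.emplace_back([&storage, cluster_idx] {
      // Each cluster writes only into storage[cluster_idx], so concurrent
      // "decompression" on different clusters does not race.
      Scratch& s = storage[cluster_idx];
      for (size_t i = 0; i < s.a.size(); ++i) {
        s.a[i] = static_cast<float>(cluster_idx);
      }
    });
  }
  for (std::thread& t : workers) t.join();
  return 0;
}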
@@ -402,7 +402,19 @@ IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages,
                     NPMultiple(ctx.allocator, N, sizeof_TC, nr, num_packages));
 }
 
-MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx), storage(ctx) {
+MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx) {
+  // Create storage per cluster. This only applies to in-cluster parallelism.
+  // For nested and sequential parallelism, a single MMStorage is used.
+  size_t num_packages = ctx.topology.NumPackages();
+  size_t num_clusters = 0;
+  for (size_t pkg_idx = 0; pkg_idx < num_packages; ++pkg_idx) {
+    num_clusters += ctx.topology.NumClusters(pkg_idx);
+  }
+  storage.reserve(num_clusters);
+  for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) {
+    storage.push_back(MMStorage(ctx));
+  }
+
   char cpu100[100];
   have_timer_stop = hwy::platform::HaveTimerStop(cpu100);
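The constructor now sizes `storage` as the total cluster count across all packages. A self-contained sketch of how a (package, cluster-within-package) pair would map onto that flat index follows; `Topology` and `GlobalClusterIndex` are invented stand-ins, not gemma.cpp API, and the real code may index clusters differently.

#include <cstddef>
#include <vector>

// Hypothetical topology; the real ctx.topology.NumClusters(pkg_idx) may
// return a different count for each package.
struct Topology {
  std::vector<size_t> clusters_per_package;
  size_t NumPackages() const { return clusters_per_package.size(); }
  size_t NumClusters(size_t pkg_idx) const {
    return clusters_per_package[pkg_idx];
  }
};

// Maps (package, cluster-within-package) to a flat index into a storage
// vector sized as in the constructor above: the sum of NumClusters(pkg)
// over all packages. GlobalClusterIndex is an invented helper name.
size_t GlobalClusterIndex(const Topology& topo, size_t pkg_idx,
                          size_t cluster_in_pkg) {
  size_t base = 0;
  for (size_t p = 0; p < pkg_idx; ++p) base += topo.NumClusters(p);
  return base + cluster_in_pkg;
}

int main() {
  const Topology topo{{2, 3}};  // 2 packages with 2 and 3 clusters: 5 total
  // The first cluster of the second package lands at flat index 2.
  return GlobalClusterIndex(topo, /*pkg_idx=*/1, /*cluster_in_pkg=*/0) == 2 ? 0 : 1;
}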
@@ -369,6 +369,7 @@ class MMStorage {
   StridedViewBF A(size_t pkg_idx, const Extents2D& extents) const {
     HWY_DASSERT(extents.rows <= kMaxBatchSize);
     HWY_DASSERT(extents.cols <= kMaxK);
+    HWY_DASSERT(pkg_A_[pkg_idx] != nullptr);
     return StridedViewBF(const_cast<BF16*>(pkg_A_[pkg_idx]->Row(0)),
                          extents.cols, pkg_A_[pkg_idx]->Stride());
   }
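The accessor above builds a view from pkg_A_'s first row, the requested column count, and the buffer's row stride; the sketch below (float in place of BF16, invented names, not gemma.cpp's actual StridedViewBF) illustrates what such a strided view amounts to. The added HWY_DASSERT only guards against building a view over a package whose buffer was never allocated.

#include <cassert>
#include <cstddef>
#include <vector>

// Rough stand-in for the StridedViewBF(row0, cols, stride) pattern: rows are
// `stride` elements apart in memory and only the first `cols` elements of
// each row are meaningful (padding keeps rows SIMD-friendly).
struct StridedView {
  float* row0;
  size_t cols;
  size_t stride;
  float* Row(size_t r) const { return row0 + r * stride; }
};

int main() {
  const size_t rows = 4, cols = 6, stride = 8;  // padded rows: stride >= cols
  std::vector<float> buf(rows * stride, 0.0f);
  StridedView view{buf.data(), cols, stride};
  view.Row(3)[5] = 1.0f;                // last valid element of the last row
  assert(buf[3 * stride + 5] == 1.0f);  // same location in the flat buffer
  return 0;
}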
@@ -733,7 +734,7 @@ struct MatMulEnv {
   // Whether to print the best config immediately after autotuning finished.
   bool print_best = false;
 
-  MMStorage storage;
+  std::vector<MMStorage> storage;
   MMKeys keys;
   std::vector<MMPerKey> per_key;