mirror of https://github.com/google/gemma.cpp.git
Minor: ParallelismStrategy->Parallelism
PiperOrigin-RevId: 828936578
This commit is contained in:
parent
a344a70c59
commit
091b4567c9
|
|
@ -105,7 +105,7 @@ MatStorageT<MatT> GenerateMat(const Extents2D& extents, MatPadding padding,
|
|||
MatStorageT<float> raw("raw", extents, ctx.allocator, MatPadding::kPacked);
|
||||
MatStorageT<MatT> compressed("mat", extents, ctx.allocator, padding);
|
||||
const float scale = SfpStream::kMax / extents.Area();
|
||||
ParallelFor(ParallelismStrategy::kFlat, extents.rows, ctx, /*cluster_idx=*/0,
|
||||
ParallelFor(Parallelism::kFlat, extents.rows, ctx, /*cluster_idx=*/0,
|
||||
Callers::kTest, [&](size_t r, size_t thread) {
|
||||
float* HWY_RESTRICT row = raw.Row(r);
|
||||
for (size_t c = 0; c < extents.cols; c++) {
|
||||
|
|
@ -134,7 +134,7 @@ MatStorageT<MatT> GenerateTransposedMat(const Extents2D extents,
|
|||
MatStorageT<float> raw("raw", extents, ctx.allocator, MatPadding::kPacked);
|
||||
MatStorageT<MatT> compressed("trans", extents, ctx.allocator, padding);
|
||||
const float scale = SfpStream::kMax / extents.Area();
|
||||
ParallelFor(ParallelismStrategy::kFlat, extents.rows, ctx, /*cluster_idx=*/0,
|
||||
ParallelFor(Parallelism::kFlat, extents.rows, ctx, /*cluster_idx=*/0,
|
||||
Callers::kTest, [&](size_t r, size_t thread) {
|
||||
float* HWY_RESTRICT row = raw.Row(r);
|
||||
for (size_t c = 0; c < extents.cols; c++) {
|
||||
|
|
|
|||
|
|
@ -278,7 +278,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
|
|||
// Note that 2D parallelism is not worth the fork/join overhead because the
|
||||
// tasks are very lightweight.
|
||||
ParallelFor(
|
||||
ParallelismStrategy::kFlat, kv_heads * num_interleaved, env.ctx,
|
||||
Parallelism::kFlat, kv_heads * num_interleaved, env.ctx,
|
||||
/*cluster_idx=*/0, Callers::kAttComputeQKV,
|
||||
[&](size_t task, size_t worker) HWY_ATTR {
|
||||
const size_t head = task % kv_heads;
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<BF16>& q_t,
|
|||
{
|
||||
const size_t num_tasks = hwy::DivCeil(q_t.Rows(), kNF);
|
||||
// Better than kFlat.
|
||||
ParallelFor(ParallelismStrategy::kHierarchical, num_tasks, ctx,
|
||||
ParallelFor(Parallelism::kHierarchical, num_tasks, ctx,
|
||||
/*cluster_idx=*/0, Callers::kFlashTransposeQ, func);
|
||||
}
|
||||
}
|
||||
|
|
@ -124,7 +124,7 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
|
|||
{
|
||||
// kHierarchical is not worth the extra sync overhead because the tasks are
|
||||
// very lightweight.
|
||||
ParallelFor(ParallelismStrategy::kFlat, num_tokens * qbatch.Size(), ctx,
|
||||
ParallelFor(Parallelism::kFlat, num_tokens * qbatch.Size(), ctx,
|
||||
/*cluster_idx=*/0, Callers::kFlashRMSNormAndPositionalEncoding,
|
||||
func);
|
||||
}
|
||||
|
|
@ -619,7 +619,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
|
|||
const hwy::Divisor div_qbatch(qbatch.Size());
|
||||
// Compress q to q_bf.
|
||||
ParallelFor(
|
||||
ParallelismStrategy::kWithinCluster, activations.q.Rows(), ctx,
|
||||
Parallelism::kWithinCluster, activations.q.Rows(), ctx,
|
||||
/*cluster_idx=*/0, Callers::kFlashAttention,
|
||||
[&](size_t row, size_t worker) {
|
||||
CompressPerThread tls;
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ template <class Mat>
|
|||
void ActivationBatched(
|
||||
ActivationType activation, Mat& c1, ThreadingContext& ctx,
|
||||
size_t cluster_idx = 0,
|
||||
ParallelismStrategy parallelism = ParallelismStrategy::kFlat) {
|
||||
Parallelism parallelism = Parallelism::kFlat) {
|
||||
using T = typename Mat::T;
|
||||
ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,
|
||||
Callers::kActivationBatched, [&](uint64_t task, size_t worker) {
|
||||
|
|
@ -115,7 +115,7 @@ template <class Mat1, class Mat2>
|
|||
HWY_NOINLINE void ActivationBatched(
|
||||
ActivationType activation, Mat1& c1, const Mat2* c2, ThreadingContext& ctx,
|
||||
size_t cluster_idx = 0,
|
||||
ParallelismStrategy parallelism = ParallelismStrategy::kFlat) {
|
||||
Parallelism parallelism = Parallelism::kFlat) {
|
||||
HWY_DASSERT(c1.SameShape(*c2));
|
||||
if (c2 && c2->HasPtr()) {
|
||||
ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,
|
||||
|
|
|
|||
|
|
@ -426,7 +426,7 @@ static void SampleAndStream(const ModelConfig& config,
|
|||
timing_info.NotifyGenerated(non_eos.Count());
|
||||
|
||||
ParallelFor(
|
||||
ParallelismStrategy::kFlat, qbatch.Size(), env.ctx,
|
||||
Parallelism::kFlat, qbatch.Size(), env.ctx,
|
||||
/*cluster_idx=*/0, Callers::kSampleAndStream,
|
||||
[&](size_t qi, size_t worker) {
|
||||
if (!non_eos.Get(qi)) return;
|
||||
|
|
|
|||
|
|
@ -431,12 +431,12 @@ void WeightsPtrs::CopyFrom(const WeightsPtrs& other) {
|
|||
void WeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
|
||||
ThreadingContext& ctx) {
|
||||
const size_t cluster_idx = 0;
|
||||
ParallelFor(ParallelismStrategy::kFlat, c_layers.size(), ctx, cluster_idx,
|
||||
ParallelFor(Parallelism::kFlat, c_layers.size(), ctx, cluster_idx,
|
||||
Callers::kFixupWeights, [&](uint64_t layer, size_t /*worker*/) {
|
||||
GetLayer(layer)->Fixup(mat_owners, ctx);
|
||||
});
|
||||
|
||||
ParallelFor(ParallelismStrategy::kFlat, vit_layers.size(), ctx, cluster_idx,
|
||||
ParallelFor(Parallelism::kFlat, vit_layers.size(), ctx, cluster_idx,
|
||||
Callers::kFixupWeights, [&](uint64_t layer, size_t /*worker*/) {
|
||||
VitLayer(layer)->Fixup(mat_owners, ctx);
|
||||
});
|
||||
|
|
@ -527,7 +527,7 @@ static void AllocateAndBindAll(std::vector<TensorToRead>& tensors,
|
|||
|
||||
// Allocate in parallel because faulting in large tensors is slow.
|
||||
ParallelFor(
|
||||
ParallelismStrategy::kFlat, tensors.size(), ctx, /*cluster_idx=*/0,
|
||||
Parallelism::kFlat, tensors.size(), ctx, /*cluster_idx=*/0,
|
||||
Callers::kAllocateAndBindAll, [&](uint64_t task, size_t /*thread*/) {
|
||||
TensorToRead& tensor = tensors[task];
|
||||
MatPtr& mat = *tensor.mat;
|
||||
|
|
@ -586,10 +586,9 @@ static void DecompressToBF16(MatPtr& mat,
|
|||
static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
|
||||
const BlobReader& reader, ThreadingContext& ctx) {
|
||||
// Especially TSAN is slow enough to warrant hierarchical parallelism.
|
||||
const ParallelismStrategy strategy = HWY_IS_DEBUG_BUILD
|
||||
? ParallelismStrategy::kHierarchical
|
||||
: ParallelismStrategy::kFlat;
|
||||
ParallelFor(strategy, tensors.size(), ctx, /*cluster_idx=*/0,
|
||||
const Parallelism parallelism =
|
||||
HWY_IS_DEBUG_BUILD ? Parallelism::kHierarchical : Parallelism::kFlat;
|
||||
ParallelFor(parallelism, tensors.size(), ctx, /*cluster_idx=*/0,
|
||||
Callers::kReadAllToBF16, [&](uint64_t task, size_t thread) {
|
||||
GCPP_ZONE(ctx, thread, Zones::kStartupWeightsReadAllToBF16);
|
||||
const TensorToRead& tensor = tensors[task];
|
||||
|
|
@ -677,7 +676,7 @@ static void ReadBatches(const BlobReader& reader,
|
|||
const std::vector<IOBatch>& batches,
|
||||
ThreadingContext& ctx) {
|
||||
// >5x speedup from parallel reads when cached.
|
||||
ParallelFor(ParallelismStrategy::kHierarchical, batches.size(), ctx,
|
||||
ParallelFor(Parallelism::kHierarchical, batches.size(), ctx,
|
||||
/*cluster_idx=*/0, Callers::kReadBatches,
|
||||
[&](uint64_t task, size_t thread) {
|
||||
GCPP_ZONE(ctx, thread, Zones::kStartupWeightsReadBatches);
|
||||
|
|
|
|||
|
|
@ -106,7 +106,7 @@ void ReadBlobs(BlobReader& reader, const RangeVec& ranges, BlobVec& blobs,
|
|||
ThreadingContext& ctx, size_t cluster_idx) {
|
||||
HWY_ASSERT(reader.Keys().size() == blobs.size());
|
||||
HWY_ASSERT(ranges.size() == blobs.size());
|
||||
ParallelFor(ParallelismStrategy::kWithinCluster, blobs.size(), ctx,
|
||||
ParallelFor(Parallelism::kWithinCluster, blobs.size(), ctx,
|
||||
cluster_idx, Callers::kTest, [&](size_t i, size_t /*thread*/) {
|
||||
HWY_ASSERT(ranges[i].bytes == blobs[i].size());
|
||||
reader.file().Read(ranges[i].offset, ranges[i].bytes,
|
||||
|
|
@ -122,7 +122,7 @@ void ReadBothBlobs(BlobReader& reader1, BlobReader& reader2,
|
|||
const double t0 = hwy::platform::Now();
|
||||
HWY_WARN("Reading %zu GiB, %zu clusters: ", total_bytes >> 30,
|
||||
ctx.pools.NumClusters());
|
||||
ParallelFor(ParallelismStrategy::kAcrossClusters, 2, ctx, 0, Callers::kTest,
|
||||
ParallelFor(Parallelism::kAcrossClusters, 2, ctx, 0, Callers::kTest,
|
||||
[&](const size_t task, size_t cluster_idx) {
|
||||
ReadBlobs(task ? reader1 : reader2, task ? ranges1 : ranges2,
|
||||
task ? blobs1 : blobs2, ctx, cluster_idx);
|
||||
|
|
@ -189,7 +189,7 @@ void CompareBlobs(const KeyVec& keys, BlobVec& blobs1, BlobVec& blobs2,
|
|||
const double t0 = hwy::platform::Now();
|
||||
std::atomic<size_t> blobs_equal{};
|
||||
std::atomic<size_t> blobs_diff{};
|
||||
ParallelFor(ParallelismStrategy::kHierarchical, keys.size(), ctx, 0,
|
||||
ParallelFor(Parallelism::kHierarchical, keys.size(), ctx, 0,
|
||||
Callers::kTest, [&](size_t i, size_t /*thread*/) {
|
||||
const size_t mismatches =
|
||||
BlobDifferences(blobs1[i], blobs2[i], keys[i]);
|
||||
|
|
|
|||
|
|
@ -488,11 +488,10 @@ void BlobWriter::Add(const std::string& key, const void* data, size_t bytes) {
|
|||
EnqueueChunks(keys_.size() - 1, curr_offset_, bytes,
|
||||
static_cast<const uint8_t*>(data), writes);
|
||||
|
||||
const ParallelismStrategy strategy = file_->IsAppendOnly()
|
||||
? ParallelismStrategy::kNone
|
||||
: ParallelismStrategy::kFlat;
|
||||
const Parallelism parallelism =
|
||||
file_->IsAppendOnly() ? Parallelism::kNone : Parallelism::kFlat;
|
||||
ParallelFor(
|
||||
strategy, writes.size(), ctx_,
|
||||
parallelism, writes.size(), ctx_,
|
||||
/*cluster_idx=*/0, Callers::kBlobWriter,
|
||||
[this, &writes](uint64_t i, size_t /*thread*/) {
|
||||
const BlobRange& range = writes[i].range;
|
||||
|
|
|
|||
|
|
@ -130,7 +130,7 @@ TEST(BlobStoreTest, TestNumBlobs) {
|
|||
HWY_ASSERT_EQ(reader.Keys().size(), num_blobs);
|
||||
|
||||
ParallelFor(
|
||||
ParallelismStrategy::kFlat, num_blobs, ctx, /*cluster_idx=*/0,
|
||||
Parallelism::kFlat, num_blobs, ctx, /*cluster_idx=*/0,
|
||||
Callers::kTest, [&](uint64_t i, size_t /*thread*/) {
|
||||
HWY_ASSERT_STRING_EQ(reader.Keys()[i].c_str(),
|
||||
std::to_string(i).c_str());
|
||||
|
|
|
|||
|
|
@ -1126,7 +1126,7 @@ void TestAllDot() {
|
|||
std::array<DotStats, kMaxWorkers> all_stats;
|
||||
|
||||
ParallelFor(
|
||||
ParallelismStrategy::kWithinCluster, kReps, ctx, 0, Callers::kTest,
|
||||
Parallelism::kWithinCluster, kReps, ctx, 0, Callers::kTest,
|
||||
[&](size_t rep, size_t thread) {
|
||||
float* HWY_RESTRICT pa = a.Row(thread);
|
||||
float* HWY_RESTRICT pb = b.Row(thread);
|
||||
|
|
|
|||
12
ops/matmul.h
12
ops/matmul.h
|
|
@ -61,7 +61,7 @@ HWY_INLINE_VAR constexpr size_t kMaxNC = 16384;
|
|||
// of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`.
|
||||
HWY_INLINE_VAR constexpr size_t kMaxKC = 8 * 1024;
|
||||
|
||||
// Policy classes for parallelism, implementing some of `ParallelismStrategy`.
|
||||
// Policy classes for parallelism, implementing some of `Parallelism`.
|
||||
|
||||
struct MMParallelNone {
|
||||
template <class Func>
|
||||
|
|
@ -220,14 +220,14 @@ struct MMParallelHierarchical {
|
|||
};
|
||||
|
||||
template <class Func, typename... Args>
|
||||
void DispatchParallelism(ParallelismStrategy parallelism, const Func& func,
|
||||
void DispatchParallelism(Parallelism parallelism, const Func& func,
|
||||
Args&&... args) {
|
||||
switch (parallelism) {
|
||||
case ParallelismStrategy::kNone:
|
||||
case Parallelism::kNone:
|
||||
return func(MMParallelNone(), std::forward<Args>(args)...);
|
||||
case ParallelismStrategy::kWithinCluster:
|
||||
case Parallelism::kWithinCluster:
|
||||
return func(MMParallelWithinCluster(), std::forward<Args>(args)...);
|
||||
case ParallelismStrategy::kHierarchical:
|
||||
case Parallelism::kHierarchical:
|
||||
return func(MMParallelHierarchical(), std::forward<Args>(args)...);
|
||||
default:
|
||||
HWY_UNREACHABLE;
|
||||
|
|
@ -716,7 +716,7 @@ class MMOptions {
|
|||
const void* opaque = nullptr;
|
||||
|
||||
uint32_t cluster_idx = 0; // for `parallelism == kWithinCluster`.
|
||||
ParallelismStrategy parallelism = ParallelismStrategy::kHierarchical;
|
||||
Parallelism parallelism = Parallelism::kHierarchical;
|
||||
};
|
||||
|
||||
// Arguments to MatMul() that are independent of the A/B/C types. Reduces
|
||||
|
|
|
|||
|
|
@ -500,7 +500,7 @@ void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
|
|||
activations.DebugCheckSameShape(out);
|
||||
|
||||
CallUpcasted(&weights, [&](const auto* weights_t) {
|
||||
ParallelFor(ParallelismStrategy::kFlat, activations.Rows(), ctx,
|
||||
ParallelFor(Parallelism::kFlat, activations.Rows(), ctx,
|
||||
cluster_idx, Callers::kOpsRMSNormBatched,
|
||||
[&](uint64_t token_idx, size_t worker) {
|
||||
RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(),
|
||||
|
|
@ -517,7 +517,7 @@ void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT<XT>& inout,
|
|||
HWY_DASSERT(weights.Cols() == inout.Cols());
|
||||
|
||||
CallUpcasted(&weights, [&](const auto* weights_t) {
|
||||
ParallelFor(ParallelismStrategy::kFlat, inout.Rows(), ctx, cluster_idx,
|
||||
ParallelFor(Parallelism::kFlat, inout.Rows(), ctx, cluster_idx,
|
||||
Callers::kOpsRMSNormInplaceBatched,
|
||||
[&](uint64_t token_idx, size_t worker) {
|
||||
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0,
|
||||
|
|
@ -550,7 +550,7 @@ static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x, MatPtrT<float>& out,
|
|||
size_t cluster_idx = 0) {
|
||||
HWY_DASSERT(out.SameShape(x));
|
||||
ParallelFor(
|
||||
ParallelismStrategy::kFlat, out.Rows(), ctx, cluster_idx,
|
||||
Parallelism::kFlat, out.Rows(), ctx, cluster_idx,
|
||||
Callers::kOpsAddFromBatched, [&](uint64_t token_idx, size_t worker) {
|
||||
AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(), ctx, worker);
|
||||
});
|
||||
|
|
@ -1290,7 +1290,7 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCapBatched(
|
|||
const float cap, MatPtrT<float>& x, const hwy::BitSet4096<>& non_eos,
|
||||
ThreadingContext& ctx, size_t cluster_idx = 0) {
|
||||
if (cap == 0.0f) return;
|
||||
ParallelFor(ParallelismStrategy::kFlat, x.Rows(), ctx, cluster_idx,
|
||||
ParallelFor(Parallelism::kFlat, x.Rows(), ctx, cluster_idx,
|
||||
Callers::kOpsMaybeLogitsSoftCapBatched,
|
||||
[&](uint64_t task, size_t worker) {
|
||||
if (non_eos.Get(task)) {
|
||||
|
|
|
|||
|
|
@ -100,7 +100,7 @@ struct ThreadingContext {
|
|||
|
||||
// Returns a worker index compatible with those from `ParallelFor`, assuming
|
||||
// the current thread is running on one thread per cluster, which happens
|
||||
// when `ParallelismStrategy` is `kAcrossClusters`.
|
||||
// when `Parallelism` is `kAcrossClusters`.
|
||||
size_t Worker(size_t cluster_idx) const {
|
||||
return cluster_idx * pools.MaxWorkersPerCluster();
|
||||
}
|
||||
|
|
@ -130,7 +130,7 @@ struct ThreadingContext {
|
|||
PROFILER_ZONE3(ctx.profiler, global_idx, ctx.profiler_zones.Get(zone_enum))
|
||||
|
||||
// Describes the strategy for distributing parallel work across cores.
|
||||
enum class ParallelismStrategy : uint8_t {
|
||||
enum class Parallelism : uint8_t {
|
||||
// Execute using a single-threaded loop on the calling thread. The `worker`
|
||||
// index passed to the user's `Func` is unique across clusters.
|
||||
kNone,
|
||||
|
|
@ -245,19 +245,19 @@ void HierarchicalParallelFor(size_t num_tasks, ThreadingContext& ctx,
|
|||
// `cluster_idx` for `kAcrossClusters`. The `cluster_idx` argument is for
|
||||
// `parallelism == {kWithinCluster, kNone}`, and should be 0 if unknown.
|
||||
template <class Func>
|
||||
void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
|
||||
void ParallelFor(Parallelism parallelism, size_t num_tasks,
|
||||
ThreadingContext& ctx, size_t cluster_idx, Callers callers,
|
||||
const Func& func) {
|
||||
HWY_DASSERT(cluster_idx < ctx.topology.NumClusters());
|
||||
if (cluster_idx != 0) {
|
||||
// If already running across clusters, only use within-cluster modes.
|
||||
HWY_DASSERT(parallelism == ParallelismStrategy::kNone ||
|
||||
parallelism == ParallelismStrategy::kWithinCluster);
|
||||
HWY_DASSERT(parallelism == Parallelism::kNone ||
|
||||
parallelism == Parallelism::kWithinCluster);
|
||||
}
|
||||
const hwy::pool::Caller caller = ctx.pool_callers.Get(callers);
|
||||
|
||||
switch (parallelism) {
|
||||
case ParallelismStrategy::kNone: {
|
||||
case Parallelism::kNone: {
|
||||
const size_t worker = ctx.Worker(cluster_idx);
|
||||
for (size_t task = 0; task < num_tasks; ++task) {
|
||||
func(task, worker);
|
||||
|
|
@ -265,16 +265,16 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
|
|||
return;
|
||||
}
|
||||
|
||||
case ParallelismStrategy::kAcrossClusters:
|
||||
case Parallelism::kAcrossClusters:
|
||||
return ParallelForAcrossClusters(
|
||||
num_tasks, ctx, caller,
|
||||
[&](uint64_t task, size_t cluster_idx) { func(task, cluster_idx); });
|
||||
|
||||
case ParallelismStrategy::kWithinCluster:
|
||||
case Parallelism::kWithinCluster:
|
||||
return ParallelForWithinCluster(num_tasks, ctx, cluster_idx, caller,
|
||||
func);
|
||||
|
||||
case ParallelismStrategy::kFlat:
|
||||
case Parallelism::kFlat:
|
||||
// Choose a single pool: the only cluster, or across all clusters
|
||||
// (slower synchronization, but more memory bandwidth)
|
||||
if (HWY_UNLIKELY(ctx.pools.NumClusters() == 1)) {
|
||||
|
|
@ -286,7 +286,7 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
|
|||
func(task, ctx.Worker(cluster_idx));
|
||||
});
|
||||
|
||||
case ParallelismStrategy::kHierarchical:
|
||||
case Parallelism::kHierarchical:
|
||||
return HierarchicalParallelFor(num_tasks, ctx, callers, func);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue