mirror of https://github.com/google/gemma.cpp.git
Minor: ParallelismStrategy->Parallelism
PiperOrigin-RevId: 828936578
parent a344a70c59
commit 091b4567c9
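The diff below is a mechanical rename of the `ParallelismStrategy` enum to `Parallelism` (including variables and comments that referenced the old name); the enumerators and the `ParallelFor`/`DispatchParallelism` call sites are otherwise unchanged. As a reading aid, here is a minimal, self-contained C++ sketch of the renamed enum and a simplified single-threaded (`kNone`) loop. It is an illustration only: `ParallelForSketch` is a hypothetical stand-in, not the library's `ParallelFor`, and the per-enumerator comments paraphrase the comments visible in the `ThreadingContext` hunks below.

// Standalone sketch (not gemma.cpp code): the renamed enum and a simplified
// kNone-style loop. Only the single-threaded path is shown; the real
// ParallelFor dispatches the other modes to cluster/hierarchical thread pools.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <functional>

enum class Parallelism : uint8_t {
  kNone,            // single-threaded loop on the calling thread
  kAcrossClusters,  // one worker per cluster
  kWithinCluster,   // pool of a single cluster, selected via cluster_idx
  kFlat,            // one pool: the only cluster, or spanning all clusters
  kHierarchical,    // nested across-cluster and within-cluster parallelism
};

// Hypothetical stand-in for ParallelFor, for illustration only.
void ParallelForSketch(Parallelism parallelism, std::size_t num_tasks,
                       const std::function<void(std::size_t, std::size_t)>& func) {
  if (parallelism == Parallelism::kNone) {
    for (std::size_t task = 0; task < num_tasks; ++task) func(task, /*worker=*/0);
  }
  // Other modes omitted from this sketch.
}

int main() {
  ParallelForSketch(Parallelism::kNone, 4, [](std::size_t task, std::size_t worker) {
    std::printf("task %zu on worker %zu\n", task, worker);
  });
  return 0;
}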
@@ -105,7 +105,7 @@ MatStorageT<MatT> GenerateMat(const Extents2D& extents, MatPadding padding,
   MatStorageT<float> raw("raw", extents, ctx.allocator, MatPadding::kPacked);
   MatStorageT<MatT> compressed("mat", extents, ctx.allocator, padding);
   const float scale = SfpStream::kMax / extents.Area();
-  ParallelFor(ParallelismStrategy::kFlat, extents.rows, ctx, /*cluster_idx=*/0,
+  ParallelFor(Parallelism::kFlat, extents.rows, ctx, /*cluster_idx=*/0,
               Callers::kTest, [&](size_t r, size_t thread) {
                 float* HWY_RESTRICT row = raw.Row(r);
                 for (size_t c = 0; c < extents.cols; c++) {
@@ -134,7 +134,7 @@ MatStorageT<MatT> GenerateTransposedMat(const Extents2D extents,
   MatStorageT<float> raw("raw", extents, ctx.allocator, MatPadding::kPacked);
   MatStorageT<MatT> compressed("trans", extents, ctx.allocator, padding);
   const float scale = SfpStream::kMax / extents.Area();
-  ParallelFor(ParallelismStrategy::kFlat, extents.rows, ctx, /*cluster_idx=*/0,
+  ParallelFor(Parallelism::kFlat, extents.rows, ctx, /*cluster_idx=*/0,
               Callers::kTest, [&](size_t r, size_t thread) {
                 float* HWY_RESTRICT row = raw.Row(r);
                 for (size_t c = 0; c < extents.cols; c++) {
@@ -278,7 +278,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
   // Note that 2D parallelism is not worth the fork/join overhead because the
   // tasks are very lightweight.
   ParallelFor(
-      ParallelismStrategy::kFlat, kv_heads * num_interleaved, env.ctx,
+      Parallelism::kFlat, kv_heads * num_interleaved, env.ctx,
       /*cluster_idx=*/0, Callers::kAttComputeQKV,
       [&](size_t task, size_t worker) HWY_ATTR {
         const size_t head = task % kv_heads;
@@ -85,7 +85,7 @@ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<BF16>& q_t,
   {
     const size_t num_tasks = hwy::DivCeil(q_t.Rows(), kNF);
     // Better than kFlat.
-    ParallelFor(ParallelismStrategy::kHierarchical, num_tasks, ctx,
+    ParallelFor(Parallelism::kHierarchical, num_tasks, ctx,
                 /*cluster_idx=*/0, Callers::kFlashTransposeQ, func);
   }
 }
@@ -124,7 +124,7 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
   {
     // kHierarchical is not worth the extra sync overhead because the tasks are
     // very lightweight.
-    ParallelFor(ParallelismStrategy::kFlat, num_tokens * qbatch.Size(), ctx,
+    ParallelFor(Parallelism::kFlat, num_tokens * qbatch.Size(), ctx,
                 /*cluster_idx=*/0, Callers::kFlashRMSNormAndPositionalEncoding,
                 func);
   }
@@ -619,7 +619,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
   const hwy::Divisor div_qbatch(qbatch.Size());
   // Compress q to q_bf.
   ParallelFor(
-      ParallelismStrategy::kWithinCluster, activations.q.Rows(), ctx,
+      Parallelism::kWithinCluster, activations.q.Rows(), ctx,
       /*cluster_idx=*/0, Callers::kFlashAttention,
       [&](size_t row, size_t worker) {
         CompressPerThread tls;
@@ -70,7 +70,7 @@ template <class Mat>
 void ActivationBatched(
     ActivationType activation, Mat& c1, ThreadingContext& ctx,
     size_t cluster_idx = 0,
-    ParallelismStrategy parallelism = ParallelismStrategy::kFlat) {
+    Parallelism parallelism = Parallelism::kFlat) {
   using T = typename Mat::T;
   ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,
               Callers::kActivationBatched, [&](uint64_t task, size_t worker) {
@@ -115,7 +115,7 @@ template <class Mat1, class Mat2>
 HWY_NOINLINE void ActivationBatched(
     ActivationType activation, Mat1& c1, const Mat2* c2, ThreadingContext& ctx,
     size_t cluster_idx = 0,
-    ParallelismStrategy parallelism = ParallelismStrategy::kFlat) {
+    Parallelism parallelism = Parallelism::kFlat) {
   HWY_DASSERT(c1.SameShape(*c2));
   if (c2 && c2->HasPtr()) {
     ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,
@@ -426,7 +426,7 @@ static void SampleAndStream(const ModelConfig& config,
   timing_info.NotifyGenerated(non_eos.Count());

   ParallelFor(
-      ParallelismStrategy::kFlat, qbatch.Size(), env.ctx,
+      Parallelism::kFlat, qbatch.Size(), env.ctx,
       /*cluster_idx=*/0, Callers::kSampleAndStream,
       [&](size_t qi, size_t worker) {
         if (!non_eos.Get(qi)) return;
@@ -431,12 +431,12 @@ void WeightsPtrs::CopyFrom(const WeightsPtrs& other) {
 void WeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
                         ThreadingContext& ctx) {
   const size_t cluster_idx = 0;
-  ParallelFor(ParallelismStrategy::kFlat, c_layers.size(), ctx, cluster_idx,
+  ParallelFor(Parallelism::kFlat, c_layers.size(), ctx, cluster_idx,
               Callers::kFixupWeights, [&](uint64_t layer, size_t /*worker*/) {
                 GetLayer(layer)->Fixup(mat_owners, ctx);
               });

-  ParallelFor(ParallelismStrategy::kFlat, vit_layers.size(), ctx, cluster_idx,
+  ParallelFor(Parallelism::kFlat, vit_layers.size(), ctx, cluster_idx,
               Callers::kFixupWeights, [&](uint64_t layer, size_t /*worker*/) {
                 VitLayer(layer)->Fixup(mat_owners, ctx);
               });
@@ -527,7 +527,7 @@ static void AllocateAndBindAll(std::vector<TensorToRead>& tensors,

   // Allocate in parallel because faulting in large tensors is slow.
   ParallelFor(
-      ParallelismStrategy::kFlat, tensors.size(), ctx, /*cluster_idx=*/0,
+      Parallelism::kFlat, tensors.size(), ctx, /*cluster_idx=*/0,
       Callers::kAllocateAndBindAll, [&](uint64_t task, size_t /*thread*/) {
         TensorToRead& tensor = tensors[task];
         MatPtr& mat = *tensor.mat;
@@ -586,10 +586,9 @@ static void DecompressToBF16(MatPtr& mat,
 static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
                           const BlobReader& reader, ThreadingContext& ctx) {
   // Especially TSAN is slow enough to warrant hierarchical parallelism.
-  const ParallelismStrategy strategy = HWY_IS_DEBUG_BUILD
-                                           ? ParallelismStrategy::kHierarchical
-                                           : ParallelismStrategy::kFlat;
-  ParallelFor(strategy, tensors.size(), ctx, /*cluster_idx=*/0,
+  const Parallelism parallelism =
+      HWY_IS_DEBUG_BUILD ? Parallelism::kHierarchical : Parallelism::kFlat;
+  ParallelFor(parallelism, tensors.size(), ctx, /*cluster_idx=*/0,
               Callers::kReadAllToBF16, [&](uint64_t task, size_t thread) {
                 GCPP_ZONE(ctx, thread, Zones::kStartupWeightsReadAllToBF16);
                 const TensorToRead& tensor = tensors[task];
@@ -677,7 +676,7 @@ static void ReadBatches(const BlobReader& reader,
                         const std::vector<IOBatch>& batches,
                         ThreadingContext& ctx) {
   // >5x speedup from parallel reads when cached.
-  ParallelFor(ParallelismStrategy::kHierarchical, batches.size(), ctx,
+  ParallelFor(Parallelism::kHierarchical, batches.size(), ctx,
               /*cluster_idx=*/0, Callers::kReadBatches,
               [&](uint64_t task, size_t thread) {
                 GCPP_ZONE(ctx, thread, Zones::kStartupWeightsReadBatches);
@@ -106,7 +106,7 @@ void ReadBlobs(BlobReader& reader, const RangeVec& ranges, BlobVec& blobs,
                ThreadingContext& ctx, size_t cluster_idx) {
   HWY_ASSERT(reader.Keys().size() == blobs.size());
   HWY_ASSERT(ranges.size() == blobs.size());
-  ParallelFor(ParallelismStrategy::kWithinCluster, blobs.size(), ctx,
+  ParallelFor(Parallelism::kWithinCluster, blobs.size(), ctx,
               cluster_idx, Callers::kTest, [&](size_t i, size_t /*thread*/) {
                 HWY_ASSERT(ranges[i].bytes == blobs[i].size());
                 reader.file().Read(ranges[i].offset, ranges[i].bytes,
@@ -122,7 +122,7 @@ void ReadBothBlobs(BlobReader& reader1, BlobReader& reader2,
   const double t0 = hwy::platform::Now();
   HWY_WARN("Reading %zu GiB, %zu clusters: ", total_bytes >> 30,
            ctx.pools.NumClusters());
-  ParallelFor(ParallelismStrategy::kAcrossClusters, 2, ctx, 0, Callers::kTest,
+  ParallelFor(Parallelism::kAcrossClusters, 2, ctx, 0, Callers::kTest,
               [&](const size_t task, size_t cluster_idx) {
                 ReadBlobs(task ? reader1 : reader2, task ? ranges1 : ranges2,
                           task ? blobs1 : blobs2, ctx, cluster_idx);
@@ -189,7 +189,7 @@ void CompareBlobs(const KeyVec& keys, BlobVec& blobs1, BlobVec& blobs2,
   const double t0 = hwy::platform::Now();
   std::atomic<size_t> blobs_equal{};
   std::atomic<size_t> blobs_diff{};
-  ParallelFor(ParallelismStrategy::kHierarchical, keys.size(), ctx, 0,
+  ParallelFor(Parallelism::kHierarchical, keys.size(), ctx, 0,
               Callers::kTest, [&](size_t i, size_t /*thread*/) {
                 const size_t mismatches =
                     BlobDifferences(blobs1[i], blobs2[i], keys[i]);
@@ -488,11 +488,10 @@ void BlobWriter::Add(const std::string& key, const void* data, size_t bytes) {
   EnqueueChunks(keys_.size() - 1, curr_offset_, bytes,
                 static_cast<const uint8_t*>(data), writes);

-  const ParallelismStrategy strategy = file_->IsAppendOnly()
-                                           ? ParallelismStrategy::kNone
-                                           : ParallelismStrategy::kFlat;
+  const Parallelism parallelism =
+      file_->IsAppendOnly() ? Parallelism::kNone : Parallelism::kFlat;
   ParallelFor(
-      strategy, writes.size(), ctx_,
+      parallelism, writes.size(), ctx_,
       /*cluster_idx=*/0, Callers::kBlobWriter,
       [this, &writes](uint64_t i, size_t /*thread*/) {
         const BlobRange& range = writes[i].range;
@@ -130,7 +130,7 @@ TEST(BlobStoreTest, TestNumBlobs) {
   HWY_ASSERT_EQ(reader.Keys().size(), num_blobs);

   ParallelFor(
-      ParallelismStrategy::kFlat, num_blobs, ctx, /*cluster_idx=*/0,
+      Parallelism::kFlat, num_blobs, ctx, /*cluster_idx=*/0,
       Callers::kTest, [&](uint64_t i, size_t /*thread*/) {
         HWY_ASSERT_STRING_EQ(reader.Keys()[i].c_str(),
                              std::to_string(i).c_str());
@@ -1126,7 +1126,7 @@ void TestAllDot() {
   std::array<DotStats, kMaxWorkers> all_stats;

   ParallelFor(
-      ParallelismStrategy::kWithinCluster, kReps, ctx, 0, Callers::kTest,
+      Parallelism::kWithinCluster, kReps, ctx, 0, Callers::kTest,
       [&](size_t rep, size_t thread) {
        float* HWY_RESTRICT pa = a.Row(thread);
        float* HWY_RESTRICT pb = b.Row(thread);
ops/matmul.h
@@ -61,7 +61,7 @@ HWY_INLINE_VAR constexpr size_t kMaxNC = 16384;
 // of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`.
 HWY_INLINE_VAR constexpr size_t kMaxKC = 8 * 1024;

-// Policy classes for parallelism, implementing some of `ParallelismStrategy`.
+// Policy classes for parallelism, implementing some of `Parallelism`.

 struct MMParallelNone {
   template <class Func>
@@ -220,14 +220,14 @@ struct MMParallelHierarchical {
 };

 template <class Func, typename... Args>
-void DispatchParallelism(ParallelismStrategy parallelism, const Func& func,
+void DispatchParallelism(Parallelism parallelism, const Func& func,
                          Args&&... args) {
   switch (parallelism) {
-    case ParallelismStrategy::kNone:
+    case Parallelism::kNone:
       return func(MMParallelNone(), std::forward<Args>(args)...);
-    case ParallelismStrategy::kWithinCluster:
+    case Parallelism::kWithinCluster:
       return func(MMParallelWithinCluster(), std::forward<Args>(args)...);
-    case ParallelismStrategy::kHierarchical:
+    case Parallelism::kHierarchical:
       return func(MMParallelHierarchical(), std::forward<Args>(args)...);
     default:
       HWY_UNREACHABLE;
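The `DispatchParallelism` hunk above maps a runtime `Parallelism` value onto a compile-time policy type (`MMParallelNone`, `MMParallelWithinCluster`, `MMParallelHierarchical`) and forwards the remaining arguments to a generic functor. A minimal, self-contained sketch of that dispatch pattern follows; the policy structs and their `Name()` method are simplified stand-ins for illustration, not the real classes in ops/matmul.h, and only the three modes handled by `DispatchParallelism` are included.

// Standalone sketch of enum-to-policy dispatch, modeled on DispatchParallelism.
// The policy structs are placeholders, not gemma.cpp's MMParallel* classes.
#include <cstdint>
#include <cstdio>
#include <utility>

enum class Parallelism : uint8_t { kNone, kWithinCluster, kHierarchical };

struct MMParallelNone { const char* Name() const { return "none"; } };
struct MMParallelWithinCluster { const char* Name() const { return "within-cluster"; } };
struct MMParallelHierarchical { const char* Name() const { return "hierarchical"; } };

// Converts the runtime enum value into a compile-time policy type, then
// invokes func with that policy plus any forwarded arguments.
template <class Func, typename... Args>
void DispatchParallelism(Parallelism parallelism, const Func& func, Args&&... args) {
  switch (parallelism) {
    case Parallelism::kNone:
      return func(MMParallelNone(), std::forward<Args>(args)...);
    case Parallelism::kWithinCluster:
      return func(MMParallelWithinCluster(), std::forward<Args>(args)...);
    case Parallelism::kHierarchical:
      return func(MMParallelHierarchical(), std::forward<Args>(args)...);
  }
}

int main() {
  // The generic lambda is instantiated once per policy type that is dispatched.
  DispatchParallelism(Parallelism::kHierarchical, [](auto policy) {
    std::printf("dispatched to %s policy\n", policy.Name());
  });
  return 0;
}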
@@ -716,7 +716,7 @@ class MMOptions {
   const void* opaque = nullptr;

   uint32_t cluster_idx = 0;  // for `parallelism == kWithinCluster`.
-  ParallelismStrategy parallelism = ParallelismStrategy::kHierarchical;
+  Parallelism parallelism = Parallelism::kHierarchical;
 };

 // Arguments to MatMul() that are independent of the A/B/C types. Reduces
@@ -500,7 +500,7 @@ void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
   activations.DebugCheckSameShape(out);

   CallUpcasted(&weights, [&](const auto* weights_t) {
-    ParallelFor(ParallelismStrategy::kFlat, activations.Rows(), ctx,
+    ParallelFor(Parallelism::kFlat, activations.Rows(), ctx,
                 cluster_idx, Callers::kOpsRMSNormBatched,
                 [&](uint64_t token_idx, size_t worker) {
                   RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(),
@@ -517,7 +517,7 @@ void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT<XT>& inout,
   HWY_DASSERT(weights.Cols() == inout.Cols());

   CallUpcasted(&weights, [&](const auto* weights_t) {
-    ParallelFor(ParallelismStrategy::kFlat, inout.Rows(), ctx, cluster_idx,
+    ParallelFor(Parallelism::kFlat, inout.Rows(), ctx, cluster_idx,
                 Callers::kOpsRMSNormInplaceBatched,
                 [&](uint64_t token_idx, size_t worker) {
                   RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0,
@@ -550,7 +550,7 @@ static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x, MatPtrT<float>& out,
                                       size_t cluster_idx = 0) {
   HWY_DASSERT(out.SameShape(x));
   ParallelFor(
-      ParallelismStrategy::kFlat, out.Rows(), ctx, cluster_idx,
+      Parallelism::kFlat, out.Rows(), ctx, cluster_idx,
       Callers::kOpsAddFromBatched, [&](uint64_t token_idx, size_t worker) {
         AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(), ctx, worker);
       });
@@ -1290,7 +1290,7 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCapBatched(
     const float cap, MatPtrT<float>& x, const hwy::BitSet4096<>& non_eos,
     ThreadingContext& ctx, size_t cluster_idx = 0) {
   if (cap == 0.0f) return;
-  ParallelFor(ParallelismStrategy::kFlat, x.Rows(), ctx, cluster_idx,
+  ParallelFor(Parallelism::kFlat, x.Rows(), ctx, cluster_idx,
               Callers::kOpsMaybeLogitsSoftCapBatched,
               [&](uint64_t task, size_t worker) {
                 if (non_eos.Get(task)) {
@@ -100,7 +100,7 @@ struct ThreadingContext {

   // Returns a worker index compatible with those from `ParallelFor`, assuming
   // the current thread is running on one thread per cluster, which happens
-  // when `ParallelismStrategy` is `kAcrossClusters`.
+  // when `Parallelism` is `kAcrossClusters`.
   size_t Worker(size_t cluster_idx) const {
     return cluster_idx * pools.MaxWorkersPerCluster();
   }
@@ -130,7 +130,7 @@ struct ThreadingContext {
   PROFILER_ZONE3(ctx.profiler, global_idx, ctx.profiler_zones.Get(zone_enum))

 // Describes the strategy for distributing parallel work across cores.
-enum class ParallelismStrategy : uint8_t {
+enum class Parallelism : uint8_t {
   // Execute using a single-threaded loop on the calling thread. The `worker`
   // index passed to the user's `Func` is unique across clusters.
   kNone,
@@ -245,19 +245,19 @@ void HierarchicalParallelFor(size_t num_tasks, ThreadingContext& ctx,
 // `cluster_idx` for `kAcrossClusters`. The `cluster_idx` argument is for
 // `parallelism == {kWithinCluster, kNone}`, and should be 0 if unknown.
 template <class Func>
-void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
+void ParallelFor(Parallelism parallelism, size_t num_tasks,
                  ThreadingContext& ctx, size_t cluster_idx, Callers callers,
                  const Func& func) {
   HWY_DASSERT(cluster_idx < ctx.topology.NumClusters());
   if (cluster_idx != 0) {
     // If already running across clusters, only use within-cluster modes.
-    HWY_DASSERT(parallelism == ParallelismStrategy::kNone ||
-                parallelism == ParallelismStrategy::kWithinCluster);
+    HWY_DASSERT(parallelism == Parallelism::kNone ||
+                parallelism == Parallelism::kWithinCluster);
   }
   const hwy::pool::Caller caller = ctx.pool_callers.Get(callers);

   switch (parallelism) {
-    case ParallelismStrategy::kNone: {
+    case Parallelism::kNone: {
       const size_t worker = ctx.Worker(cluster_idx);
       for (size_t task = 0; task < num_tasks; ++task) {
         func(task, worker);
@@ -265,16 +265,16 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
       return;
     }

-    case ParallelismStrategy::kAcrossClusters:
+    case Parallelism::kAcrossClusters:
       return ParallelForAcrossClusters(
           num_tasks, ctx, caller,
           [&](uint64_t task, size_t cluster_idx) { func(task, cluster_idx); });

-    case ParallelismStrategy::kWithinCluster:
+    case Parallelism::kWithinCluster:
       return ParallelForWithinCluster(num_tasks, ctx, cluster_idx, caller,
                                       func);

-    case ParallelismStrategy::kFlat:
+    case Parallelism::kFlat:
       // Choose a single pool: the only cluster, or across all clusters
       // (slower synchronization, but more memory bandwidth)
       if (HWY_UNLIKELY(ctx.pools.NumClusters() == 1)) {
@@ -286,7 +286,7 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
           func(task, ctx.Worker(cluster_idx));
         });

-    case ParallelismStrategy::kHierarchical:
+    case Parallelism::kHierarchical:
       return HierarchicalParallelFor(num_tasks, ctx, callers, func);
   }
 }