Major cleanup of profiler zones; add Caller annotations to all pool.Run calls

Pass ThreadingContext instead of Pools/Profiler individually, for access to Zones
Add GCPP_ZONE helper
Add Caller argument to pool.Run to enable new stats
Remove most direct dependencies on ThreadPool; prefer ParallelFor

PiperOrigin-RevId: 822934530
Jan Wassenberg 2025-10-23 01:53:50 -07:00 committed by Copybara-Service
parent 9e8ac7e2f0
commit 3ed403e287
43 changed files with 811 additions and 592 deletions
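For orientation, a minimal sketch of the two patterns this change introduces, based on the hunks below. `Zones::kExample` and `Callers::kExample` are placeholder names, not identifiers from this commit:

// Before: look up the zone and pass hwy::Profiler explicitly:
//   PROFILER_ZONE3(ctx.profiler, worker, GetProfilerZone(Zones::kExample));
// After: GCPP_ZONE takes the ThreadingContext and the zone enum directly.
GCPP_ZONE(ctx, worker, Zones::kExample);

// ParallelFor (and pool.Run) now take a Caller, fetched from
// ctx.pool_callers, so per-call-site statistics can be attributed.
ParallelFor(ParallelismStrategy::kFlat, num_tasks, ctx, /*cluster_idx=*/0,
            Callers::kExample, [&](uint64_t task, size_t worker) {
              GCPP_ZONE(ctx, worker, Zones::kExample);
              // ... per-task work ...
            });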

View File

@ -95,7 +95,6 @@ cc_library(
":topology",
# Placeholder for container detection, do not remove
"@highway//:hwy",
"@highway//:profiler",
"@highway//:thread_pool",
"@highway//:topology",
],
@ -124,7 +123,9 @@ cc_library(
srcs = ["util/zones.cc"],
hdrs = ["util/zones.h"],
deps = [
"@highway//:hwy",
"@highway//:profiler",
"@highway//:thread_pool",
],
)
@ -258,7 +259,6 @@ cc_library(
"//io:fields",
"@highway//:hwy",
"@highway//:profiler",
"@highway//:thread_pool",
],
)
@ -278,7 +278,6 @@ cc_library(
"//io:blob_store",
"@highway//:hwy",
"@highway//:profiler",
"@highway//:thread_pool",
],
)
@ -309,7 +308,6 @@ cc_library(
deps = [
":allocator",
":basics",
":configs",
":mat",
":threading",
":threading_context",
@ -397,7 +395,6 @@ cc_library(
"@highway//:hwy",
"@highway//:math",
"@highway//:profiler",
"@highway//:thread_pool",
"@highway//hwy/contrib/sort:vqsort",
],
)
@ -424,7 +421,6 @@ cc_test(
"@highway//:nanobenchmark", #buildcleaner: keep
"@highway//:profiler",
"@highway//:stats",
"@highway//:thread_pool",
],
)
@ -507,7 +503,6 @@ cc_test(
"@highway//:hwy_test_util",
"@highway//:nanobenchmark",
"@highway//:profiler",
"@highway//:thread_pool",
],
)

View File

@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9781a1698ee0756ef1eaaf96930113ed7cb6d3ee EXCLUDE_FROM_ALL)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2a16a50ff61071bb25ddef0ce35d92b0e2b9c579 EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(highway)
## Note: absl needs to be installed by sentencepiece. This will only happen if

View File

@ -18,7 +18,7 @@ bazel_dep(name = "google_benchmark", version = "1.8.5")
# Require a more recent version.
git_override(
module_name = "highway",
commit = "9781a1698ee0756ef1eaaf96930113ed7cb6d3ee",
commit = "2a16a50ff61071bb25ddef0ce35d92b0e2b9c579",
remote = "https://github.com/google/highway",
)

View File

@ -452,7 +452,7 @@ FetchContent_MakeAvailable(sentencepiece)
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG origin/main)
FetchContent_MakeAvailable(gemma)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9781a1698ee0756ef1eaaf96930113ed7cb6d3ee)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2a16a50ff61071bb25ddef0ce35d92b0e2b9c579)
FetchContent_MakeAvailable(highway)
```

View File

@ -120,9 +120,9 @@ cc_library(
":compress",
":distortion",
"//:mat",
"//:threading_context",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:thread_pool",
],
)
@ -180,6 +180,7 @@ cc_library(
":sfp",
"//:basics",
"//:mat",
"//:threading_context",
"@highway//:hwy",
"@highway//:nanobenchmark",
"@highway//:profiler",
@ -203,9 +204,9 @@ cc_test(
":test_util",
"@googletest//:gtest_main", # buildcleaner: keep
"//:test_util",
"//:threading_context",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:thread_pool",
],
)

View File

@ -26,10 +26,10 @@
#include "compression/compress.h" // IWYU pragma: export
#include "compression/distortion.h"
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/timer.h"
#if COMPRESS_STATS
#include <cmath> // lroundf
@ -493,13 +493,13 @@ struct CompressTraits<NuqStream> {
}
};
// Compresses `num` elements of `raw` to `packed` starting at `packed_ofs`,
// which is useful for compressing sub-regions of an array.
// DEPRECATED: Use the overload with ThreadingContext instead.
template <typename Packed>
HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
CompressWorkingSet& work,
const PackedSpan<Packed>& packed,
const size_t packed_ofs, hwy::ThreadPool& pool) {
const size_t packed_ofs, hwy::ThreadPool& pool,
hwy::pool::Caller caller = hwy::pool::Caller()) {
packed.BoundsCheck(packed_ofs, num);
work.tls.resize(pool.NumWorkers());
if constexpr (COMPRESS_STATS) {
@ -511,7 +511,7 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
using Traits = CompressTraits<hwy::RemoveConst<Packed>>;
constexpr size_t kBatch = 8192;
const size_t num_batches = hwy::DivCeil(num, kBatch);
pool.Run(0, num_batches,
pool.Run(0, num_batches, caller,
[&](const uint32_t idx_batch, size_t thread) HWY_ATTR {
const hn::ScalableTag<float> df;
@ -530,6 +530,17 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
}
}
// Compresses `num` elements of `raw` to `packed` starting at `packed_ofs`,
// which is useful for compressing sub-regions of an array.
template <typename Packed>
HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
CompressWorkingSet& work,
const PackedSpan<Packed>& packed,
const size_t packed_ofs, ThreadingContext& ctx) {
Compress(raw, num, work, packed, packed_ofs, ctx.pools.Pool(),
ctx.pool_callers.Get(Callers::kCompress));
}
// Same as above, but without parallelization nor benchmarking.
template <typename Packed>
HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,

View File

@ -24,9 +24,9 @@
#include "compression/compress.h"
#include "compression/distortion.h"
#include "util/test_util.h"
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/tests/hwy_gtest.h"
// clang-format off
#undef HWY_TARGET_INCLUDE
@ -42,6 +42,17 @@ namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
static ThreadingArgs SingleThreadArgs() {
ThreadingArgs args;
args.max_lps = 1;
return args;
}
static ThreadingContext& Ctx() {
static ThreadingContext* ctx = new ThreadingContext(SingleThreadArgs());
return *ctx;
}
// Calls Compress and Decompress2 and verifies the distortion/error.
template <typename Packed>
struct TestDecompress2 {
@ -49,7 +60,9 @@ struct TestDecompress2 {
HWY_INLINE void operator()(T /*unused*/, D d) {
const size_t N = hn::Lanes(d);
CompressWorkingSet work;
hwy::ThreadPool pool(0);
ThreadingArgs args;
args.max_lps = 1;
ThreadingContext ctx(args);
hwy::RandomState rng;
const size_t num = 2 * N;
@ -68,7 +81,7 @@ struct TestDecompress2 {
// Short inputs fail VerifyGaussian.
const size_t packed_ofs = 0;
Compress(raw.get(), num, work, packed_span, packed_ofs, pool);
Compress(raw.get(), num, work, packed_span, packed_ofs, ctx);
hn::Vec<D> raw0, raw1;
Decompress2(d, MakeConst(packed_span), packed_ofs, raw0, raw1);
hn::Store(raw0, d, dec.get());
@ -129,7 +142,6 @@ struct TestShortLengths {
HWY_INLINE void operator()(T /*unused*/, D d) {
const size_t N = hn::Lanes(d);
CompressWorkingSet work;
hwy::ThreadPool pool(0);
hwy::RandomState rng;
for (size_t num = 1; num < 5 * hn::Lanes(d); ++num) {
@ -149,7 +161,7 @@ struct TestShortLengths {
// Short inputs fail VerifyGaussian.
const size_t packed_ofs = 0;
Compress(raw.get(), num, work, packed_span, packed_ofs, pool);
Compress(raw.get(), num, work, packed_span, packed_ofs, Ctx());
DecompressAndZeroPad(d, MakeConst(packed_span), packed_ofs, dec.get(),
num);

View File

@ -37,7 +37,6 @@
#include "util/basics.h"
#include "util/mat.h"
#include "util/threading_context.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
@ -57,9 +56,6 @@ class SbsWriterImpl : public ISbsWriter {
template <typename Packed>
void InsertT(const char* name, F32Span weights,
const TensorInfo& tensor_info) {
// TODO(janwas): 1D parallel-for.
hwy::ThreadPool& pool = ctx_.pools.Pool();
MatPtrT<Packed> mat(name, ExtentsFromInfo(&tensor_info));
// SFP and NUQ (which uses SFP for cluster centers) have a limited range
// and depending on the input values may require rescaling. Scaling is
@ -82,21 +78,20 @@ class SbsWriterImpl : public ISbsWriter {
// succeeds, but we only have 10 floats, not the full tensor.
if (weights.size() == 10 && mat.Extents().Area() != 10) {
Compress(weights.data(), weights.size(), working_set_, mat.Span(),
/*packed_ofs=*/0, pool);
/*packed_ofs=*/0, ctx_);
writer_.Add(name, mat.Packed(), mat.ElementBytes() * 10);
return;
}
HWY_ASSERT(weights.size() == mat.Extents().Area());
Compress(weights.data(), weights.size(), working_set_, mat.Span(),
/*packed_ofs=*/0, pool);
/*packed_ofs=*/0, ctx_);
writer_.Add(name, mat.Packed(), mat.PackedBytes());
}
public:
SbsWriterImpl(const std::string& sbs_path)
: ctx_(ThreadingArgs()),
writer_(gcpp::Path(sbs_path), ctx_.pools.Pool()) {}
: ctx_(ThreadingArgs()), writer_(gcpp::Path(sbs_path), ctx_) {}
void Insert(const char* name, F32Span weights, Type type,
const TensorInfo& tensor_info) override {

View File

@ -23,7 +23,7 @@
// IWYU pragma: end_exports
#include "compression/compress.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "util/threading_context.h"
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_
@ -98,25 +98,26 @@ void ForeachActivationType3(D d) {
// Generates inputs: deterministic, within max SfpStream range.
template <typename MatT>
MatStorageT<MatT> GenerateMat(const Extents2D& extents,
const Allocator& allocator, MatPadding padding,
hwy::ThreadPool& pool) {
MatStorageT<MatT> GenerateMat(const Extents2D& extents, MatPadding padding,
ThreadingContext& ctx) {
gcpp::CompressWorkingSet ws;
ws.tls.resize(pool.NumWorkers());
MatStorageT<float> raw("raw", extents, allocator, MatPadding::kPacked);
MatStorageT<MatT> compressed("mat", extents, allocator, padding);
ws.tls.resize(ctx.pools.MaxWorkers());
MatStorageT<float> raw("raw", extents, ctx.allocator, MatPadding::kPacked);
MatStorageT<MatT> compressed("mat", extents, ctx.allocator, padding);
const float scale = SfpStream::kMax / extents.Area();
pool.Run(0, extents.rows, [&](const size_t r, size_t thread) {
float* HWY_RESTRICT row = raw.Row(r);
for (size_t c = 0; c < extents.cols; c++) {
float f = static_cast<float>(r * extents.cols + c) * scale;
if ((r + c) & 1) f = -f; // Also generate some negative values.
row[c] = f;
}
Compress(raw.Row(r), raw.Cols(), ws.tls[thread],
MakeSpan(compressed.Row(r), extents.cols),
/*packed_ofs=*/0);
});
ParallelFor(ParallelismStrategy::kFlat, extents.rows, ctx, /*cluster_idx=*/0,
Callers::kTest, [&](size_t r, size_t thread) {
float* HWY_RESTRICT row = raw.Row(r);
for (size_t c = 0; c < extents.cols; c++) {
float f = static_cast<float>(r * extents.cols + c) * scale;
if ((r + c) & 1)
f = -f; // Also generate some negative values.
row[c] = f;
}
Compress(raw.Row(r), raw.Cols(), ws.tls[thread],
MakeSpan(compressed.Row(r), extents.cols),
/*packed_ofs=*/0);
});
compressed.SetScale(0.6f); // Arbitrary value, different from 1.
return compressed;
@ -126,25 +127,26 @@ MatStorageT<MatT> GenerateMat(const Extents2D& extents,
// `f` swaps `r` and `c`.
template <typename MatT>
MatStorageT<MatT> GenerateTransposedMat(const Extents2D extents,
const Allocator& allocator,
MatPadding padding,
hwy::ThreadPool& pool) {
ThreadingContext& ctx) {
gcpp::CompressWorkingSet ws;
ws.tls.resize(pool.NumWorkers());
MatStorageT<float> raw("raw", extents, allocator, MatPadding::kPacked);
MatStorageT<MatT> compressed("trans", extents, allocator, padding);
ws.tls.resize(ctx.pools.MaxWorkers());
MatStorageT<float> raw("raw", extents, ctx.allocator, MatPadding::kPacked);
MatStorageT<MatT> compressed("trans", extents, ctx.allocator, padding);
const float scale = SfpStream::kMax / extents.Area();
pool.Run(0, extents.rows, [&](const size_t r, size_t thread) {
float* HWY_RESTRICT row = raw.Row(r);
for (size_t c = 0; c < extents.cols; c++) {
float f = static_cast<float>(c * extents.rows + r) * scale;
if ((r + c) & 1) f = -f; // Also generate some negative values.
row[c] = f;
}
Compress(raw.Row(r), raw.Cols(), ws.tls[thread],
MakeSpan(compressed.Row(r), extents.cols),
/*packed_ofs=*/0);
});
ParallelFor(ParallelismStrategy::kFlat, extents.rows, ctx, /*cluster_idx=*/0,
Callers::kTest, [&](size_t r, size_t thread) {
float* HWY_RESTRICT row = raw.Row(r);
for (size_t c = 0; c < extents.cols; c++) {
float f = static_cast<float>(c * extents.rows + r) * scale;
if ((r + c) & 1)
f = -f; // Also generate some negative values.
row[c] = f;
}
Compress(raw.Row(r), raw.Cols(), ws.tls[thread],
MakeSpan(compressed.Row(r), extents.cols),
/*packed_ofs=*/0);
});
// Arbitrary value, different from 1, must match `GenerateMat`.
compressed.SetScale(0.6f);

View File

@ -83,8 +83,8 @@ HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
void CallSoftmax(Logits logits, hwy::Profiler& p) {
Softmax(logits, p, hwy::Profiler::GlobalIdx());
void CallSoftmax(Logits logits, ThreadingContext& ctx) {
Softmax(logits, ctx, hwy::Profiler::GlobalIdx());
}
} // namespace HWY_NAMESPACE
@ -109,7 +109,7 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
const SampleFunc sample_token = [&](size_t qi, size_t pos, Logits logits,
size_t /*worker*/) -> TokenAndProb {
// input is logits, not yet probabilities
HWY_DYNAMIC_DISPATCH(CallSoftmax)(logits, env.ctx.profiler);
HWY_DYNAMIC_DISPATCH(CallSoftmax)(logits, env.ctx);
// We are called for each token, but pos starts at 1. Clamping
// max_generated_tokens to prompt.size() should prevent overrun.
HWY_ASSERT(pos < prompt.size());

View File

@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
include(FetchContent)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9781a1698ee0756ef1eaaf96930113ed7cb6d3ee)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2a16a50ff61071bb25ddef0ce35d92b0e2b9c579)
FetchContent_MakeAvailable(highway)
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 9045b2f60fa2b323dfac0eaef8fc17565036f9f9)
FetchContent_MakeAvailable(sentencepiece)

View File

@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
include(FetchContent)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9781a1698ee0756ef1eaaf96930113ed7cb6d3ee)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2a16a50ff61071bb25ddef0ce35d92b0e2b9c579)
FetchContent_MakeAvailable(highway)
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
FetchContent_MakeAvailable(sentencepiece)

View File

@ -55,8 +55,8 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
const hwy::Divisor& div_seq_len,
const float* HWY_RESTRICT q,
const MatPtrT<KV_t>& k, float* HWY_RESTRICT att,
hwy::Profiler& p, const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kGenAttentionQDotK));
ThreadingContext& ctx, const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kGenAttentionQDotK);
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
// Slightly faster: no wraparound.
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
@ -75,7 +75,7 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
void PositionalEncodingQK(float* qk, const size_t layer_idx,
const LayerWeightsPtrs& layer,
const AttentionActivations& activations,
hwy::Profiler& p, const size_t worker,
ThreadingContext& ctx, const size_t worker,
const size_t pos, const float mul) {
const size_t qkv_dim = layer.layer_config.qkv_dim;
const PostQKType& post_qk = layer.layer_config.post_qk;
@ -88,10 +88,10 @@ void PositionalEncodingQK(float* qk, const size_t layer_idx,
}
// PostQKType::Rope
if (post_qk == PostQKType::HalfRope) {
Rope(qk, qkv_dim / 2, inv_timescale, pos, p, worker);
Rope(qk, qkv_dim / 2, inv_timescale, pos, ctx, worker);
if (mul != 1.0f) MulByConst(mul, qk, qkv_dim);
} else {
RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, p, worker);
RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, ctx, worker);
}
}
@ -99,18 +99,16 @@ void PositionalEncodingQK(float* qk, const size_t layer_idx,
// `att_out`. Equivalent in gemma/modules.py:
// encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
// `v` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
static HWY_INLINE void WeightedSumV(const size_t start_pos,
const size_t last_pos,
const hwy::Divisor& div_seq_len,
const float* HWY_RESTRICT att,
const MatPtrT<KV_t>& v,
float* HWY_RESTRICT att_out,
hwy::Profiler& p, const size_t worker) {
static HWY_INLINE void WeightedSumV(
const size_t start_pos, const size_t last_pos,
const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT att,
const MatPtrT<KV_t>& v, float* HWY_RESTRICT att_out, ThreadingContext& ctx,
const size_t worker) {
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
// Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if
// we supported non-transposed B.
// TODO: 2..4x unroll
MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), p,
MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), ctx,
worker);
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols());
@ -118,7 +116,8 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
} else {
{
const size_t pos_mod = div_seq_len.Remainder(start_pos);
MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), p, worker);
MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), ctx,
worker);
}
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
const size_t pos_mod = div_seq_len.Remainder(pos);
@ -134,7 +133,7 @@ void SingleDotSoftmaxWeightedSum(
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
const size_t layer_idx, const LayerWeightsPtrs& layer,
const AttentionActivations& activations, float* HWY_RESTRICT att,
float* HWY_RESTRICT att_out, hwy::Profiler& p, const size_t worker) {
float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) {
const float att_cap = activations.config.att_cap;
const float query_scale = activations.query_scale;
const size_t seq_len =
@ -144,23 +143,23 @@ void SingleDotSoftmaxWeightedSum(
if (layer.query_norm_scale.HasPtr()) {
CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q,
layer.layer_config.qkv_dim, p, worker);
layer.layer_config.qkv_dim, ctx, worker);
});
}
PositionalEncodingQK(q, layer_idx, layer, activations, p, worker, pos,
PositionalEncodingQK(q, layer_idx, layer, activations, ctx, worker, pos,
query_scale);
QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, p, worker);
QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, ctx, worker);
// SoftMax with optional SoftCap yields "probabilities" in att.
const size_t att_len = HWY_MIN(last_pos + 1, seq_len);
const Logits logits(att, att_len);
MaybeLogitsSoftCap(att_cap, logits, p, worker);
Softmax(logits, p, worker, /*temperature=*/1.0f);
MaybeLogitsSoftCap(att_cap, logits, ctx, worker);
Softmax(logits, ctx, worker, /*temperature=*/1.0f);
WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out, p,
worker);
WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out,
ctx, worker);
}
// The attention window usually starts at 0 unless `pos` is larger than
@ -174,12 +173,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
const LayerWeightsPtrs& layer,
AttentionActivations& activations, QBatch& qbatch,
ThreadingContext& ctx) {
static const auto root_zone =
ctx.profiler.AddZone("Gen.Attention.DotSoftmaxWeightedSumInclusive",
hwy::ProfilerFlags::kInclusive);
PROFILER_ZONE3(ctx.profiler, 0, root_zone);
const auto zone =
GetProfilerZone(Zones::kGenAttentionDotSoftmaxWeightedSumPar);
GCPP_ZONE(ctx, 0, Zones::kGenAttentionDotSoftmaxWeightedSumInclusive);
const hwy::Divisor div_qbatch(qbatch.Size());
const LayerConfig& layer_config = layer.layer_config;
@ -199,7 +193,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
const size_t tq_idx = activations.div_heads.Divide(task);
const size_t head = activations.div_heads.Remainder(task);
PROFILER_ZONE3(ctx.profiler, worker, zone);
GCPP_ZONE(ctx, worker, Zones::kGenAttentionDotSoftmaxWeightedSumPar);
const size_t qi = div_qbatch.Remainder(tq_idx);
const size_t batch_idx = div_qbatch.Divide(tq_idx);
@ -231,16 +225,15 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());
SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx,
layer, activations, att, att_out, ctx.profiler,
worker);
layer, activations, att, att_out, ctx, worker);
};
{
PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin");
// Full parallelism is helpful, kAcrossClusters is insufficient.
HierarchicalParallelFor(
num_tokens * div_qbatch.GetDivisor() * layer_config.heads, ctx.pools,
func);
num_tokens * div_qbatch.GetDivisor() * layer_config.heads, ctx,
Callers::kAttDotSoftmaxWeightedSum, func);
}
}
@ -256,9 +249,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
AttentionActivations& activations,
const QBatch& qbatch, const int flags,
MatMulEnv& env) {
static const auto zone = env.ctx.profiler.AddZone(
"Gen.Attention.ComputeQKV", hwy::ProfilerFlags::kInclusive);
PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::GlobalIdx(), zone);
GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(),
Zones::kGenAttentionComputeQKV);
const hwy::Divisor div_qbatch(qbatch.Size());
const size_t num_interleaved = num_tokens * div_qbatch.GetDivisor();
@ -295,7 +287,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
// tasks are very lightweight.
ParallelFor(
ParallelismStrategy::kFlat, kv_heads * num_interleaved, env.ctx,
/*cluster_idx=*/0, [&](size_t task, size_t worker) HWY_ATTR {
/*cluster_idx=*/0, Callers::kAttComputeQKV,
[&](size_t task, size_t worker) HWY_ATTR {
const size_t head = task % kv_heads;
const size_t interleaved_idx = task / kv_heads;
const size_t qi = div_qbatch.Remainder(interleaved_idx);
@ -316,12 +309,12 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
if (layer.key_norm_scale.HasPtr()) {
CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, kv_f32,
qkv_dim, env.ctx.profiler, worker);
qkv_dim, env.ctx, worker);
});
}
PositionalEncodingQK(kv_f32, layer_idx, layer, activations,
env.ctx.profiler, worker, pos, /*mul=*/1.0f);
PositionalEncodingQK(kv_f32, layer_idx, layer, activations, env.ctx,
worker, pos, /*mul=*/1.0f);
CompressPerThread tls;
Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
});
@ -332,9 +325,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
AttentionActivations& activations,
MatMulEnv& env) {
static const auto zone = env.ctx.profiler.AddZone(
"Gen.Attention.SumHeads", hwy::ProfilerFlags::kInclusive);
PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::GlobalIdx(), zone);
GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttentionSumHeads);
const LayerConfig& layer_config = layer.layer_config;
(void)layer_config; // For HWY_DASSERT
// att_weights and att_out are concatenated heads, each of length
@ -352,9 +343,7 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
const LayerWeightsPtrs& layer,
AttentionActivations& activations, QBatch& qbatch,
MatMulEnv& env, int flags) {
static const auto zone =
env.ctx.profiler.AddZone("Gen.Attention", hwy::ProfilerFlags::kInclusive);
PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::GlobalIdx(), zone);
GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttention);
const LayerConfig& layer_config = layer.layer_config;
HWY_DASSERT(!layer_config.IsMHA()); // No longer supported.

View File

@ -26,33 +26,33 @@
namespace gcpp {
// Passed to HWY_VISIT_TARGETS; declares for one target.
#define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \
namespace NAMESPACE { \
void PositionalEncodingQK(float* qk, size_t layer_idx, \
const LayerWeightsPtrs& layer, \
const AttentionActivations& activations, \
hwy::Profiler& p, size_t worker, size_t pos, \
float mul); \
\
size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx); \
\
void SingleDotSoftmaxWeightedSum( \
const size_t pos, const size_t start_pos, const size_t last_pos, \
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
size_t layer_idx, const LayerWeightsPtrs& layer, \
const AttentionActivations& activations, float* HWY_RESTRICT att, \
float* HWY_RESTRICT att_out, hwy::Profiler& p, size_t worker); \
\
void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \
const LayerWeightsPtrs& layer, \
AttentionActivations& activations, \
QBatch& qbatch, ThreadingContext& ctx); \
\
void GemmaAttention(size_t num_tokens, const size_t layer_idx, \
const LayerWeightsPtrs& layer, \
AttentionActivations& activations, QBatch& qbatch, \
MatMulEnv& env, int flags); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
#define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \
namespace NAMESPACE { \
void PositionalEncodingQK(float* qk, size_t layer_idx, \
const LayerWeightsPtrs& layer, \
const AttentionActivations& activations, \
ThreadingContext& ctx, size_t worker, size_t pos, \
float mul); \
\
size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx); \
\
void SingleDotSoftmaxWeightedSum( \
const size_t pos, const size_t start_pos, const size_t last_pos, \
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
size_t layer_idx, const LayerWeightsPtrs& layer, \
const AttentionActivations& activations, float* HWY_RESTRICT att, \
float* HWY_RESTRICT att_out, ThreadingContext& ctx, size_t worker); \
\
void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \
const LayerWeightsPtrs& layer, \
AttentionActivations& activations, \
QBatch& qbatch, ThreadingContext& ctx); \
\
void GemmaAttention(size_t num_tokens, const size_t layer_idx, \
const LayerWeightsPtrs& layer, \
AttentionActivations& activations, QBatch& qbatch, \
MatMulEnv& env, int flags); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE
// Function declarations for each SIMD target. Allows direct call from the

View File

@ -61,13 +61,12 @@ static constexpr size_t kNFx8HTileSize = 8;
// possible consecutive elements have the same KV.
static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
const size_t qbatch_size, ThreadingContext& ctx) {
const auto zone = GetProfilerZone(Zones::kFlashAttentionTransposeQ);
// Group floats by the number of floats in a cache line.
const size_t kNF = ctx.cache_info.LineBytes() / sizeof(float);
const size_t num_heads = q.Cols() / q_t.Rows();
const size_t batch_size = q.Rows() / qbatch_size;
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
PROFILER_ZONE3(ctx.profiler, worker, zone);
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTransposeQ);
for (size_t lane = 0; lane < kNF; ++lane) {
size_t q_row = task * kNF + lane;
if (q_row >= q_t.Rows()) break;
@ -83,10 +82,10 @@ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
}
};
{
const size_t num_tasks = hwy::DivCeil(q_t.Rows(), kNF);
// Better than kFlat.
size_t num_tasks = hwy::DivCeil(q_t.Rows(), kNF);
ParallelFor(ParallelismStrategy::kHierarchical, num_tasks, ctx,
/*cluster_idx=*/0, func);
/*cluster_idx=*/0, Callers::kFlashTransposeQ, func);
}
}
@ -96,12 +95,10 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
const LayerWeightsPtrs& layer,
const AttentionActivations& activations,
ThreadingContext& ctx) {
const auto zone =
GetProfilerZone(Zones::kFlashAttentionRmsNormAndPositionalEncoding);
const float query_scale = activations.query_scale;
const hwy::Divisor div_qbatch(qbatch.Size());
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
PROFILER_ZONE3(ctx.profiler, worker, zone);
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionRmsNormAndPositionalEncoding);
size_t qi = div_qbatch.Remainder(task);
size_t batch_idx = div_qbatch.Divide(task);
for (size_t h = 0; h < layer.layer_config.heads; ++h) {
@ -115,18 +112,19 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
if (layer.query_norm_scale.HasPtr()) {
CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q_row,
layer.layer_config.qkv_dim, ctx.profiler, worker);
layer.layer_config.qkv_dim, ctx, worker);
});
}
PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx.profiler,
worker, pos, query_scale);
PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx, worker,
pos, query_scale);
}
};
{
// kHierarchical is not worth the extra sync overhead because the tasks are
// very lightweight.
ParallelFor(ParallelismStrategy::kFlat, num_tokens * qbatch.Size(), ctx,
/*cluster_idx=*/0, func);
/*cluster_idx=*/0, Callers::kFlashRMSNormAndPositionalEncoding,
func);
}
}
@ -158,10 +156,9 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
const MatPtrT<KV_t>& v, const size_t layer_idx,
const LayerWeightsPtrs& layer,
const AttentionActivations& activations,
float* HWY_RESTRICT att_out, hwy::Profiler& p,
float* HWY_RESTRICT att_out, ThreadingContext& ctx,
const size_t worker) {
PROFILER_ZONE3(p, worker,
GetProfilerZone(Zones::kFlashAttentionSingleFlashAttention));
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionSingleFlashAttention);
const size_t pos_mod = activations.div_seq_len.Remainder(start_pos);
float m = Dot(q, k.Row(pos_mod), k.Cols());
if (float cap = activations.config.att_cap; cap > 0.0f) {
@ -170,7 +167,7 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
}
float d = 1.0f;
// This is just a copy of the first token.
MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), p, worker);
MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), ctx, worker);
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
const size_t pos_mod = activations.div_seq_len.Remainder(pos);
float x = Dot(q, k.Row(pos_mod), k.Cols());
@ -276,9 +273,8 @@ void TileFlashAttention(
const MatPtrT<KV_t>& v, const size_t layer_idx,
const LayerWeightsPtrs& layer, const AttentionActivations& activations,
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets,
hwy::Profiler& p, const size_t worker) {
PROFILER_ZONE3(p, worker,
GetProfilerZone(Zones::kFlashAttentionTileFlashAttention));
ThreadingContext& ctx, const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention);
constexpr int kHTileSize = kNFx8HTileSize;
using DF = hn::ScalableTag<float>;
const DF df;
@ -430,9 +426,8 @@ void TileFlashAttention4(
const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
const LayerWeightsPtrs& layer, const AttentionActivations& activations,
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets,
hwy::Profiler& p, const size_t worker) {
PROFILER_ZONE3(p, worker,
GetProfilerZone(Zones::kFlashAttentionTileFlashAttention4));
ThreadingContext& ctx, const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention4);
using DF = hn::ScalableTag<float>;
const DF df;
using VF = hn::Vec<DF>;
@ -597,10 +592,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
const size_t layer_idx, const LayerWeightsPtrs& layer,
AttentionActivations& activations, QBatch& qbatch,
ThreadingContext& ctx) {
static const auto root_zone = ctx.profiler.AddZone(
"FlashAttention.Inclusive", hwy::ProfilerFlags::kInclusive);
PROFILER_ZONE3(ctx.profiler, 0, root_zone);
const auto zone = GetProfilerZone(Zones::kFlashAttentionFlashAttention);
GCPP_ZONE(ctx, 0, Zones::kFlashAttentionInclusive);
RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q, layer_idx,
layer, activations, ctx);
const hwy::Divisor div_qbatch(qbatch.Size());
@ -653,7 +645,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
// For each head/token/query, compute fused flash Q.K, softmax and weighted V.
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
PROFILER_ZONE3(ctx.profiler, worker, zone);
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionFlashAttention);
// Offsets into original Q for each row in the tile.
uint32_t q_offsets[kMaxNF];
// Offsets into att_out for each row in the tile.
@ -741,13 +733,12 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
TileFlashAttention(activations.q, q_offsets, qT, k,
start_positions[offset], last_pos, min_last_pos,
max_last_pos, v, layer_idx, layer, activations,
activations.att_out, out_offsets, ctx.profiler,
worker);
activations.att_out, out_offsets, ctx, worker);
} else if (kVTileSize == 4) {
TileFlashAttention4(
activations.q, q_offsets, k, start_positions[offset], last_pos,
min_last_pos, max_last_pos, v, layer_idx, layer, activations,
activations.att_out, out_offsets, ctx.profiler, worker);
TileFlashAttention4(activations.q, q_offsets, k,
start_positions[offset], last_pos, min_last_pos,
max_last_pos, v, layer_idx, layer, activations,
activations.att_out, out_offsets, ctx, worker);
} else {
HWY_UNREACHABLE;
}
@ -757,7 +748,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
activations.q.Row(0) + q_offsets[offset], k, v,
layer_idx, layer, activations,
activations.att_out.Row(0) + out_offsets[offset],
ctx.profiler, worker);
ctx, worker);
}
}
};
@ -765,7 +756,8 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
{
PROFILER_ZONE("Gen.FlashAttention.ForkJoin");
// Full parallelism is helpful, SmallParallelFor is insufficient.
HierarchicalParallelFor(num_thread_tasks, ctx.pools, func);
HierarchicalParallelFor(num_thread_tasks, ctx, Callers::kFlashAttention,
func);
}
}

View File

@ -39,8 +39,8 @@ namespace gcpp {
const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
size_t layer_idx, const LayerWeightsPtrs& layer, \
const AttentionActivations& activations, \
float* HWY_RESTRICT att_out, hwy::Profiler& p, \
size_t worker); \
float* HWY_RESTRICT att_out, \
ThreadingContext& ctx, size_t worker); \
\
size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \
size_t total_tasks, size_t target_parallelism); \

View File

@ -47,9 +47,9 @@ namespace HWY_NAMESPACE {
// For use by Vit even if !GEMMA_FUSED_FFN.
template <typename T1, typename T2>
void Activation(ActivationType activation, T1* HWY_RESTRICT c1,
const T2* HWY_RESTRICT c2, const size_t count, hwy::Profiler& p,
const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kGenActivation));
const T2* HWY_RESTRICT c2, const size_t count,
ThreadingContext& ctx, const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kGenActivation);
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>;
@ -73,11 +73,11 @@ void ActivationBatched(
ParallelismStrategy parallelism = ParallelismStrategy::kFlat) {
using T = typename Mat::T;
ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,
[&](uint64_t task, size_t worker) {
Callers::kActivationBatched, [&](uint64_t task, size_t worker) {
// Cast to correct type so type deduction works.
Activation(activation, c1.Row(task),
static_cast<const T*>(nullptr), c1.Cols(),
ctx.profiler, worker);
static_cast<const T*>(nullptr), c1.Cols(), ctx,
worker);
});
}
@ -87,8 +87,8 @@ void ActivationBatched(
static inline void Activation(ActivationType activation, const RowPtrsBF C1,
const IndexRange range_r,
const IndexRange range_c, const StridedViewBF C2,
hwy::Profiler& p, const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kGenActivationFused));
ThreadingContext& ctx, const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kGenActivationFused);
const size_t cols = range_c.Num();
HWY_DASSERT(C2.Cols() == cols);
@ -119,16 +119,16 @@ HWY_NOINLINE void ActivationBatched(
HWY_DASSERT(c1.SameShape(*c2));
if (c2 && c2->HasPtr()) {
ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,
[&](uint64_t task, size_t worker) {
Callers::kActivationBatched, [&](uint64_t task, size_t worker) {
Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(),
ctx.profiler, worker);
ctx, worker);
});
} else { // No multiplier
ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,
[&](uint64_t task, size_t worker) {
Callers::kActivationBatched, [&](uint64_t task, size_t worker) {
Activation(activation, c1.Row(task),
static_cast<const typename Mat2::T*>(nullptr),
c1.Cols(), ctx.profiler, worker);
c1.Cols(), ctx, worker);
});
}
}
@ -153,9 +153,7 @@ void PostNorm(PostNormType post_norm, const MatPtr& weights,
static inline void FFWNoVit(const LayerWeightsPtrs& layer,
Activations& activations, MatMulEnv& env) {
static const auto zone =
env.ctx.profiler.AddZone("Gen.FFW", hwy::ProfilerFlags::kInclusive);
PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::GlobalIdx(), zone);
GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenFFW);
const LayerConfig& layer_config = layer.layer_config;
HWY_DASSERT(!layer_config.ff_biases); // Only used in Vit.
@ -163,8 +161,8 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer,
#if GEMMA_FUSED_FFN
const auto fused = [&](RowPtrsBF C1, IndexRange range_r, IndexRange range_c,
StridedViewBF C2, size_t worker) {
Activation(layer_config.activation, C1, range_r, range_c, C2,
env.ctx.profiler, worker);
Activation(layer_config.activation, C1, range_r, range_c, C2, env.ctx,
worker);
};
MMOptions options;
options.SetFunc(fused);

View File

@ -55,7 +55,7 @@
#include "io/io.h" // Path
#include "ops/matmul.h"
#include "paligemma/image.h"
#include "util/basics.h" // PROFILER_ZONE3
#include "util/basics.h"
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h"
@ -138,9 +138,7 @@ EmbedMMToken(int token, size_t x_row, size_t pos, size_t pos_in_prompt,
MatStorageT<float>& x, ThreadingContext& ctx,
const ImageTokens* image_tokens = nullptr,
size_t image_token_position = 0) {
static const auto zone =
ctx.profiler.AddZone("Gen.Embed", hwy::ProfilerFlags::kInclusive);
PROFILER_ZONE3(ctx.profiler, hwy::Profiler::GlobalIdx(), zone);
GCPP_ZONE(ctx, hwy::Profiler::GlobalIdx(), Zones::kGenEmbed);
// Image tokens just need to be copied.
if (model_config.wrapping == PromptWrapping::GEMMA_VLM &&
@ -415,9 +413,7 @@ static void SampleAndStream(const ModelConfig& config,
MaybeObserve(runtime_config, activations, qbatch, -1);
{
static const auto zone = env.ctx.profiler.AddZone(
"Gen.EmbeddingMatmul", hwy::ProfilerFlags::kInclusive);
PROFILER_ZONE3(env.ctx.profiler, /*worker=*/0, zone);
GCPP_ZONE(env.ctx, /*worker=*/0, Zones::kGenEmbeddingMatmul);
// Compute logits from last layer activations.
CallMatMul(activations.x_bf, weights.embedder_input_embedding,
/*add=*/nullptr, env, activations.logits);
@ -431,7 +427,8 @@ static void SampleAndStream(const ModelConfig& config,
ParallelFor(
ParallelismStrategy::kFlat, qbatch.Size(), env.ctx,
/*cluster_idx=*/0, [&](size_t qi, size_t worker) {
/*cluster_idx=*/0, Callers::kSampleAndStream,
[&](size_t qi, size_t worker) {
if (!non_eos.Get(qi)) return;
// We streamed all prefill tokens, but pos is still one behind
@ -469,8 +466,7 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config,
if (runtime_config.top_k == 1 && !runtime_config.accept_token) {
return [&](size_t /*qi*/, size_t /*pos*/, Logits logits, size_t worker)
HWY_ATTR -> TokenAndProb {
PROFILER_ZONE3(ctx.profiler, worker,
GetProfilerZone(Zones::kGenSampleTop1));
GCPP_ZONE(ctx, worker, Zones::kGenSampleTop1);
return Top1OfSoftmax(logits);
};
}
@ -478,14 +474,13 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config,
// General case: Softmax with top-k sampling.
return [&](size_t qi, size_t pos, Logits logits,
size_t worker) HWY_ATTR -> TokenAndProb {
PROFILER_ZONE3(ctx.profiler, worker,
GetProfilerZone(Zones::kGenSampleTopK));
GCPP_ZONE(ctx, worker, Zones::kGenSampleTopK);
// We want a different sequence for each batch element and position.
const uint64_t stream = (static_cast<uint64_t>(qi) << 32) | pos;
RngStream gen(engine, stream);
return FusedSoftmaxAndSampleTopK(
logits, runtime_config.top_k, gen, runtime_config.temperature,
runtime_config.accept_token, ctx.profiler, worker);
return FusedSoftmaxAndSampleTopK(logits, runtime_config.top_k, gen,
runtime_config.temperature,
runtime_config.accept_token, ctx, worker);
};
}
@ -657,8 +652,8 @@ Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
Gemma::~Gemma() = default;
void Gemma::Save(const Path& weights_path, NestedPools& pools) const {
BlobWriter writer(weights_path, pools.Pool());
void Gemma::Save(const Path& weights_path, ThreadingContext& ctx) const {
BlobWriter writer(weights_path, ctx);
const std::vector<uint32_t> serialized_mat_ptrs =
weights_.AddTensorDataToWriter(writer);
WriteSingleFile(model_.Config(), model_.Tokenizer(), serialized_mat_ptrs,

View File

@ -246,7 +246,7 @@ class Gemma {
const GemmaChatTemplate& ChatTemplate() const { return chat_template_; }
const InferenceArgs& Inference() const { return inference_; }
void Save(const Path& weights_path, NestedPools& pools) const;
void Save(const Path& weights_path, ThreadingContext& ctx) const;
// `pos` is the position in the KV cache. Users are responsible for
// incrementing it in the `*StreamFunc`, or setting to zero for single-turn.

View File

@ -36,7 +36,6 @@
// IWYU pragma: end_exports
#include "util/allocator.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {

View File

@ -90,39 +90,43 @@ class VitAttention {
ZeroInit(activations_.attention.att_out);
for (size_t head = 0; head < heads; ++head) {
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR {
const size_t token = task;
float* HWY_RESTRICT q =
activations_.attention.q.Row(token) + head * 3 * qkv_dim;
// TODO: shift to MatMul with A.scale once MatMul is confirmed working
MulByConst(query_scale, q, qkv_dim);
hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float));
});
pool_.Run(0, num_tokens_, caller1_,
[&](uint64_t task, size_t worker) HWY_ATTR {
const size_t token = task;
float* HWY_RESTRICT q =
activations_.attention.q.Row(token) + head * 3 * qkv_dim;
// TODO: shift to MatMul with A.scale once MatMul is confirmed
// working
MulByConst(query_scale, q, qkv_dim);
hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float));
});
pool_.Run(0, seq_len, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
const size_t seq_idx = task;
float* HWY_RESTRICT k = activations_.attention.q.Row(seq_idx) +
head * 3 * qkv_dim + qkv_dim;
hwy::CopyBytes(k, K.Row(seq_idx), qkv_dim * sizeof(float));
});
pool_.Run(
0, seq_len, caller2_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
const size_t seq_idx = task;
float* HWY_RESTRICT k = activations_.attention.q.Row(seq_idx) +
head * 3 * qkv_dim + qkv_dim;
hwy::CopyBytes(k, K.Row(seq_idx), qkv_dim * sizeof(float));
});
// this produces C, a (num_tokens_, seq_len) matrix of dot products
CallMatMul(Q, K, nullptr, env_, C);
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR {
Softmax(C.RowSpan(task), env_.ctx.profiler, worker);
});
pool_.Run(0, num_tokens_, caller3_,
[&](uint64_t task, size_t worker)
HWY_ATTR { Softmax(C.RowSpan(task), env_.ctx, worker); });
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR {
size_t token = task;
float* HWY_RESTRICT att_out =
activations_.attention.att_out.Row(token) + head * qkv_dim;
for (size_t i = 0; i < seq_len; ++i) {
float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
head * 3 * qkv_dim + 2 * qkv_dim;
MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim);
}
});
pool_.Run(
0, num_tokens_, caller4_, [&](uint64_t task, size_t worker) HWY_ATTR {
size_t token = task;
float* HWY_RESTRICT att_out =
activations_.attention.att_out.Row(token) + head * qkv_dim;
for (size_t i = 0; i < seq_len; ++i) {
float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
head * 3 * qkv_dim + 2 * qkv_dim;
MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim);
}
});
}
}
@ -136,7 +140,7 @@ class VitAttention {
PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
// Compute Q.K, softmax, and weighted V.
pool_.Run(0, layer_config_.heads * num_tokens_,
pool_.Run(0, layer_config_.heads * num_tokens_, caller1_,
[&](uint64_t task, size_t worker) HWY_ATTR {
const size_t head = task % layer_config_.heads;
const size_t token = task / layer_config_.heads;
@ -152,7 +156,7 @@ class VitAttention {
head_att[i] = Dot(q, k, qkv_dim); // score = q.k
}
// SoftMax yields "probabilities" in head_att.
Softmax(Logits(head_att, seq_len), env_.ctx.profiler, worker);
Softmax(Logits(head_att, seq_len), env_.ctx, worker);
// Compute weighted sum of v into att_out.
float* HWY_RESTRICT att_out =
activations_.attention.att_out.Row(token) + head * qkv_dim;
@ -185,7 +189,11 @@ class VitAttention {
layer_(layer),
layer_config_(layer.layer_config),
env_(env),
pool_(env_.ctx.pools.Pool(0)) {}
pool_(env_.ctx.pools.Pool(0)),
caller1_(env_.ctx.pool_callers.Get(Callers::kVitDotSoftmax1)),
caller2_(env_.ctx.pool_callers.Get(Callers::kVitDotSoftmax2)),
caller3_(env_.ctx.pool_callers.Get(Callers::kVitDotSoftmax3)),
caller4_(env_.ctx.pool_callers.Get(Callers::kVitDotSoftmax4)) {}
HWY_INLINE void operator()() {
ComputeQKV();
@ -204,6 +212,10 @@ class VitAttention {
const LayerConfig& layer_config_;
MatMulEnv& env_;
hwy::ThreadPool& pool_;
hwy::pool::Caller caller1_;
hwy::pool::Caller caller2_;
hwy::pool::Caller caller3_;
hwy::pool::Caller caller4_;
};
// Same as FFWNoVit, but with different layer members and no second
@ -333,7 +345,7 @@ void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights,
// Apply soft embedding norm before input projection.
CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0,
activations.x.Row(0), vit_model_dim, env.ctx.profiler,
activations.x.Row(0), vit_model_dim, env.ctx,
hwy::Profiler::GlobalIdx());
});
}

View File

@ -34,7 +34,6 @@
#include "util/threading_context.h"
#include "util/zones.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
#include "hwy/profiler.h"
@ -150,7 +149,7 @@ void LayerWeightsPtrs::SplitAttW1() {
static void HWY_MAYBE_UNUSED InitAttWeightsI8(
const LayerConfig& layer_config, MatPtrT<I8Stream>& attn_vec_einsum_w,
MatPtrT<I8Stream>& att_weights, std::vector<MatOwner>& mat_owners,
const Allocator& allocator) {
ThreadingContext& ctx) {
if (!attn_vec_einsum_w.HasPtr()) return;
HWY_ASSERT(attn_vec_einsum_w.GetType() == Type::kI8);
@ -160,7 +159,8 @@ static void HWY_MAYBE_UNUSED InitAttWeightsI8(
static std::mutex m;
std::lock_guard<std::mutex> lock(m);
mat_owners.emplace_back();
mat_owners.back().AllocateFor(att_weights, allocator, MatPadding::kPacked);
mat_owners.back().AllocateFor(att_weights, ctx.allocator,
MatPadding::kPacked);
}
const size_t model_dim = layer_config.model_dim;
@ -188,10 +188,9 @@ static void HWY_MAYBE_UNUSED InitAttWeightsI8(
}
CompressWorkingSet work;
hwy::ThreadPool pool(0);
HWY_NAMESPACE::Compress(att_weights_tmp.get(), model_dim * heads * qkv_dim,
work, att_weights.Span(),
/*packed_ofs=*/0, pool);
/*packed_ofs=*/0, ctx);
att_weights.SetScale(attn_vec_einsum_w.Scale());
}
@ -201,7 +200,7 @@ static void HWY_MAYBE_UNUSED SplitW1I8(const LayerConfig& layer_config,
MatPtrT<I8Stream>& gating_einsum_w1,
MatPtrT<I8Stream>& gating_einsum_w2,
std::vector<MatOwner>& mat_owners,
const Allocator& allocator) {
ThreadingContext& ctx) {
// Files have both or neither of w1 and w2.
HWY_ASSERT(gating_einsum_w1.HasPtr() == gating_einsum_w2.HasPtr());
// w is mutually exclusive with w1 and w2 in the file.
@ -228,10 +227,10 @@ static void HWY_MAYBE_UNUSED SplitW1I8(const LayerConfig& layer_config,
static std::mutex m;
std::lock_guard<std::mutex> lock(m);
mat_owners.emplace_back();
mat_owners.back().AllocateFor(gating_einsum_w1, allocator,
mat_owners.back().AllocateFor(gating_einsum_w1, ctx.allocator,
MatPadding::kPacked);
mat_owners.emplace_back();
mat_owners.back().AllocateFor(gating_einsum_w2, allocator,
mat_owners.back().AllocateFor(gating_einsum_w2, ctx.allocator,
MatPadding::kPacked);
}
@ -248,11 +247,10 @@ static void HWY_MAYBE_UNUSED SplitW1I8(const LayerConfig& layer_config,
float* w2_tmp = w_tmp.get() + split_size;
CompressWorkingSet work;
hwy::ThreadPool pool(0);
HWY_NAMESPACE::Compress(w1_tmp, split_size, work, gating_einsum_w1.Span(), 0,
pool);
ctx);
HWY_NAMESPACE::Compress(w2_tmp, split_size, work, gating_einsum_w2.Span(), 0,
pool);
ctx);
gating_einsum_w1.SetScale(1.0f);
gating_einsum_w2.SetScale(1.0f);
@ -265,7 +263,7 @@ static void HWY_MAYBE_UNUSED SplitAttW1I8(const LayerConfig& layer_config,
MatPtrT<I8Stream>& qkv_einsum_w1,
MatPtrT<I8Stream>& qkv_einsum_w2,
std::vector<MatOwner>& mat_owners,
const Allocator& allocator) {
ThreadingContext& ctx) {
// w is mutually exclusive with w1 in the file.
HWY_ASSERT(qkv_einsum_w.HasPtr() ^ qkv_einsum_w1.HasPtr());
// Done if we already read split tensors.
@ -291,10 +289,10 @@ static void HWY_MAYBE_UNUSED SplitAttW1I8(const LayerConfig& layer_config,
static std::mutex m;
std::lock_guard<std::mutex> lock(m);
mat_owners.emplace_back();
mat_owners.back().AllocateFor(qkv_einsum_w1, allocator,
mat_owners.back().AllocateFor(qkv_einsum_w1, ctx.allocator,
MatPadding::kPacked);
mat_owners.emplace_back();
mat_owners.back().AllocateFor(qkv_einsum_w2, allocator,
mat_owners.back().AllocateFor(qkv_einsum_w2, ctx.allocator,
MatPadding::kPacked);
}
@ -312,9 +310,8 @@ static void HWY_MAYBE_UNUSED SplitAttW1I8(const LayerConfig& layer_config,
float* w2_tmp = w_tmp.get() + w1_size;
CompressWorkingSet work;
hwy::ThreadPool pool(0);
HWY_NAMESPACE::Compress(w1_tmp, w1_size, work, qkv_einsum_w1.Span(), 0, pool);
HWY_NAMESPACE::Compress(w2_tmp, w2_size, work, qkv_einsum_w2.Span(), 0, pool);
HWY_NAMESPACE::Compress(w1_tmp, w1_size, work, qkv_einsum_w1.Span(), 0, ctx);
HWY_NAMESPACE::Compress(w2_tmp, w2_size, work, qkv_einsum_w2.Span(), 0, ctx);
qkv_einsum_w1.SetScale(1.0f);
qkv_einsum_w2.SetScale(1.0f);
@ -326,16 +323,16 @@ static void HWY_MAYBE_UNUSED SplitAttW1I8(const LayerConfig& layer_config,
// TODO: exporters should bake this into the weights already.
// WARNING: called from multiple threads; `mat_owners` requires a lock.
void LayerWeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
const Allocator& allocator) {
ThreadingContext& ctx) {
if (attn_vec_einsum_w.GetType() == Type::kI8) {
MatPtrT<I8Stream> attn_vec_einsum_w_i8(attn_vec_einsum_w);
MatPtrT<I8Stream> att_weights_i8(att_weights);
InitAttWeightsI8(layer_config, attn_vec_einsum_w_i8, att_weights_i8,
mat_owners, allocator);
mat_owners, ctx);
attn_vec_einsum_w = attn_vec_einsum_w_i8;
att_weights = att_weights_i8;
} else {
InitAttWeights(mat_owners, allocator);
InitAttWeights(mat_owners, ctx.allocator);
}
if (gating_einsum_w.GetType() == Type::kI8) {
@ -343,7 +340,7 @@ void LayerWeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
MatPtrT<I8Stream> gating_einsum_w1_i8(gating_einsum_w1);
MatPtrT<I8Stream> gating_einsum_w2_i8(gating_einsum_w2);
SplitW1I8(layer_config, gating_einsum_w_i8, gating_einsum_w1_i8,
gating_einsum_w2_i8, mat_owners, allocator);
gating_einsum_w2_i8, mat_owners, ctx);
gating_einsum_w = gating_einsum_w_i8;
gating_einsum_w1 = gating_einsum_w1_i8;
gating_einsum_w2 = gating_einsum_w2_i8;
@ -356,7 +353,7 @@ void LayerWeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
MatPtrT<I8Stream> qkv_einsum_w1_i8(qkv_einsum_w1);
MatPtrT<I8Stream> qkv_einsum_w2_i8(qkv_einsum_w2);
SplitAttW1I8(layer_config, qkv_einsum_w_i8, qkv_einsum_w1_i8,
qkv_einsum_w2_i8, mat_owners, allocator);
qkv_einsum_w2_i8, mat_owners, ctx);
qkv_einsum_w = qkv_einsum_w_i8;
qkv_einsum_w1 = qkv_einsum_w1_i8;
qkv_einsum_w2 = qkv_einsum_w2_i8;
@ -367,7 +364,8 @@ void LayerWeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
static void HWY_MAYBE_UNUSED InitAttWeightsNUQ(
const LayerConfig& layer_config, MatPtrT<NuqStream>& attn_vec_einsum_w,
MatPtrT<NuqStream>& att_weights, std::vector<MatOwner>& mat_owners) {
MatPtrT<NuqStream>& att_weights, std::vector<MatOwner>& mat_owners,
ThreadingContext& ctx) {
if (!attn_vec_einsum_w.HasPtr()) return;
HWY_ASSERT(attn_vec_einsum_w.GetType() == Type::kNUQ);
@ -399,10 +397,9 @@ static void HWY_MAYBE_UNUSED InitAttWeightsNUQ(
}
CompressWorkingSet work;
hwy::ThreadPool pool(0);
HWY_NAMESPACE::Compress(att_weights_tmp.get(), model_dim * heads * qkv_dim,
work, att_weights.Span(),
/*packed_ofs=*/0, pool);
/*packed_ofs=*/0, ctx);
att_weights.SetScale(attn_vec_einsum_w.Scale());
}
@ -435,13 +432,13 @@ void WeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
ThreadingContext& ctx) {
const size_t cluster_idx = 0;
ParallelFor(ParallelismStrategy::kFlat, c_layers.size(), ctx, cluster_idx,
[&](uint64_t layer, size_t /*worker*/) {
GetLayer(layer)->Fixup(mat_owners, ctx.allocator);
Callers::kFixupWeights, [&](uint64_t layer, size_t /*worker*/) {
GetLayer(layer)->Fixup(mat_owners, ctx);
});
ParallelFor(ParallelismStrategy::kFlat, vit_layers.size(), ctx, cluster_idx,
[&](uint64_t layer, size_t /*worker*/) {
VitLayer(layer)->Fixup(mat_owners, ctx.allocator);
Callers::kFixupWeights, [&](uint64_t layer, size_t /*worker*/) {
VitLayer(layer)->Fixup(mat_owners, ctx);
});
}
@ -529,8 +526,9 @@ static void AllocateAndBindAll(std::vector<TensorToRead>& tensors,
owners.resize(start + tensors.size());
// Allocate in parallel because faulting in large tensors is slow.
ctx.pools.Pool().Run(
0, tensors.size(), [&](uint64_t task, size_t /*thread*/) {
ParallelFor(
ParallelismStrategy::kFlat, tensors.size(), ctx, /*cluster_idx=*/0,
Callers::kAllocateAndBindAll, [&](uint64_t task, size_t /*thread*/) {
TensorToRead& tensor = tensors[task];
MatPtr& mat = *tensor.mat;
@ -587,14 +585,13 @@ static void DecompressToBF16(MatPtr& mat,
static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
const BlobReader& reader, ThreadingContext& ctx) {
const auto zone = GetProfilerZone(Zones::kStartupWeightsReadAllToBF16);
// Especially TSAN is slow enough to warrant hierarchical parallelism.
const ParallelismStrategy strategy = HWY_IS_DEBUG_BUILD
? ParallelismStrategy::kHierarchical
: ParallelismStrategy::kFlat;
ParallelFor(strategy, tensors.size(), ctx, /*cluster_idx=*/0,
[&](uint64_t task, size_t thread) {
PROFILER_ZONE3(ctx.profiler, thread, zone);
Callers::kReadAllToBF16, [&](uint64_t task, size_t thread) {
GCPP_ZONE(ctx, thread, Zones::kStartupWeightsReadAllToBF16);
const TensorToRead& tensor = tensors[task];
MatPtr& mat = *tensor.mat;
@ -679,12 +676,11 @@ static std::vector<IOBatch> MakeBatches(
static void ReadBatches(const BlobReader& reader,
const std::vector<IOBatch>& batches,
ThreadingContext& ctx) {
const auto zone = GetProfilerZone(Zones::kStartupWeightsReadBatches);
// >5x speedup from parallel reads when cached.
ParallelFor(ParallelismStrategy::kHierarchical,
batches.size(), ctx, /*cluster_idx=*/0,
ParallelFor(ParallelismStrategy::kHierarchical, batches.size(), ctx,
/*cluster_idx=*/0, Callers::kReadBatches,
[&](uint64_t task, size_t thread) {
PROFILER_ZONE3(ctx.profiler, thread, zone);
GCPP_ZONE(ctx, thread, Zones::kStartupWeightsReadBatches);
const IOBatch& batch = batches[task];
const std::string& key = reader.Keys()[batch.KeyIdx()];
const uint64_t bytes_read = batch.Read(reader.file());

View File

@ -254,7 +254,7 @@ struct LayerWeightsPtrs {
// Must be called after reading weights via `ForEachTensor`.
// TODO: exporters should bake this into the weights already.
// WARNING: called from multiple threads; `mat_owners` requires a lock.
void Fixup(std::vector<MatOwner>& mat_owners, const Allocator& allocator);
void Fixup(std::vector<MatOwner>& mat_owners, ThreadingContext& ctx);
private:
// Copies att_weights from `attn_vec_einsum_w`.

View File

@ -79,7 +79,6 @@ cc_library(
"//:threading_context",
"@highway//:hwy",
"@highway//:profiler",
"@highway//:thread_pool",
],
)
@ -108,7 +107,6 @@ cc_binary(
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:nanobenchmark",
"@highway//:thread_pool",
],
)

View File

@ -107,7 +107,7 @@ void ReadBlobs(BlobReader& reader, const RangeVec& ranges, BlobVec& blobs,
HWY_ASSERT(reader.Keys().size() == blobs.size());
HWY_ASSERT(ranges.size() == blobs.size());
ParallelFor(ParallelismStrategy::kWithinCluster, blobs.size(), ctx,
cluster_idx, [&](size_t i, size_t /*thread*/) {
cluster_idx, Callers::kTest, [&](size_t i, size_t /*thread*/) {
HWY_ASSERT(ranges[i].bytes == blobs[i].size());
reader.file().Read(ranges[i].offset, ranges[i].bytes,
blobs[i].data());
@ -122,7 +122,7 @@ void ReadBothBlobs(BlobReader& reader1, BlobReader& reader2,
const double t0 = hwy::platform::Now();
HWY_WARN("Reading %zu GiB, %zu clusters: ", total_bytes >> 30,
ctx.pools.NumClusters());
ParallelFor(ParallelismStrategy::kAcrossClusters, 2, ctx, 0,
ParallelFor(ParallelismStrategy::kAcrossClusters, 2, ctx, 0, Callers::kTest,
[&](const size_t task, size_t cluster_idx) {
ReadBlobs(task ? reader1 : reader2, task ? ranges1 : ranges2,
task ? blobs1 : blobs2, ctx, cluster_idx);
@ -190,7 +190,7 @@ void CompareBlobs(const KeyVec& keys, BlobVec& blobs1, BlobVec& blobs2,
std::atomic<size_t> blobs_equal{};
std::atomic<size_t> blobs_diff{};
ParallelFor(ParallelismStrategy::kHierarchical, keys.size(), ctx, 0,
[&](size_t i, size_t /*thread*/) {
Callers::kTest, [&](size_t i, size_t /*thread*/) {
const size_t mismatches =
BlobDifferences(blobs1[i], blobs2[i], keys[i]);
if (mismatches != 0) {


@ -28,7 +28,6 @@
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/detect_compiler_arch.h"
#include "hwy/profiler.h"
@ -469,8 +468,8 @@ static void EnqueueChunks(size_t key_idx, uint64_t offset, uint64_t bytes,
}
}
BlobWriter::BlobWriter(const Path& filename, hwy::ThreadPool& pool)
: file_(OpenFileOrNull(filename, "w+")), pool_(pool) {
BlobWriter::BlobWriter(const Path& filename, ThreadingContext& ctx)
: file_(OpenFileOrNull(filename, "w+")), ctx_(ctx) {
if (!file_) HWY_ABORT("Failed to open for writing %s", filename.path.c_str());
// Write a placeholder header to the beginning of the file. If append-only,
// we will later write a footer, else we will update the header.
@ -489,10 +488,13 @@ void BlobWriter::Add(const std::string& key, const void* data, size_t bytes) {
EnqueueChunks(keys_.size() - 1, curr_offset_, bytes,
static_cast<const uint8_t*>(data), writes);
hwy::ThreadPool null_pool(0);
hwy::ThreadPool& pool_or_serial = file_->IsAppendOnly() ? null_pool : pool_;
pool_or_serial.Run(
0, writes.size(), [this, &writes](uint64_t i, size_t /*thread*/) {
const ParallelismStrategy strategy = file_->IsAppendOnly()
? ParallelismStrategy::kNone
: ParallelismStrategy::kFlat;
ParallelFor(
strategy, writes.size(), ctx_,
/*cluster_idx=*/0, Callers::kBlobWriter,
[this, &writes](uint64_t i, size_t /*thread*/) {
const BlobRange& range = writes[i].range;
if (!file_->Write(writes[i].data, range.bytes, range.offset)) {
const std::string& key = StringFromKey(keys_[range.key_idx]);


@ -28,9 +28,9 @@
#include "io/io.h" // File, Path, MapPtr
#include "util/basics.h" // Tristate
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h" // HWY_ASSERT
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
@ -117,7 +117,7 @@ class BlobReader {
// does not make sense to call the methods concurrently.
class BlobWriter {
public:
BlobWriter(const Path& filename, hwy::ThreadPool& pool);
BlobWriter(const Path& filename, ThreadingContext& ctx);
// Writes the blob to disk with padding for alignment. Aborts on error.
void Add(const std::string& key, const void* data, size_t bytes);
@ -129,7 +129,7 @@ class BlobWriter {
std::unique_ptr<File> file_;
std::vector<hwy::uint128_t> keys_;
std::vector<size_t> blob_sizes_;
hwy::ThreadPool& pool_;
ThreadingContext& ctx_;
// Current offset in the file used for writing.
int64_t curr_offset_ = 0;
};


@ -38,7 +38,6 @@ class BlobStoreTest : public testing::Test {};
TEST(BlobStoreTest, TestReadWrite) {
ThreadingArgs threading_args;
ThreadingContext ctx(threading_args);
hwy::ThreadPool& pool = ctx.pools.Pool();
static const std::array<float, 4> kOriginalData = {-1, 0, 3.14159, 2.71828};
@ -52,7 +51,7 @@ TEST(BlobStoreTest, TestReadWrite) {
const std::string keyA("0123456789abcdef"); // max 16 characters
const std::string keyB("q");
BlobWriter writer(path, pool);
BlobWriter writer(path, ctx);
writer.Add(keyA, "DATA", 5);
writer.Add(keyB, buffer.data(), sizeof(buffer));
writer.Finalize();
@ -96,7 +95,6 @@ TEST(BlobStoreTest, TestReadWrite) {
TEST(BlobStoreTest, TestNumBlobs) {
ThreadingArgs threading_args;
ThreadingContext ctx(threading_args);
hwy::ThreadPool& pool = ctx.pools.Pool();
hwy::RandomState rng;
for (size_t num_blobs = 1; num_blobs <= 512; ++num_blobs) {
@ -106,7 +104,7 @@ TEST(BlobStoreTest, TestNumBlobs) {
HWY_ASSERT(fd > 0);
const Path path(path_str);
BlobWriter writer(path, pool);
BlobWriter writer(path, ctx);
std::vector<std::string> keys;
keys.reserve(num_blobs);
std::vector<std::vector<uint8_t>> blobs;
@ -130,26 +128,31 @@ TEST(BlobStoreTest, TestNumBlobs) {
BlobReader reader(path);
HWY_ASSERT_EQ(reader.Keys().size(), num_blobs);
pool.Run(0, num_blobs, [&](uint64_t i, size_t /*thread*/) {
HWY_ASSERT_STRING_EQ(reader.Keys()[i].c_str(), std::to_string(i).c_str());
const BlobRange* range = reader.Find(keys[i]);
HWY_ASSERT(range);
HWY_ASSERT_EQ(blobs[i].size(), range->bytes);
HWY_ASSERT(reader.CallWithSpan<uint8_t>(
keys[i], [path_str, num_blobs, i, range,
&blobs](const hwy::Span<const uint8_t> span) {
HWY_ASSERT_EQ(blobs[i].size(), span.size());
const bool match1 = span[0] == static_cast<uint8_t>(i & 255);
// If size == 1, we don't have a second byte to check.
const bool match2 =
span.size() == 1 ||
span[span.size() - 1] == static_cast<uint8_t>(i >> 8);
if (!match1 || !match2) {
HWY_ABORT("%s num_blobs %zu blob %zu offset %zu is corrupted.",
path_str, num_blobs, i, range->offset);
}
}));
});
ParallelFor(
ParallelismStrategy::kFlat, num_blobs, ctx, /*cluster_idx=*/0,
Callers::kTest, [&](uint64_t i, size_t /*thread*/) {
HWY_ASSERT_STRING_EQ(reader.Keys()[i].c_str(),
std::to_string(i).c_str());
const BlobRange* range = reader.Find(keys[i]);
HWY_ASSERT(range);
HWY_ASSERT_EQ(blobs[i].size(), range->bytes);
HWY_ASSERT(reader.CallWithSpan<uint8_t>(
keys[i], [path_str, num_blobs, i, range,
&blobs](const hwy::Span<const uint8_t> span) {
HWY_ASSERT_EQ(blobs[i].size(), span.size());
const bool match1 = span[0] == static_cast<uint8_t>(i & 255);
// If size == 1, we don't have a second byte to check.
const bool match2 =
span.size() == 1 ||
span[span.size() - 1] == static_cast<uint8_t>(i >> 8);
if (!match1 || !match2) {
HWY_ABORT(
"%s num_blobs %zu blob %zu offset %zu is corrupted.",
path_str, num_blobs, i, range->offset);
}
}));
});
close(fd);
unlink(path_str);


@ -44,6 +44,6 @@ int main(int argc, char** argv) {
}
gcpp::GemmaEnv env(argc, argv);
env.GetGemma()->Save(args.output_weights, env.Env().ctx.pools);
env.GetGemma()->Save(args.output_weights, env.Env().ctx);
return 0;
}


@ -30,7 +30,6 @@
#include "ops/matmul.h"
#include "util/basics.h"
#include "util/threading_context.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/nanobenchmark.h"
#include "hwy/profiler.h"
#include "hwy/timer.h"
@ -72,7 +71,6 @@ void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents,
// M = A rows, K = A cols, N = C cols.
template <typename TA, typename TB = TA, typename TC = float>
void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
hwy::ThreadPool& pool = env.ctx.pools.Pool(0);
if (env.print_config || env.print_measurement) {
fprintf(stderr, "\n");
}
@ -91,15 +89,14 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
MatStorageT<float> add_storage("add", Extents2D(), env.ctx.allocator,
MatPadding::kPacked);
if (add) {
add_storage = GenerateMat<float>(Extents2D(1, N), env.ctx.allocator,
MatPadding::kPacked, pool);
add_storage =
GenerateMat<float>(Extents2D(1, N), MatPadding::kPacked, env.ctx);
add_storage.SetScale(1.0f);
}
MatStorageT<TA> a =
GenerateMat<TA>(A_extents, env.ctx.allocator, MatPadding::kOdd, pool);
MatStorageT<TB> b_trans = GenerateTransposedMat<TB>(
B_extents, env.ctx.allocator, MatPadding::kOdd, pool);
MatStorageT<TA> a = GenerateMat<TA>(A_extents, MatPadding::kOdd, env.ctx);
MatStorageT<TB> b_trans =
GenerateTransposedMat<TB>(B_extents, MatPadding::kOdd, env.ctx);
const float* add_row = add ? add_storage.PackedScale1() : nullptr;


@ -31,7 +31,6 @@
#include "util/test_util.h"
#include "util/threading_context.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"
#include "hwy/stats.h"
#include "hwy/timer.h"
@ -922,9 +921,11 @@ void GenerateWellConditionedInputs(const size_t num, float* HWY_RESTRICT raw,
(void)ScaleWeights(raw, num);
}
hwy::ThreadPool pool(0); // num is too small for parallelization
ThreadingArgs threading_args;
threading_args.max_lps = 1; // num is too small for parallelization
ThreadingContext ctx(threading_args);
const size_t packed_ofs = 0;
Compress(raw, num, work, packed, packed_ofs, pool);
Compress(raw, num, work, packed, packed_ofs, ctx);
const hn::ScalableTag<float> df;
DecompressAndZeroPad(df, MakeConst(packed), packed_ofs, raw, num);
@ -1125,7 +1126,7 @@ void TestAllDot() {
std::array<DotStats, kMaxWorkers> all_stats;
ParallelFor(
ParallelismStrategy::kWithinCluster, kReps, ctx, 0,
ParallelismStrategy::kWithinCluster, kReps, ctx, 0, Callers::kTest,
[&](size_t rep, size_t thread) {
float* HWY_RESTRICT pa = a.Row(thread);
float* HWY_RESTRICT pb = b.Row(thread);


@ -291,7 +291,7 @@ class MMDecompress {
const hn::ScalableTag<BF16> dbf;
const size_t NBF = hn::Lanes(dbf);
const auto zone = GetProfilerZone(Zones::kMMDecompressA);
const auto zone = env.ctx.profiler_zones.Get(Zones::kMMDecompressA);
const auto do_range =
[&](const IndexRange& range_M, const IndexRange& range_K, size_t worker)
@ -879,9 +879,8 @@ class MMLoops {
static HWY_NOINLINE void Dispatch(const StridedViewBF A, const MatPtrT<TB>& B,
const MatPtrT<TB>* B2, RowPtrs<TC> C,
const MMArgs& args) {
PROFILER_ZONE3(args.env.ctx.profiler,
args.env.ctx.Worker(args.options.cluster_idx),
GetProfilerZone(Zones::kMMDispatch));
GCPP_ZONE(args.env.ctx, args.env.ctx.Worker(args.options.cluster_idx),
Zones::kMMDispatch);
DispatchParallelism(
args.options.parallelism, [&](const auto& parallel) HWY_ATTR {
@ -904,7 +903,7 @@ class MMLoops {
const StridedViewBF A, const MatPtrT<TB>& B,
const MatPtrT<TB>* B2, RowPtrs<TC> C,
const MMArgs& args) {
const auto zone = GetProfilerZone(Zones::kMMNT);
const auto zone = args.env.ctx.profiler_zones.Get(Zones::kMMNT);
HWY_DASSERT(args.ranges_mc.NumTasks() == 1);
HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
const IndexRange& range_mc = args.ranges_mc.Range(0);
@ -940,7 +939,7 @@ class MMLoops {
const StridedViewBF A, const MatPtrT<TB>& B,
const MatPtrT<TB>* B2, RowPtrs<TC> C,
const MMArgs& args) {
const auto zone = GetProfilerZone(Zones::kMMNT_K);
const auto zone = args.env.ctx.profiler_zones.Get(Zones::kMMNT_K);
HWY_DASSERT(args.ranges_mc.NumTasks() == 1);
const IndexRange& range_mc = args.ranges_mc.Range(0);
@ -976,7 +975,7 @@ class MMLoops {
const StridedViewBF A, const MatPtrT<TB>& B,
const MatPtrT<TB>* B2, RowPtrs<TC> C,
const MMArgs& args) {
const auto zone = GetProfilerZone(Zones::kMMNT_MT);
const auto zone = args.env.ctx.profiler_zones.Get(Zones::kMMNT_MT);
HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
const IndexRange& range_kc = args.ranges_kc.Range(0);
@ -1010,7 +1009,7 @@ class MMLoops {
const StridedViewBF A, const MatPtrT<TB>& B,
const MatPtrT<TB>* B2, RowPtrs<TC> C,
const MMArgs& args) {
const auto zone = GetProfilerZone(Zones::kMMNT_MT_K);
const auto zone = args.env.ctx.profiler_zones.Get(Zones::kMMNT_MT_K);
parallel.ForRangesMC_NC(
args.env.ctx, args.ranges_mc, args.ranges_nc, args.options.cluster_idx,
@ -1063,8 +1062,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
MatPtrT<TC>& C, MMOptions options = MMOptions()) {
const size_t cluster_idx = options.cluster_idx;
HWY_DASSERT(cluster_idx < env.row_ptrs.size());
PROFILER_ZONE3(env.ctx.profiler, env.ctx.Worker(cluster_idx),
GetProfilerZone(Zones::kMMMatMul));
GCPP_ZONE(env.ctx, env.ctx.Worker(cluster_idx), Zones::kMMMatMul);
RowPtrs<TC> C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[cluster_idx]);
@ -1124,8 +1122,7 @@ HWY_NOINLINE MMPerKey* TwoMatMul(const MatPtrT<BF16>& A, const MatPtrT<TB>& B1,
MatPtrT<BF16>& C, MMOptions options) {
const size_t cluster_idx = options.cluster_idx;
HWY_DASSERT(cluster_idx < env.row_ptrs.size());
PROFILER_ZONE3(env.ctx.profiler, env.ctx.Worker(cluster_idx),
GetProfilerZone(Zones::kMMTwoMatMul));
GCPP_ZONE(env.ctx, env.ctx.Worker(cluster_idx), Zones::kMMTwoMatMul);
HWY_DASSERT(options.func != nullptr); // no other way to get access to C2.


@ -111,6 +111,7 @@ struct MMParallelWithinCluster {
const IndexRangePartition ranges_n = StaticPartition(
range_n, cluster.NumWorkers() * inner_tasks, n_multiple);
ParallelizeOneRange(ranges_n, cluster,
ctx.pool_callers.Get(Callers::kMMClusterForN),
[&](const IndexRange& worker_range, size_t worker) {
func(worker_range, base + worker);
});
@ -127,12 +128,14 @@ struct MMParallelWithinCluster {
// Low-batch: avoid Divide/Remainder.
if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
ParallelizeOneRange(ranges_nc, cluster,
ctx.pool_callers.Get(Callers::kMMClusterForMCNC),
[&](const IndexRange& range_nc, size_t worker) {
func(ranges_mc.Range(0), range_nc, base + worker);
});
} else {
ParallelizeTwoRanges(
ranges_mc, ranges_nc, cluster,
ctx.pool_callers.Get(Callers::kMMClusterForMCNC),
[&](const IndexRange& range_mc, const IndexRange& range_nc,
size_t worker) { func(range_mc, range_nc, base + worker); });
}
@ -146,6 +149,7 @@ struct MMParallelWithinCluster {
cluster.Run(
range_mc.begin(), range_mc.end(),
ctx.pool_callers.Get(Callers::kMMClusterForMC),
[&](uint64_t row_a, size_t worker) { func(row_a, base + worker); });
}
};
@ -159,6 +163,7 @@ struct MMParallelHierarchical {
const Func& func) const {
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
HWY_DASSERT(caller_cluster_idx == 0);
const hwy::pool::Caller caller = ctx.pool_callers.Get(Callers::kMMHierForN);
// Single cluster: parallel-for over static partition of `range_n`.
hwy::ThreadPool& all_clusters = ctx.pools.AllClusters();
@ -169,7 +174,7 @@ struct MMParallelHierarchical {
const IndexRangePartition ranges_n = StaticPartition(
range_n, cluster.NumWorkers() * inner_tasks, n_multiple);
return ParallelizeOneRange(
ranges_n, cluster,
ranges_n, cluster, caller,
[&](const IndexRange& worker_range, size_t worker) {
func(worker_range, worker);
});
@ -179,7 +184,7 @@ struct MMParallelHierarchical {
const IndexRangePartition ranges_n =
StaticPartition(range_n, num_clusters, n_multiple);
ParallelizeOneRange(
ranges_n, all_clusters,
ranges_n, all_clusters, caller,
[&](const IndexRange& n_range, const size_t cluster_idx) {
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
const size_t cluster_base = ctx.Worker(cluster_idx);
@ -187,7 +192,7 @@ struct MMParallelHierarchical {
const IndexRangePartition worker_ranges = StaticPartition(
n_range, cluster.NumWorkers() * inner_tasks, n_multiple);
ParallelizeOneRange(
worker_ranges, cluster,
worker_ranges, cluster, caller,
[&](const IndexRange& worker_range, size_t worker) {
func(worker_range, cluster_base + worker);
});
@ -203,6 +208,8 @@ struct MMParallelHierarchical {
HWY_MAYBE_UNUSED size_t caller_cluster_idx,
const Func& func) const {
HWY_DASSERT(caller_cluster_idx == 0);
const hwy::pool::Caller caller =
ctx.pool_callers.Get(Callers::kMMHierForMCNC);
hwy::ThreadPool& all_clusters = ctx.pools.AllClusters();
// `all_clusters` is a pool with one worker per cluster in a package.
@ -215,12 +222,13 @@ struct MMParallelHierarchical {
// Low-batch: avoid Divide/Remainder.
if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
return ParallelizeOneRange(
ranges_nc, cluster, [&](const IndexRange& range_nc, size_t worker) {
ranges_nc, cluster, caller,
[&](const IndexRange& range_nc, size_t worker) {
func(ranges_mc.Range(0), range_nc, worker);
});
} else {
return ParallelizeTwoRanges(
ranges_mc, ranges_nc, cluster,
ranges_mc, ranges_nc, cluster, caller,
[&](const IndexRange& range_mc, const IndexRange& range_nc,
size_t worker) { func(range_mc, range_nc, worker); });
}
@ -229,11 +237,11 @@ struct MMParallelHierarchical {
// Multiple clusters: N across clusters (both are usually the larger), and
// M within each cluster. We assume auto-tuning finds small MC/NC tasks.
ParallelizeOneRange(
ranges_nc, all_clusters,
ranges_nc, all_clusters, caller,
[&](const IndexRange range_nc, size_t cluster_idx) {
const size_t cluster_base = ctx.Worker(cluster_idx);
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
ParallelizeOneRange(ranges_mc, cluster,
ParallelizeOneRange(ranges_mc, cluster, caller,
[&](const IndexRange& range_mc, size_t worker) {
func(range_mc, range_nc, cluster_base + worker);
});
@ -244,7 +252,7 @@ struct MMParallelHierarchical {
template <class Func>
void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
size_t caller_cluster_idx, const Func& func) const {
HierarchicalParallelFor(range_mc.Num(), ctx.pools,
HierarchicalParallelFor(range_mc.Num(), ctx, Callers::kMMHierForMC,
[&](size_t task, size_t worker) {
func(range_mc.begin() + task, worker);
});
@ -811,7 +819,7 @@ class MMZone {
private:
uint64_t data_ = 0;
uint64_t data2_ = 0;
HWY_MEMBER_VAR_MAYBE_UNUSED uint64_t data2_ = 0;
};
#else
struct MMZone {


@ -196,7 +196,7 @@ HWY_INLINE void MatMulSlow(const MatPtrT<TA> A, const MatPtrT<TB> B,
const IndexRangePartition get_col_c =
StaticPartition(all_cols_c, all_clusters.NumWorkers(), multiple);
ParallelizeOneRange(
get_col_c, all_clusters,
get_col_c, all_clusters, env.ctx.pool_callers.Get(Callers::kTest),
[&](const IndexRange& cols_c, size_t cluster_idx) HWY_ATTR {
for (size_t r : all_rows_c) {
TC* HWY_RESTRICT C_row = C.Row(r);
@ -221,7 +221,6 @@ void PrintSpeed(const char* algo, const Extents2D& A_extents,
template <typename TA, typename TB = TA, typename TC = float>
void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
MatMulEnv& env, int line) {
hwy::ThreadPool& pool = env.ctx.pools.Pool();
fprintf(stderr, "TestMatMul %zu, K=%zu, %zu, add=%d, TA=%s, TB=%s, TC=%s\n",
rows_ac, cols_a_rows_b, cols_bc, add, TypeName<TA>(), TypeName<TB>(),
TypeName<TC>());
@ -233,11 +232,10 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
const Extents2D B_extents(cols_bc, cols_a_rows_b); // already transposed
const Extents2D C_extents(rows_ac, cols_bc);
MatStorageT<TA> A(
GenerateMat<TA>(A_extents, env.ctx.allocator, MatPadding::kOdd, pool));
MatStorageT<TA> A(GenerateMat<TA>(A_extents, MatPadding::kOdd, env.ctx));
// Must be packed because we call Span() on it.
MatStorageT<TB> BT(GenerateTransposedMat<TB>(B_extents, env.ctx.allocator,
MatPadding::kPacked, pool));
MatStorageT<TB> BT(
GenerateTransposedMat<TB>(B_extents, MatPadding::kPacked, env.ctx));
MatStorageT<TC> C_slow("C_slow", C_extents, env.ctx.allocator,
MatPadding::kOdd);
MatStorageT<TC> C("C", C_extents, env.ctx.allocator, MatPadding::kOdd);
@ -246,8 +244,8 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
C2.AllocateAndAttachRowPtrs(env.row_ptrs);
MatStorageT<float> add_storage =
add ? GenerateMat<float>(Extents2D(1, cols_bc), env.ctx.allocator,
MatPadding::kPacked, pool)
add ? GenerateMat<float>(Extents2D(1, cols_bc), MatPadding::kPacked,
env.ctx)
: MatStorageT<float>("add", Extents2D(), env.ctx.allocator,
MatPadding::kPacked);
add_storage.SetScale(1.0f);


@ -205,9 +205,9 @@ namespace detail {
// Shared by RMSNorm and RMSNormInplace.
template <typename VT>
float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size, hwy::Profiler& p,
const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRmsNormMul));
float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size,
ThreadingContext& ctx, const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kOpsRmsNormMul);
const hn::ScalableTag<float> d;
const float l2 = DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault());
@ -218,19 +218,17 @@ float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size, hwy::Profiler& p,
} // namespace detail
template <typename XT, typename WT, typename OT>
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x,
const WT* HWY_RESTRICT weight,
const size_t w_ofs,
OT* HWY_RESTRICT out,
const size_t size, hwy::Profiler& p,
const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRmsNorm));
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
const XT* HWY_RESTRICT x, const WT* HWY_RESTRICT weight, const size_t w_ofs,
OT* HWY_RESTRICT out, const size_t size, ThreadingContext& ctx,
const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kOpsRmsNorm);
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>;
const VF mul = hn::Set(DF(), detail::RMSNormMul(x, size, p, worker));
const VF mul = hn::Set(DF(), detail::RMSNormMul(x, size, ctx, worker));
const VF* HWY_RESTRICT pmul = &mul;
Decompress2AndCompressTo(DF(), out, size, x, weight, w_ofs,
@ -245,13 +243,13 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x,
template <typename WT, typename XT>
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
const WT* HWY_RESTRICT weight, const size_t w_ofs, XT* HWY_RESTRICT inout,
const size_t size, hwy::Profiler& p, const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRmsNormInplace));
const size_t size, ThreadingContext& ctx, const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kOpsRmsNormInplace);
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>;
const VF mul = hn::Set(DF(), detail::RMSNormMul(inout, size, p, worker));
const VF mul = hn::Set(DF(), detail::RMSNormMul(inout, size, ctx, worker));
const VF* HWY_RESTRICT pmul = &mul;
Decompress1AndCompressInplace(DF(), inout, size, weight, w_ofs,
@ -359,9 +357,9 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
// This overload is called if `post_qk == PostQKType::HalfRope`.
static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
float* HWY_RESTRICT x, const size_t dim_qkv,
const float* HWY_RESTRICT inv_timescale, const int pos, hwy::Profiler& p,
const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRope));
const float* HWY_RESTRICT inv_timescale, const int pos,
ThreadingContext& ctx, const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kOpsRope);
HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2;
@ -418,9 +416,9 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
// `inv_timescale[dim_qkv / 2]` is precomputed in AttentionActivations.
static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
const float mul, float* HWY_RESTRICT x, const size_t dim_qkv,
const float* HWY_RESTRICT inv_timescale, const int pos, hwy::Profiler& p,
const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRopeAndMulBy));
const float* HWY_RESTRICT inv_timescale, const int pos,
ThreadingContext& ctx, const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kOpsRopeAndMulBy);
HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2;
@ -480,9 +478,9 @@ template <typename XT>
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x,
float* HWY_RESTRICT out,
const size_t size,
hwy::Profiler& p,
ThreadingContext& ctx,
const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsAddFrom));
GCPP_ZONE(ctx, worker, Zones::kOpsAddFrom);
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
@ -503,10 +501,11 @@ void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
CallUpcasted(&weights, [&](const auto* weights_t) {
ParallelFor(ParallelismStrategy::kFlat, activations.Rows(), ctx,
cluster_idx, [&](uint64_t token_idx, size_t worker) {
cluster_idx, Callers::kOpsRMSNormBatched,
[&](uint64_t token_idx, size_t worker) {
RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(),
/*w_ofs=*/0, out.Row(token_idx), activations.Cols(),
ctx.profiler, worker);
ctx, worker);
});
});
}
@ -519,10 +518,11 @@ void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT<XT>& inout,
CallUpcasted(&weights, [&](const auto* weights_t) {
ParallelFor(ParallelismStrategy::kFlat, inout.Rows(), ctx, cluster_idx,
Callers::kOpsRMSNormInplaceBatched,
[&](uint64_t token_idx, size_t worker) {
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0,
inout.Row(token_idx), inout.Cols(),
ctx.profiler, worker);
inout.Row(token_idx), inout.Cols(), ctx,
worker);
});
});
}
@ -549,13 +549,14 @@ static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x, MatPtrT<float>& out,
ThreadingContext& ctx,
size_t cluster_idx = 0) {
HWY_DASSERT(out.SameShape(x));
ParallelFor(ParallelismStrategy::kFlat, out.Rows(), ctx, cluster_idx,
[&](uint64_t token_idx, size_t worker) {
AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(),
ctx.profiler, worker);
});
ParallelFor(
ParallelismStrategy::kFlat, out.Rows(), ctx, cluster_idx,
Callers::kOpsAddFromBatched, [&](uint64_t token_idx, size_t worker) {
AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(), ctx, worker);
});
}
// No profiler zone because this is short and frequently called.
template <typename XT>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(const float c, XT* HWY_RESTRICT x,
const size_t size) {
@ -575,8 +576,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(const float c, XT* HWY_RESTRICT x,
template <typename XT, typename OT>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo(
const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,
const size_t size, hwy::Profiler& p, const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstTo));
const size_t size, ThreadingContext& ctx, const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kOpsMulByConstTo);
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>;
@ -1121,10 +1122,10 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
// See below for a specialized version for top-1 sampling.
// TODO: support bf16 logits using Decompress2.
static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p,
static HWY_NOINLINE void Softmax(Logits logits, ThreadingContext& ctx,
const size_t worker,
float temperature = 1.0f) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsSoftmax));
GCPP_ZONE(ctx, worker, Zones::kOpsSoftmax);
HWY_DASSERT(logits.size() != 0);
namespace hn = hwy::HWY_NAMESPACE;
@ -1256,8 +1257,9 @@ static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(Logits logits) {
}
static HWY_NOINLINE void LogitsSoftCap(const float cap, Logits logits,
hwy::Profiler& p, const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsLogitsSoftCap));
ThreadingContext& ctx,
const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kOpsLogitsSoftCap);
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
@ -1277,9 +1279,10 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, Logits logits,
// Calls LogitsSoftCap if cap != 0.0f.
static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap(
const float cap, Logits logits, hwy::Profiler& p, const size_t worker) {
const float cap, Logits logits, ThreadingContext& ctx,
const size_t worker) {
if (cap != 0.0f) {
LogitsSoftCap(cap, logits, p, worker);
LogitsSoftCap(cap, logits, ctx, worker);
}
}
@ -1288,9 +1291,10 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCapBatched(
ThreadingContext& ctx, size_t cluster_idx = 0) {
if (cap == 0.0f) return;
ParallelFor(ParallelismStrategy::kFlat, x.Rows(), ctx, cluster_idx,
Callers::kOpsMaybeLogitsSoftCapBatched,
[&](uint64_t task, size_t worker) {
if (non_eos.Get(task)) {
LogitsSoftCap(cap, x.RowSpan(task), ctx.profiler, worker);
LogitsSoftCap(cap, x.RowSpan(task), ctx, worker);
}
});
}
@ -1371,7 +1375,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(Logits logits, size_t k,
template <typename TAcceptToken>
HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
Logits logits, size_t k, RngStream& gen, float temperature,
TAcceptToken& accept_token, hwy::Profiler& p, size_t worker) {
TAcceptToken& accept_token, ThreadingContext& ctx, size_t worker) {
// Softmax and sample top-K is equivalent to taking the top-K logits and
// sampling from the softmax of the top-K logits. The latter is faster as it
// avoids computing the softmax of all logits.
@ -1384,7 +1388,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
}
const size_t mask = token_logits.size();
Softmax(Logits(topk_logits.data(), mask), p, worker, temperature);
Softmax(Logits(topk_logits.data(), mask), ctx, worker, temperature);
auto distribution = std::discrete_distribution<int>(
std::begin(topk_logits), std::begin(topk_logits) + mask);
int topk_sampled_index = distribution(gen);


@ -58,6 +58,11 @@ namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
static ThreadingContext& Ctx() {
static ThreadingContext* ctx = new ThreadingContext(ThreadingArgs());
return *ctx;
}
static RngStream MakeRng() {
static AesCtrEngine engine(/*deterministic=*/true);
static uint64_t stream = 0;
@ -133,8 +138,7 @@ class TestAddFrom {
}
SimpleAddFrom(o, e, count);
InitProfilerZones(hwy::Profiler::Get());
AddFrom(o, x, count, hwy::Profiler::Get(), /*worker=*/0);
AddFrom(o, x, count, Ctx(), /*worker=*/0);
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
__LINE__);
@ -182,7 +186,6 @@ class TestMulByConstAndAdd {
T constant = Random<T>(rng);
SimpleMulByConstAndAdd(constant, o, e, count);
InitProfilerZones(hwy::Profiler::Get());
MulByConstAndAdd(constant, o, x, count);
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
@ -231,7 +234,6 @@ class TestMulByConst {
T constant = Random<T>(rng);
SimpleMulByConst(constant, e, count);
InitProfilerZones(hwy::Profiler::Get());
MulByConst(constant, x, count);
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
@ -278,8 +280,7 @@ struct TestMulByConstTo {
hwy::ConvertScalarTo<float>(constant));
}
InitProfilerZones(hwy::Profiler::Get());
MulByConstTo(constant, x, actual, count, hwy::Profiler::Get(),
MulByConstTo(constant, x, actual, count, Ctx(),
/*worker=*/0);
hwy::AssertArraySimilar(e, actual, count, hwy::TargetName(HWY_TARGET),
@ -315,8 +316,7 @@ class TestSoftmax {
}
SimpleSoftmax(e, count);
InitProfilerZones(hwy::Profiler::Get());
Softmax(Logits(x, count), hwy::Profiler::Get(), /*worker=*/0);
Softmax(Logits(x, count), Ctx(), /*worker=*/0);
T sum = 0.0f;
for (size_t i = 0; i < count; ++i) {
@ -440,9 +440,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
}
void TestRopeAndMulBy() {
ThreadingArgs threading_args;
ThreadingContext ctx(threading_args);
hwy::Profiler& p = ctx.profiler;
ThreadingContext& ctx = Ctx();
const size_t worker = 0;
const ModelConfig config(Model::GEMMA2_9B, Type::kSFP,
@ -476,7 +474,7 @@ void TestRopeAndMulBy() {
CopyMat(x, qactual);
ScalarRopeAndMulBy(qmul, qexpected.Row(0), dim_qkv, inv_timescale.Row(0),
pos);
RopeAndMulBy(qmul, qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, p,
RopeAndMulBy(qmul, qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, ctx,
worker);
for (size_t i = 0; i < dim_qkv; ++i) {
EXPECT_NEAR(qexpected.Row(0)[i], qactual.Row(0)[i], 1e-4) << " " << i;
@ -487,7 +485,7 @@ void TestRopeAndMulBy() {
CopyMat(x, qactual);
ScalarRopeAndMulBy(1.0f, qexpected.Row(0), dim_qkv, inv_timescale.Row(0),
pos);
Rope(qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, p, worker);
Rope(qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, ctx, worker);
for (size_t i = 0; i < dim_qkv; ++i) {
EXPECT_NEAR(qexpected.Row(0)[i], qactual.Row(0)[i], 1e-4) << " " << i;
}
@ -498,10 +496,10 @@ void TestRopeAndMulBy() {
CopyMat(x, kactual2);
ScalarRopeAndMulBy(kmul, kexpected.Row(0), dim_qkv, inv_timescale.Row(0),
pos);
RopeAndMulBy(kmul, kactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, p,
RopeAndMulBy(kmul, kactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, ctx,
worker);
static_assert(kmul == 1.0f, "");
Rope(kactual2.Row(0), dim_qkv, inv_timescale.Row(0), pos, p, worker);
Rope(kactual2.Row(0), dim_qkv, inv_timescale.Row(0), pos, ctx, worker);
for (size_t i = 0; i < dim_qkv; ++i) {
EXPECT_NEAR(kexpected.Row(0)[i], kactual.Row(0)[i], 1e-4) << " " << i;
@ -557,8 +555,7 @@ struct TestRMSNorm {
}
ScalarRMSNorm(vec, weight, expected, kSize);
InitProfilerZones(hwy::Profiler::Get());
RMSNorm(vec, weight, /*w_ofs=*/0, actual, kSize, hwy::Profiler::Get(),
RMSNorm(vec, weight, /*w_ofs=*/0, actual, kSize, Ctx(),
/*worker=*/0);
for (size_t i = 0; i < kSize; i++) {
@ -593,8 +590,7 @@ struct TestRMSNormInplace {
}
ScalarRMSNorm(expected, weight, expected, kSize);
InitProfilerZones(hwy::Profiler::Get());
RMSNormInplace(weight, /*w_ofs=*/0, actual, kSize, hwy::Profiler::Get(),
RMSNormInplace(weight, /*w_ofs=*/0, actual, kSize, Ctx(),
/*worker=*/0);
for (size_t i = 0; i < kSize; i++) {
@ -715,15 +711,14 @@ void TestAllLayerNorm() {
}
void TestSampleTopK() {
hwy::Profiler& p = hwy::Profiler::Get();
InitProfilerZones(p);
ThreadingContext& ctx = Ctx();
const size_t worker = 0;
const size_t kSize = 52;
std::vector<float> logits_vec(kSize);
Logits logits(logits_vec.data(), kSize);
// Create a vector going from -100 to -100+51=49 and take Softmax.
std::iota(logits.begin(), logits.end(), -100.0f);
Softmax(logits, p, worker);
Softmax(logits, ctx, worker);
RngStream rng = MakeRng();
float temperature = 1.0f;
// SampleTopK<1> should return the argmax.
@ -736,7 +731,7 @@ void TestSampleTopK() {
EXPECT_EQ(sample, 50); // Last even index.
// Reset the logits to a positive, increasing sequence and take Softmax.
std::iota(logits.begin(), logits.end(), 1.0f);
Softmax(logits, p, worker);
Softmax(logits, ctx, worker);
// Sample from the top 3, expect one of the top 3 even indices.
for (int i = 0; i < 100; ++i) {
sample = SampleTopK(logits, /*k=*/3, rng, temperature, accept_token);


@ -49,9 +49,10 @@ PinningPolicy::PinningPolicy(Tristate pin) {
static void MaybePin(const BoundedTopology& topology, size_t cluster_idx,
const BoundedTopology::Cluster& cluster,
PinningPolicy& pinning, hwy::ThreadPool& pool) {
static hwy::pool::Caller caller = hwy::ThreadPool::AddCaller("MaybePin");
const std::vector<size_t> lps = cluster.LPVector();
HWY_ASSERT(pool.NumWorkers() <= lps.size());
pool.Run(0, pool.NumWorkers(), [&](uint64_t task, size_t thread) {
pool.Run(0, pool.NumWorkers(), caller, [&](uint64_t task, size_t thread) {
HWY_ASSERT(task == thread); // each worker has one task
char buf[16]; // Linux limitation
@ -141,17 +142,20 @@ NestedPools::NestedPools(const BoundedTopology& topology,
// Parallel so we also pin the calling worker in `all_clusters` to
// `cluster.lps`.
all_clusters_->Run(0, num_clusters, [&](size_t cluster_idx, size_t thread) {
HWY_ASSERT(cluster_idx == thread); // each thread has one task
const BoundedTopology::Cluster& tcluster = topology.GetCluster(cluster_idx);
clusters_[cluster_idx] =
MakePool(allocator, workers_per_cluster[cluster_idx],
hwy::PoolWorkerMapping(cluster_idx, max_workers_per_cluster_),
tcluster.Node());
// Pin workers AND the calling thread from `all_clusters_`.
MaybePin(topology, cluster_idx, tcluster, pinning_,
*clusters_[cluster_idx]);
});
static hwy::pool::Caller caller = hwy::ThreadPool::AddCaller("NestedPools");
all_clusters_->Run(
0, num_clusters, caller, [&](size_t cluster_idx, size_t thread) {
HWY_ASSERT(cluster_idx == thread); // each thread has one task
const BoundedTopology::Cluster& tcluster =
topology.GetCluster(cluster_idx);
clusters_[cluster_idx] = MakePool(
allocator, workers_per_cluster[cluster_idx],
hwy::PoolWorkerMapping(cluster_idx, max_workers_per_cluster_),
tcluster.Node());
// Pin workers AND the calling thread from `all_clusters_`.
MaybePin(topology, cluster_idx, tcluster, pinning_,
*clusters_[cluster_idx]);
});
all_pinned_ = pinning_.AllPinned(&pin_string_);
}


@ -266,9 +266,9 @@ static inline IndexRangePartition StaticPartition(const IndexRange& range,
// index to a range.
template <class Func>
void ParallelizeOneRange(const IndexRangePartition& get1, hwy::ThreadPool& pool,
const Func& func) {
hwy::pool::Caller caller, const Func& func) {
const size_t num_tasks = get1.NumTasks();
pool.Run(0, num_tasks, [&](uint64_t task, size_t thread) {
pool.Run(0, num_tasks, caller, [&](uint64_t task, size_t thread) {
const IndexRange range1 = get1.Range(task);
func(range1, thread);
});
@ -282,11 +282,12 @@ void ParallelizeOneRange(const IndexRangePartition& get1, hwy::ThreadPool& pool,
template <class Func>
void ParallelizeTwoRanges(const IndexRangePartition& get1,
const IndexRangePartition& get2,
hwy::ThreadPool& pool, const Func& func) {
hwy::ThreadPool& pool, hwy::pool::Caller caller,
const Func& func) {
const hwy::Divisor div1(static_cast<uint32_t>(get1.NumTasks()));
const size_t num_tasks = get1.NumTasks() * get2.NumTasks();
pool.Run(0, num_tasks, [&](uint64_t task, size_t thread) {
pool.Run(0, num_tasks, caller, [&](uint64_t task, size_t thread) {
HWY_DASSERT(task < (uint64_t{1} << 32));
const size_t idx2 = div1.Divide(static_cast<uint32_t>(task));
const size_t idx1 = div1.Remainder(static_cast<uint32_t>(task));
@ -298,37 +299,6 @@ void ParallelizeTwoRanges(const IndexRangePartition& get1,
});
}
// Calls `func(task, worker)` for each task in `[0, num_tasks)`. Parallelizes
// over clusters of ONE package, then within each cluster.
template <class Func>
void HierarchicalParallelFor(size_t num_tasks, NestedPools& pools,
const Func& func) {
// If few tasks, run on a single cluster. Also avoids a bit of overhead if
// there is only one cluster.
hwy::ThreadPool& all_clusters = pools.AllClusters();
const size_t num_clusters = all_clusters.NumWorkers();
hwy::ThreadPool& cluster = pools.Cluster(0);
if (num_clusters == 1 || num_tasks <= cluster.NumWorkers()) {
return cluster.Run(0, num_tasks, [&](uint64_t task, size_t thread) {
func(task, thread);
});
}
// Assign each cluster a sub-range.
const IndexRangePartition ranges =
StaticPartition(IndexRange(0, num_tasks), num_clusters, 1);
ParallelizeOneRange(
ranges, all_clusters,
[&](const IndexRange& range, const size_t cluster_idx) {
hwy::ThreadPool& cluster = pools.Cluster(cluster_idx);
const size_t cluster_base = cluster_idx * pools.MaxWorkersPerCluster();
cluster.Run(range.begin(), range.end(),
[&](uint64_t task, size_t thread) {
func(task, cluster_base + thread);
});
});
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_
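
A minimal usage sketch of the updated `ParallelizeOneRange` signature, modeled on the unit-test call sites; the "ExamplePartition" caller name, the partition sizes, and the loop body are illustrative only:

```
#include "util/threading.h"

namespace gcpp {

// Sketch: split [0, 32) into up to 4 sub-ranges, each a multiple of 4 in size,
// and process each on the pool, attributing the fork-join to a named caller.
void ExamplePartitionedFor(hwy::ThreadPool& pool) {
  static const hwy::pool::Caller caller =
      hwy::ThreadPool::AddCaller("ExamplePartition");
  const IndexRangePartition ranges =
      StaticPartition(IndexRange(0, 32), 4, 4);
  ParallelizeOneRange(ranges, pool, caller,
                      [&](const IndexRange& range, size_t thread) {
                        for (size_t i = range.begin(); i < range.end(); ++i) {
                          // ... per-element work on worker `thread` ...
                        }
                      });
}

}  // namespace gcpp
```

`ParallelizeTwoRanges` takes the same extra `caller` argument, as does the raw `pool.Run` used by `MaybePin` and the tests.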


@ -78,31 +78,33 @@ static void TunePool(hwy::PoolWaitMode wait_mode, hwy::ThreadPool& pool) {
#endif
}
static void TunePools(hwy::PoolWaitMode wait_mode, NestedPools& pools) {
hwy::ThreadPool& clusters = pools.AllClusters();
static void TunePools(hwy::PoolWaitMode wait_mode, ThreadingContext& ctx) {
hwy::ThreadPool& clusters = ctx.pools.AllClusters();
TunePool(wait_mode, clusters);
// Run in parallel because Turin CPUs have 16, and in real usage, we often
// run all at the same time.
clusters.Run(0, clusters.NumWorkers(),
ctx.pool_callers.Get(Callers::kTunePool),
[&](uint64_t cluster_idx, size_t /*thread*/) {
TunePool(wait_mode, pools.Cluster(cluster_idx));
TunePool(wait_mode, ctx.pools.Cluster(cluster_idx));
});
}
ThreadingContext::ThreadingContext(const ThreadingArgs& args)
: profiler(hwy::Profiler::Get()),
profiler_zones(profiler),
pool_callers(),
topology(BoundedSlice(args.skip_packages, args.max_packages),
BoundedSlice(args.skip_clusters, args.max_clusters),
BoundedSlice(args.skip_lps, args.max_lps)),
cache_info(topology),
allocator(topology, cache_info, args.bind != Tristate::kFalse),
pools(topology, allocator, args.max_threads, args.pin) {
InitProfilerZones(profiler);
PROFILER_ZONE("Startup.ThreadingContext autotune");
TunePools(hwy::PoolWaitMode::kSpin, pools);
TunePools(hwy::PoolWaitMode::kSpin, *this);
// kBlock is the default, hence set/tune it last.
TunePools(hwy::PoolWaitMode::kBlock, pools);
TunePools(hwy::PoolWaitMode::kBlock, *this);
}
} // namespace gcpp


@ -28,6 +28,7 @@
#include "util/basics.h" // Tristate
#include "util/threading.h"
#include "util/topology.h"
#include "util/zones.h"
#include "hwy/profiler.h"
// IWYU pragma: end_exports
@ -107,6 +108,9 @@ struct ThreadingContext {
// Singleton; pass around a reference to reduce overhead.
hwy::Profiler& profiler;
ProfilerZones profiler_zones;
PoolCallers pool_callers;
// Detects topology, subject to limits imposed by user-specified `args`.
// For example, if `args.max_clusters` is 1, then `topology.NumClusters()`
// will be 1 regardless of the actual system topology.
@ -122,6 +126,9 @@ struct ThreadingContext {
NestedPools pools;
};
#define GCPP_ZONE(ctx, global_idx, zone_enum) \
PROFILER_ZONE3(ctx.profiler, global_idx, ctx.profiler_zones.Get(zone_enum))
// Describes the strategy for distributing parallel work across cores.
enum class ParallelismStrategy : uint8_t {
// Execute using a single-threaded loop on the calling thread. The `worker`
@ -147,18 +154,53 @@ enum class ParallelismStrategy : uint8_t {
kHierarchical,
};
// Calls `func(task, worker)` for each task in `[0, num_tasks)`. Parallelizes
// over clusters of ONE package, then within each cluster.
template <class Func>
void HierarchicalParallelFor(size_t num_tasks, ThreadingContext& ctx,
Callers callers, const Func& func) {
const hwy::pool::Caller caller = ctx.pool_callers.Get(callers);
// If few tasks, run on a single cluster. Also avoids a bit of overhead if
// there is only one cluster.
hwy::ThreadPool& all_clusters = ctx.pools.AllClusters();
const size_t num_clusters = all_clusters.NumWorkers();
hwy::ThreadPool& cluster = ctx.pools.Cluster(0);
if (num_clusters == 1 || num_tasks <= cluster.NumWorkers()) {
return cluster.Run(0, num_tasks, caller, [&](uint64_t task, size_t thread) {
func(task, thread);
});
}
// Assign each cluster a sub-range.
const IndexRangePartition ranges =
StaticPartition(IndexRange(0, num_tasks), num_clusters, 1);
ParallelizeOneRange(ranges, all_clusters, caller,
[&](const IndexRange& range, const size_t cluster_idx) {
hwy::ThreadPool& cluster =
ctx.pools.Cluster(cluster_idx);
const size_t cluster_base =
cluster_idx * ctx.pools.MaxWorkersPerCluster();
cluster.Run(range.begin(), range.end(), caller,
[&](uint64_t task, size_t thread) {
func(task, cluster_base + thread);
});
});
}
// Calls `func(task, worker)` for each `task` in `[0, num_tasks)`, with the
// number/type of workers determined by `parallelism`. `cluster_idx` is for
// `parallelism == kWithinCluster`, and should be 0 if unknown.
template <class Func>
void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
ThreadingContext& ctx, size_t cluster_idx, const Func& func) {
ThreadingContext& ctx, size_t cluster_idx, Callers callers,
const Func& func) {
HWY_DASSERT(cluster_idx < ctx.topology.NumClusters());
if (cluster_idx != 0) {
// If already running across clusters, only use within-cluster modes.
HWY_DASSERT(parallelism == ParallelismStrategy::kNone ||
parallelism == ParallelismStrategy::kWithinCluster);
}
const hwy::pool::Caller caller = ctx.pool_callers.Get(callers);
switch (parallelism) {
case ParallelismStrategy::kNone: {
@ -171,7 +213,7 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
case ParallelismStrategy::kAcrossClusters:
return ctx.pools.AllClusters().Run(
0, num_tasks,
0, num_tasks, caller,
[&](uint64_t task, size_t cluster_idx) { func(task, cluster_idx); });
case ParallelismStrategy::kWithinCluster: {
@ -179,7 +221,7 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
// used for TLS indexing for example in profiler.h.
const size_t base = ctx.Worker(cluster_idx);
return ctx.pools.Cluster(cluster_idx)
.Run(0, num_tasks, [&](uint64_t task, size_t worker) {
.Run(0, num_tasks, caller, [&](uint64_t task, size_t worker) {
func(task, base + worker);
});
}
@ -191,19 +233,19 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
const size_t num_clusters = all_clusters.NumWorkers();
if (num_clusters == 1) {
return ctx.pools.Cluster(cluster_idx)
.Run(0, num_tasks,
.Run(0, num_tasks, caller,
[&](uint64_t task, size_t worker) { func(task, worker); });
}
return ctx.pools.AllClusters().Run(
0, num_tasks, [&](uint64_t task, size_t cluster_idx) {
const size_t worker = ctx.Worker(cluster_idx);
func(task, worker);
});
return all_clusters.Run(0, num_tasks, caller,
[&](uint64_t task, size_t cluster_idx) {
const size_t worker = ctx.Worker(cluster_idx);
func(task, worker);
});
}
case ParallelismStrategy::kHierarchical:
return HierarchicalParallelFor(num_tasks, ctx.pools, func);
return HierarchicalParallelFor(num_tasks, ctx, callers, func);
}
}
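
Putting the pieces together, a minimal sketch of the new call shape; `Callers::kTest` and `Zones::kOpsAddFrom` stand in for whatever entries a real call site registers:

```
#include "util/threading_context.h"

namespace gcpp {

// Sketch: flat (single-level) parallelism over `n` tasks, attributing pool
// time to a Caller and CPU time to a profiler Zone.
void ExampleScale(float* data, size_t n, ThreadingContext& ctx) {
  ParallelFor(ParallelismStrategy::kFlat, n, ctx, /*cluster_idx=*/0,
              Callers::kTest, [&](uint64_t task, size_t worker) {
                // `worker` is a global worker index, valid for profiler TLS.
                GCPP_ZONE(ctx, worker, Zones::kOpsAddFrom);
                data[task] *= 2.0f;
              });
}

}  // namespace gcpp
```

Switching to `kWithinCluster` or `kHierarchical` keeps the same call shape: the former offsets `worker` by `ctx.Worker(cluster_idx)` so indices stay globally unique, and the latter forwards to `HierarchicalParallelFor` with the same `Callers` value.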


@ -37,6 +37,8 @@ namespace {
using ::testing::ElementsAre;
static const hwy::pool::Caller kCaller = hwy::ThreadPool::AddCaller("Test");
TEST(ThreadingTest, TestBoundedSlice) {
const char* name = "test";
// No args = no limit.
@ -205,7 +207,7 @@ TEST(ThreadingTest, TestParallelizeOneRange) {
const IndexRangePartition partition = StaticPartition(range, 2, 4);
hwy::ThreadPool null_pool(0);
size_t calls = 0;
ParallelizeOneRange(partition, null_pool,
ParallelizeOneRange(partition, null_pool, kCaller,
[&](const IndexRange& range, size_t) {
if (++calls == 1) {
HWY_ASSERT(range.begin() == 0 && range.end() == 8);
@ -226,7 +228,7 @@ TEST(ThreadingTest, TestParallelizeTwoRanges) {
{
size_t calls = 0;
ParallelizeTwoRanges(
partition1, partition2, null_pool,
partition1, partition2, null_pool, kCaller,
[&](const IndexRange& range1, const IndexRange& range2, size_t) {
++calls;
HWY_ASSERT(range1.begin() == 0 || range1.begin() == 8);
@ -240,7 +242,7 @@ TEST(ThreadingTest, TestParallelizeTwoRanges) {
{
size_t calls = 0;
ParallelizeTwoRanges(
partition2, partition1, null_pool,
partition2, partition1, null_pool, kCaller,
[&](const IndexRange& range2, const IndexRange& range1, size_t) {
++calls;
HWY_ASSERT(range1.begin() == 0 || range1.begin() == 8);
@ -265,7 +267,7 @@ std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
const double t0 = hwy::platform::Now();
for (size_t reps = 0; reps < 1200; ++reps) {
pool.Run(0, pool.NumWorkers(), [&](uint64_t task, size_t thread) {
pool.Run(0, pool.NumWorkers(), kCaller, [&](uint64_t task, size_t thread) {
outputs[thread * kU64PerThread] = base + thread;
});
hwy::PreventElision(outputs[base]);
@ -305,18 +307,20 @@ std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
if (have_stop) {
for (size_t rep = 0; rep < max_reps; ++rep) {
const uint64_t t0 = hwy::timer::Start();
pool.Run(0, pool.NumWorkers(), [&](uint64_t task, size_t thread) {
outputs[thread * kU64PerThread] = base + thread;
});
pool.Run(0, pool.NumWorkers(), kCaller,
[&](uint64_t task, size_t thread) {
outputs[thread * kU64PerThread] = base + thread;
});
const uint64_t t1 = hwy::timer::Stop();
times.push_back(t1 - t0);
}
} else {
for (size_t rep = 0; rep < max_reps; ++rep) {
const uint64_t t0 = hwy::timer::Start();
pool.Run(0, pool.NumWorkers(), [&](uint64_t task, size_t thread) {
outputs[thread * kU64PerThread] = base + thread;
});
pool.Run(0, pool.NumWorkers(), kCaller,
[&](uint64_t task, size_t thread) {
outputs[thread * kU64PerThread] = base + thread;
});
const uint64_t t1 = hwy::timer::Start();
times.push_back(t1 - t0);
}


@ -1,70 +1,201 @@
#include "util/zones.h"
#include <stddef.h>
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"
namespace gcpp {
namespace {
#if PROFILER_ENABLED
static constexpr size_t kNumZones = static_cast<size_t>(Zones::kNumZones);
static const char* kProfilerZoneNames[kNumZones] = {
// Keep in sync with Zones enum.
"Ops.RMSNormMul",
"Ops.RMSNorm",
"Ops.RMSNormInplace",
"Ops.Rope",
"Ops.RopeAndMulBy",
"Ops.AddFrom",
"Ops.MulByConst",
"Ops.MulByConstTo",
"Ops.MulByConstAndAdd",
"Ops.MulByConstAndAddTile",
"Ops.MulByConstAndAddTile4",
"Ops.MulByConstAndAddVector",
"Ops.Softmax",
"Ops.LogitsSoftCap",
"FlashAttention.TransposeQ",
"FlashAttention.RMSNormAndPositionalEncoding",
"FlashAttention.SingleFlashAttention",
"FlashAttention.TileFlashAttention",
"FlashAttention.TileFlashAttention4",
"FlashAttention.FlashAttention",
"Gen.Activation",
"Gen.ActivationFused",
"Gen.SampleTop1",
"Gen.SampleTopK",
"Gen.Attention.QDotK",
"Gen.Attention.DotSoftmaxWeightedSum.par",
"Startup.Weights.ReadAllToBF16",
"Startup.Weights.ReadBatches",
"MM.Dispatch",
"MM.MatMul",
"MM.TwoMatMul",
"MM.DecompressA",
"MM.NT",
"MM.NT_K",
"MM.NT_MT",
"MM.NT_MT_K",
};
static hwy::profiler::ZoneHandle profiler_zone_handles[kNumZones];
#endif
void InitProfilerZones(hwy::Profiler& profiler) {
#if PROFILER_ENABLED
// Initialize the zone handles. This is done once at startup.
for (size_t i = 0; i < kNumZones; ++i) {
profiler_zone_handles[i] = profiler.AddZone(kProfilerZoneNames[i]);
const char* ZoneName(Zones zone) {
switch (zone) {
case Zones::kFlashAttentionFlashAttention:
return "FlashAttention.FlashAttention";
case Zones::kFlashAttentionInclusive:
return "FlashAttention.Inclusive";
case Zones::kFlashAttentionRmsNormAndPositionalEncoding:
return "FlashAttention.RMSNormAndPositionalEncoding";
case Zones::kFlashAttentionSingleFlashAttention:
return "FlashAttention.SingleFlashAttention";
case Zones::kFlashAttentionTileFlashAttention:
return "FlashAttention.TileFlashAttention";
case Zones::kFlashAttentionTileFlashAttention4:
return "FlashAttention.TileFlashAttention4";
case Zones::kFlashAttentionTransposeQ:
return "FlashAttention.TransposeQ";
case Zones::kGenActivation:
return "Gen.Activation";
case Zones::kGenActivationFused:
return "Gen.ActivationFused";
case Zones::kGenAttention:
return "Gen.Attention";
case Zones::kGenAttentionComputeQKV:
return "Gen.Attention.ComputeQKV";
case Zones::kGenAttentionDotSoftmaxWeightedSumInclusive:
return "Gen.Attention.DotSoftmaxWeightedSumInclusive";
case Zones::kGenAttentionDotSoftmaxWeightedSumPar:
return "Gen.Attention.DotSoftmaxWeightedSum.par";
case Zones::kGenAttentionQDotK:
return "Gen.Attention.QDotK";
case Zones::kGenAttentionSumHeads:
return "Gen.Attention.SumHeads";
case Zones::kGenEmbed:
return "Gen.Embed";
case Zones::kGenEmbeddingMatmul:
return "Gen.EmbeddingMatmul";
case Zones::kGenFFW:
return "Gen.FFW";
case Zones::kGenSampleTop1:
return "Gen.SampleTop1";
case Zones::kGenSampleTopK:
return "Gen.SampleTopK";
case Zones::kMMDecompressA:
return "MM.DecompressA";
case Zones::kMMDispatch:
return "MM.Dispatch";
case Zones::kMMMatMul:
return "MM.MatMul";
case Zones::kMMNT_K:
return "MM.NT_K";
case Zones::kMMNT_MT_K:
return "MM.NT_MT_K";
case Zones::kMMNT_MT:
return "MM.NT_MT";
case Zones::kMMNT:
return "MM.NT";
case Zones::kMMTwoMatMul:
return "MM.TwoMatMul";
case Zones::kOpsAddFrom:
return "Ops.AddFrom";
case Zones::kOpsLogitsSoftCap:
return "Ops.LogitsSoftCap";
// case Zones::kOpsMulByConst: // removed due to overhead
// case Zones::kOpsMulByConstAndAdd: // removed due to overhead
// case Zones::kOpsMulByConstAndAddTile: // removed due to overhead
// case Zones::kOpsMulByConstAndAddTile4: // removed due to overhead
// case Zones::kOpsMulByConstAndAddVector: // removed due to overhead
case Zones::kOpsMulByConstTo:
return "Ops.MulByConstTo";
case Zones::kOpsRmsNorm:
return "Ops.RMSNorm";
case Zones::kOpsRmsNormInplace:
return "Ops.RMSNormInplace";
case Zones::kOpsRmsNormMul:
return "Ops.RMSNormMul";
case Zones::kOpsRope:
return "Ops.Rope";
case Zones::kOpsRopeAndMulBy:
return "Ops.RopeAndMulBy";
case Zones::kOpsSoftmax:
return "Ops.Softmax";
case Zones::kStartupWeightsReadAllToBF16:
return "Startup.Weights.ReadAllToBF16";
case Zones::kStartupWeightsReadBatches:
return "Startup.Weights.ReadBatches";
default:
HWY_ABORT("Invalid zone %d.", static_cast<int>(zone));
}
#endif
}
hwy::profiler::ZoneHandle GetProfilerZone(Zones zone) {
#if PROFILER_ENABLED
return profiler_zone_handles[static_cast<size_t>(zone)];
#else
return hwy::profiler::ZoneHandle();
#endif
hwy::ProfilerFlags ZoneFlags(Zones zone) {
switch (zone) {
case Zones::kFlashAttentionInclusive:
case Zones::kGenAttention:
case Zones::kGenAttentionComputeQKV:
case Zones::kGenAttentionDotSoftmaxWeightedSumInclusive:
case Zones::kGenAttentionSumHeads:
case Zones::kGenEmbed:
case Zones::kGenEmbeddingMatmul:
case Zones::kGenFFW:
return hwy::ProfilerFlags::kInclusive;
default:
return hwy::ProfilerFlags::kDefault;
}
}
const char* CallerName(Callers caller) {
switch (caller) {
case Callers::kActivationBatched:
return "ActivationBatched";
case Callers::kAllocateAndBindAll:
return "AllocateAndBindAll";
case Callers::kAttComputeQKV:
return "Att.ComputeQKV";
case Callers::kAttDotSoftmaxWeightedSum:
return "Att.DotSoftmaxWeightedSum";
case Callers::kBlobWriter:
return "BlobWriter";
case Callers::kCompress:
return "Compress";
case Callers::kFixupWeights:
return "FixupWeights";
case Callers::kFlashAttention:
return "FlashAttention";
case Callers::kFlashRMSNormAndPositionalEncoding:
return "Flash.RMSNormAndPositionalEncoding";
case Callers::kFlashTransposeQ:
return "Flash.TransposeQ";
case Callers::kMMClusterForMC:
return "MM.ClusterForMC";
case Callers::kMMClusterForMCNC:
return "MM.ClusterForMCNC";
case Callers::kMMClusterForN:
return "MM.ClusterForN";
case Callers::kMMHierForMC:
return "MM.HierForMC";
case Callers::kMMHierForMCNC:
return "MM.HierForMCNC";
case Callers::kMMHierForN:
return "MM.HierForN";
case Callers::kOpsAddFromBatched:
return "Ops.AddFromBatched";
case Callers::kOpsMaybeLogitsSoftCapBatched:
return "Ops.MaybeLogitsSoftCapBatched";
case Callers::kOpsRMSNormBatched:
return "Ops.RMSNormBatched";
case Callers::kOpsRMSNormInplaceBatched:
return "Ops.RMSNormInplaceBatched";
case Callers::kReadAllToBF16:
return "ReadAllToBF16";
case Callers::kReadBatches:
return "ReadBatches";
case Callers::kSampleAndStream:
return "SampleAndStream";
case Callers::kTest: // only for unit tests.
return "Test-only!";
case Callers::kTunePool:
return "TunePool";
case Callers::kVitDotSoftmax1:
return "Vit.DotSoftmax1";
case Callers::kVitDotSoftmax2:
return "Vit.DotSoftmax2";
case Callers::kVitDotSoftmax3:
return "Vit.DotSoftmax3";
case Callers::kVitDotSoftmax4:
return "Vit.DotSoftmax4";
default:
HWY_ABORT("Invalid caller %d.", static_cast<int>(caller));
}
}
} // namespace
ProfilerZones::ProfilerZones(hwy::Profiler& profiler) {
for (size_t i = 0;; ++i) {
const Zones zone = static_cast<Zones>(i);
if (zone == Zones::kNumZones) break;
handles_[i] = profiler.AddZone(ZoneName(zone), ZoneFlags(zone));
}
}
PoolCallers::PoolCallers() {
for (size_t i = 0;; ++i) {
const Callers caller = static_cast<Callers>(i);
if (caller == Callers::kNumCallers) break;
callers_[i] = hwy::ThreadPool::AddCaller(CallerName(caller));
}
}
} // namespace gcpp
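
`ZoneName`/`CallerName` abort on values they do not recognize, and the constructors above walk the enums up to `kNumZones`/`kNumCallers`, so the sorted enums in zones.h must stay in sync with these switches. Both tables are built once by the `ThreadingContext` constructor; call sites then resolve enum values to Highway handles, roughly as in this sketch (the function name is illustrative):

```
#include "util/threading_context.h"

namespace gcpp {

// Sketch: resolve enum values to pre-registered Highway handles.
// ThreadingContext owns both tables, so Get() is just array indexing.
void ExampleLookups(ThreadingContext& ctx) {
  const hwy::profiler::ZoneHandle matmul_zone =
      ctx.profiler_zones.Get(Zones::kMMMatMul);
  const hwy::pool::Caller read_caller =
      ctx.pool_callers.Get(Callers::kReadBatches);
  // `read_caller` can be passed to a raw pool.Run; GCPP_ZONE wraps the zone
  // lookup plus PROFILER_ZONE3 for the common case.
  (void)matmul_zone;
  (void)read_caller;
}

}  // namespace gcpp
```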


@ -1,57 +1,123 @@
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_ZONES_H_
#define THIRD_PARTY_GEMMA_CPP_UTIL_ZONES_H_
#include <stddef.h>
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"
namespace gcpp {
// Zones for the profiler.
enum class Zones {
kOpsRmsNormMul,
kOpsRmsNorm,
kOpsRmsNormInplace,
kOpsRope,
kOpsRopeAndMulBy,
kOpsAddFrom,
kOpsMulByConst,
kOpsMulByConstTo,
kOpsMulByConstAndAdd,
kOpsMulByConstAndAddTile,
kOpsMulByConstAndAddTile4,
kOpsMulByConstAndAddVector,
kOpsSoftmax,
kOpsLogitsSoftCap,
kFlashAttentionTransposeQ,
enum class Zones { // Keep sorted
kFlashAttentionFlashAttention,
kFlashAttentionInclusive,
kFlashAttentionRmsNormAndPositionalEncoding,
kFlashAttentionSingleFlashAttention,
kFlashAttentionTileFlashAttention,
kFlashAttentionTileFlashAttention4,
kFlashAttentionFlashAttention,
kFlashAttentionTransposeQ,
kGenActivation,
kGenActivationFused,
kGenAttention,
kGenAttentionComputeQKV,
kGenAttentionDotSoftmaxWeightedSumInclusive,
kGenAttentionDotSoftmaxWeightedSumPar,
kGenAttentionQDotK,
kGenAttentionSumHeads,
kGenEmbed,
kGenEmbeddingMatmul,
kGenFFW,
kGenSampleTop1,
kGenSampleTopK,
kGenAttentionQDotK,
kGenAttentionDotSoftmaxWeightedSumPar,
kStartupWeightsReadAllToBF16,
kStartupWeightsReadBatches,
kMMDecompressA,
kMMDispatch,
kMMMatMul,
kMMTwoMatMul,
kMMDecompressA,
kMMNT,
kMMNT_K,
kMMNT_MT,
kMMNT_MT_K,
kNumZones
kMMNT_MT,
kMMNT,
kMMTwoMatMul,
kOpsAddFrom,
kOpsLogitsSoftCap,
// kOpsMulByConst, // removed due to overhead
// kOpsMulByConstAndAdd, // removed due to overhead
// kOpsMulByConstAndAddTile, // removed due to overhead
// kOpsMulByConstAndAddTile4, // removed due to overhead
// kOpsMulByConstAndAddVector, // removed due to overhead
kOpsMulByConstTo,
kOpsRmsNorm,
kOpsRmsNormInplace,
kOpsRmsNormMul,
kOpsRope,
kOpsRopeAndMulBy,
kOpsSoftmax,
kStartupWeightsReadAllToBF16,
kStartupWeightsReadBatches,
kNumZones // must be last
};
// Initializes the profiler zones. Must be called before any other profiler
// functions.
void InitProfilerZones(hwy::Profiler& profiler);
// Owned by ThreadingContext.
class ProfilerZones {
public:
ProfilerZones(hwy::Profiler& profiler);
// Returns the zone handle for the given zone enum value.
hwy::profiler::ZoneHandle GetProfilerZone(Zones zone);
hwy::profiler::ZoneHandle Get(Zones zone) {
HWY_DASSERT(zone != Zones::kNumZones);
return handles_[static_cast<size_t>(zone)];
}
private:
hwy::profiler::ZoneHandle handles_[static_cast<size_t>(Zones::kNumZones)];
};
enum class Callers { // Keep sorted
kActivationBatched,
kAllocateAndBindAll,
kAttComputeQKV,
kAttDotSoftmaxWeightedSum,
kBlobWriter,
kCompress,
kFixupWeights,
kFlashAttention,
kFlashRMSNormAndPositionalEncoding,
kFlashTransposeQ,
kMMClusterForMC,
kMMClusterForMCNC,
kMMClusterForN,
kMMHierForMC,
kMMHierForMCNC,
kMMHierForN,
kOpsAddFromBatched,
kOpsMaybeLogitsSoftCapBatched,
kOpsRMSNormBatched,
kOpsRMSNormInplaceBatched,
kReadAllToBF16,
kReadBatches,
kSampleAndStream,
kTest, // only for unit tests.
kTunePool,
kVitDotSoftmax1,
kVitDotSoftmax2,
kVitDotSoftmax3,
kVitDotSoftmax4,
kNumCallers // must be last
};
// Owned by ThreadingContext.
class PoolCallers {
public:
PoolCallers();
hwy::pool::Caller Get(Callers caller) {
HWY_DASSERT(caller != Callers::kNumCallers);
return callers_[static_cast<size_t>(caller)];
}
private:
hwy::pool::Caller callers_[static_cast<size_t>(Callers::kNumCallers)];
};
} // namespace gcpp