Major cleanup of profiler zones; add Caller annotation for all pool.Run call sites

Pass ThreadingContext instead of Pools/Profiler individually, for access to Zones
Add GCPP_ZONE helper
Add Caller argument to pool.Run to enable new stats
Remove most direct dependencies on ThreadPool, prefer ParallelFor

PiperOrigin-RevId: 822934530
Jan Wassenberg 2025-10-23 01:53:50 -07:00 committed by Copybara-Service
parent 9e8ac7e2f0
commit 3ed403e287
43 changed files with 811 additions and 592 deletions
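The hunks below all follow the same before/after shape; distilled into a minimal sketch (names such as `Zones::kGenAttention` and `Callers::kAttComputeQKV` are taken from the diff, and the exact signatures live in util/zones.h and util/threading_context.h rather than here):

```cpp
// Before: pools and profiler travel separately, and each call site registers
// its own zone before opening it.
static const auto zone =
    ctx.profiler.AddZone("Gen.Attention", hwy::ProfilerFlags::kInclusive);
PROFILER_ZONE3(ctx.profiler, worker, zone);
pool.Run(0, num_tasks, [&](uint64_t task, size_t worker) { /* ... */ });

// After: ThreadingContext carries pools, profiler and the Zones table.
// GCPP_ZONE replaces the AddZone/PROFILER_ZONE3 pair, and every pool.Run or
// ParallelFor is tagged with a Caller so the new stats can attribute work to
// its call site.
GCPP_ZONE(ctx, worker, Zones::kGenAttention);
pool.Run(0, num_tasks, ctx.pool_callers.Get(Callers::kAttComputeQKV),
         [&](uint64_t task, size_t worker) { /* ... */ });
ParallelFor(ParallelismStrategy::kFlat, num_tasks, ctx, /*cluster_idx=*/0,
            Callers::kAttComputeQKV,
            [&](uint64_t task, size_t worker) { /* ... */ });
```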

View File

@ -95,7 +95,6 @@ cc_library(
":topology", ":topology",
# Placeholder for container detection, do not remove # Placeholder for container detection, do not remove
"@highway//:hwy", "@highway//:hwy",
"@highway//:profiler",
"@highway//:thread_pool", "@highway//:thread_pool",
"@highway//:topology", "@highway//:topology",
], ],
@ -124,7 +123,9 @@ cc_library(
srcs = ["util/zones.cc"], srcs = ["util/zones.cc"],
hdrs = ["util/zones.h"], hdrs = ["util/zones.h"],
deps = [ deps = [
"@highway//:hwy",
"@highway//:profiler", "@highway//:profiler",
"@highway//:thread_pool",
], ],
) )
@ -258,7 +259,6 @@ cc_library(
"//io:fields", "//io:fields",
"@highway//:hwy", "@highway//:hwy",
"@highway//:profiler", "@highway//:profiler",
"@highway//:thread_pool",
], ],
) )
@ -278,7 +278,6 @@ cc_library(
"//io:blob_store", "//io:blob_store",
"@highway//:hwy", "@highway//:hwy",
"@highway//:profiler", "@highway//:profiler",
"@highway//:thread_pool",
], ],
) )
@ -309,7 +308,6 @@ cc_library(
deps = [ deps = [
":allocator", ":allocator",
":basics", ":basics",
":configs",
":mat", ":mat",
":threading", ":threading",
":threading_context", ":threading_context",
@ -397,7 +395,6 @@ cc_library(
"@highway//:hwy", "@highway//:hwy",
"@highway//:math", "@highway//:math",
"@highway//:profiler", "@highway//:profiler",
"@highway//:thread_pool",
"@highway//hwy/contrib/sort:vqsort", "@highway//hwy/contrib/sort:vqsort",
], ],
) )
@ -424,7 +421,6 @@ cc_test(
"@highway//:nanobenchmark", #buildcleaner: keep "@highway//:nanobenchmark", #buildcleaner: keep
"@highway//:profiler", "@highway//:profiler",
"@highway//:stats", "@highway//:stats",
"@highway//:thread_pool",
], ],
) )
@ -507,7 +503,6 @@ cc_test(
"@highway//:hwy_test_util", "@highway//:hwy_test_util",
"@highway//:nanobenchmark", "@highway//:nanobenchmark",
"@highway//:profiler", "@highway//:profiler",
"@highway//:thread_pool",
], ],
) )

View File

@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9781a1698ee0756ef1eaaf96930113ed7cb6d3ee EXCLUDE_FROM_ALL) FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2a16a50ff61071bb25ddef0ce35d92b0e2b9c579 EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(highway) FetchContent_MakeAvailable(highway)
## Note: absl needs to be installed by sentencepiece. This will only happen if ## Note: absl needs to be installed by sentencepiece. This will only happen if

View File

@ -18,7 +18,7 @@ bazel_dep(name = "google_benchmark", version = "1.8.5")
# Require a more recent version. # Require a more recent version.
git_override( git_override(
module_name = "highway", module_name = "highway",
commit = "9781a1698ee0756ef1eaaf96930113ed7cb6d3ee", commit = "2a16a50ff61071bb25ddef0ce35d92b0e2b9c579",
remote = "https://github.com/google/highway", remote = "https://github.com/google/highway",
) )

View File

@ -452,7 +452,7 @@ FetchContent_MakeAvailable(sentencepiece)
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG origin/main) FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG origin/main)
FetchContent_MakeAvailable(gemma) FetchContent_MakeAvailable(gemma)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9781a1698ee0756ef1eaaf96930113ed7cb6d3ee) FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2a16a50ff61071bb25ddef0ce35d92b0e2b9c579)
FetchContent_MakeAvailable(highway) FetchContent_MakeAvailable(highway)
``` ```

View File

@ -120,9 +120,9 @@ cc_library(
":compress", ":compress",
":distortion", ":distortion",
"//:mat", "//:mat",
"//:threading_context",
"@highway//:hwy", "@highway//:hwy",
"@highway//:hwy_test_util", "@highway//:hwy_test_util",
"@highway//:thread_pool",
], ],
) )
@ -180,6 +180,7 @@ cc_library(
":sfp", ":sfp",
"//:basics", "//:basics",
"//:mat", "//:mat",
"//:threading_context",
"@highway//:hwy", "@highway//:hwy",
"@highway//:nanobenchmark", "@highway//:nanobenchmark",
"@highway//:profiler", "@highway//:profiler",
@ -203,9 +204,9 @@ cc_test(
":test_util", ":test_util",
"@googletest//:gtest_main", # buildcleaner: keep "@googletest//:gtest_main", # buildcleaner: keep
"//:test_util", "//:test_util",
"//:threading_context",
"@highway//:hwy", "@highway//:hwy",
"@highway//:hwy_test_util", "@highway//:hwy_test_util",
"@highway//:thread_pool",
], ],
) )

View File

@ -26,10 +26,10 @@
#include "compression/compress.h" // IWYU pragma: export #include "compression/compress.h" // IWYU pragma: export
#include "compression/distortion.h" #include "compression/distortion.h"
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/timer.h"
#if COMPRESS_STATS #if COMPRESS_STATS
#include <cmath> // lroundf #include <cmath> // lroundf
@ -493,13 +493,13 @@ struct CompressTraits<NuqStream> {
} }
}; };
// Compresses `num` elements of `raw` to `packed` starting at `packed_ofs`, // DEPRECATED: Use the overload with ThreadingContext instead.
// which is useful for compressing sub-regions of an array.
template <typename Packed> template <typename Packed>
HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num, HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
CompressWorkingSet& work, CompressWorkingSet& work,
const PackedSpan<Packed>& packed, const PackedSpan<Packed>& packed,
const size_t packed_ofs, hwy::ThreadPool& pool) { const size_t packed_ofs, hwy::ThreadPool& pool,
hwy::pool::Caller caller = hwy::pool::Caller()) {
packed.BoundsCheck(packed_ofs, num); packed.BoundsCheck(packed_ofs, num);
work.tls.resize(pool.NumWorkers()); work.tls.resize(pool.NumWorkers());
if constexpr (COMPRESS_STATS) { if constexpr (COMPRESS_STATS) {
@ -511,7 +511,7 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
using Traits = CompressTraits<hwy::RemoveConst<Packed>>; using Traits = CompressTraits<hwy::RemoveConst<Packed>>;
constexpr size_t kBatch = 8192; constexpr size_t kBatch = 8192;
const size_t num_batches = hwy::DivCeil(num, kBatch); const size_t num_batches = hwy::DivCeil(num, kBatch);
pool.Run(0, num_batches, pool.Run(0, num_batches, caller,
[&](const uint32_t idx_batch, size_t thread) HWY_ATTR { [&](const uint32_t idx_batch, size_t thread) HWY_ATTR {
const hn::ScalableTag<float> df; const hn::ScalableTag<float> df;
@ -530,6 +530,17 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
} }
} }
// Compresses `num` elements of `raw` to `packed` starting at `packed_ofs`,
// which is useful for compressing sub-regions of an array.
template <typename Packed>
HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
CompressWorkingSet& work,
const PackedSpan<Packed>& packed,
const size_t packed_ofs, ThreadingContext& ctx) {
Compress(raw, num, work, packed, packed_ofs, ctx.pools.Pool(),
ctx.pool_callers.Get(Callers::kCompress));
}
// Same as above, but without parallelization nor benchmarking. // Same as above, but without parallelization nor benchmarking.
template <typename Packed> template <typename Packed>
HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num, HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
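For call sites outside this file, migration means handing over the whole `ThreadingContext` instead of a bare pool; the new overload picks `ctx.pools.Pool()` and the `Callers::kCompress` tag itself. A minimal usage sketch (assuming `raw`, `num`, `packed_span` and `packed_ofs` as in the existing overloads):

```cpp
// Sketch of a migrated call site; ThreadingArgs/ThreadingContext construction
// follows the pattern used in compress_test below.
ThreadingArgs args;
ThreadingContext ctx(args);
CompressWorkingSet work;
// Old: Compress(raw, num, work, packed_span, packed_ofs, ctx.pools.Pool());
Compress(raw, num, work, packed_span, packed_ofs, ctx);
```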

View File

@ -24,9 +24,9 @@
#include "compression/compress.h" #include "compression/compress.h"
#include "compression/distortion.h" #include "compression/distortion.h"
#include "util/test_util.h" #include "util/test_util.h"
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/tests/hwy_gtest.h" #include "hwy/tests/hwy_gtest.h"
// clang-format off // clang-format off
#undef HWY_TARGET_INCLUDE #undef HWY_TARGET_INCLUDE
@ -42,6 +42,17 @@ namespace gcpp {
namespace HWY_NAMESPACE { namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
static ThreadingArgs SingleThreadArgs() {
ThreadingArgs args;
args.max_lps = 1;
return args;
}
static ThreadingContext& Ctx() {
static ThreadingContext* ctx = new ThreadingContext(SingleThreadArgs());
return *ctx;
}
// Calls Compress and Decompress2 and verifies the distortion/error. // Calls Compress and Decompress2 and verifies the distortion/error.
template <typename Packed> template <typename Packed>
struct TestDecompress2 { struct TestDecompress2 {
@ -49,7 +60,9 @@ struct TestDecompress2 {
HWY_INLINE void operator()(T /*unused*/, D d) { HWY_INLINE void operator()(T /*unused*/, D d) {
const size_t N = hn::Lanes(d); const size_t N = hn::Lanes(d);
CompressWorkingSet work; CompressWorkingSet work;
hwy::ThreadPool pool(0); ThreadingArgs args;
args.max_lps = 1;
ThreadingContext ctx(args);
hwy::RandomState rng; hwy::RandomState rng;
const size_t num = 2 * N; const size_t num = 2 * N;
@ -68,7 +81,7 @@ struct TestDecompress2 {
// Short inputs fail VerifyGaussian. // Short inputs fail VerifyGaussian.
const size_t packed_ofs = 0; const size_t packed_ofs = 0;
Compress(raw.get(), num, work, packed_span, packed_ofs, pool); Compress(raw.get(), num, work, packed_span, packed_ofs, ctx);
hn::Vec<D> raw0, raw1; hn::Vec<D> raw0, raw1;
Decompress2(d, MakeConst(packed_span), packed_ofs, raw0, raw1); Decompress2(d, MakeConst(packed_span), packed_ofs, raw0, raw1);
hn::Store(raw0, d, dec.get()); hn::Store(raw0, d, dec.get());
@ -129,7 +142,6 @@ struct TestShortLengths {
HWY_INLINE void operator()(T /*unused*/, D d) { HWY_INLINE void operator()(T /*unused*/, D d) {
const size_t N = hn::Lanes(d); const size_t N = hn::Lanes(d);
CompressWorkingSet work; CompressWorkingSet work;
hwy::ThreadPool pool(0);
hwy::RandomState rng; hwy::RandomState rng;
for (size_t num = 1; num < 5 * hn::Lanes(d); ++num) { for (size_t num = 1; num < 5 * hn::Lanes(d); ++num) {
@ -149,7 +161,7 @@ struct TestShortLengths {
// Short inputs fail VerifyGaussian. // Short inputs fail VerifyGaussian.
const size_t packed_ofs = 0; const size_t packed_ofs = 0;
Compress(raw.get(), num, work, packed_span, packed_ofs, pool); Compress(raw.get(), num, work, packed_span, packed_ofs, Ctx());
DecompressAndZeroPad(d, MakeConst(packed_span), packed_ofs, dec.get(), DecompressAndZeroPad(d, MakeConst(packed_span), packed_ofs, dec.get(),
num); num);

View File

@ -37,7 +37,6 @@
#include "util/basics.h" #include "util/basics.h"
#include "util/mat.h" #include "util/mat.h"
#include "util/threading_context.h" #include "util/threading_context.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#undef HWY_TARGET_INCLUDE #undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \ #define HWY_TARGET_INCLUDE \
@ -57,9 +56,6 @@ class SbsWriterImpl : public ISbsWriter {
template <typename Packed> template <typename Packed>
void InsertT(const char* name, F32Span weights, void InsertT(const char* name, F32Span weights,
const TensorInfo& tensor_info) { const TensorInfo& tensor_info) {
// TODO(janwas): 1D parallel-for.
hwy::ThreadPool& pool = ctx_.pools.Pool();
MatPtrT<Packed> mat(name, ExtentsFromInfo(&tensor_info)); MatPtrT<Packed> mat(name, ExtentsFromInfo(&tensor_info));
// SFP and NUQ (which uses SFP for cluster centers) have a limited range // SFP and NUQ (which uses SFP for cluster centers) have a limited range
// and depending on the input values may require rescaling. Scaling is // and depending on the input values may require rescaling. Scaling is
@ -82,21 +78,20 @@ class SbsWriterImpl : public ISbsWriter {
// succeeds, but we only have 10 floats, not the full tensor. // succeeds, but we only have 10 floats, not the full tensor.
if (weights.size() == 10 && mat.Extents().Area() != 10) { if (weights.size() == 10 && mat.Extents().Area() != 10) {
Compress(weights.data(), weights.size(), working_set_, mat.Span(), Compress(weights.data(), weights.size(), working_set_, mat.Span(),
/*packed_ofs=*/0, pool); /*packed_ofs=*/0, ctx_);
writer_.Add(name, mat.Packed(), mat.ElementBytes() * 10); writer_.Add(name, mat.Packed(), mat.ElementBytes() * 10);
return; return;
} }
HWY_ASSERT(weights.size() == mat.Extents().Area()); HWY_ASSERT(weights.size() == mat.Extents().Area());
Compress(weights.data(), weights.size(), working_set_, mat.Span(), Compress(weights.data(), weights.size(), working_set_, mat.Span(),
/*packed_ofs=*/0, pool); /*packed_ofs=*/0, ctx_);
writer_.Add(name, mat.Packed(), mat.PackedBytes()); writer_.Add(name, mat.Packed(), mat.PackedBytes());
} }
public: public:
SbsWriterImpl(const std::string& sbs_path) SbsWriterImpl(const std::string& sbs_path)
: ctx_(ThreadingArgs()), : ctx_(ThreadingArgs()), writer_(gcpp::Path(sbs_path), ctx_) {}
writer_(gcpp::Path(sbs_path), ctx_.pools.Pool()) {}
void Insert(const char* name, F32Span weights, Type type, void Insert(const char* name, F32Span weights, Type type,
const TensorInfo& tensor_info) override { const TensorInfo& tensor_info) override {

View File

@ -23,7 +23,7 @@
// IWYU pragma: end_exports // IWYU pragma: end_exports
#include "compression/compress.h" #include "compression/compress.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "util/threading_context.h"
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_ #endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_
@ -98,25 +98,26 @@ void ForeachActivationType3(D d) {
 // Generates inputs: deterministic, within max SfpStream range.
 template <typename MatT>
-MatStorageT<MatT> GenerateMat(const Extents2D& extents,
-                              const Allocator& allocator, MatPadding padding,
-                              hwy::ThreadPool& pool) {
+MatStorageT<MatT> GenerateMat(const Extents2D& extents, MatPadding padding,
+                              ThreadingContext& ctx) {
   gcpp::CompressWorkingSet ws;
-  ws.tls.resize(pool.NumWorkers());
-  MatStorageT<float> raw("raw", extents, allocator, MatPadding::kPacked);
-  MatStorageT<MatT> compressed("mat", extents, allocator, padding);
+  ws.tls.resize(ctx.pools.MaxWorkers());
+  MatStorageT<float> raw("raw", extents, ctx.allocator, MatPadding::kPacked);
+  MatStorageT<MatT> compressed("mat", extents, ctx.allocator, padding);
   const float scale = SfpStream::kMax / extents.Area();
-  pool.Run(0, extents.rows, [&](const size_t r, size_t thread) {
-    float* HWY_RESTRICT row = raw.Row(r);
-    for (size_t c = 0; c < extents.cols; c++) {
-      float f = static_cast<float>(r * extents.cols + c) * scale;
-      if ((r + c) & 1) f = -f;  // Also generate some negative values.
-      row[c] = f;
-    }
-    Compress(raw.Row(r), raw.Cols(), ws.tls[thread],
-             MakeSpan(compressed.Row(r), extents.cols),
-             /*packed_ofs=*/0);
-  });
+  ParallelFor(ParallelismStrategy::kFlat, extents.rows, ctx, /*cluster_idx=*/0,
+              Callers::kTest, [&](size_t r, size_t thread) {
+                float* HWY_RESTRICT row = raw.Row(r);
+                for (size_t c = 0; c < extents.cols; c++) {
+                  float f = static_cast<float>(r * extents.cols + c) * scale;
+                  if ((r + c) & 1)
+                    f = -f;  // Also generate some negative values.
+                  row[c] = f;
+                }
+                Compress(raw.Row(r), raw.Cols(), ws.tls[thread],
+                         MakeSpan(compressed.Row(r), extents.cols),
+                         /*packed_ofs=*/0);
+              });
   compressed.SetScale(0.6f);  // Arbitrary value, different from 1.
   return compressed;
@ -126,25 +127,26 @@ MatStorageT<MatT> GenerateMat(const Extents2D& extents,
 // `f` swaps `r` and `c`.
 template <typename MatT>
 MatStorageT<MatT> GenerateTransposedMat(const Extents2D extents,
-                                        const Allocator& allocator,
                                         MatPadding padding,
-                                        hwy::ThreadPool& pool) {
+                                        ThreadingContext& ctx) {
   gcpp::CompressWorkingSet ws;
-  ws.tls.resize(pool.NumWorkers());
-  MatStorageT<float> raw("raw", extents, allocator, MatPadding::kPacked);
-  MatStorageT<MatT> compressed("trans", extents, allocator, padding);
+  ws.tls.resize(ctx.pools.MaxWorkers());
+  MatStorageT<float> raw("raw", extents, ctx.allocator, MatPadding::kPacked);
+  MatStorageT<MatT> compressed("trans", extents, ctx.allocator, padding);
   const float scale = SfpStream::kMax / extents.Area();
-  pool.Run(0, extents.rows, [&](const size_t r, size_t thread) {
-    float* HWY_RESTRICT row = raw.Row(r);
-    for (size_t c = 0; c < extents.cols; c++) {
-      float f = static_cast<float>(c * extents.rows + r) * scale;
-      if ((r + c) & 1) f = -f;  // Also generate some negative values.
-      row[c] = f;
-    }
-    Compress(raw.Row(r), raw.Cols(), ws.tls[thread],
-             MakeSpan(compressed.Row(r), extents.cols),
-             /*packed_ofs=*/0);
-  });
+  ParallelFor(ParallelismStrategy::kFlat, extents.rows, ctx, /*cluster_idx=*/0,
+              Callers::kTest, [&](size_t r, size_t thread) {
+                float* HWY_RESTRICT row = raw.Row(r);
+                for (size_t c = 0; c < extents.cols; c++) {
+                  float f = static_cast<float>(c * extents.rows + r) * scale;
+                  if ((r + c) & 1)
+                    f = -f;  // Also generate some negative values.
+                  row[c] = f;
+                }
+                Compress(raw.Row(r), raw.Cols(), ws.tls[thread],
+                         MakeSpan(compressed.Row(r), extents.cols),
+                         /*packed_ofs=*/0);
+              });
   // Arbitrary value, different from 1, must match `GenerateMat`.
   compressed.SetScale(0.6f);
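With allocator and pool folded into the context, the test helpers above are now called with just extents, padding and a `ThreadingContext`. An illustrative call (assuming `extents` and `ctx` are provided by the surrounding test):

```cpp
// Hypothetical test usage of the new signatures; SfpStream is one of the
// Packed types these helpers are instantiated with.
MatStorageT<SfpStream> a =
    GenerateMat<SfpStream>(extents, MatPadding::kPacked, ctx);
MatStorageT<SfpStream> b_trans =
    GenerateTransposedMat<SfpStream>(extents, MatPadding::kPacked, ctx);
```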

View File

@ -83,8 +83,8 @@ HWY_BEFORE_NAMESPACE();
namespace gcpp { namespace gcpp {
namespace HWY_NAMESPACE { namespace HWY_NAMESPACE {
void CallSoftmax(Logits logits, hwy::Profiler& p) { void CallSoftmax(Logits logits, ThreadingContext& ctx) {
Softmax(logits, p, hwy::Profiler::GlobalIdx()); Softmax(logits, ctx, hwy::Profiler::GlobalIdx());
} }
} // namespace HWY_NAMESPACE } // namespace HWY_NAMESPACE
@ -109,7 +109,7 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
const SampleFunc sample_token = [&](size_t qi, size_t pos, Logits logits, const SampleFunc sample_token = [&](size_t qi, size_t pos, Logits logits,
size_t /*worker*/) -> TokenAndProb { size_t /*worker*/) -> TokenAndProb {
// input is logits, not yet probabilities // input is logits, not yet probabilities
HWY_DYNAMIC_DISPATCH(CallSoftmax)(logits, env.ctx.profiler); HWY_DYNAMIC_DISPATCH(CallSoftmax)(logits, env.ctx);
// We are called for each token, but pos starts at 1. Clamping // We are called for each token, but pos starts at 1. Clamping
// max_generated_tokens to prompt.size() should prevent overrun. // max_generated_tokens to prompt.size() should prevent overrun.
HWY_ASSERT(pos < prompt.size()); HWY_ASSERT(pos < prompt.size());

View File

@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_STANDARD_REQUIRED ON)
include(FetchContent) include(FetchContent)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9781a1698ee0756ef1eaaf96930113ed7cb6d3ee) FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2a16a50ff61071bb25ddef0ce35d92b0e2b9c579)
FetchContent_MakeAvailable(highway) FetchContent_MakeAvailable(highway)
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 9045b2f60fa2b323dfac0eaef8fc17565036f9f9) FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 9045b2f60fa2b323dfac0eaef8fc17565036f9f9)
FetchContent_MakeAvailable(sentencepiece) FetchContent_MakeAvailable(sentencepiece)

View File

@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_STANDARD_REQUIRED ON)
include(FetchContent) include(FetchContent)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9781a1698ee0756ef1eaaf96930113ed7cb6d3ee) FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2a16a50ff61071bb25ddef0ce35d92b0e2b9c579)
FetchContent_MakeAvailable(highway) FetchContent_MakeAvailable(highway)
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c) FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
FetchContent_MakeAvailable(sentencepiece) FetchContent_MakeAvailable(sentencepiece)

View File

@ -55,8 +55,8 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
const hwy::Divisor& div_seq_len, const hwy::Divisor& div_seq_len,
const float* HWY_RESTRICT q, const float* HWY_RESTRICT q,
const MatPtrT<KV_t>& k, float* HWY_RESTRICT att, const MatPtrT<KV_t>& k, float* HWY_RESTRICT att,
hwy::Profiler& p, const size_t worker) { ThreadingContext& ctx, const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kGenAttentionQDotK)); GCPP_ZONE(ctx, worker, Zones::kGenAttentionQDotK);
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) { if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
// Slightly faster: no wraparound. // Slightly faster: no wraparound.
for (size_t pos = start_pos; pos <= last_pos; ++pos) { for (size_t pos = start_pos; pos <= last_pos; ++pos) {
@ -75,7 +75,7 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
void PositionalEncodingQK(float* qk, const size_t layer_idx, void PositionalEncodingQK(float* qk, const size_t layer_idx,
const LayerWeightsPtrs& layer, const LayerWeightsPtrs& layer,
const AttentionActivations& activations, const AttentionActivations& activations,
hwy::Profiler& p, const size_t worker, ThreadingContext& ctx, const size_t worker,
const size_t pos, const float mul) { const size_t pos, const float mul) {
const size_t qkv_dim = layer.layer_config.qkv_dim; const size_t qkv_dim = layer.layer_config.qkv_dim;
const PostQKType& post_qk = layer.layer_config.post_qk; const PostQKType& post_qk = layer.layer_config.post_qk;
@ -88,10 +88,10 @@ void PositionalEncodingQK(float* qk, const size_t layer_idx,
} }
// PostQKType::Rope // PostQKType::Rope
if (post_qk == PostQKType::HalfRope) { if (post_qk == PostQKType::HalfRope) {
Rope(qk, qkv_dim / 2, inv_timescale, pos, p, worker); Rope(qk, qkv_dim / 2, inv_timescale, pos, ctx, worker);
if (mul != 1.0f) MulByConst(mul, qk, qkv_dim); if (mul != 1.0f) MulByConst(mul, qk, qkv_dim);
} else { } else {
RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, p, worker); RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, ctx, worker);
} }
} }
@ -99,18 +99,16 @@ void PositionalEncodingQK(float* qk, const size_t layer_idx,
 // `att_out`. Equivalent in gemma/modules.py:
 // encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
 // `v` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
-static HWY_INLINE void WeightedSumV(const size_t start_pos,
-                                    const size_t last_pos,
-                                    const hwy::Divisor& div_seq_len,
-                                    const float* HWY_RESTRICT att,
-                                    const MatPtrT<KV_t>& v,
-                                    float* HWY_RESTRICT att_out,
-                                    hwy::Profiler& p, const size_t worker) {
+static HWY_INLINE void WeightedSumV(
+    const size_t start_pos, const size_t last_pos,
+    const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT att,
+    const MatPtrT<KV_t>& v, float* HWY_RESTRICT att_out, ThreadingContext& ctx,
+    const size_t worker) {
   if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
     // Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if
     // we supported non-transposed B.
     // TODO: 2..4x unroll
-    MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), p,
+    MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), ctx,
                  worker);
     for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
       MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols());
@ -118,7 +116,8 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
   } else {
     {
       const size_t pos_mod = div_seq_len.Remainder(start_pos);
-      MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), p, worker);
+      MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), ctx,
+                   worker);
     }
     for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
       const size_t pos_mod = div_seq_len.Remainder(pos);
@ -134,7 +133,7 @@ void SingleDotSoftmaxWeightedSum(
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
const size_t layer_idx, const LayerWeightsPtrs& layer, const size_t layer_idx, const LayerWeightsPtrs& layer,
const AttentionActivations& activations, float* HWY_RESTRICT att, const AttentionActivations& activations, float* HWY_RESTRICT att,
float* HWY_RESTRICT att_out, hwy::Profiler& p, const size_t worker) { float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) {
const float att_cap = activations.config.att_cap; const float att_cap = activations.config.att_cap;
const float query_scale = activations.query_scale; const float query_scale = activations.query_scale;
const size_t seq_len = const size_t seq_len =
@ -144,23 +143,23 @@ void SingleDotSoftmaxWeightedSum(
if (layer.query_norm_scale.HasPtr()) { if (layer.query_norm_scale.HasPtr()) {
CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) { CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q, RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q,
layer.layer_config.qkv_dim, p, worker); layer.layer_config.qkv_dim, ctx, worker);
}); });
} }
PositionalEncodingQK(q, layer_idx, layer, activations, p, worker, pos, PositionalEncodingQK(q, layer_idx, layer, activations, ctx, worker, pos,
query_scale); query_scale);
QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, p, worker); QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, ctx, worker);
// SoftMax with optional SoftCap yields "probabilities" in att. // SoftMax with optional SoftCap yields "probabilities" in att.
const size_t att_len = HWY_MIN(last_pos + 1, seq_len); const size_t att_len = HWY_MIN(last_pos + 1, seq_len);
const Logits logits(att, att_len); const Logits logits(att, att_len);
MaybeLogitsSoftCap(att_cap, logits, p, worker); MaybeLogitsSoftCap(att_cap, logits, ctx, worker);
Softmax(logits, p, worker, /*temperature=*/1.0f); Softmax(logits, ctx, worker, /*temperature=*/1.0f);
WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out, p, WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out,
worker); ctx, worker);
} }
// The attention window usually starts at 0 unless `pos` is larger than // The attention window usually starts at 0 unless `pos` is larger than
@ -174,12 +173,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
const LayerWeightsPtrs& layer, const LayerWeightsPtrs& layer,
AttentionActivations& activations, QBatch& qbatch, AttentionActivations& activations, QBatch& qbatch,
ThreadingContext& ctx) { ThreadingContext& ctx) {
static const auto root_zone = GCPP_ZONE(ctx, 0, Zones::kGenAttentionDotSoftmaxWeightedSumInclusive);
ctx.profiler.AddZone("Gen.Attention.DotSoftmaxWeightedSumInclusive",
hwy::ProfilerFlags::kInclusive);
PROFILER_ZONE3(ctx.profiler, 0, root_zone);
const auto zone =
GetProfilerZone(Zones::kGenAttentionDotSoftmaxWeightedSumPar);
const hwy::Divisor div_qbatch(qbatch.Size()); const hwy::Divisor div_qbatch(qbatch.Size());
const LayerConfig& layer_config = layer.layer_config; const LayerConfig& layer_config = layer.layer_config;
@ -199,7 +193,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
const auto func = [&](const size_t task, size_t worker) HWY_ATTR { const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
const size_t tq_idx = activations.div_heads.Divide(task); const size_t tq_idx = activations.div_heads.Divide(task);
const size_t head = activations.div_heads.Remainder(task); const size_t head = activations.div_heads.Remainder(task);
PROFILER_ZONE3(ctx.profiler, worker, zone); GCPP_ZONE(ctx, worker, Zones::kGenAttentionDotSoftmaxWeightedSumPar);
const size_t qi = div_qbatch.Remainder(tq_idx); const size_t qi = div_qbatch.Remainder(tq_idx);
const size_t batch_idx = div_qbatch.Divide(tq_idx); const size_t batch_idx = div_qbatch.Divide(tq_idx);
@ -231,16 +225,15 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride()); v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());
SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx, SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx,
layer, activations, att, att_out, ctx.profiler, layer, activations, att, att_out, ctx, worker);
worker);
}; };
{ {
PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin"); PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin");
// Full parallelism is helpful, kAcrossClusters is insufficient. // Full parallelism is helpful, kAcrossClusters is insufficient.
HierarchicalParallelFor( HierarchicalParallelFor(
num_tokens * div_qbatch.GetDivisor() * layer_config.heads, ctx.pools, num_tokens * div_qbatch.GetDivisor() * layer_config.heads, ctx,
func); Callers::kAttDotSoftmaxWeightedSum, func);
} }
} }
@ -256,9 +249,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
AttentionActivations& activations, AttentionActivations& activations,
const QBatch& qbatch, const int flags, const QBatch& qbatch, const int flags,
MatMulEnv& env) { MatMulEnv& env) {
static const auto zone = env.ctx.profiler.AddZone( GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(),
"Gen.Attention.ComputeQKV", hwy::ProfilerFlags::kInclusive); Zones::kGenAttentionComputeQKV);
PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::GlobalIdx(), zone);
const hwy::Divisor div_qbatch(qbatch.Size()); const hwy::Divisor div_qbatch(qbatch.Size());
const size_t num_interleaved = num_tokens * div_qbatch.GetDivisor(); const size_t num_interleaved = num_tokens * div_qbatch.GetDivisor();
@ -295,7 +287,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
// tasks are very lightweight. // tasks are very lightweight.
ParallelFor( ParallelFor(
ParallelismStrategy::kFlat, kv_heads * num_interleaved, env.ctx, ParallelismStrategy::kFlat, kv_heads * num_interleaved, env.ctx,
/*cluster_idx=*/0, [&](size_t task, size_t worker) HWY_ATTR { /*cluster_idx=*/0, Callers::kAttComputeQKV,
[&](size_t task, size_t worker) HWY_ATTR {
const size_t head = task % kv_heads; const size_t head = task % kv_heads;
const size_t interleaved_idx = task / kv_heads; const size_t interleaved_idx = task / kv_heads;
const size_t qi = div_qbatch.Remainder(interleaved_idx); const size_t qi = div_qbatch.Remainder(interleaved_idx);
@ -316,12 +309,12 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
if (layer.key_norm_scale.HasPtr()) { if (layer.key_norm_scale.HasPtr()) {
CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) { CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, kv_f32, RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, kv_f32,
qkv_dim, env.ctx.profiler, worker); qkv_dim, env.ctx, worker);
}); });
} }
PositionalEncodingQK(kv_f32, layer_idx, layer, activations, PositionalEncodingQK(kv_f32, layer_idx, layer, activations, env.ctx,
env.ctx.profiler, worker, pos, /*mul=*/1.0f); worker, pos, /*mul=*/1.0f);
CompressPerThread tls; CompressPerThread tls;
Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0); Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
}); });
@ -332,9 +325,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer, static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
AttentionActivations& activations, AttentionActivations& activations,
MatMulEnv& env) { MatMulEnv& env) {
static const auto zone = env.ctx.profiler.AddZone( GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttentionSumHeads);
"Gen.Attention.SumHeads", hwy::ProfilerFlags::kInclusive);
PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::GlobalIdx(), zone);
const LayerConfig& layer_config = layer.layer_config; const LayerConfig& layer_config = layer.layer_config;
(void)layer_config; // For HWY_DASSERT (void)layer_config; // For HWY_DASSERT
// att_weights and att_out are concatenated heads, each of length // att_weights and att_out are concatenated heads, each of length
@ -352,9 +343,7 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
const LayerWeightsPtrs& layer, const LayerWeightsPtrs& layer,
AttentionActivations& activations, QBatch& qbatch, AttentionActivations& activations, QBatch& qbatch,
MatMulEnv& env, int flags) { MatMulEnv& env, int flags) {
static const auto zone = GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttention);
env.ctx.profiler.AddZone("Gen.Attention", hwy::ProfilerFlags::kInclusive);
PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::GlobalIdx(), zone);
const LayerConfig& layer_config = layer.layer_config; const LayerConfig& layer_config = layer.layer_config;
HWY_DASSERT(!layer_config.IsMHA()); // No longer supported. HWY_DASSERT(!layer_config.IsMHA()); // No longer supported.

View File

@ -26,33 +26,33 @@
namespace gcpp { namespace gcpp {
// Passed to HWY_VISIT_TARGETS; declares for one target. // Passed to HWY_VISIT_TARGETS; declares for one target.
#define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \ #define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \
namespace NAMESPACE { \ namespace NAMESPACE { \
void PositionalEncodingQK(float* qk, size_t layer_idx, \ void PositionalEncodingQK(float* qk, size_t layer_idx, \
const LayerWeightsPtrs& layer, \ const LayerWeightsPtrs& layer, \
const AttentionActivations& activations, \ const AttentionActivations& activations, \
hwy::Profiler& p, size_t worker, size_t pos, \ ThreadingContext& ctx, size_t worker, size_t pos, \
float mul); \ float mul); \
\ \
size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx); \ size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx); \
\ \
void SingleDotSoftmaxWeightedSum( \ void SingleDotSoftmaxWeightedSum( \
const size_t pos, const size_t start_pos, const size_t last_pos, \ const size_t pos, const size_t start_pos, const size_t last_pos, \
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \ float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
size_t layer_idx, const LayerWeightsPtrs& layer, \ size_t layer_idx, const LayerWeightsPtrs& layer, \
const AttentionActivations& activations, float* HWY_RESTRICT att, \ const AttentionActivations& activations, float* HWY_RESTRICT att, \
float* HWY_RESTRICT att_out, hwy::Profiler& p, size_t worker); \ float* HWY_RESTRICT att_out, ThreadingContext& ctx, size_t worker); \
\ \
void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \ void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \
const LayerWeightsPtrs& layer, \ const LayerWeightsPtrs& layer, \
AttentionActivations& activations, \ AttentionActivations& activations, \
QBatch& qbatch, ThreadingContext& ctx); \ QBatch& qbatch, ThreadingContext& ctx); \
\ \
void GemmaAttention(size_t num_tokens, const size_t layer_idx, \ void GemmaAttention(size_t num_tokens, const size_t layer_idx, \
const LayerWeightsPtrs& layer, \ const LayerWeightsPtrs& layer, \
AttentionActivations& activations, QBatch& qbatch, \ AttentionActivations& activations, QBatch& qbatch, \
MatMulEnv& env, int flags); \ MatMulEnv& env, int flags); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE } // namespace NAMESPACE
// Function declarations for each SIMD target. Allows direct call from the // Function declarations for each SIMD target. Allows direct call from the

View File

@ -61,13 +61,12 @@ static constexpr size_t kNFx8HTileSize = 8;
// possible consecutive elements have the same KV. // possible consecutive elements have the same KV.
static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t, static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
const size_t qbatch_size, ThreadingContext& ctx) { const size_t qbatch_size, ThreadingContext& ctx) {
const auto zone = GetProfilerZone(Zones::kFlashAttentionTransposeQ);
// Group floats by the number of floats in a cache line. // Group floats by the number of floats in a cache line.
const size_t kNF = ctx.cache_info.LineBytes() / sizeof(float); const size_t kNF = ctx.cache_info.LineBytes() / sizeof(float);
const size_t num_heads = q.Cols() / q_t.Rows(); const size_t num_heads = q.Cols() / q_t.Rows();
const size_t batch_size = q.Rows() / qbatch_size; const size_t batch_size = q.Rows() / qbatch_size;
const auto func = [&](const size_t task, size_t worker) HWY_ATTR { const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
PROFILER_ZONE3(ctx.profiler, worker, zone); GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTransposeQ);
for (size_t lane = 0; lane < kNF; ++lane) { for (size_t lane = 0; lane < kNF; ++lane) {
size_t q_row = task * kNF + lane; size_t q_row = task * kNF + lane;
if (q_row >= q_t.Rows()) break; if (q_row >= q_t.Rows()) break;
@ -83,10 +82,10 @@ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
} }
}; };
{ {
const size_t num_tasks = hwy::DivCeil(q_t.Rows(), kNF);
// Better than kFlat. // Better than kFlat.
size_t num_tasks = hwy::DivCeil(q_t.Rows(), kNF);
ParallelFor(ParallelismStrategy::kHierarchical, num_tasks, ctx, ParallelFor(ParallelismStrategy::kHierarchical, num_tasks, ctx,
/*cluster_idx=*/0, func); /*cluster_idx=*/0, Callers::kFlashTransposeQ, func);
} }
} }
@ -96,12 +95,10 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
const LayerWeightsPtrs& layer, const LayerWeightsPtrs& layer,
const AttentionActivations& activations, const AttentionActivations& activations,
ThreadingContext& ctx) { ThreadingContext& ctx) {
const auto zone =
GetProfilerZone(Zones::kFlashAttentionRmsNormAndPositionalEncoding);
const float query_scale = activations.query_scale; const float query_scale = activations.query_scale;
const hwy::Divisor div_qbatch(qbatch.Size()); const hwy::Divisor div_qbatch(qbatch.Size());
const auto func = [&](const size_t task, size_t worker) HWY_ATTR { const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
PROFILER_ZONE3(ctx.profiler, worker, zone); GCPP_ZONE(ctx, worker, Zones::kFlashAttentionRmsNormAndPositionalEncoding);
size_t qi = div_qbatch.Remainder(task); size_t qi = div_qbatch.Remainder(task);
size_t batch_idx = div_qbatch.Divide(task); size_t batch_idx = div_qbatch.Divide(task);
for (size_t h = 0; h < layer.layer_config.heads; ++h) { for (size_t h = 0; h < layer.layer_config.heads; ++h) {
@ -115,18 +112,19 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
if (layer.query_norm_scale.HasPtr()) { if (layer.query_norm_scale.HasPtr()) {
CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) { CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q_row, RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q_row,
layer.layer_config.qkv_dim, ctx.profiler, worker); layer.layer_config.qkv_dim, ctx, worker);
}); });
} }
PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx.profiler, PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx, worker,
worker, pos, query_scale); pos, query_scale);
} }
}; };
{ {
// kHierarchical is not worth the extra sync overhead because the tasks are // kHierarchical is not worth the extra sync overhead because the tasks are
// very lightweight. // very lightweight.
ParallelFor(ParallelismStrategy::kFlat, num_tokens * qbatch.Size(), ctx, ParallelFor(ParallelismStrategy::kFlat, num_tokens * qbatch.Size(), ctx,
/*cluster_idx=*/0, func); /*cluster_idx=*/0, Callers::kFlashRMSNormAndPositionalEncoding,
func);
} }
} }
@ -158,10 +156,9 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
const MatPtrT<KV_t>& v, const size_t layer_idx, const MatPtrT<KV_t>& v, const size_t layer_idx,
const LayerWeightsPtrs& layer, const LayerWeightsPtrs& layer,
const AttentionActivations& activations, const AttentionActivations& activations,
float* HWY_RESTRICT att_out, hwy::Profiler& p, float* HWY_RESTRICT att_out, ThreadingContext& ctx,
const size_t worker) { const size_t worker) {
PROFILER_ZONE3(p, worker, GCPP_ZONE(ctx, worker, Zones::kFlashAttentionSingleFlashAttention);
GetProfilerZone(Zones::kFlashAttentionSingleFlashAttention));
const size_t pos_mod = activations.div_seq_len.Remainder(start_pos); const size_t pos_mod = activations.div_seq_len.Remainder(start_pos);
float m = Dot(q, k.Row(pos_mod), k.Cols()); float m = Dot(q, k.Row(pos_mod), k.Cols());
if (float cap = activations.config.att_cap; cap > 0.0f) { if (float cap = activations.config.att_cap; cap > 0.0f) {
@ -170,7 +167,7 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
} }
float d = 1.0f; float d = 1.0f;
// This is just a copy of the first token. // This is just a copy of the first token.
MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), p, worker); MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), ctx, worker);
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) { for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
const size_t pos_mod = activations.div_seq_len.Remainder(pos); const size_t pos_mod = activations.div_seq_len.Remainder(pos);
float x = Dot(q, k.Row(pos_mod), k.Cols()); float x = Dot(q, k.Row(pos_mod), k.Cols());
@ -276,9 +273,8 @@ void TileFlashAttention(
const MatPtrT<KV_t>& v, const size_t layer_idx, const MatPtrT<KV_t>& v, const size_t layer_idx,
const LayerWeightsPtrs& layer, const AttentionActivations& activations, const LayerWeightsPtrs& layer, const AttentionActivations& activations,
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets, MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets,
hwy::Profiler& p, const size_t worker) { ThreadingContext& ctx, const size_t worker) {
PROFILER_ZONE3(p, worker, GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention);
GetProfilerZone(Zones::kFlashAttentionTileFlashAttention));
constexpr int kHTileSize = kNFx8HTileSize; constexpr int kHTileSize = kNFx8HTileSize;
using DF = hn::ScalableTag<float>; using DF = hn::ScalableTag<float>;
const DF df; const DF df;
@ -430,9 +426,8 @@ void TileFlashAttention4(
const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx, const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
const LayerWeightsPtrs& layer, const AttentionActivations& activations, const LayerWeightsPtrs& layer, const AttentionActivations& activations,
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets, MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets,
hwy::Profiler& p, const size_t worker) { ThreadingContext& ctx, const size_t worker) {
PROFILER_ZONE3(p, worker, GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention4);
GetProfilerZone(Zones::kFlashAttentionTileFlashAttention4));
using DF = hn::ScalableTag<float>; using DF = hn::ScalableTag<float>;
const DF df; const DF df;
using VF = hn::Vec<DF>; using VF = hn::Vec<DF>;
@ -597,10 +592,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
const size_t layer_idx, const LayerWeightsPtrs& layer, const size_t layer_idx, const LayerWeightsPtrs& layer,
AttentionActivations& activations, QBatch& qbatch, AttentionActivations& activations, QBatch& qbatch,
ThreadingContext& ctx) { ThreadingContext& ctx) {
static const auto root_zone = ctx.profiler.AddZone( GCPP_ZONE(ctx, 0, Zones::kFlashAttentionInclusive);
"FlashAttention.Inclusive", hwy::ProfilerFlags::kInclusive);
PROFILER_ZONE3(ctx.profiler, 0, root_zone);
const auto zone = GetProfilerZone(Zones::kFlashAttentionFlashAttention);
RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q, layer_idx, RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q, layer_idx,
layer, activations, ctx); layer, activations, ctx);
const hwy::Divisor div_qbatch(qbatch.Size()); const hwy::Divisor div_qbatch(qbatch.Size());
@ -653,7 +645,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
// For each head/token/query, compute fused flash Q.K, softmax and weighted V. // For each head/token/query, compute fused flash Q.K, softmax and weighted V.
const auto func = [&](const size_t task, size_t worker) HWY_ATTR { const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
PROFILER_ZONE3(ctx.profiler, worker, zone); GCPP_ZONE(ctx, worker, Zones::kFlashAttentionFlashAttention);
// Offsets into original Q for each row in the tile. // Offsets into original Q for each row in the tile.
uint32_t q_offsets[kMaxNF]; uint32_t q_offsets[kMaxNF];
// Offsets into att_out for each row in the tile. // Offsets into att_out for each row in the tile.
@ -741,13 +733,12 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
       TileFlashAttention(activations.q, q_offsets, qT, k,
                          start_positions[offset], last_pos, min_last_pos,
                          max_last_pos, v, layer_idx, layer, activations,
-                         activations.att_out, out_offsets, ctx.profiler,
-                         worker);
+                         activations.att_out, out_offsets, ctx, worker);
     } else if (kVTileSize == 4) {
-      TileFlashAttention4(
-          activations.q, q_offsets, k, start_positions[offset], last_pos,
-          min_last_pos, max_last_pos, v, layer_idx, layer, activations,
-          activations.att_out, out_offsets, ctx.profiler, worker);
+      TileFlashAttention4(activations.q, q_offsets, k,
+                          start_positions[offset], last_pos, min_last_pos,
+                          max_last_pos, v, layer_idx, layer, activations,
+                          activations.att_out, out_offsets, ctx, worker);
     } else {
       HWY_UNREACHABLE;
     }
@ -757,7 +748,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
activations.q.Row(0) + q_offsets[offset], k, v, activations.q.Row(0) + q_offsets[offset], k, v,
layer_idx, layer, activations, layer_idx, layer, activations,
activations.att_out.Row(0) + out_offsets[offset], activations.att_out.Row(0) + out_offsets[offset],
ctx.profiler, worker); ctx, worker);
} }
} }
}; };
@ -765,7 +756,8 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
{ {
PROFILER_ZONE("Gen.FlashAttention.ForkJoin"); PROFILER_ZONE("Gen.FlashAttention.ForkJoin");
// Full parallelism is helpful, SmallParallelFor is insufficient. // Full parallelism is helpful, SmallParallelFor is insufficient.
HierarchicalParallelFor(num_thread_tasks, ctx.pools, func); HierarchicalParallelFor(num_thread_tasks, ctx, Callers::kFlashAttention,
func);
} }
} }

View File

@ -39,8 +39,8 @@ namespace gcpp {
const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \ const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
size_t layer_idx, const LayerWeightsPtrs& layer, \ size_t layer_idx, const LayerWeightsPtrs& layer, \
const AttentionActivations& activations, \ const AttentionActivations& activations, \
float* HWY_RESTRICT att_out, hwy::Profiler& p, \ float* HWY_RESTRICT att_out, \
size_t worker); \ ThreadingContext& ctx, size_t worker); \
\ \
size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \ size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \
size_t total_tasks, size_t target_parallelism); \ size_t total_tasks, size_t target_parallelism); \

View File

@ -47,9 +47,9 @@ namespace HWY_NAMESPACE {
// For use by Vit even if !GEMMA_FUSED_FFN. // For use by Vit even if !GEMMA_FUSED_FFN.
template <typename T1, typename T2> template <typename T1, typename T2>
void Activation(ActivationType activation, T1* HWY_RESTRICT c1, void Activation(ActivationType activation, T1* HWY_RESTRICT c1,
const T2* HWY_RESTRICT c2, const size_t count, hwy::Profiler& p, const T2* HWY_RESTRICT c2, const size_t count,
const size_t worker) { ThreadingContext& ctx, const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kGenActivation)); GCPP_ZONE(ctx, worker, Zones::kGenActivation);
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>; using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>; using VF = hn::Vec<DF>;
@ -73,11 +73,11 @@ void ActivationBatched(
ParallelismStrategy parallelism = ParallelismStrategy::kFlat) { ParallelismStrategy parallelism = ParallelismStrategy::kFlat) {
using T = typename Mat::T; using T = typename Mat::T;
ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx, ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,
[&](uint64_t task, size_t worker) { Callers::kActivationBatched, [&](uint64_t task, size_t worker) {
// Cast to correct type so type deduction works. // Cast to correct type so type deduction works.
Activation(activation, c1.Row(task), Activation(activation, c1.Row(task),
static_cast<const T*>(nullptr), c1.Cols(), static_cast<const T*>(nullptr), c1.Cols(), ctx,
ctx.profiler, worker); worker);
}); });
} }
@ -87,8 +87,8 @@ void ActivationBatched(
static inline void Activation(ActivationType activation, const RowPtrsBF C1, static inline void Activation(ActivationType activation, const RowPtrsBF C1,
const IndexRange range_r, const IndexRange range_r,
const IndexRange range_c, const StridedViewBF C2, const IndexRange range_c, const StridedViewBF C2,
hwy::Profiler& p, const size_t worker) { ThreadingContext& ctx, const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kGenActivationFused)); GCPP_ZONE(ctx, worker, Zones::kGenActivationFused);
const size_t cols = range_c.Num(); const size_t cols = range_c.Num();
HWY_DASSERT(C2.Cols() == cols); HWY_DASSERT(C2.Cols() == cols);
@ -119,16 +119,16 @@ HWY_NOINLINE void ActivationBatched(
HWY_DASSERT(c1.SameShape(*c2)); HWY_DASSERT(c1.SameShape(*c2));
if (c2 && c2->HasPtr()) { if (c2 && c2->HasPtr()) {
ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx, ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,
[&](uint64_t task, size_t worker) { Callers::kActivationBatched, [&](uint64_t task, size_t worker) {
Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(), Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(),
ctx.profiler, worker); ctx, worker);
}); });
} else { // No multiplier } else { // No multiplier
ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx, ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,
[&](uint64_t task, size_t worker) { Callers::kActivationBatched, [&](uint64_t task, size_t worker) {
Activation(activation, c1.Row(task), Activation(activation, c1.Row(task),
static_cast<const typename Mat2::T*>(nullptr), static_cast<const typename Mat2::T*>(nullptr),
c1.Cols(), ctx.profiler, worker); c1.Cols(), ctx, worker);
}); });
} }
} }
@ -153,9 +153,7 @@ void PostNorm(PostNormType post_norm, const MatPtr& weights,
static inline void FFWNoVit(const LayerWeightsPtrs& layer, static inline void FFWNoVit(const LayerWeightsPtrs& layer,
Activations& activations, MatMulEnv& env) { Activations& activations, MatMulEnv& env) {
static const auto zone = GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenFFW);
env.ctx.profiler.AddZone("Gen.FFW", hwy::ProfilerFlags::kInclusive);
PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::GlobalIdx(), zone);
const LayerConfig& layer_config = layer.layer_config; const LayerConfig& layer_config = layer.layer_config;
HWY_DASSERT(!layer_config.ff_biases); // Only used in Vit. HWY_DASSERT(!layer_config.ff_biases); // Only used in Vit.
@ -163,8 +161,8 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer,
#if GEMMA_FUSED_FFN #if GEMMA_FUSED_FFN
const auto fused = [&](RowPtrsBF C1, IndexRange range_r, IndexRange range_c, const auto fused = [&](RowPtrsBF C1, IndexRange range_r, IndexRange range_c,
StridedViewBF C2, size_t worker) { StridedViewBF C2, size_t worker) {
Activation(layer_config.activation, C1, range_r, range_c, C2, Activation(layer_config.activation, C1, range_r, range_c, C2, env.ctx,
env.ctx.profiler, worker); worker);
}; };
MMOptions options; MMOptions options;
options.SetFunc(fused); options.SetFunc(fused);

View File

@ -55,7 +55,7 @@
#include "io/io.h" // Path #include "io/io.h" // Path
#include "ops/matmul.h" #include "ops/matmul.h"
#include "paligemma/image.h" #include "paligemma/image.h"
#include "util/basics.h" // PROFILER_ZONE3 #include "util/basics.h"
#include "util/threading_context.h" #include "util/threading_context.h"
#include "hwy/aligned_allocator.h" // Span #include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h" #include "hwy/base.h"
@ -138,9 +138,7 @@ EmbedMMToken(int token, size_t x_row, size_t pos, size_t pos_in_prompt,
MatStorageT<float>& x, ThreadingContext& ctx, MatStorageT<float>& x, ThreadingContext& ctx,
const ImageTokens* image_tokens = nullptr, const ImageTokens* image_tokens = nullptr,
size_t image_token_position = 0) { size_t image_token_position = 0) {
static const auto zone = GCPP_ZONE(ctx, hwy::Profiler::GlobalIdx(), Zones::kGenEmbed);
ctx.profiler.AddZone("Gen.Embed", hwy::ProfilerFlags::kInclusive);
PROFILER_ZONE3(ctx.profiler, hwy::Profiler::GlobalIdx(), zone);
// Image tokens just need to be copied. // Image tokens just need to be copied.
if (model_config.wrapping == PromptWrapping::GEMMA_VLM && if (model_config.wrapping == PromptWrapping::GEMMA_VLM &&
@ -415,9 +413,7 @@ static void SampleAndStream(const ModelConfig& config,
MaybeObserve(runtime_config, activations, qbatch, -1); MaybeObserve(runtime_config, activations, qbatch, -1);
{ {
static const auto zone = env.ctx.profiler.AddZone( GCPP_ZONE(env.ctx, /*worker=*/0, Zones::kGenEmbeddingMatmul);
"Gen.EmbeddingMatmul", hwy::ProfilerFlags::kInclusive);
PROFILER_ZONE3(env.ctx.profiler, /*worker=*/0, zone);
// Compute logits from last layer activations. // Compute logits from last layer activations.
CallMatMul(activations.x_bf, weights.embedder_input_embedding, CallMatMul(activations.x_bf, weights.embedder_input_embedding,
/*add=*/nullptr, env, activations.logits); /*add=*/nullptr, env, activations.logits);
@ -431,7 +427,8 @@ static void SampleAndStream(const ModelConfig& config,
ParallelFor( ParallelFor(
ParallelismStrategy::kFlat, qbatch.Size(), env.ctx, ParallelismStrategy::kFlat, qbatch.Size(), env.ctx,
/*cluster_idx=*/0, [&](size_t qi, size_t worker) { /*cluster_idx=*/0, Callers::kSampleAndStream,
[&](size_t qi, size_t worker) {
if (!non_eos.Get(qi)) return; if (!non_eos.Get(qi)) return;
// We streamed all prefill tokens, but pos is still one behind // We streamed all prefill tokens, but pos is still one behind
@ -469,8 +466,7 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config,
if (runtime_config.top_k == 1 && !runtime_config.accept_token) { if (runtime_config.top_k == 1 && !runtime_config.accept_token) {
return [&](size_t /*qi*/, size_t /*pos*/, Logits logits, size_t worker) return [&](size_t /*qi*/, size_t /*pos*/, Logits logits, size_t worker)
HWY_ATTR -> TokenAndProb { HWY_ATTR -> TokenAndProb {
PROFILER_ZONE3(ctx.profiler, worker, GCPP_ZONE(ctx, worker, Zones::kGenSampleTop1);
GetProfilerZone(Zones::kGenSampleTop1));
return Top1OfSoftmax(logits); return Top1OfSoftmax(logits);
}; };
} }
@ -478,14 +474,13 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config,
// General case: Softmax with top-k sampling. // General case: Softmax with top-k sampling.
return [&](size_t qi, size_t pos, Logits logits, return [&](size_t qi, size_t pos, Logits logits,
size_t worker) HWY_ATTR -> TokenAndProb { size_t worker) HWY_ATTR -> TokenAndProb {
PROFILER_ZONE3(ctx.profiler, worker, GCPP_ZONE(ctx, worker, Zones::kGenSampleTopK);
GetProfilerZone(Zones::kGenSampleTopK));
// We want a different sequence for each batch element and position. // We want a different sequence for each batch element and position.
const uint64_t stream = (static_cast<uint64_t>(qi) << 32) | pos; const uint64_t stream = (static_cast<uint64_t>(qi) << 32) | pos;
RngStream gen(engine, stream); RngStream gen(engine, stream);
return FusedSoftmaxAndSampleTopK( return FusedSoftmaxAndSampleTopK(logits, runtime_config.top_k, gen,
logits, runtime_config.top_k, gen, runtime_config.temperature, runtime_config.temperature,
runtime_config.accept_token, ctx.profiler, worker); runtime_config.accept_token, ctx, worker);
}; };
} }
@ -657,8 +652,8 @@ Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
Gemma::~Gemma() = default; Gemma::~Gemma() = default;
void Gemma::Save(const Path& weights_path, NestedPools& pools) const { void Gemma::Save(const Path& weights_path, ThreadingContext& ctx) const {
BlobWriter writer(weights_path, pools.Pool()); BlobWriter writer(weights_path, ctx);
const std::vector<uint32_t> serialized_mat_ptrs = const std::vector<uint32_t> serialized_mat_ptrs =
weights_.AddTensorDataToWriter(writer); weights_.AddTensorDataToWriter(writer);
WriteSingleFile(model_.Config(), model_.Tokenizer(), serialized_mat_ptrs, WriteSingleFile(model_.Config(), model_.Tokenizer(), serialized_mat_ptrs,
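Note on the recurring change in the hunks above: the two-step AddZone + PROFILER_ZONE3 call sites collapse into the new GCPP_ZONE helper, keyed by the Zones enum and the ThreadingContext. A minimal before/after sketch (assuming GCPP_ZONE is the macro added in util/zones.h, as the zone hunks here indicate; not the actual call site):

#include "util/threading_context.h"
#include "util/zones.h"

namespace gcpp {

void ExampleZone(ThreadingContext& ctx, size_t worker) {
  // Old pattern: register a zone by name, then open it explicitly.
  //   static const auto zone =
  //       ctx.profiler.AddZone("Gen.Embed", hwy::ProfilerFlags::kInclusive);
  //   PROFILER_ZONE3(ctx.profiler, worker, zone);
  // New pattern: one macro, keyed by the Zones enum and the context.
  GCPP_ZONE(ctx, worker, Zones::kGenEmbed);
}

}  // namespace gcpp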

View File

@ -246,7 +246,7 @@ class Gemma {
const GemmaChatTemplate& ChatTemplate() const { return chat_template_; } const GemmaChatTemplate& ChatTemplate() const { return chat_template_; }
const InferenceArgs& Inference() const { return inference_; } const InferenceArgs& Inference() const { return inference_; }
void Save(const Path& weights_path, NestedPools& pools) const; void Save(const Path& weights_path, ThreadingContext& ctx) const;
// `pos` is the position in the KV cache. Users are responsible for // `pos` is the position in the KV cache. Users are responsible for
// incrementing it in the `*StreamFunc`, or setting to zero for single-turn. // incrementing it in the `*StreamFunc`, or setting to zero for single-turn.

View File

@ -36,7 +36,6 @@
// IWYU pragma: end_exports // IWYU pragma: end_exports
#include "util/allocator.h" #include "util/allocator.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp { namespace gcpp {

View File

@ -90,39 +90,43 @@ class VitAttention {
ZeroInit(activations_.attention.att_out); ZeroInit(activations_.attention.att_out);
for (size_t head = 0; head < heads; ++head) { for (size_t head = 0; head < heads; ++head) {
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR { pool_.Run(0, num_tokens_, caller1_,
const size_t token = task; [&](uint64_t task, size_t worker) HWY_ATTR {
float* HWY_RESTRICT q = const size_t token = task;
activations_.attention.q.Row(token) + head * 3 * qkv_dim; float* HWY_RESTRICT q =
// TODO: shift to MatMul with A.scale once MatMul is confirmed working activations_.attention.q.Row(token) + head * 3 * qkv_dim;
MulByConst(query_scale, q, qkv_dim); // TODO: shift to MatMul with A.scale once MatMul is confirmed
hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float)); // working
}); MulByConst(query_scale, q, qkv_dim);
hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float));
});
pool_.Run(0, seq_len, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { pool_.Run(
const size_t seq_idx = task; 0, seq_len, caller2_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
float* HWY_RESTRICT k = activations_.attention.q.Row(seq_idx) + const size_t seq_idx = task;
head * 3 * qkv_dim + qkv_dim; float* HWY_RESTRICT k = activations_.attention.q.Row(seq_idx) +
hwy::CopyBytes(k, K.Row(seq_idx), qkv_dim * sizeof(float)); head * 3 * qkv_dim + qkv_dim;
}); hwy::CopyBytes(k, K.Row(seq_idx), qkv_dim * sizeof(float));
});
// this produces C, a (num_tokens_, seq_len) matrix of dot products // this produces C, a (num_tokens_, seq_len) matrix of dot products
CallMatMul(Q, K, nullptr, env_, C); CallMatMul(Q, K, nullptr, env_, C);
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR { pool_.Run(0, num_tokens_, caller3_,
Softmax(C.RowSpan(task), env_.ctx.profiler, worker); [&](uint64_t task, size_t worker)
}); HWY_ATTR { Softmax(C.RowSpan(task), env_.ctx, worker); });
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR { pool_.Run(
size_t token = task; 0, num_tokens_, caller4_, [&](uint64_t task, size_t worker) HWY_ATTR {
float* HWY_RESTRICT att_out = size_t token = task;
activations_.attention.att_out.Row(token) + head * qkv_dim; float* HWY_RESTRICT att_out =
for (size_t i = 0; i < seq_len; ++i) { activations_.attention.att_out.Row(token) + head * qkv_dim;
float* HWY_RESTRICT v = activations_.attention.q.Row(i) + for (size_t i = 0; i < seq_len; ++i) {
head * 3 * qkv_dim + 2 * qkv_dim; float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim); head * 3 * qkv_dim + 2 * qkv_dim;
} MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim);
}); }
});
} }
} }
@ -136,7 +140,7 @@ class VitAttention {
PROFILER_ZONE("Gen.VitAttention.DotSoftmax"); PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
// Compute Q.K, softmax, and weighted V. // Compute Q.K, softmax, and weighted V.
pool_.Run(0, layer_config_.heads * num_tokens_, pool_.Run(0, layer_config_.heads * num_tokens_, caller1_,
[&](uint64_t task, size_t worker) HWY_ATTR { [&](uint64_t task, size_t worker) HWY_ATTR {
const size_t head = task % layer_config_.heads; const size_t head = task % layer_config_.heads;
const size_t token = task / layer_config_.heads; const size_t token = task / layer_config_.heads;
@ -152,7 +156,7 @@ class VitAttention {
head_att[i] = Dot(q, k, qkv_dim); // score = q.k head_att[i] = Dot(q, k, qkv_dim); // score = q.k
} }
// SoftMax yields "probabilities" in head_att. // SoftMax yields "probabilities" in head_att.
Softmax(Logits(head_att, seq_len), env_.ctx.profiler, worker); Softmax(Logits(head_att, seq_len), env_.ctx, worker);
// Compute weighted sum of v into att_out. // Compute weighted sum of v into att_out.
float* HWY_RESTRICT att_out = float* HWY_RESTRICT att_out =
activations_.attention.att_out.Row(token) + head * qkv_dim; activations_.attention.att_out.Row(token) + head * qkv_dim;
@ -185,7 +189,11 @@ class VitAttention {
layer_(layer), layer_(layer),
layer_config_(layer.layer_config), layer_config_(layer.layer_config),
env_(env), env_(env),
pool_(env_.ctx.pools.Pool(0)) {} pool_(env_.ctx.pools.Pool(0)),
caller1_(env_.ctx.pool_callers.Get(Callers::kVitDotSoftmax1)),
caller2_(env_.ctx.pool_callers.Get(Callers::kVitDotSoftmax2)),
caller3_(env_.ctx.pool_callers.Get(Callers::kVitDotSoftmax3)),
caller4_(env_.ctx.pool_callers.Get(Callers::kVitDotSoftmax4)) {}
HWY_INLINE void operator()() { HWY_INLINE void operator()() {
ComputeQKV(); ComputeQKV();
@ -204,6 +212,10 @@ class VitAttention {
const LayerConfig& layer_config_; const LayerConfig& layer_config_;
MatMulEnv& env_; MatMulEnv& env_;
hwy::ThreadPool& pool_; hwy::ThreadPool& pool_;
hwy::pool::Caller caller1_;
hwy::pool::Caller caller2_;
hwy::pool::Caller caller3_;
hwy::pool::Caller caller4_;
}; };
// Same as FFWNoVit, but with different layer members and no second // Same as FFWNoVit, but with different layer members and no second
@ -333,7 +345,7 @@ void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights,
// Apply soft embedding norm before input projection. // Apply soft embedding norm before input projection.
CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) { CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0,
activations.x.Row(0), vit_model_dim, env.ctx.profiler, activations.x.Row(0), vit_model_dim, env.ctx,
hwy::Profiler::GlobalIdx()); hwy::Profiler::GlobalIdx());
}); });
} }
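Each pool_.Run call site above now passes a hwy::pool::Caller obtained once from ctx.pool_callers, presumably so the pool can attribute runs to their call site. A condensed sketch of the pattern, using the names visible in this section (a sketch, not the VitAttention code itself):

namespace gcpp {

void ExampleRun(ThreadingContext& ctx, size_t num_tokens) {
  hwy::ThreadPool& pool = ctx.pools.Pool(0);
  // Fetch the per-call-site tag once; it is then passed to every Run.
  const hwy::pool::Caller caller =
      ctx.pool_callers.Get(Callers::kVitDotSoftmax1);
  pool.Run(0, num_tokens, caller, [&](uint64_t task, size_t worker) {
    (void)task;
    (void)worker;  // per-token work goes here
  });
}

}  // namespace gcpp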

View File

@ -34,7 +34,6 @@
#include "util/threading_context.h" #include "util/threading_context.h"
#include "util/zones.h" #include "util/zones.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h" #include "hwy/highway.h"
#include "hwy/profiler.h" #include "hwy/profiler.h"
@ -150,7 +149,7 @@ void LayerWeightsPtrs::SplitAttW1() {
static void HWY_MAYBE_UNUSED InitAttWeightsI8( static void HWY_MAYBE_UNUSED InitAttWeightsI8(
const LayerConfig& layer_config, MatPtrT<I8Stream>& attn_vec_einsum_w, const LayerConfig& layer_config, MatPtrT<I8Stream>& attn_vec_einsum_w,
MatPtrT<I8Stream>& att_weights, std::vector<MatOwner>& mat_owners, MatPtrT<I8Stream>& att_weights, std::vector<MatOwner>& mat_owners,
const Allocator& allocator) { ThreadingContext& ctx) {
if (!attn_vec_einsum_w.HasPtr()) return; if (!attn_vec_einsum_w.HasPtr()) return;
HWY_ASSERT(attn_vec_einsum_w.GetType() == Type::kI8); HWY_ASSERT(attn_vec_einsum_w.GetType() == Type::kI8);
@ -160,7 +159,8 @@ static void HWY_MAYBE_UNUSED InitAttWeightsI8(
static std::mutex m; static std::mutex m;
std::lock_guard<std::mutex> lock(m); std::lock_guard<std::mutex> lock(m);
mat_owners.emplace_back(); mat_owners.emplace_back();
mat_owners.back().AllocateFor(att_weights, allocator, MatPadding::kPacked); mat_owners.back().AllocateFor(att_weights, ctx.allocator,
MatPadding::kPacked);
} }
const size_t model_dim = layer_config.model_dim; const size_t model_dim = layer_config.model_dim;
@ -188,10 +188,9 @@ static void HWY_MAYBE_UNUSED InitAttWeightsI8(
} }
CompressWorkingSet work; CompressWorkingSet work;
hwy::ThreadPool pool(0);
HWY_NAMESPACE::Compress(att_weights_tmp.get(), model_dim * heads * qkv_dim, HWY_NAMESPACE::Compress(att_weights_tmp.get(), model_dim * heads * qkv_dim,
work, att_weights.Span(), work, att_weights.Span(),
/*packed_ofs=*/0, pool); /*packed_ofs=*/0, ctx);
att_weights.SetScale(attn_vec_einsum_w.Scale()); att_weights.SetScale(attn_vec_einsum_w.Scale());
} }
@ -201,7 +200,7 @@ static void HWY_MAYBE_UNUSED SplitW1I8(const LayerConfig& layer_config,
MatPtrT<I8Stream>& gating_einsum_w1, MatPtrT<I8Stream>& gating_einsum_w1,
MatPtrT<I8Stream>& gating_einsum_w2, MatPtrT<I8Stream>& gating_einsum_w2,
std::vector<MatOwner>& mat_owners, std::vector<MatOwner>& mat_owners,
const Allocator& allocator) { ThreadingContext& ctx) {
// Files have both or neither of w1 and w2. // Files have both or neither of w1 and w2.
HWY_ASSERT(gating_einsum_w1.HasPtr() == gating_einsum_w2.HasPtr()); HWY_ASSERT(gating_einsum_w1.HasPtr() == gating_einsum_w2.HasPtr());
// w is mutually exclusive with w1 and w2 in the file. // w is mutually exclusive with w1 and w2 in the file.
@ -228,10 +227,10 @@ static void HWY_MAYBE_UNUSED SplitW1I8(const LayerConfig& layer_config,
static std::mutex m; static std::mutex m;
std::lock_guard<std::mutex> lock(m); std::lock_guard<std::mutex> lock(m);
mat_owners.emplace_back(); mat_owners.emplace_back();
mat_owners.back().AllocateFor(gating_einsum_w1, allocator, mat_owners.back().AllocateFor(gating_einsum_w1, ctx.allocator,
MatPadding::kPacked); MatPadding::kPacked);
mat_owners.emplace_back(); mat_owners.emplace_back();
mat_owners.back().AllocateFor(gating_einsum_w2, allocator, mat_owners.back().AllocateFor(gating_einsum_w2, ctx.allocator,
MatPadding::kPacked); MatPadding::kPacked);
} }
@ -248,11 +247,10 @@ static void HWY_MAYBE_UNUSED SplitW1I8(const LayerConfig& layer_config,
float* w2_tmp = w_tmp.get() + split_size; float* w2_tmp = w_tmp.get() + split_size;
CompressWorkingSet work; CompressWorkingSet work;
hwy::ThreadPool pool(0);
HWY_NAMESPACE::Compress(w1_tmp, split_size, work, gating_einsum_w1.Span(), 0, HWY_NAMESPACE::Compress(w1_tmp, split_size, work, gating_einsum_w1.Span(), 0,
pool); ctx);
HWY_NAMESPACE::Compress(w2_tmp, split_size, work, gating_einsum_w2.Span(), 0, HWY_NAMESPACE::Compress(w2_tmp, split_size, work, gating_einsum_w2.Span(), 0,
pool); ctx);
gating_einsum_w1.SetScale(1.0f); gating_einsum_w1.SetScale(1.0f);
gating_einsum_w2.SetScale(1.0f); gating_einsum_w2.SetScale(1.0f);
@ -265,7 +263,7 @@ static void HWY_MAYBE_UNUSED SplitAttW1I8(const LayerConfig& layer_config,
MatPtrT<I8Stream>& qkv_einsum_w1, MatPtrT<I8Stream>& qkv_einsum_w1,
MatPtrT<I8Stream>& qkv_einsum_w2, MatPtrT<I8Stream>& qkv_einsum_w2,
std::vector<MatOwner>& mat_owners, std::vector<MatOwner>& mat_owners,
const Allocator& allocator) { ThreadingContext& ctx) {
// w is mutually exclusive with w1 in the file. // w is mutually exclusive with w1 in the file.
HWY_ASSERT(qkv_einsum_w.HasPtr() ^ qkv_einsum_w1.HasPtr()); HWY_ASSERT(qkv_einsum_w.HasPtr() ^ qkv_einsum_w1.HasPtr());
// Done if we already read split tensors. // Done if we already read split tensors.
@ -291,10 +289,10 @@ static void HWY_MAYBE_UNUSED SplitAttW1I8(const LayerConfig& layer_config,
static std::mutex m; static std::mutex m;
std::lock_guard<std::mutex> lock(m); std::lock_guard<std::mutex> lock(m);
mat_owners.emplace_back(); mat_owners.emplace_back();
mat_owners.back().AllocateFor(qkv_einsum_w1, allocator, mat_owners.back().AllocateFor(qkv_einsum_w1, ctx.allocator,
MatPadding::kPacked); MatPadding::kPacked);
mat_owners.emplace_back(); mat_owners.emplace_back();
mat_owners.back().AllocateFor(qkv_einsum_w2, allocator, mat_owners.back().AllocateFor(qkv_einsum_w2, ctx.allocator,
MatPadding::kPacked); MatPadding::kPacked);
} }
@ -312,9 +310,8 @@ static void HWY_MAYBE_UNUSED SplitAttW1I8(const LayerConfig& layer_config,
float* w2_tmp = w_tmp.get() + w1_size; float* w2_tmp = w_tmp.get() + w1_size;
CompressWorkingSet work; CompressWorkingSet work;
hwy::ThreadPool pool(0); HWY_NAMESPACE::Compress(w1_tmp, w1_size, work, qkv_einsum_w1.Span(), 0, ctx);
HWY_NAMESPACE::Compress(w1_tmp, w1_size, work, qkv_einsum_w1.Span(), 0, pool); HWY_NAMESPACE::Compress(w2_tmp, w2_size, work, qkv_einsum_w2.Span(), 0, ctx);
HWY_NAMESPACE::Compress(w2_tmp, w2_size, work, qkv_einsum_w2.Span(), 0, pool);
qkv_einsum_w1.SetScale(1.0f); qkv_einsum_w1.SetScale(1.0f);
qkv_einsum_w2.SetScale(1.0f); qkv_einsum_w2.SetScale(1.0f);
@ -326,16 +323,16 @@ static void HWY_MAYBE_UNUSED SplitAttW1I8(const LayerConfig& layer_config,
// TODO: exporters should bake this into the weights already. // TODO: exporters should bake this into the weights already.
// WARNING: called from multiple threads; `mat_owners` requires a lock. // WARNING: called from multiple threads; `mat_owners` requires a lock.
void LayerWeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners, void LayerWeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
const Allocator& allocator) { ThreadingContext& ctx) {
if (attn_vec_einsum_w.GetType() == Type::kI8) { if (attn_vec_einsum_w.GetType() == Type::kI8) {
MatPtrT<I8Stream> attn_vec_einsum_w_i8(attn_vec_einsum_w); MatPtrT<I8Stream> attn_vec_einsum_w_i8(attn_vec_einsum_w);
MatPtrT<I8Stream> att_weights_i8(att_weights); MatPtrT<I8Stream> att_weights_i8(att_weights);
InitAttWeightsI8(layer_config, attn_vec_einsum_w_i8, att_weights_i8, InitAttWeightsI8(layer_config, attn_vec_einsum_w_i8, att_weights_i8,
mat_owners, allocator); mat_owners, ctx);
attn_vec_einsum_w = attn_vec_einsum_w_i8; attn_vec_einsum_w = attn_vec_einsum_w_i8;
att_weights = att_weights_i8; att_weights = att_weights_i8;
} else { } else {
InitAttWeights(mat_owners, allocator); InitAttWeights(mat_owners, ctx.allocator);
} }
if (gating_einsum_w.GetType() == Type::kI8) { if (gating_einsum_w.GetType() == Type::kI8) {
@ -343,7 +340,7 @@ void LayerWeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
MatPtrT<I8Stream> gating_einsum_w1_i8(gating_einsum_w1); MatPtrT<I8Stream> gating_einsum_w1_i8(gating_einsum_w1);
MatPtrT<I8Stream> gating_einsum_w2_i8(gating_einsum_w2); MatPtrT<I8Stream> gating_einsum_w2_i8(gating_einsum_w2);
SplitW1I8(layer_config, gating_einsum_w_i8, gating_einsum_w1_i8, SplitW1I8(layer_config, gating_einsum_w_i8, gating_einsum_w1_i8,
gating_einsum_w2_i8, mat_owners, allocator); gating_einsum_w2_i8, mat_owners, ctx);
gating_einsum_w = gating_einsum_w_i8; gating_einsum_w = gating_einsum_w_i8;
gating_einsum_w1 = gating_einsum_w1_i8; gating_einsum_w1 = gating_einsum_w1_i8;
gating_einsum_w2 = gating_einsum_w2_i8; gating_einsum_w2 = gating_einsum_w2_i8;
@ -356,7 +353,7 @@ void LayerWeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
MatPtrT<I8Stream> qkv_einsum_w1_i8(qkv_einsum_w1); MatPtrT<I8Stream> qkv_einsum_w1_i8(qkv_einsum_w1);
MatPtrT<I8Stream> qkv_einsum_w2_i8(qkv_einsum_w2); MatPtrT<I8Stream> qkv_einsum_w2_i8(qkv_einsum_w2);
SplitAttW1I8(layer_config, qkv_einsum_w_i8, qkv_einsum_w1_i8, SplitAttW1I8(layer_config, qkv_einsum_w_i8, qkv_einsum_w1_i8,
qkv_einsum_w2_i8, mat_owners, allocator); qkv_einsum_w2_i8, mat_owners, ctx);
qkv_einsum_w = qkv_einsum_w_i8; qkv_einsum_w = qkv_einsum_w_i8;
qkv_einsum_w1 = qkv_einsum_w1_i8; qkv_einsum_w1 = qkv_einsum_w1_i8;
qkv_einsum_w2 = qkv_einsum_w2_i8; qkv_einsum_w2 = qkv_einsum_w2_i8;
@ -367,7 +364,8 @@ void LayerWeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
static void HWY_MAYBE_UNUSED InitAttWeightsNUQ( static void HWY_MAYBE_UNUSED InitAttWeightsNUQ(
const LayerConfig& layer_config, MatPtrT<NuqStream>& attn_vec_einsum_w, const LayerConfig& layer_config, MatPtrT<NuqStream>& attn_vec_einsum_w,
MatPtrT<NuqStream>& att_weights, std::vector<MatOwner>& mat_owners) { MatPtrT<NuqStream>& att_weights, std::vector<MatOwner>& mat_owners,
ThreadingContext& ctx) {
if (!attn_vec_einsum_w.HasPtr()) return; if (!attn_vec_einsum_w.HasPtr()) return;
HWY_ASSERT(attn_vec_einsum_w.GetType() == Type::kNUQ); HWY_ASSERT(attn_vec_einsum_w.GetType() == Type::kNUQ);
@ -399,10 +397,9 @@ static void HWY_MAYBE_UNUSED InitAttWeightsNUQ(
} }
CompressWorkingSet work; CompressWorkingSet work;
hwy::ThreadPool pool(0);
HWY_NAMESPACE::Compress(att_weights_tmp.get(), model_dim * heads * qkv_dim, HWY_NAMESPACE::Compress(att_weights_tmp.get(), model_dim * heads * qkv_dim,
work, att_weights.Span(), work, att_weights.Span(),
/*packed_ofs=*/0, pool); /*packed_ofs=*/0, ctx);
att_weights.SetScale(attn_vec_einsum_w.Scale()); att_weights.SetScale(attn_vec_einsum_w.Scale());
} }
@ -435,13 +432,13 @@ void WeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
ThreadingContext& ctx) { ThreadingContext& ctx) {
const size_t cluster_idx = 0; const size_t cluster_idx = 0;
ParallelFor(ParallelismStrategy::kFlat, c_layers.size(), ctx, cluster_idx, ParallelFor(ParallelismStrategy::kFlat, c_layers.size(), ctx, cluster_idx,
[&](uint64_t layer, size_t /*worker*/) { Callers::kFixupWeights, [&](uint64_t layer, size_t /*worker*/) {
GetLayer(layer)->Fixup(mat_owners, ctx.allocator); GetLayer(layer)->Fixup(mat_owners, ctx);
}); });
ParallelFor(ParallelismStrategy::kFlat, vit_layers.size(), ctx, cluster_idx, ParallelFor(ParallelismStrategy::kFlat, vit_layers.size(), ctx, cluster_idx,
[&](uint64_t layer, size_t /*worker*/) { Callers::kFixupWeights, [&](uint64_t layer, size_t /*worker*/) {
VitLayer(layer)->Fixup(mat_owners, ctx.allocator); VitLayer(layer)->Fixup(mat_owners, ctx);
}); });
} }
@ -529,8 +526,9 @@ static void AllocateAndBindAll(std::vector<TensorToRead>& tensors,
owners.resize(start + tensors.size()); owners.resize(start + tensors.size());
// Allocate in parallel because faulting in large tensors is slow. // Allocate in parallel because faulting in large tensors is slow.
ctx.pools.Pool().Run( ParallelFor(
0, tensors.size(), [&](uint64_t task, size_t /*thread*/) { ParallelismStrategy::kFlat, tensors.size(), ctx, /*cluster_idx=*/0,
Callers::kAllocateAndBindAll, [&](uint64_t task, size_t /*thread*/) {
TensorToRead& tensor = tensors[task]; TensorToRead& tensor = tensors[task];
MatPtr& mat = *tensor.mat; MatPtr& mat = *tensor.mat;
@ -587,14 +585,13 @@ static void DecompressToBF16(MatPtr& mat,
static void ReadAllToBF16(const std::vector<TensorToRead>& tensors, static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
const BlobReader& reader, ThreadingContext& ctx) { const BlobReader& reader, ThreadingContext& ctx) {
const auto zone = GetProfilerZone(Zones::kStartupWeightsReadAllToBF16);
// Especially TSAN is slow enough to warrant hierarchical parallelism. // Especially TSAN is slow enough to warrant hierarchical parallelism.
const ParallelismStrategy strategy = HWY_IS_DEBUG_BUILD const ParallelismStrategy strategy = HWY_IS_DEBUG_BUILD
? ParallelismStrategy::kHierarchical ? ParallelismStrategy::kHierarchical
: ParallelismStrategy::kFlat; : ParallelismStrategy::kFlat;
ParallelFor(strategy, tensors.size(), ctx, /*cluster_idx=*/0, ParallelFor(strategy, tensors.size(), ctx, /*cluster_idx=*/0,
[&](uint64_t task, size_t thread) { Callers::kReadAllToBF16, [&](uint64_t task, size_t thread) {
PROFILER_ZONE3(ctx.profiler, thread, zone); GCPP_ZONE(ctx, thread, Zones::kStartupWeightsReadAllToBF16);
const TensorToRead& tensor = tensors[task]; const TensorToRead& tensor = tensors[task];
MatPtr& mat = *tensor.mat; MatPtr& mat = *tensor.mat;
@ -679,12 +676,11 @@ static std::vector<IOBatch> MakeBatches(
static void ReadBatches(const BlobReader& reader, static void ReadBatches(const BlobReader& reader,
const std::vector<IOBatch>& batches, const std::vector<IOBatch>& batches,
ThreadingContext& ctx) { ThreadingContext& ctx) {
const auto zone = GetProfilerZone(Zones::kStartupWeightsReadBatches);
// >5x speedup from parallel reads when cached. // >5x speedup from parallel reads when cached.
ParallelFor(ParallelismStrategy::kHierarchical, ParallelFor(ParallelismStrategy::kHierarchical, batches.size(), ctx,
batches.size(), ctx, /*cluster_idx=*/0, /*cluster_idx=*/0, Callers::kReadBatches,
[&](uint64_t task, size_t thread) { [&](uint64_t task, size_t thread) {
PROFILER_ZONE3(ctx.profiler, thread, zone); GCPP_ZONE(ctx, thread, Zones::kStartupWeightsReadBatches);
const IOBatch& batch = batches[task]; const IOBatch& batch = batches[task];
const std::string& key = reader.Keys()[batch.KeyIdx()]; const std::string& key = reader.Keys()[batch.KeyIdx()];
const uint64_t bytes_read = batch.Read(reader.file()); const uint64_t bytes_read = batch.Read(reader.file());
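The ParallelFor overload used throughout this section now takes a Callers value in addition to the context and cluster index, and the per-task zone is opened with GCPP_ZONE inside the lambda. The shape of the call, condensed from the ReadBatches hunk above (body elided):

ParallelFor(ParallelismStrategy::kHierarchical, batches.size(), ctx,
            /*cluster_idx=*/0, Callers::kReadBatches,
            [&](uint64_t task, size_t thread) {
              GCPP_ZONE(ctx, thread, Zones::kStartupWeightsReadBatches);
              // Issue the reads for batches[task] on `thread`.
            });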

View File

@ -254,7 +254,7 @@ struct LayerWeightsPtrs {
// Must be called after reading weights via `ForEachTensor`. // Must be called after reading weights via `ForEachTensor`.
// TODO: exporters should bake this into the weights already. // TODO: exporters should bake this into the weights already.
// WARNING: called from multiple threads; `mat_owners` requires a lock. // WARNING: called from multiple threads; `mat_owners` requires a lock.
void Fixup(std::vector<MatOwner>& mat_owners, const Allocator& allocator); void Fixup(std::vector<MatOwner>& mat_owners, ThreadingContext& ctx);
private: private:
// Copies att_weights from `attn_vec_einsum_w`. // Copies att_weights from `attn_vec_einsum_w`.

View File

@ -79,7 +79,6 @@ cc_library(
"//:threading_context", "//:threading_context",
"@highway//:hwy", "@highway//:hwy",
"@highway//:profiler", "@highway//:profiler",
"@highway//:thread_pool",
], ],
) )
@ -108,7 +107,6 @@ cc_binary(
"@highway//:hwy", "@highway//:hwy",
"@highway//:hwy_test_util", "@highway//:hwy_test_util",
"@highway//:nanobenchmark", "@highway//:nanobenchmark",
"@highway//:thread_pool",
], ],
) )

View File

@ -107,7 +107,7 @@ void ReadBlobs(BlobReader& reader, const RangeVec& ranges, BlobVec& blobs,
HWY_ASSERT(reader.Keys().size() == blobs.size()); HWY_ASSERT(reader.Keys().size() == blobs.size());
HWY_ASSERT(ranges.size() == blobs.size()); HWY_ASSERT(ranges.size() == blobs.size());
ParallelFor(ParallelismStrategy::kWithinCluster, blobs.size(), ctx, ParallelFor(ParallelismStrategy::kWithinCluster, blobs.size(), ctx,
cluster_idx, [&](size_t i, size_t /*thread*/) { cluster_idx, Callers::kTest, [&](size_t i, size_t /*thread*/) {
HWY_ASSERT(ranges[i].bytes == blobs[i].size()); HWY_ASSERT(ranges[i].bytes == blobs[i].size());
reader.file().Read(ranges[i].offset, ranges[i].bytes, reader.file().Read(ranges[i].offset, ranges[i].bytes,
blobs[i].data()); blobs[i].data());
@ -122,7 +122,7 @@ void ReadBothBlobs(BlobReader& reader1, BlobReader& reader2,
const double t0 = hwy::platform::Now(); const double t0 = hwy::platform::Now();
HWY_WARN("Reading %zu GiB, %zu clusters: ", total_bytes >> 30, HWY_WARN("Reading %zu GiB, %zu clusters: ", total_bytes >> 30,
ctx.pools.NumClusters()); ctx.pools.NumClusters());
ParallelFor(ParallelismStrategy::kAcrossClusters, 2, ctx, 0, ParallelFor(ParallelismStrategy::kAcrossClusters, 2, ctx, 0, Callers::kTest,
[&](const size_t task, size_t cluster_idx) { [&](const size_t task, size_t cluster_idx) {
ReadBlobs(task ? reader1 : reader2, task ? ranges1 : ranges2, ReadBlobs(task ? reader1 : reader2, task ? ranges1 : ranges2,
task ? blobs1 : blobs2, ctx, cluster_idx); task ? blobs1 : blobs2, ctx, cluster_idx);
@ -190,7 +190,7 @@ void CompareBlobs(const KeyVec& keys, BlobVec& blobs1, BlobVec& blobs2,
std::atomic<size_t> blobs_equal{}; std::atomic<size_t> blobs_equal{};
std::atomic<size_t> blobs_diff{}; std::atomic<size_t> blobs_diff{};
ParallelFor(ParallelismStrategy::kHierarchical, keys.size(), ctx, 0, ParallelFor(ParallelismStrategy::kHierarchical, keys.size(), ctx, 0,
[&](size_t i, size_t /*thread*/) { Callers::kTest, [&](size_t i, size_t /*thread*/) {
const size_t mismatches = const size_t mismatches =
BlobDifferences(blobs1[i], blobs2[i], keys[i]); BlobDifferences(blobs1[i], blobs2[i], keys[i]);
if (mismatches != 0) { if (mismatches != 0) {

View File

@ -28,7 +28,6 @@
#include "util/threading_context.h" #include "util/threading_context.h"
#include "hwy/aligned_allocator.h" // Span #include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/detect_compiler_arch.h" #include "hwy/detect_compiler_arch.h"
#include "hwy/profiler.h" #include "hwy/profiler.h"
@ -469,8 +468,8 @@ static void EnqueueChunks(size_t key_idx, uint64_t offset, uint64_t bytes,
} }
} }
BlobWriter::BlobWriter(const Path& filename, hwy::ThreadPool& pool) BlobWriter::BlobWriter(const Path& filename, ThreadingContext& ctx)
: file_(OpenFileOrNull(filename, "w+")), pool_(pool) { : file_(OpenFileOrNull(filename, "w+")), ctx_(ctx) {
if (!file_) HWY_ABORT("Failed to open for writing %s", filename.path.c_str()); if (!file_) HWY_ABORT("Failed to open for writing %s", filename.path.c_str());
// Write a placeholder header to the beginning of the file. If append-only, // Write a placeholder header to the beginning of the file. If append-only,
// we will later write a footer, else we will update the header. // we will later write a footer, else we will update the header.
@ -489,10 +488,13 @@ void BlobWriter::Add(const std::string& key, const void* data, size_t bytes) {
EnqueueChunks(keys_.size() - 1, curr_offset_, bytes, EnqueueChunks(keys_.size() - 1, curr_offset_, bytes,
static_cast<const uint8_t*>(data), writes); static_cast<const uint8_t*>(data), writes);
hwy::ThreadPool null_pool(0); const ParallelismStrategy strategy = file_->IsAppendOnly()
hwy::ThreadPool& pool_or_serial = file_->IsAppendOnly() ? null_pool : pool_; ? ParallelismStrategy::kNone
pool_or_serial.Run( : ParallelismStrategy::kFlat;
0, writes.size(), [this, &writes](uint64_t i, size_t /*thread*/) { ParallelFor(
strategy, writes.size(), ctx_,
/*cluster_idx=*/0, Callers::kBlobWriter,
[this, &writes](uint64_t i, size_t /*thread*/) {
const BlobRange& range = writes[i].range; const BlobRange& range = writes[i].range;
if (!file_->Write(writes[i].data, range.bytes, range.offset)) { if (!file_->Write(writes[i].data, range.bytes, range.offset)) {
const std::string& key = StringFromKey(keys_[range.key_idx]); const std::string& key = StringFromKey(keys_[range.key_idx]);
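BlobWriter::Add now selects a ParallelismStrategy instead of swapping in a null thread pool: kNone keeps writes serial for append-only files (presumably because appends must land in order), kFlat parallelizes them otherwise. Condensed from the hunk above (member names as in that hunk):

const ParallelismStrategy strategy =
    file_->IsAppendOnly() ? ParallelismStrategy::kNone   // serial, in order
                          : ParallelismStrategy::kFlat;  // parallel chunks
ParallelFor(strategy, writes.size(), ctx_, /*cluster_idx=*/0,
            Callers::kBlobWriter, [&](uint64_t i, size_t /*thread*/) {
              // Write chunk i; a failed write aborts with the offending key.
            });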

View File

@ -28,9 +28,9 @@
#include "io/io.h" // File, Path, MapPtr #include "io/io.h" // File, Path, MapPtr
#include "util/basics.h" // Tristate #include "util/basics.h" // Tristate
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h" // Span #include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h" // HWY_ASSERT #include "hwy/base.h" // HWY_ASSERT
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp { namespace gcpp {
@ -117,7 +117,7 @@ class BlobReader {
// does not make sense to call the methods concurrently. // does not make sense to call the methods concurrently.
class BlobWriter { class BlobWriter {
public: public:
BlobWriter(const Path& filename, hwy::ThreadPool& pool); BlobWriter(const Path& filename, ThreadingContext& ctx);
// Writes the blob to disk with padding for alignment. Aborts on error. // Writes the blob to disk with padding for alignment. Aborts on error.
void Add(const std::string& key, const void* data, size_t bytes); void Add(const std::string& key, const void* data, size_t bytes);
@ -129,7 +129,7 @@ class BlobWriter {
std::unique_ptr<File> file_; std::unique_ptr<File> file_;
std::vector<hwy::uint128_t> keys_; std::vector<hwy::uint128_t> keys_;
std::vector<size_t> blob_sizes_; std::vector<size_t> blob_sizes_;
hwy::ThreadPool& pool_; ThreadingContext& ctx_;
// Current offset in the file used for writing. // Current offset in the file used for writing.
int64_t curr_offset_ = 0; int64_t curr_offset_ = 0;
}; };

View File

@ -38,7 +38,6 @@ class BlobStoreTest : public testing::Test {};
TEST(BlobStoreTest, TestReadWrite) { TEST(BlobStoreTest, TestReadWrite) {
ThreadingArgs threading_args; ThreadingArgs threading_args;
ThreadingContext ctx(threading_args); ThreadingContext ctx(threading_args);
hwy::ThreadPool& pool = ctx.pools.Pool();
static const std::array<float, 4> kOriginalData = {-1, 0, 3.14159, 2.71828}; static const std::array<float, 4> kOriginalData = {-1, 0, 3.14159, 2.71828};
@ -52,7 +51,7 @@ TEST(BlobStoreTest, TestReadWrite) {
const std::string keyA("0123456789abcdef"); // max 16 characters const std::string keyA("0123456789abcdef"); // max 16 characters
const std::string keyB("q"); const std::string keyB("q");
BlobWriter writer(path, pool); BlobWriter writer(path, ctx);
writer.Add(keyA, "DATA", 5); writer.Add(keyA, "DATA", 5);
writer.Add(keyB, buffer.data(), sizeof(buffer)); writer.Add(keyB, buffer.data(), sizeof(buffer));
writer.Finalize(); writer.Finalize();
@ -96,7 +95,6 @@ TEST(BlobStoreTest, TestReadWrite) {
TEST(BlobStoreTest, TestNumBlobs) { TEST(BlobStoreTest, TestNumBlobs) {
ThreadingArgs threading_args; ThreadingArgs threading_args;
ThreadingContext ctx(threading_args); ThreadingContext ctx(threading_args);
hwy::ThreadPool& pool = ctx.pools.Pool();
hwy::RandomState rng; hwy::RandomState rng;
for (size_t num_blobs = 1; num_blobs <= 512; ++num_blobs) { for (size_t num_blobs = 1; num_blobs <= 512; ++num_blobs) {
@ -106,7 +104,7 @@ TEST(BlobStoreTest, TestNumBlobs) {
HWY_ASSERT(fd > 0); HWY_ASSERT(fd > 0);
const Path path(path_str); const Path path(path_str);
BlobWriter writer(path, pool); BlobWriter writer(path, ctx);
std::vector<std::string> keys; std::vector<std::string> keys;
keys.reserve(num_blobs); keys.reserve(num_blobs);
std::vector<std::vector<uint8_t>> blobs; std::vector<std::vector<uint8_t>> blobs;
@ -130,26 +128,31 @@ TEST(BlobStoreTest, TestNumBlobs) {
BlobReader reader(path); BlobReader reader(path);
HWY_ASSERT_EQ(reader.Keys().size(), num_blobs); HWY_ASSERT_EQ(reader.Keys().size(), num_blobs);
pool.Run(0, num_blobs, [&](uint64_t i, size_t /*thread*/) {
HWY_ASSERT_STRING_EQ(reader.Keys()[i].c_str(), std::to_string(i).c_str()); ParallelFor(
const BlobRange* range = reader.Find(keys[i]); ParallelismStrategy::kFlat, num_blobs, ctx, /*cluster_idx=*/0,
HWY_ASSERT(range); Callers::kTest, [&](uint64_t i, size_t /*thread*/) {
HWY_ASSERT_EQ(blobs[i].size(), range->bytes); HWY_ASSERT_STRING_EQ(reader.Keys()[i].c_str(),
HWY_ASSERT(reader.CallWithSpan<uint8_t>( std::to_string(i).c_str());
keys[i], [path_str, num_blobs, i, range, const BlobRange* range = reader.Find(keys[i]);
&blobs](const hwy::Span<const uint8_t> span) { HWY_ASSERT(range);
HWY_ASSERT_EQ(blobs[i].size(), span.size()); HWY_ASSERT_EQ(blobs[i].size(), range->bytes);
const bool match1 = span[0] == static_cast<uint8_t>(i & 255); HWY_ASSERT(reader.CallWithSpan<uint8_t>(
// If size == 1, we don't have a second byte to check. keys[i], [path_str, num_blobs, i, range,
const bool match2 = &blobs](const hwy::Span<const uint8_t> span) {
span.size() == 1 || HWY_ASSERT_EQ(blobs[i].size(), span.size());
span[span.size() - 1] == static_cast<uint8_t>(i >> 8); const bool match1 = span[0] == static_cast<uint8_t>(i & 255);
if (!match1 || !match2) { // If size == 1, we don't have a second byte to check.
HWY_ABORT("%s num_blobs %zu blob %zu offset %zu is corrupted.", const bool match2 =
path_str, num_blobs, i, range->offset); span.size() == 1 ||
} span[span.size() - 1] == static_cast<uint8_t>(i >> 8);
})); if (!match1 || !match2) {
}); HWY_ABORT(
"%s num_blobs %zu blob %zu offset %zu is corrupted.",
path_str, num_blobs, i, range->offset);
}
}));
});
close(fd); close(fd);
unlink(path_str); unlink(path_str);
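Condensed from the test above, a round-trip usage sketch of the updated blob store API, with BlobWriter taking the ThreadingContext directly (error handling as in the test; not a replacement for it):

namespace gcpp {

void RoundTrip(const Path& path) {
  ThreadingArgs threading_args;
  ThreadingContext ctx(threading_args);

  BlobWriter writer(path, ctx);               // was: BlobWriter(path, pool)
  writer.Add("0123456789abcdef", "DATA", 5);  // keys are at most 16 characters
  writer.Finalize();

  BlobReader reader(path);
  HWY_ASSERT(reader.Find("0123456789abcdef") != nullptr);
}

}  // namespace gcpp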

View File

@ -44,6 +44,6 @@ int main(int argc, char** argv) {
} }
gcpp::GemmaEnv env(argc, argv); gcpp::GemmaEnv env(argc, argv);
env.GetGemma()->Save(args.output_weights, env.Env().ctx.pools); env.GetGemma()->Save(args.output_weights, env.Env().ctx);
return 0; return 0;
} }

View File

@ -30,7 +30,6 @@
#include "ops/matmul.h" #include "ops/matmul.h"
#include "util/basics.h" #include "util/basics.h"
#include "util/threading_context.h" #include "util/threading_context.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/nanobenchmark.h" #include "hwy/nanobenchmark.h"
#include "hwy/profiler.h" #include "hwy/profiler.h"
#include "hwy/timer.h" #include "hwy/timer.h"
@ -72,7 +71,6 @@ void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents,
// M = A rows, K = A cols, N = C cols. // M = A rows, K = A cols, N = C cols.
template <typename TA, typename TB = TA, typename TC = float> template <typename TA, typename TB = TA, typename TC = float>
void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
hwy::ThreadPool& pool = env.ctx.pools.Pool(0);
if (env.print_config || env.print_measurement) { if (env.print_config || env.print_measurement) {
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
@ -91,15 +89,14 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
MatStorageT<float> add_storage("add", Extents2D(), env.ctx.allocator, MatStorageT<float> add_storage("add", Extents2D(), env.ctx.allocator,
MatPadding::kPacked); MatPadding::kPacked);
if (add) { if (add) {
add_storage = GenerateMat<float>(Extents2D(1, N), env.ctx.allocator, add_storage =
MatPadding::kPacked, pool); GenerateMat<float>(Extents2D(1, N), MatPadding::kPacked, env.ctx);
add_storage.SetScale(1.0f); add_storage.SetScale(1.0f);
} }
MatStorageT<TA> a = MatStorageT<TA> a = GenerateMat<TA>(A_extents, MatPadding::kOdd, env.ctx);
GenerateMat<TA>(A_extents, env.ctx.allocator, MatPadding::kOdd, pool); MatStorageT<TB> b_trans =
MatStorageT<TB> b_trans = GenerateTransposedMat<TB>( GenerateTransposedMat<TB>(B_extents, MatPadding::kOdd, env.ctx);
B_extents, env.ctx.allocator, MatPadding::kOdd, pool);
const float* add_row = add ? add_storage.PackedScale1() : nullptr; const float* add_row = add ? add_storage.PackedScale1() : nullptr;

View File

@ -31,7 +31,6 @@
#include "util/test_util.h" #include "util/test_util.h"
#include "util/threading_context.h" #include "util/threading_context.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h" #include "hwy/profiler.h"
#include "hwy/stats.h" #include "hwy/stats.h"
#include "hwy/timer.h" #include "hwy/timer.h"
@ -922,9 +921,11 @@ void GenerateWellConditionedInputs(const size_t num, float* HWY_RESTRICT raw,
(void)ScaleWeights(raw, num); (void)ScaleWeights(raw, num);
} }
hwy::ThreadPool pool(0); // num is too small for parallelization ThreadingArgs threading_args;
threading_args.max_lps = 1; // num is too small for parallelization
ThreadingContext ctx(threading_args);
const size_t packed_ofs = 0; const size_t packed_ofs = 0;
Compress(raw, num, work, packed, packed_ofs, pool); Compress(raw, num, work, packed, packed_ofs, ctx);
const hn::ScalableTag<float> df; const hn::ScalableTag<float> df;
DecompressAndZeroPad(df, MakeConst(packed), packed_ofs, raw, num); DecompressAndZeroPad(df, MakeConst(packed), packed_ofs, raw, num);
@ -1125,7 +1126,7 @@ void TestAllDot() {
std::array<DotStats, kMaxWorkers> all_stats; std::array<DotStats, kMaxWorkers> all_stats;
ParallelFor( ParallelFor(
ParallelismStrategy::kWithinCluster, kReps, ctx, 0, ParallelismStrategy::kWithinCluster, kReps, ctx, 0, Callers::kTest,
[&](size_t rep, size_t thread) { [&](size_t rep, size_t thread) {
float* HWY_RESTRICT pa = a.Row(thread); float* HWY_RESTRICT pa = a.Row(thread);
float* HWY_RESTRICT pb = b.Row(thread); float* HWY_RESTRICT pb = b.Row(thread);

View File

@ -291,7 +291,7 @@ class MMDecompress {
const hn::ScalableTag<BF16> dbf; const hn::ScalableTag<BF16> dbf;
const size_t NBF = hn::Lanes(dbf); const size_t NBF = hn::Lanes(dbf);
const auto zone = GetProfilerZone(Zones::kMMDecompressA); const auto zone = env.ctx.profiler_zones.Get(Zones::kMMDecompressA);
const auto do_range = const auto do_range =
[&](const IndexRange& range_M, const IndexRange& range_K, size_t worker) [&](const IndexRange& range_M, const IndexRange& range_K, size_t worker)
@ -879,9 +879,8 @@ class MMLoops {
static HWY_NOINLINE void Dispatch(const StridedViewBF A, const MatPtrT<TB>& B, static HWY_NOINLINE void Dispatch(const StridedViewBF A, const MatPtrT<TB>& B,
const MatPtrT<TB>* B2, RowPtrs<TC> C, const MatPtrT<TB>* B2, RowPtrs<TC> C,
const MMArgs& args) { const MMArgs& args) {
PROFILER_ZONE3(args.env.ctx.profiler, GCPP_ZONE(args.env.ctx, args.env.ctx.Worker(args.options.cluster_idx),
args.env.ctx.Worker(args.options.cluster_idx), Zones::kMMDispatch);
GetProfilerZone(Zones::kMMDispatch));
DispatchParallelism( DispatchParallelism(
args.options.parallelism, [&](const auto& parallel) HWY_ATTR { args.options.parallelism, [&](const auto& parallel) HWY_ATTR {
@ -904,7 +903,7 @@ class MMLoops {
const StridedViewBF A, const MatPtrT<TB>& B, const StridedViewBF A, const MatPtrT<TB>& B,
const MatPtrT<TB>* B2, RowPtrs<TC> C, const MatPtrT<TB>* B2, RowPtrs<TC> C,
const MMArgs& args) { const MMArgs& args) {
const auto zone = GetProfilerZone(Zones::kMMNT); const auto zone = args.env.ctx.profiler_zones.Get(Zones::kMMNT);
HWY_DASSERT(args.ranges_mc.NumTasks() == 1); HWY_DASSERT(args.ranges_mc.NumTasks() == 1);
HWY_DASSERT(args.ranges_kc.NumTasks() == 1); HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
const IndexRange& range_mc = args.ranges_mc.Range(0); const IndexRange& range_mc = args.ranges_mc.Range(0);
@ -940,7 +939,7 @@ class MMLoops {
const StridedViewBF A, const MatPtrT<TB>& B, const StridedViewBF A, const MatPtrT<TB>& B,
const MatPtrT<TB>* B2, RowPtrs<TC> C, const MatPtrT<TB>* B2, RowPtrs<TC> C,
const MMArgs& args) { const MMArgs& args) {
const auto zone = GetProfilerZone(Zones::kMMNT_K); const auto zone = args.env.ctx.profiler_zones.Get(Zones::kMMNT_K);
HWY_DASSERT(args.ranges_mc.NumTasks() == 1); HWY_DASSERT(args.ranges_mc.NumTasks() == 1);
const IndexRange& range_mc = args.ranges_mc.Range(0); const IndexRange& range_mc = args.ranges_mc.Range(0);
@ -976,7 +975,7 @@ class MMLoops {
const StridedViewBF A, const MatPtrT<TB>& B, const StridedViewBF A, const MatPtrT<TB>& B,
const MatPtrT<TB>* B2, RowPtrs<TC> C, const MatPtrT<TB>* B2, RowPtrs<TC> C,
const MMArgs& args) { const MMArgs& args) {
const auto zone = GetProfilerZone(Zones::kMMNT_MT); const auto zone = args.env.ctx.profiler_zones.Get(Zones::kMMNT_MT);
HWY_DASSERT(args.ranges_kc.NumTasks() == 1); HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
const IndexRange& range_kc = args.ranges_kc.Range(0); const IndexRange& range_kc = args.ranges_kc.Range(0);
@ -1010,7 +1009,7 @@ class MMLoops {
const StridedViewBF A, const MatPtrT<TB>& B, const StridedViewBF A, const MatPtrT<TB>& B,
const MatPtrT<TB>* B2, RowPtrs<TC> C, const MatPtrT<TB>* B2, RowPtrs<TC> C,
const MMArgs& args) { const MMArgs& args) {
const auto zone = GetProfilerZone(Zones::kMMNT_MT_K); const auto zone = args.env.ctx.profiler_zones.Get(Zones::kMMNT_MT_K);
parallel.ForRangesMC_NC( parallel.ForRangesMC_NC(
args.env.ctx, args.ranges_mc, args.ranges_nc, args.options.cluster_idx, args.env.ctx, args.ranges_mc, args.ranges_nc, args.options.cluster_idx,
@ -1063,8 +1062,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
MatPtrT<TC>& C, MMOptions options = MMOptions()) { MatPtrT<TC>& C, MMOptions options = MMOptions()) {
const size_t cluster_idx = options.cluster_idx; const size_t cluster_idx = options.cluster_idx;
HWY_DASSERT(cluster_idx < env.row_ptrs.size()); HWY_DASSERT(cluster_idx < env.row_ptrs.size());
PROFILER_ZONE3(env.ctx.profiler, env.ctx.Worker(cluster_idx), GCPP_ZONE(env.ctx, env.ctx.Worker(cluster_idx), Zones::kMMMatMul);
GetProfilerZone(Zones::kMMMatMul));
RowPtrs<TC> C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[cluster_idx]); RowPtrs<TC> C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[cluster_idx]);
@ -1124,8 +1122,7 @@ HWY_NOINLINE MMPerKey* TwoMatMul(const MatPtrT<BF16>& A, const MatPtrT<TB>& B1,
MatPtrT<BF16>& C, MMOptions options) { MatPtrT<BF16>& C, MMOptions options) {
const size_t cluster_idx = options.cluster_idx; const size_t cluster_idx = options.cluster_idx;
HWY_DASSERT(cluster_idx < env.row_ptrs.size()); HWY_DASSERT(cluster_idx < env.row_ptrs.size());
PROFILER_ZONE3(env.ctx.profiler, env.ctx.Worker(cluster_idx), GCPP_ZONE(env.ctx, env.ctx.Worker(cluster_idx), Zones::kMMTwoMatMul);
GetProfilerZone(Zones::kMMTwoMatMul));
HWY_DASSERT(options.func != nullptr); // no other way to get access to C2. HWY_DASSERT(options.func != nullptr); // no other way to get access to C2.

View File

@ -111,6 +111,7 @@ struct MMParallelWithinCluster {
const IndexRangePartition ranges_n = StaticPartition( const IndexRangePartition ranges_n = StaticPartition(
range_n, cluster.NumWorkers() * inner_tasks, n_multiple); range_n, cluster.NumWorkers() * inner_tasks, n_multiple);
ParallelizeOneRange(ranges_n, cluster, ParallelizeOneRange(ranges_n, cluster,
ctx.pool_callers.Get(Callers::kMMClusterForN),
[&](const IndexRange& worker_range, size_t worker) { [&](const IndexRange& worker_range, size_t worker) {
func(worker_range, base + worker); func(worker_range, base + worker);
}); });
@ -127,12 +128,14 @@ struct MMParallelWithinCluster {
// Low-batch: avoid Divide/Remainder. // Low-batch: avoid Divide/Remainder.
if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) { if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
ParallelizeOneRange(ranges_nc, cluster, ParallelizeOneRange(ranges_nc, cluster,
ctx.pool_callers.Get(Callers::kMMClusterForMCNC),
[&](const IndexRange& range_nc, size_t worker) { [&](const IndexRange& range_nc, size_t worker) {
func(ranges_mc.Range(0), range_nc, base + worker); func(ranges_mc.Range(0), range_nc, base + worker);
}); });
} else { } else {
ParallelizeTwoRanges( ParallelizeTwoRanges(
ranges_mc, ranges_nc, cluster, ranges_mc, ranges_nc, cluster,
ctx.pool_callers.Get(Callers::kMMClusterForMCNC),
[&](const IndexRange& range_mc, const IndexRange& range_nc, [&](const IndexRange& range_mc, const IndexRange& range_nc,
size_t worker) { func(range_mc, range_nc, base + worker); }); size_t worker) { func(range_mc, range_nc, base + worker); });
} }
@ -146,6 +149,7 @@ struct MMParallelWithinCluster {
cluster.Run( cluster.Run(
range_mc.begin(), range_mc.end(), range_mc.begin(), range_mc.end(),
ctx.pool_callers.Get(Callers::kMMClusterForMC),
[&](uint64_t row_a, size_t worker) { func(row_a, base + worker); }); [&](uint64_t row_a, size_t worker) { func(row_a, base + worker); });
} }
}; };
@ -159,6 +163,7 @@ struct MMParallelHierarchical {
const Func& func) const { const Func& func) const {
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4); HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
HWY_DASSERT(caller_cluster_idx == 0); HWY_DASSERT(caller_cluster_idx == 0);
const hwy::pool::Caller caller = ctx.pool_callers.Get(Callers::kMMHierForN);
// Single cluster: parallel-for over static partition of `range_n`. // Single cluster: parallel-for over static partition of `range_n`.
hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(); hwy::ThreadPool& all_clusters = ctx.pools.AllClusters();
@ -169,7 +174,7 @@ struct MMParallelHierarchical {
const IndexRangePartition ranges_n = StaticPartition( const IndexRangePartition ranges_n = StaticPartition(
range_n, cluster.NumWorkers() * inner_tasks, n_multiple); range_n, cluster.NumWorkers() * inner_tasks, n_multiple);
return ParallelizeOneRange( return ParallelizeOneRange(
ranges_n, cluster, ranges_n, cluster, caller,
[&](const IndexRange& worker_range, size_t worker) { [&](const IndexRange& worker_range, size_t worker) {
func(worker_range, worker); func(worker_range, worker);
}); });
@ -179,7 +184,7 @@ struct MMParallelHierarchical {
const IndexRangePartition ranges_n = const IndexRangePartition ranges_n =
StaticPartition(range_n, num_clusters, n_multiple); StaticPartition(range_n, num_clusters, n_multiple);
ParallelizeOneRange( ParallelizeOneRange(
ranges_n, all_clusters, ranges_n, all_clusters, caller,
[&](const IndexRange& n_range, const size_t cluster_idx) { [&](const IndexRange& n_range, const size_t cluster_idx) {
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx); hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
const size_t cluster_base = ctx.Worker(cluster_idx); const size_t cluster_base = ctx.Worker(cluster_idx);
@ -187,7 +192,7 @@ struct MMParallelHierarchical {
const IndexRangePartition worker_ranges = StaticPartition( const IndexRangePartition worker_ranges = StaticPartition(
n_range, cluster.NumWorkers() * inner_tasks, n_multiple); n_range, cluster.NumWorkers() * inner_tasks, n_multiple);
ParallelizeOneRange( ParallelizeOneRange(
worker_ranges, cluster, worker_ranges, cluster, caller,
[&](const IndexRange& worker_range, size_t worker) { [&](const IndexRange& worker_range, size_t worker) {
func(worker_range, cluster_base + worker); func(worker_range, cluster_base + worker);
}); });
@ -203,6 +208,8 @@ struct MMParallelHierarchical {
HWY_MAYBE_UNUSED size_t caller_cluster_idx, HWY_MAYBE_UNUSED size_t caller_cluster_idx,
const Func& func) const { const Func& func) const {
HWY_DASSERT(caller_cluster_idx == 0); HWY_DASSERT(caller_cluster_idx == 0);
const hwy::pool::Caller caller =
ctx.pool_callers.Get(Callers::kMMHierForMCNC);
hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(); hwy::ThreadPool& all_clusters = ctx.pools.AllClusters();
// `all_clusters` is a pool with one worker per cluster in a package. // `all_clusters` is a pool with one worker per cluster in a package.
@ -215,12 +222,13 @@ struct MMParallelHierarchical {
// Low-batch: avoid Divide/Remainder. // Low-batch: avoid Divide/Remainder.
if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) { if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
return ParallelizeOneRange( return ParallelizeOneRange(
ranges_nc, cluster, [&](const IndexRange& range_nc, size_t worker) { ranges_nc, cluster, caller,
[&](const IndexRange& range_nc, size_t worker) {
func(ranges_mc.Range(0), range_nc, worker); func(ranges_mc.Range(0), range_nc, worker);
}); });
} else { } else {
return ParallelizeTwoRanges( return ParallelizeTwoRanges(
ranges_mc, ranges_nc, cluster, ranges_mc, ranges_nc, cluster, caller,
[&](const IndexRange& range_mc, const IndexRange& range_nc, [&](const IndexRange& range_mc, const IndexRange& range_nc,
size_t worker) { func(range_mc, range_nc, worker); }); size_t worker) { func(range_mc, range_nc, worker); });
} }
@ -229,11 +237,11 @@ struct MMParallelHierarchical {
// Multiple clusters: N across clusters (both are usually the larger), and // Multiple clusters: N across clusters (both are usually the larger), and
// M within each cluster. We assume auto-tuning finds small MC/NC tasks. // M within each cluster. We assume auto-tuning finds small MC/NC tasks.
ParallelizeOneRange( ParallelizeOneRange(
ranges_nc, all_clusters, ranges_nc, all_clusters, caller,
[&](const IndexRange range_nc, size_t cluster_idx) { [&](const IndexRange range_nc, size_t cluster_idx) {
const size_t cluster_base = ctx.Worker(cluster_idx); const size_t cluster_base = ctx.Worker(cluster_idx);
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx); hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
ParallelizeOneRange(ranges_mc, cluster, ParallelizeOneRange(ranges_mc, cluster, caller,
[&](const IndexRange& range_mc, size_t worker) { [&](const IndexRange& range_mc, size_t worker) {
func(range_mc, range_nc, cluster_base + worker); func(range_mc, range_nc, cluster_base + worker);
}); });
@ -244,7 +252,7 @@ struct MMParallelHierarchical {
template <class Func> template <class Func>
void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc, void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
size_t caller_cluster_idx, const Func& func) const { size_t caller_cluster_idx, const Func& func) const {
HierarchicalParallelFor(range_mc.Num(), ctx.pools, HierarchicalParallelFor(range_mc.Num(), ctx, Callers::kMMHierForMC,
[&](size_t task, size_t worker) { [&](size_t task, size_t worker) {
func(range_mc.begin() + task, worker); func(range_mc.begin() + task, worker);
}); });
@ -811,7 +819,7 @@ class MMZone {
private: private:
uint64_t data_ = 0; uint64_t data_ = 0;
uint64_t data2_ = 0; HWY_MEMBER_VAR_MAYBE_UNUSED uint64_t data2_ = 0;
}; };
#else #else
struct MMZone { struct MMZone {

View File

@ -196,7 +196,7 @@ HWY_INLINE void MatMulSlow(const MatPtrT<TA> A, const MatPtrT<TB> B,
const IndexRangePartition get_col_c = const IndexRangePartition get_col_c =
StaticPartition(all_cols_c, all_clusters.NumWorkers(), multiple); StaticPartition(all_cols_c, all_clusters.NumWorkers(), multiple);
ParallelizeOneRange( ParallelizeOneRange(
get_col_c, all_clusters, get_col_c, all_clusters, env.ctx.pool_callers.Get(Callers::kTest),
[&](const IndexRange& cols_c, size_t cluster_idx) HWY_ATTR { [&](const IndexRange& cols_c, size_t cluster_idx) HWY_ATTR {
for (size_t r : all_rows_c) { for (size_t r : all_rows_c) {
TC* HWY_RESTRICT C_row = C.Row(r); TC* HWY_RESTRICT C_row = C.Row(r);
@ -221,7 +221,6 @@ void PrintSpeed(const char* algo, const Extents2D& A_extents,
template <typename TA, typename TB = TA, typename TC = float> template <typename TA, typename TB = TA, typename TC = float>
void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add, void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
MatMulEnv& env, int line) { MatMulEnv& env, int line) {
hwy::ThreadPool& pool = env.ctx.pools.Pool();
fprintf(stderr, "TestMatMul %zu, K=%zu, %zu, add=%d, TA=%s, TB=%s, TC=%s\n", fprintf(stderr, "TestMatMul %zu, K=%zu, %zu, add=%d, TA=%s, TB=%s, TC=%s\n",
rows_ac, cols_a_rows_b, cols_bc, add, TypeName<TA>(), TypeName<TB>(), rows_ac, cols_a_rows_b, cols_bc, add, TypeName<TA>(), TypeName<TB>(),
TypeName<TC>()); TypeName<TC>());
@ -233,11 +232,10 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
const Extents2D B_extents(cols_bc, cols_a_rows_b); // already transposed const Extents2D B_extents(cols_bc, cols_a_rows_b); // already transposed
const Extents2D C_extents(rows_ac, cols_bc); const Extents2D C_extents(rows_ac, cols_bc);
MatStorageT<TA> A( MatStorageT<TA> A(GenerateMat<TA>(A_extents, MatPadding::kOdd, env.ctx));
GenerateMat<TA>(A_extents, env.ctx.allocator, MatPadding::kOdd, pool));
// Must be packed because we call Span() on it. // Must be packed because we call Span() on it.
MatStorageT<TB> BT(GenerateTransposedMat<TB>(B_extents, env.ctx.allocator, MatStorageT<TB> BT(
MatPadding::kPacked, pool)); GenerateTransposedMat<TB>(B_extents, MatPadding::kPacked, env.ctx));
MatStorageT<TC> C_slow("C_slow", C_extents, env.ctx.allocator, MatStorageT<TC> C_slow("C_slow", C_extents, env.ctx.allocator,
MatPadding::kOdd); MatPadding::kOdd);
MatStorageT<TC> C("C", C_extents, env.ctx.allocator, MatPadding::kOdd); MatStorageT<TC> C("C", C_extents, env.ctx.allocator, MatPadding::kOdd);
@ -246,8 +244,8 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
C2.AllocateAndAttachRowPtrs(env.row_ptrs); C2.AllocateAndAttachRowPtrs(env.row_ptrs);
MatStorageT<float> add_storage = MatStorageT<float> add_storage =
add ? GenerateMat<float>(Extents2D(1, cols_bc), env.ctx.allocator, add ? GenerateMat<float>(Extents2D(1, cols_bc), MatPadding::kPacked,
MatPadding::kPacked, pool) env.ctx)
: MatStorageT<float>("add", Extents2D(), env.ctx.allocator, : MatStorageT<float>("add", Extents2D(), env.ctx.allocator,
MatPadding::kPacked); MatPadding::kPacked);
add_storage.SetScale(1.0f); add_storage.SetScale(1.0f);
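The test helpers drop their explicit allocator and pool arguments; GenerateMat and GenerateTransposedMat now take the ThreadingContext, presumably deriving both from it. Condensed from the hunks above (a sketch of the new signatures, not the full test):

MatStorageT<float> A =
    GenerateMat<float>(A_extents, MatPadding::kOdd, env.ctx);
MatStorageT<BF16> BT =
    GenerateTransposedMat<BF16>(B_extents, MatPadding::kPacked, env.ctx);
MatStorageT<float> add =
    GenerateMat<float>(Extents2D(1, cols_bc), MatPadding::kPacked, env.ctx);
add.SetScale(1.0f);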

View File

@ -205,9 +205,9 @@ namespace detail {
// Shared by RMSNorm and RMSNormInplace. // Shared by RMSNorm and RMSNormInplace.
template <typename VT> template <typename VT>
float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size, hwy::Profiler& p, float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size,
const size_t worker) { ThreadingContext& ctx, const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRmsNormMul)); GCPP_ZONE(ctx, worker, Zones::kOpsRmsNormMul);
const hn::ScalableTag<float> d; const hn::ScalableTag<float> d;
const float l2 = DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault()); const float l2 = DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault());
@ -218,19 +218,17 @@ float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size, hwy::Profiler& p,
} // namespace detail } // namespace detail
template <typename XT, typename WT, typename OT> template <typename XT, typename WT, typename OT>
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x, HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
const WT* HWY_RESTRICT weight, const XT* HWY_RESTRICT x, const WT* HWY_RESTRICT weight, const size_t w_ofs,
const size_t w_ofs, OT* HWY_RESTRICT out, const size_t size, ThreadingContext& ctx,
OT* HWY_RESTRICT out, const size_t worker) {
const size_t size, hwy::Profiler& p, GCPP_ZONE(ctx, worker, Zones::kOpsRmsNorm);
const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRmsNorm));
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>; using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>; using VF = hn::Vec<DF>;
const VF mul = hn::Set(DF(), detail::RMSNormMul(x, size, p, worker)); const VF mul = hn::Set(DF(), detail::RMSNormMul(x, size, ctx, worker));
const VF* HWY_RESTRICT pmul = &mul; const VF* HWY_RESTRICT pmul = &mul;
Decompress2AndCompressTo(DF(), out, size, x, weight, w_ofs, Decompress2AndCompressTo(DF(), out, size, x, weight, w_ofs,
@ -245,13 +243,13 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x,
template <typename WT, typename XT> template <typename WT, typename XT>
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace( HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
const WT* HWY_RESTRICT weight, const size_t w_ofs, XT* HWY_RESTRICT inout, const WT* HWY_RESTRICT weight, const size_t w_ofs, XT* HWY_RESTRICT inout,
const size_t size, hwy::Profiler& p, const size_t worker) { const size_t size, ThreadingContext& ctx, const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRmsNormInplace)); GCPP_ZONE(ctx, worker, Zones::kOpsRmsNormInplace);
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>; using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>; using VF = hn::Vec<DF>;
const VF mul = hn::Set(DF(), detail::RMSNormMul(inout, size, p, worker)); const VF mul = hn::Set(DF(), detail::RMSNormMul(inout, size, ctx, worker));
const VF* HWY_RESTRICT pmul = &mul; const VF* HWY_RESTRICT pmul = &mul;
Decompress1AndCompressInplace(DF(), inout, size, weight, w_ofs, Decompress1AndCompressInplace(DF(), inout, size, weight, w_ofs,
@ -359,9 +357,9 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
// This overload is called if `post_qk == PostQKType::HalfRope`. // This overload is called if `post_qk == PostQKType::HalfRope`.
static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope( static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
float* HWY_RESTRICT x, const size_t dim_qkv, float* HWY_RESTRICT x, const size_t dim_qkv,
const float* HWY_RESTRICT inv_timescale, const int pos, hwy::Profiler& p, const float* HWY_RESTRICT inv_timescale, const int pos,
const size_t worker) { ThreadingContext& ctx, const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRope)); GCPP_ZONE(ctx, worker, Zones::kOpsRope);
HWY_DASSERT(dim_qkv % 2 == 0); HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2; const size_t half_dim_qkv = dim_qkv / 2;
@ -418,9 +416,9 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
// `inv_timescale[dim_qkv / 2]` is precomputed in AttentionActivations. // `inv_timescale[dim_qkv / 2]` is precomputed in AttentionActivations.
static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy( static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
const float mul, float* HWY_RESTRICT x, const size_t dim_qkv, const float mul, float* HWY_RESTRICT x, const size_t dim_qkv,
const float* HWY_RESTRICT inv_timescale, const int pos, hwy::Profiler& p, const float* HWY_RESTRICT inv_timescale, const int pos,
const size_t worker) { ThreadingContext& ctx, const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRopeAndMulBy)); GCPP_ZONE(ctx, worker, Zones::kOpsRopeAndMulBy);
HWY_DASSERT(dim_qkv % 2 == 0); HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2; const size_t half_dim_qkv = dim_qkv / 2;
@ -480,9 +478,9 @@ template <typename XT>
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x, static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x,
float* HWY_RESTRICT out, float* HWY_RESTRICT out,
const size_t size, const size_t size,
hwy::Profiler& p, ThreadingContext& ctx,
const size_t worker) { const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsAddFrom)); GCPP_ZONE(ctx, worker, Zones::kOpsAddFrom);
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>; using DF = hn::ScalableTag<float>;
@ -503,10 +501,11 @@ void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
CallUpcasted(&weights, [&](const auto* weights_t) { CallUpcasted(&weights, [&](const auto* weights_t) {
ParallelFor(ParallelismStrategy::kFlat, activations.Rows(), ctx, ParallelFor(ParallelismStrategy::kFlat, activations.Rows(), ctx,
cluster_idx, [&](uint64_t token_idx, size_t worker) { cluster_idx, Callers::kOpsRMSNormBatched,
[&](uint64_t token_idx, size_t worker) {
RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(), RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(),
/*w_ofs=*/0, out.Row(token_idx), activations.Cols(), /*w_ofs=*/0, out.Row(token_idx), activations.Cols(),
ctx.profiler, worker); ctx, worker);
}); });
}); });
} }
@ -519,10 +518,11 @@ void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT<XT>& inout,
CallUpcasted(&weights, [&](const auto* weights_t) { CallUpcasted(&weights, [&](const auto* weights_t) {
ParallelFor(ParallelismStrategy::kFlat, inout.Rows(), ctx, cluster_idx, ParallelFor(ParallelismStrategy::kFlat, inout.Rows(), ctx, cluster_idx,
Callers::kOpsRMSNormInplaceBatched,
[&](uint64_t token_idx, size_t worker) { [&](uint64_t token_idx, size_t worker) {
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0,
inout.Row(token_idx), inout.Cols(), inout.Row(token_idx), inout.Cols(), ctx,
ctx.profiler, worker); worker);
}); });
}); });
} }
@ -549,13 +549,14 @@ static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x, MatPtrT<float>& out,
ThreadingContext& ctx, ThreadingContext& ctx,
size_t cluster_idx = 0) { size_t cluster_idx = 0) {
HWY_DASSERT(out.SameShape(x)); HWY_DASSERT(out.SameShape(x));
ParallelFor(ParallelismStrategy::kFlat, out.Rows(), ctx, cluster_idx, ParallelFor(
[&](uint64_t token_idx, size_t worker) { ParallelismStrategy::kFlat, out.Rows(), ctx, cluster_idx,
AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(), Callers::kOpsAddFromBatched, [&](uint64_t token_idx, size_t worker) {
ctx.profiler, worker); AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(), ctx, worker);
}); });
} }
// No profiler zone because this is short and frequently called.
template <typename XT> template <typename XT>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(const float c, XT* HWY_RESTRICT x, HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(const float c, XT* HWY_RESTRICT x,
const size_t size) { const size_t size) {
@ -575,8 +576,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(const float c, XT* HWY_RESTRICT x,
template <typename XT, typename OT> template <typename XT, typename OT>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo( HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo(
const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out, const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,
const size_t size, hwy::Profiler& p, const size_t worker) { const size_t size, ThreadingContext& ctx, const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstTo)); GCPP_ZONE(ctx, worker, Zones::kOpsMulByConstTo);
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>; using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>; using VF = hn::Vec<DF>;
@ -1121,10 +1122,10 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
// See below for a specialized version for top-1 sampling. // See below for a specialized version for top-1 sampling.
// TODO: support bf16 logits using Decompress2. // TODO: support bf16 logits using Decompress2.
static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p, static HWY_NOINLINE void Softmax(Logits logits, ThreadingContext& ctx,
const size_t worker, const size_t worker,
float temperature = 1.0f) { float temperature = 1.0f) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsSoftmax)); GCPP_ZONE(ctx, worker, Zones::kOpsSoftmax);
HWY_DASSERT(logits.size() != 0); HWY_DASSERT(logits.size() != 0);
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
@ -1256,8 +1257,9 @@ static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(Logits logits) {
} }
static HWY_NOINLINE void LogitsSoftCap(const float cap, Logits logits, static HWY_NOINLINE void LogitsSoftCap(const float cap, Logits logits,
hwy::Profiler& p, const size_t worker) { ThreadingContext& ctx,
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsLogitsSoftCap)); const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kOpsLogitsSoftCap);
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>; using DF = hn::ScalableTag<float>;
@ -1277,9 +1279,10 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, Logits logits,
// Calls LogitsSoftCap if cap != 0.0f. // Calls LogitsSoftCap if cap != 0.0f.
static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap( static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap(
const float cap, Logits logits, hwy::Profiler& p, const size_t worker) { const float cap, Logits logits, ThreadingContext& ctx,
const size_t worker) {
if (cap != 0.0f) { if (cap != 0.0f) {
LogitsSoftCap(cap, logits, p, worker); LogitsSoftCap(cap, logits, ctx, worker);
} }
} }
@ -1288,9 +1291,10 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCapBatched(
ThreadingContext& ctx, size_t cluster_idx = 0) { ThreadingContext& ctx, size_t cluster_idx = 0) {
if (cap == 0.0f) return; if (cap == 0.0f) return;
ParallelFor(ParallelismStrategy::kFlat, x.Rows(), ctx, cluster_idx, ParallelFor(ParallelismStrategy::kFlat, x.Rows(), ctx, cluster_idx,
Callers::kOpsMaybeLogitsSoftCapBatched,
[&](uint64_t task, size_t worker) { [&](uint64_t task, size_t worker) {
if (non_eos.Get(task)) { if (non_eos.Get(task)) {
LogitsSoftCap(cap, x.RowSpan(task), ctx.profiler, worker); LogitsSoftCap(cap, x.RowSpan(task), ctx, worker);
} }
}); });
} }
@ -1371,7 +1375,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(Logits logits, size_t k,
template <typename TAcceptToken> template <typename TAcceptToken>
HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK( HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
Logits logits, size_t k, RngStream& gen, float temperature, Logits logits, size_t k, RngStream& gen, float temperature,
TAcceptToken& accept_token, hwy::Profiler& p, size_t worker) { TAcceptToken& accept_token, ThreadingContext& ctx, size_t worker) {
// Softmax and sample top-K is equivalent to taking the top-K logits and // Softmax and sample top-K is equivalent to taking the top-K logits and
// sampling from the softmax of the top-K logits. The latter is faster as it // sampling from the softmax of the top-K logits. The latter is faster as it
// avoids computing the softmax of all logits. // avoids computing the softmax of all logits.
@ -1384,7 +1388,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
} }
const size_t mask = token_logits.size(); const size_t mask = token_logits.size();
Softmax(Logits(topk_logits.data(), mask), p, worker, temperature); Softmax(Logits(topk_logits.data(), mask), ctx, worker, temperature);
auto distribution = std::discrete_distribution<int>( auto distribution = std::discrete_distribution<int>(
std::begin(topk_logits), std::begin(topk_logits) + mask); std::begin(topk_logits), std::begin(topk_logits) + mask);
int topk_sampled_index = distribution(gen); int topk_sampled_index = distribution(gen);
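The ops changes above all follow one pattern: each kernel now receives ThreadingContext& ctx plus a worker index and opens its profiler zone via GCPP_ZONE, and each batched wrapper names itself to ParallelFor through a Callers tag. Below is a minimal caller-side sketch of that convention; the wrapper function, its inputs, and the reuse of Callers::kOpsMaybeLogitsSoftCapBatched are illustrative assumptions, not code from this commit.

    // Hypothetical helper: soft-cap then normalize each row of logits.
    void SoftCapAndSoftmaxRows(float cap, MatPtrT<float>& logits,
                               ThreadingContext& ctx) {
      ParallelFor(ParallelismStrategy::kFlat, logits.Rows(), ctx,
                  /*cluster_idx=*/0, Callers::kOpsMaybeLogitsSoftCapBatched,
                  [&](uint64_t row, size_t worker) {
                    // Both ops open their own zones via GCPP_ZONE internally.
                    LogitsSoftCap(cap, logits.RowSpan(row), ctx, worker);
                    Softmax(logits.RowSpan(row), ctx, worker);
                  });
    }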
@ -58,6 +58,11 @@ namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
static ThreadingContext& Ctx() {
static ThreadingContext* ctx = new ThreadingContext(ThreadingArgs());
return *ctx;
}
static RngStream MakeRng() { static RngStream MakeRng() {
static AesCtrEngine engine(/*deterministic=*/true); static AesCtrEngine engine(/*deterministic=*/true);
static uint64_t stream = 0; static uint64_t stream = 0;
@ -133,8 +138,7 @@ class TestAddFrom {
} }
SimpleAddFrom(o, e, count); SimpleAddFrom(o, e, count);
InitProfilerZones(hwy::Profiler::Get()); AddFrom(o, x, count, Ctx(), /*worker=*/0);
AddFrom(o, x, count, hwy::Profiler::Get(), /*worker=*/0);
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
__LINE__); __LINE__);
@ -182,7 +186,6 @@ class TestMulByConstAndAdd {
T constant = Random<T>(rng); T constant = Random<T>(rng);
SimpleMulByConstAndAdd(constant, o, e, count); SimpleMulByConstAndAdd(constant, o, e, count);
InitProfilerZones(hwy::Profiler::Get());
MulByConstAndAdd(constant, o, x, count); MulByConstAndAdd(constant, o, x, count);
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
@ -231,7 +234,6 @@ class TestMulByConst {
T constant = Random<T>(rng); T constant = Random<T>(rng);
SimpleMulByConst(constant, e, count); SimpleMulByConst(constant, e, count);
InitProfilerZones(hwy::Profiler::Get());
MulByConst(constant, x, count); MulByConst(constant, x, count);
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
@ -278,8 +280,7 @@ struct TestMulByConstTo {
hwy::ConvertScalarTo<float>(constant)); hwy::ConvertScalarTo<float>(constant));
} }
InitProfilerZones(hwy::Profiler::Get()); MulByConstTo(constant, x, actual, count, Ctx(),
MulByConstTo(constant, x, actual, count, hwy::Profiler::Get(),
/*worker=*/0); /*worker=*/0);
hwy::AssertArraySimilar(e, actual, count, hwy::TargetName(HWY_TARGET), hwy::AssertArraySimilar(e, actual, count, hwy::TargetName(HWY_TARGET),
@ -315,8 +316,7 @@ class TestSoftmax {
} }
SimpleSoftmax(e, count); SimpleSoftmax(e, count);
InitProfilerZones(hwy::Profiler::Get()); Softmax(Logits(x, count), Ctx(), /*worker=*/0);
Softmax(Logits(x, count), hwy::Profiler::Get(), /*worker=*/0);
T sum = 0.0f; T sum = 0.0f;
for (size_t i = 0; i < count; ++i) { for (size_t i = 0; i < count; ++i) {
@ -440,9 +440,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
} }
void TestRopeAndMulBy() { void TestRopeAndMulBy() {
ThreadingArgs threading_args; ThreadingContext& ctx = Ctx();
ThreadingContext ctx(threading_args);
hwy::Profiler& p = ctx.profiler;
const size_t worker = 0; const size_t worker = 0;
const ModelConfig config(Model::GEMMA2_9B, Type::kSFP, const ModelConfig config(Model::GEMMA2_9B, Type::kSFP,
@ -476,7 +474,7 @@ void TestRopeAndMulBy() {
CopyMat(x, qactual); CopyMat(x, qactual);
ScalarRopeAndMulBy(qmul, qexpected.Row(0), dim_qkv, inv_timescale.Row(0), ScalarRopeAndMulBy(qmul, qexpected.Row(0), dim_qkv, inv_timescale.Row(0),
pos); pos);
RopeAndMulBy(qmul, qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, p, RopeAndMulBy(qmul, qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, ctx,
worker); worker);
for (size_t i = 0; i < dim_qkv; ++i) { for (size_t i = 0; i < dim_qkv; ++i) {
EXPECT_NEAR(qexpected.Row(0)[i], qactual.Row(0)[i], 1e-4) << " " << i; EXPECT_NEAR(qexpected.Row(0)[i], qactual.Row(0)[i], 1e-4) << " " << i;
@ -487,7 +485,7 @@ void TestRopeAndMulBy() {
CopyMat(x, qactual); CopyMat(x, qactual);
ScalarRopeAndMulBy(1.0f, qexpected.Row(0), dim_qkv, inv_timescale.Row(0), ScalarRopeAndMulBy(1.0f, qexpected.Row(0), dim_qkv, inv_timescale.Row(0),
pos); pos);
Rope(qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, p, worker); Rope(qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, ctx, worker);
for (size_t i = 0; i < dim_qkv; ++i) { for (size_t i = 0; i < dim_qkv; ++i) {
EXPECT_NEAR(qexpected.Row(0)[i], qactual.Row(0)[i], 1e-4) << " " << i; EXPECT_NEAR(qexpected.Row(0)[i], qactual.Row(0)[i], 1e-4) << " " << i;
} }
@ -498,10 +496,10 @@ void TestRopeAndMulBy() {
CopyMat(x, kactual2); CopyMat(x, kactual2);
ScalarRopeAndMulBy(kmul, kexpected.Row(0), dim_qkv, inv_timescale.Row(0), ScalarRopeAndMulBy(kmul, kexpected.Row(0), dim_qkv, inv_timescale.Row(0),
pos); pos);
RopeAndMulBy(kmul, kactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, p, RopeAndMulBy(kmul, kactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, ctx,
worker); worker);
static_assert(kmul == 1.0f, ""); static_assert(kmul == 1.0f, "");
Rope(kactual2.Row(0), dim_qkv, inv_timescale.Row(0), pos, p, worker); Rope(kactual2.Row(0), dim_qkv, inv_timescale.Row(0), pos, ctx, worker);
for (size_t i = 0; i < dim_qkv; ++i) { for (size_t i = 0; i < dim_qkv; ++i) {
EXPECT_NEAR(kexpected.Row(0)[i], kactual.Row(0)[i], 1e-4) << " " << i; EXPECT_NEAR(kexpected.Row(0)[i], kactual.Row(0)[i], 1e-4) << " " << i;
@ -557,8 +555,7 @@ struct TestRMSNorm {
} }
ScalarRMSNorm(vec, weight, expected, kSize); ScalarRMSNorm(vec, weight, expected, kSize);
InitProfilerZones(hwy::Profiler::Get()); RMSNorm(vec, weight, /*w_ofs=*/0, actual, kSize, Ctx(),
RMSNorm(vec, weight, /*w_ofs=*/0, actual, kSize, hwy::Profiler::Get(),
/*worker=*/0); /*worker=*/0);
for (size_t i = 0; i < kSize; i++) { for (size_t i = 0; i < kSize; i++) {
@ -593,8 +590,7 @@ struct TestRMSNormInplace {
} }
ScalarRMSNorm(expected, weight, expected, kSize); ScalarRMSNorm(expected, weight, expected, kSize);
InitProfilerZones(hwy::Profiler::Get()); RMSNormInplace(weight, /*w_ofs=*/0, actual, kSize, Ctx(),
RMSNormInplace(weight, /*w_ofs=*/0, actual, kSize, hwy::Profiler::Get(),
/*worker=*/0); /*worker=*/0);
for (size_t i = 0; i < kSize; i++) { for (size_t i = 0; i < kSize; i++) {
@ -715,15 +711,14 @@ void TestAllLayerNorm() {
} }
void TestSampleTopK() { void TestSampleTopK() {
hwy::Profiler& p = hwy::Profiler::Get(); ThreadingContext& ctx = Ctx();
InitProfilerZones(p);
const size_t worker = 0; const size_t worker = 0;
const size_t kSize = 52; const size_t kSize = 52;
std::vector<float> logits_vec(kSize); std::vector<float> logits_vec(kSize);
Logits logits(logits_vec.data(), kSize); Logits logits(logits_vec.data(), kSize);
// Create a vector going from -100 to -100+51=49 and take Softmax. // Create a vector going from -100 to -100+51=49 and take Softmax.
std::iota(logits.begin(), logits.end(), -100.0f); std::iota(logits.begin(), logits.end(), -100.0f);
Softmax(logits, p, worker); Softmax(logits, ctx, worker);
RngStream rng = MakeRng(); RngStream rng = MakeRng();
float temperature = 1.0f; float temperature = 1.0f;
// SampleTopK<1> should return the argmax. // SampleTopK<1> should return the argmax.
@ -736,7 +731,7 @@ void TestSampleTopK() {
EXPECT_EQ(sample, 50); // Last even index. EXPECT_EQ(sample, 50); // Last even index.
// Reset the logits to a positive, increasing sequence and take Softmax. // Reset the logits to a positive, increasing sequence and take Softmax.
std::iota(logits.begin(), logits.end(), 1.0f); std::iota(logits.begin(), logits.end(), 1.0f);
Softmax(logits, p, worker); Softmax(logits, ctx, worker);
// Sample from the top 3, expect one of the top 3 even indices. // Sample from the top 3, expect one of the top 3 even indices.
for (int i = 0; i < 100; ++i) { for (int i = 0; i < 100; ++i) {
sample = SampleTopK(logits, /*k=*/3, rng, temperature, accept_token); sample = SampleTopK(logits, /*k=*/3, rng, temperature, accept_token);
@ -49,9 +49,10 @@ PinningPolicy::PinningPolicy(Tristate pin) {
static void MaybePin(const BoundedTopology& topology, size_t cluster_idx, static void MaybePin(const BoundedTopology& topology, size_t cluster_idx,
const BoundedTopology::Cluster& cluster, const BoundedTopology::Cluster& cluster,
PinningPolicy& pinning, hwy::ThreadPool& pool) { PinningPolicy& pinning, hwy::ThreadPool& pool) {
static hwy::pool::Caller caller = hwy::ThreadPool::AddCaller("MaybePin");
const std::vector<size_t> lps = cluster.LPVector(); const std::vector<size_t> lps = cluster.LPVector();
HWY_ASSERT(pool.NumWorkers() <= lps.size()); HWY_ASSERT(pool.NumWorkers() <= lps.size());
pool.Run(0, pool.NumWorkers(), [&](uint64_t task, size_t thread) { pool.Run(0, pool.NumWorkers(), caller, [&](uint64_t task, size_t thread) {
HWY_ASSERT(task == thread); // each worker has one task HWY_ASSERT(task == thread); // each worker has one task
char buf[16]; // Linux limitation char buf[16]; // Linux limitation
@ -141,17 +142,20 @@ NestedPools::NestedPools(const BoundedTopology& topology,
// Parallel so we also pin the calling worker in `all_clusters` to // Parallel so we also pin the calling worker in `all_clusters` to
// `cluster.lps`. // `cluster.lps`.
all_clusters_->Run(0, num_clusters, [&](size_t cluster_idx, size_t thread) { static hwy::pool::Caller caller = hwy::ThreadPool::AddCaller("NestedPools");
HWY_ASSERT(cluster_idx == thread); // each thread has one task all_clusters_->Run(
const BoundedTopology::Cluster& tcluster = topology.GetCluster(cluster_idx); 0, num_clusters, caller, [&](size_t cluster_idx, size_t thread) {
clusters_[cluster_idx] = HWY_ASSERT(cluster_idx == thread); // each thread has one task
MakePool(allocator, workers_per_cluster[cluster_idx], const BoundedTopology::Cluster& tcluster =
hwy::PoolWorkerMapping(cluster_idx, max_workers_per_cluster_), topology.GetCluster(cluster_idx);
tcluster.Node()); clusters_[cluster_idx] = MakePool(
// Pin workers AND the calling thread from `all_clusters_`. allocator, workers_per_cluster[cluster_idx],
MaybePin(topology, cluster_idx, tcluster, pinning_, hwy::PoolWorkerMapping(cluster_idx, max_workers_per_cluster_),
*clusters_[cluster_idx]); tcluster.Node());
}); // Pin workers AND the calling thread from `all_clusters_`.
MaybePin(topology, cluster_idx, tcluster, pinning_,
*clusters_[cluster_idx]);
});
all_pinned_ = pinning_.AllPinned(&pin_string_); all_pinned_ = pinning_.AllPinned(&pin_string_);
} }
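The two call sites above show the new requirement that every pool.Run receives a hwy::pool::Caller registered once via hwy::ThreadPool::AddCaller. A standalone sketch of the same plumbing, with an invented function and label:

    // Hypothetical function; only the AddCaller/Run signatures are taken from
    // the changes above.
    void TouchAllWorkers(hwy::ThreadPool& pool) {
      // Register once (static) so repeated Run calls aggregate under one name.
      static hwy::pool::Caller caller =
          hwy::ThreadPool::AddCaller("TouchAllWorkers");
      pool.Run(0, pool.NumWorkers(), caller, [](uint64_t task, size_t thread) {
        // One task per worker, as in MaybePin above.
        (void)task;
        (void)thread;
      });
    }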
@ -266,9 +266,9 @@ static inline IndexRangePartition StaticPartition(const IndexRange& range,
// index to a range. // index to a range.
template <class Func> template <class Func>
void ParallelizeOneRange(const IndexRangePartition& get1, hwy::ThreadPool& pool, void ParallelizeOneRange(const IndexRangePartition& get1, hwy::ThreadPool& pool,
const Func& func) { hwy::pool::Caller caller, const Func& func) {
const size_t num_tasks = get1.NumTasks(); const size_t num_tasks = get1.NumTasks();
pool.Run(0, num_tasks, [&](uint64_t task, size_t thread) { pool.Run(0, num_tasks, caller, [&](uint64_t task, size_t thread) {
const IndexRange range1 = get1.Range(task); const IndexRange range1 = get1.Range(task);
func(range1, thread); func(range1, thread);
}); });
@ -282,11 +282,12 @@ void ParallelizeOneRange(const IndexRangePartition& get1, hwy::ThreadPool& pool,
template <class Func> template <class Func>
void ParallelizeTwoRanges(const IndexRangePartition& get1, void ParallelizeTwoRanges(const IndexRangePartition& get1,
const IndexRangePartition& get2, const IndexRangePartition& get2,
hwy::ThreadPool& pool, const Func& func) { hwy::ThreadPool& pool, hwy::pool::Caller caller,
const Func& func) {
const hwy::Divisor div1(static_cast<uint32_t>(get1.NumTasks())); const hwy::Divisor div1(static_cast<uint32_t>(get1.NumTasks()));
const size_t num_tasks = get1.NumTasks() * get2.NumTasks(); const size_t num_tasks = get1.NumTasks() * get2.NumTasks();
pool.Run(0, num_tasks, [&](uint64_t task, size_t thread) { pool.Run(0, num_tasks, caller, [&](uint64_t task, size_t thread) {
HWY_DASSERT(task < (uint64_t{1} << 32)); HWY_DASSERT(task < (uint64_t{1} << 32));
const size_t idx2 = div1.Divide(static_cast<uint32_t>(task)); const size_t idx2 = div1.Divide(static_cast<uint32_t>(task));
const size_t idx1 = div1.Remainder(static_cast<uint32_t>(task)); const size_t idx1 = div1.Remainder(static_cast<uint32_t>(task));
@ -298,37 +299,6 @@ void ParallelizeTwoRanges(const IndexRangePartition& get1,
}); });
} }
// Calls `func(task, worker)` for each task in `[0, num_tasks)`. Parallelizes
// over clusters of ONE package, then within each cluster.
template <class Func>
void HierarchicalParallelFor(size_t num_tasks, NestedPools& pools,
const Func& func) {
// If few tasks, run on a single cluster. Also avoids a bit of overhead if
// there is only one cluster.
hwy::ThreadPool& all_clusters = pools.AllClusters();
const size_t num_clusters = all_clusters.NumWorkers();
hwy::ThreadPool& cluster = pools.Cluster(0);
if (num_clusters == 1 || num_tasks <= cluster.NumWorkers()) {
return cluster.Run(0, num_tasks, [&](uint64_t task, size_t thread) {
func(task, thread);
});
}
// Assign each cluster a sub-range.
const IndexRangePartition ranges =
StaticPartition(IndexRange(0, num_tasks), num_clusters, 1);
ParallelizeOneRange(
ranges, all_clusters,
[&](const IndexRange& range, const size_t cluster_idx) {
hwy::ThreadPool& cluster = pools.Cluster(cluster_idx);
const size_t cluster_base = cluster_idx * pools.MaxWorkersPerCluster();
cluster.Run(range.begin(), range.end(),
[&](uint64_t task, size_t thread) {
func(task, cluster_base + thread);
});
});
}
} // namespace gcpp } // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_ #endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_
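ParallelizeOneRange and ParallelizeTwoRanges now also forward a caller to pool.Run. A sketch of the updated ParallelizeOneRange, with an invented label and range; the StaticPartition arguments mirror the ones HierarchicalParallelFor uses elsewhere in this commit:

    // Hypothetical usage; processes 1024 tasks in per-worker chunks.
    void ProcessInChunks(hwy::ThreadPool& pool) {
      static hwy::pool::Caller caller =
          hwy::ThreadPool::AddCaller("ProcessInChunks");
      const IndexRangePartition partition =
          StaticPartition(IndexRange(0, 1024), pool.NumWorkers(), 1);
      ParallelizeOneRange(partition, pool, caller,
                          [&](const IndexRange& range, size_t thread) {
                            // Handle [range.begin(), range.end()) on thread.
                          });
    }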
@ -78,31 +78,33 @@ static void TunePool(hwy::PoolWaitMode wait_mode, hwy::ThreadPool& pool) {
#endif #endif
} }
static void TunePools(hwy::PoolWaitMode wait_mode, NestedPools& pools) { static void TunePools(hwy::PoolWaitMode wait_mode, ThreadingContext& ctx) {
hwy::ThreadPool& clusters = pools.AllClusters(); hwy::ThreadPool& clusters = ctx.pools.AllClusters();
TunePool(wait_mode, clusters); TunePool(wait_mode, clusters);
// Run in parallel because Turin CPUs have 16, and in real usage, we often // Run in parallel because Turin CPUs have 16, and in real usage, we often
// run all at the same time. // run all at the same time.
clusters.Run(0, clusters.NumWorkers(), clusters.Run(0, clusters.NumWorkers(),
ctx.pool_callers.Get(Callers::kTunePool),
[&](uint64_t cluster_idx, size_t /*thread*/) { [&](uint64_t cluster_idx, size_t /*thread*/) {
TunePool(wait_mode, pools.Cluster(cluster_idx)); TunePool(wait_mode, ctx.pools.Cluster(cluster_idx));
}); });
} }
ThreadingContext::ThreadingContext(const ThreadingArgs& args) ThreadingContext::ThreadingContext(const ThreadingArgs& args)
: profiler(hwy::Profiler::Get()), : profiler(hwy::Profiler::Get()),
profiler_zones(profiler),
pool_callers(),
topology(BoundedSlice(args.skip_packages, args.max_packages), topology(BoundedSlice(args.skip_packages, args.max_packages),
BoundedSlice(args.skip_clusters, args.max_clusters), BoundedSlice(args.skip_clusters, args.max_clusters),
BoundedSlice(args.skip_lps, args.max_lps)), BoundedSlice(args.skip_lps, args.max_lps)),
cache_info(topology), cache_info(topology),
allocator(topology, cache_info, args.bind != Tristate::kFalse), allocator(topology, cache_info, args.bind != Tristate::kFalse),
pools(topology, allocator, args.max_threads, args.pin) { pools(topology, allocator, args.max_threads, args.pin) {
InitProfilerZones(profiler);
PROFILER_ZONE("Startup.ThreadingContext autotune"); PROFILER_ZONE("Startup.ThreadingContext autotune");
TunePools(hwy::PoolWaitMode::kSpin, pools); TunePools(hwy::PoolWaitMode::kSpin, *this);
// kBlock is the default, hence set/tune it last. // kBlock is the default, hence set/tune it last.
TunePools(hwy::PoolWaitMode::kBlock, pools); TunePools(hwy::PoolWaitMode::kBlock, *this);
} }
} // namespace gcpp } // namespace gcpp
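The constructor now builds profiler_zones and pool_callers before autotuning, so zone handles and pool callers are registered exactly once per process. A sketch of client startup, assuming util/threading_context.h is the header and reusing Callers::kTunePool purely for illustration:

    // Hypothetical warm-up helper.
    void InitAndWarmUp() {
      gcpp::ThreadingArgs args;
      gcpp::ThreadingContext ctx(args);  // registers all Zones and Callers
      const hwy::pool::Caller caller =
          ctx.pool_callers.Get(gcpp::Callers::kTunePool);
      hwy::ThreadPool& clusters = ctx.pools.AllClusters();
      clusters.Run(0, clusters.NumWorkers(), caller,
                   [](uint64_t /*task*/, size_t /*thread*/) {});
    }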
@ -28,6 +28,7 @@
#include "util/basics.h" // Tristate #include "util/basics.h" // Tristate
#include "util/threading.h" #include "util/threading.h"
#include "util/topology.h" #include "util/topology.h"
#include "util/zones.h"
#include "hwy/profiler.h" #include "hwy/profiler.h"
// IWYU pragma: end_exports // IWYU pragma: end_exports
@ -107,6 +108,9 @@ struct ThreadingContext {
// Singleton; pass around a reference to reduce overhead. // Singleton; pass around a reference to reduce overhead.
hwy::Profiler& profiler; hwy::Profiler& profiler;
ProfilerZones profiler_zones;
PoolCallers pool_callers;
// Detects topology, subject to limits imposed by user-specified `args`. // Detects topology, subject to limits imposed by user-specified `args`.
// For example, if `args.max_clusters` is 1, then `topology.NumClusters()` // For example, if `args.max_clusters` is 1, then `topology.NumClusters()`
// will be 1 regardless of the actual system topology. // will be 1 regardless of the actual system topology.
@ -122,6 +126,9 @@ struct ThreadingContext {
NestedPools pools; NestedPools pools;
}; };
#define GCPP_ZONE(ctx, global_idx, zone_enum) \
PROFILER_ZONE3(ctx.profiler, global_idx, ctx.profiler_zones.Get(zone_enum))
// Describes the strategy for distributing parallel work across cores. // Describes the strategy for distributing parallel work across cores.
enum class ParallelismStrategy : uint8_t { enum class ParallelismStrategy : uint8_t {
// Execute using a single-threaded loop on the calling thread. The `worker` // Execute using a single-threaded loop on the calling thread. The `worker`
@ -147,18 +154,53 @@ enum class ParallelismStrategy : uint8_t {
kHierarchical, kHierarchical,
}; };
// Calls `func(task, worker)` for each task in `[0, num_tasks)`. Parallelizes
// over clusters of ONE package, then within each cluster.
template <class Func>
void HierarchicalParallelFor(size_t num_tasks, ThreadingContext& ctx,
Callers callers, const Func& func) {
const hwy::pool::Caller caller = ctx.pool_callers.Get(callers);
// If few tasks, run on a single cluster. Also avoids a bit of overhead if
// there is only one cluster.
hwy::ThreadPool& all_clusters = ctx.pools.AllClusters();
const size_t num_clusters = all_clusters.NumWorkers();
hwy::ThreadPool& cluster = ctx.pools.Cluster(0);
if (num_clusters == 1 || num_tasks <= cluster.NumWorkers()) {
return cluster.Run(0, num_tasks, caller, [&](uint64_t task, size_t thread) {
func(task, thread);
});
}
// Assign each cluster a sub-range.
const IndexRangePartition ranges =
StaticPartition(IndexRange(0, num_tasks), num_clusters, 1);
ParallelizeOneRange(ranges, all_clusters, caller,
[&](const IndexRange& range, const size_t cluster_idx) {
hwy::ThreadPool& cluster =
ctx.pools.Cluster(cluster_idx);
const size_t cluster_base =
cluster_idx * ctx.pools.MaxWorkersPerCluster();
cluster.Run(range.begin(), range.end(), caller,
[&](uint64_t task, size_t thread) {
func(task, cluster_base + thread);
});
});
}
// Calls `func(task, worker)` for each `task` in `[0, num_tasks)`, with the // Calls `func(task, worker)` for each `task` in `[0, num_tasks)`, with the
// number/type of workers determined by `parallelism`. `cluster_idx` is for // number/type of workers determined by `parallelism`. `cluster_idx` is for
// `parallelism == kWithinCluster`, and should be 0 if unknown. // `parallelism == kWithinCluster`, and should be 0 if unknown.
template <class Func> template <class Func>
void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks, void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
ThreadingContext& ctx, size_t cluster_idx, const Func& func) { ThreadingContext& ctx, size_t cluster_idx, Callers callers,
const Func& func) {
HWY_DASSERT(cluster_idx < ctx.topology.NumClusters()); HWY_DASSERT(cluster_idx < ctx.topology.NumClusters());
if (cluster_idx != 0) { if (cluster_idx != 0) {
// If already running across clusters, only use within-cluster modes. // If already running across clusters, only use within-cluster modes.
HWY_DASSERT(parallelism == ParallelismStrategy::kNone || HWY_DASSERT(parallelism == ParallelismStrategy::kNone ||
parallelism == ParallelismStrategy::kWithinCluster); parallelism == ParallelismStrategy::kWithinCluster);
} }
const hwy::pool::Caller caller = ctx.pool_callers.Get(callers);
switch (parallelism) { switch (parallelism) {
case ParallelismStrategy::kNone: { case ParallelismStrategy::kNone: {
@ -171,7 +213,7 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
case ParallelismStrategy::kAcrossClusters: case ParallelismStrategy::kAcrossClusters:
return ctx.pools.AllClusters().Run( return ctx.pools.AllClusters().Run(
0, num_tasks, 0, num_tasks, caller,
[&](uint64_t task, size_t cluster_idx) { func(task, cluster_idx); }); [&](uint64_t task, size_t cluster_idx) { func(task, cluster_idx); });
case ParallelismStrategy::kWithinCluster: { case ParallelismStrategy::kWithinCluster: {
@ -179,7 +221,7 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
// used for TLS indexing for example in profiler.h. // used for TLS indexing for example in profiler.h.
const size_t base = ctx.Worker(cluster_idx); const size_t base = ctx.Worker(cluster_idx);
return ctx.pools.Cluster(cluster_idx) return ctx.pools.Cluster(cluster_idx)
.Run(0, num_tasks, [&](uint64_t task, size_t worker) { .Run(0, num_tasks, caller, [&](uint64_t task, size_t worker) {
func(task, base + worker); func(task, base + worker);
}); });
} }
@ -191,19 +233,19 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
const size_t num_clusters = all_clusters.NumWorkers(); const size_t num_clusters = all_clusters.NumWorkers();
if (num_clusters == 1) { if (num_clusters == 1) {
return ctx.pools.Cluster(cluster_idx) return ctx.pools.Cluster(cluster_idx)
.Run(0, num_tasks, .Run(0, num_tasks, caller,
[&](uint64_t task, size_t worker) { func(task, worker); }); [&](uint64_t task, size_t worker) { func(task, worker); });
} }
return ctx.pools.AllClusters().Run( return all_clusters.Run(0, num_tasks, caller,
0, num_tasks, [&](uint64_t task, size_t cluster_idx) { [&](uint64_t task, size_t cluster_idx) {
const size_t worker = ctx.Worker(cluster_idx); const size_t worker = ctx.Worker(cluster_idx);
func(task, worker); func(task, worker);
}); });
} }
case ParallelismStrategy::kHierarchical: case ParallelismStrategy::kHierarchical:
return HierarchicalParallelFor(num_tasks, ctx.pools, func); return HierarchicalParallelFor(num_tasks, ctx, callers, func);
} }
} }
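A sketch of the new ParallelFor signature together with GCPP_ZONE; the function, its output buffer, and the reuse of kActivationBatched / kGenActivation are assumptions for illustration:

    // Hypothetical kernel: fills out[] across clusters, then within each
    // cluster, attributing time to an existing zone/caller pair.
    void FillParallel(float* out, size_t n, ThreadingContext& ctx) {
      ParallelFor(ParallelismStrategy::kHierarchical, n, ctx, /*cluster_idx=*/0,
                  Callers::kActivationBatched,  // any registered Callers value
                  [&](uint64_t task, size_t worker) {
                    GCPP_ZONE(ctx, worker, Zones::kGenActivation);
                    out[task] = static_cast<float>(worker);
                  });
    }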
@ -37,6 +37,8 @@ namespace {
using ::testing::ElementsAre; using ::testing::ElementsAre;
static const hwy::pool::Caller kCaller = hwy::ThreadPool::AddCaller("Test");
TEST(ThreadingTest, TestBoundedSlice) { TEST(ThreadingTest, TestBoundedSlice) {
const char* name = "test"; const char* name = "test";
// No args = no limit. // No args = no limit.
@ -205,7 +207,7 @@ TEST(ThreadingTest, TestParallelizeOneRange) {
const IndexRangePartition partition = StaticPartition(range, 2, 4); const IndexRangePartition partition = StaticPartition(range, 2, 4);
hwy::ThreadPool null_pool(0); hwy::ThreadPool null_pool(0);
size_t calls = 0; size_t calls = 0;
ParallelizeOneRange(partition, null_pool, ParallelizeOneRange(partition, null_pool, kCaller,
[&](const IndexRange& range, size_t) { [&](const IndexRange& range, size_t) {
if (++calls == 1) { if (++calls == 1) {
HWY_ASSERT(range.begin() == 0 && range.end() == 8); HWY_ASSERT(range.begin() == 0 && range.end() == 8);
@ -226,7 +228,7 @@ TEST(ThreadingTest, TestParallelizeTwoRanges) {
{ {
size_t calls = 0; size_t calls = 0;
ParallelizeTwoRanges( ParallelizeTwoRanges(
partition1, partition2, null_pool, partition1, partition2, null_pool, kCaller,
[&](const IndexRange& range1, const IndexRange& range2, size_t) { [&](const IndexRange& range1, const IndexRange& range2, size_t) {
++calls; ++calls;
HWY_ASSERT(range1.begin() == 0 || range1.begin() == 8); HWY_ASSERT(range1.begin() == 0 || range1.begin() == 8);
@ -240,7 +242,7 @@ TEST(ThreadingTest, TestParallelizeTwoRanges) {
{ {
size_t calls = 0; size_t calls = 0;
ParallelizeTwoRanges( ParallelizeTwoRanges(
partition2, partition1, null_pool, partition2, partition1, null_pool, kCaller,
[&](const IndexRange& range2, const IndexRange& range1, size_t) { [&](const IndexRange& range2, const IndexRange& range1, size_t) {
++calls; ++calls;
HWY_ASSERT(range1.begin() == 0 || range1.begin() == 8); HWY_ASSERT(range1.begin() == 0 || range1.begin() == 8);
@ -265,7 +267,7 @@ std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
const double t0 = hwy::platform::Now(); const double t0 = hwy::platform::Now();
for (size_t reps = 0; reps < 1200; ++reps) { for (size_t reps = 0; reps < 1200; ++reps) {
pool.Run(0, pool.NumWorkers(), [&](uint64_t task, size_t thread) { pool.Run(0, pool.NumWorkers(), kCaller, [&](uint64_t task, size_t thread) {
outputs[thread * kU64PerThread] = base + thread; outputs[thread * kU64PerThread] = base + thread;
}); });
hwy::PreventElision(outputs[base]); hwy::PreventElision(outputs[base]);
@ -305,18 +307,20 @@ std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
if (have_stop) { if (have_stop) {
for (size_t rep = 0; rep < max_reps; ++rep) { for (size_t rep = 0; rep < max_reps; ++rep) {
const uint64_t t0 = hwy::timer::Start(); const uint64_t t0 = hwy::timer::Start();
pool.Run(0, pool.NumWorkers(), [&](uint64_t task, size_t thread) { pool.Run(0, pool.NumWorkers(), kCaller,
outputs[thread * kU64PerThread] = base + thread; [&](uint64_t task, size_t thread) {
}); outputs[thread * kU64PerThread] = base + thread;
});
const uint64_t t1 = hwy::timer::Stop(); const uint64_t t1 = hwy::timer::Stop();
times.push_back(t1 - t0); times.push_back(t1 - t0);
} }
} else { } else {
for (size_t rep = 0; rep < max_reps; ++rep) { for (size_t rep = 0; rep < max_reps; ++rep) {
const uint64_t t0 = hwy::timer::Start(); const uint64_t t0 = hwy::timer::Start();
pool.Run(0, pool.NumWorkers(), [&](uint64_t task, size_t thread) { pool.Run(0, pool.NumWorkers(), kCaller,
outputs[thread * kU64PerThread] = base + thread; [&](uint64_t task, size_t thread) {
}); outputs[thread * kU64PerThread] = base + thread;
});
const uint64_t t1 = hwy::timer::Start(); const uint64_t t1 = hwy::timer::Start();
times.push_back(t1 - t0); times.push_back(t1 - t0);
} }
@ -1,70 +1,201 @@
#include "util/zones.h" #include "util/zones.h"
#include <stddef.h>
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h" #include "hwy/profiler.h"
namespace gcpp { namespace gcpp {
namespace {
#if PROFILER_ENABLED const char* ZoneName(Zones zone) {
static constexpr size_t kNumZones = static_cast<size_t>(Zones::kNumZones); switch (zone) {
case Zones::kFlashAttentionFlashAttention:
static const char* kProfilerZoneNames[kNumZones] = { return "FlashAttention.FlashAttention";
// Keep in sync with Zones enum. case Zones::kFlashAttentionInclusive:
"Ops.RMSNormMul", return "FlashAttention.Inclusive";
"Ops.RMSNorm", case Zones::kFlashAttentionRmsNormAndPositionalEncoding:
"Ops.RMSNormInplace", return "FlashAttention.RMSNormAndPositionalEncoding";
"Ops.Rope", case Zones::kFlashAttentionSingleFlashAttention:
"Ops.RopeAndMulBy", return "FlashAttention.SingleFlashAttention";
"Ops.AddFrom", case Zones::kFlashAttentionTileFlashAttention:
"Ops.MulByConst", return "FlashAttention.TileFlashAttention";
"Ops.MulByConstTo", case Zones::kFlashAttentionTileFlashAttention4:
"Ops.MulByConstAndAdd", return "FlashAttention.TileFlashAttention4";
"Ops.MulByConstAndAddTile", case Zones::kFlashAttentionTransposeQ:
"Ops.MulByConstAndAddTile4", return "FlashAttention.TransposeQ";
"Ops.MulByConstAndAddVector", case Zones::kGenActivation:
"Ops.Softmax", return "Gen.Activation";
"Ops.LogitsSoftCap", case Zones::kGenActivationFused:
"FlashAttention.TransposeQ", return "Gen.ActivationFused";
"FlashAttention.RMSNormAndPositionalEncoding", case Zones::kGenAttention:
"FlashAttention.SingleFlashAttention", return "Gen.Attention";
"FlashAttention.TileFlashAttention", case Zones::kGenAttentionComputeQKV:
"FlashAttention.TileFlashAttention4", return "Gen.Attention.ComputeQKV";
"FlashAttention.FlashAttention", case Zones::kGenAttentionDotSoftmaxWeightedSumInclusive:
"Gen.Activation", return "Gen.Attention.DotSoftmaxWeightedSumInclusive";
"Gen.ActivationFused", case Zones::kGenAttentionDotSoftmaxWeightedSumPar:
"Gen.SampleTop1", return "Gen.Attention.DotSoftmaxWeightedSum.par";
"Gen.SampleTopK", case Zones::kGenAttentionQDotK:
"Gen.Attention.QDotK", return "Gen.Attention.QDotK";
"Gen.Attention.DotSoftmaxWeightedSum.par", case Zones::kGenAttentionSumHeads:
"Startup.Weights.ReadAllToBF16", return "Gen.Attention.SumHeads";
"Startup.Weights.ReadBatches", case Zones::kGenEmbed:
"MM.Dispatch", return "Gen.Embed";
"MM.MatMul", case Zones::kGenEmbeddingMatmul:
"MM.TwoMatMul", return "Gen.EmbeddingMatmul";
"MM.DecompressA", case Zones::kGenFFW:
"MM.NT", return "Gen.FFW";
"MM.NT_K", case Zones::kGenSampleTop1:
"MM.NT_MT", return "Gen.SampleTop1";
"MM.NT_MT_K", case Zones::kGenSampleTopK:
}; return "Gen.SampleTopK";
case Zones::kMMDecompressA:
static hwy::profiler::ZoneHandle profiler_zone_handles[kNumZones]; return "MM.DecompressA";
#endif case Zones::kMMDispatch:
return "MM.Dispatch";
void InitProfilerZones(hwy::Profiler& profiler) { case Zones::kMMMatMul:
#if PROFILER_ENABLED return "MM.MatMul";
// Initialize the zone handles. This is done once at startup. case Zones::kMMNT_K:
for (size_t i = 0; i < kNumZones; ++i) { return "MM.NT_K";
profiler_zone_handles[i] = profiler.AddZone(kProfilerZoneNames[i]); case Zones::kMMNT_MT_K:
return "MM.NT_MT_K";
case Zones::kMMNT_MT:
return "MM.NT_MT";
case Zones::kMMNT:
return "MM.NT";
case Zones::kMMTwoMatMul:
return "MM.TwoMatMul";
case Zones::kOpsAddFrom:
return "Ops.AddFrom";
case Zones::kOpsLogitsSoftCap:
return "Ops.LogitsSoftCap";
// case Zones::kOpsMulByConst: // removed due to overhead
// case Zones::kOpsMulByConstAndAdd: // removed due to overhead
// case Zones::kOpsMulByConstAndAddTile: // removed due to overhead
// case Zones::kOpsMulByConstAndAddTile4: // removed due to overhead
// case Zones::kOpsMulByConstAndAddVector: // removed due to overhead
case Zones::kOpsMulByConstTo:
return "Ops.MulByConstTo";
case Zones::kOpsRmsNorm:
return "Ops.RMSNorm";
case Zones::kOpsRmsNormInplace:
return "Ops.RMSNormInplace";
case Zones::kOpsRmsNormMul:
return "Ops.RMSNormMul";
case Zones::kOpsRope:
return "Ops.Rope";
case Zones::kOpsRopeAndMulBy:
return "Ops.RopeAndMulBy";
case Zones::kOpsSoftmax:
return "Ops.Softmax";
case Zones::kStartupWeightsReadAllToBF16:
return "Startup.Weights.ReadAllToBF16";
case Zones::kStartupWeightsReadBatches:
return "Startup.Weights.ReadBatches";
default:
HWY_ABORT("Invalid zone %d.", static_cast<int>(zone));
} }
#endif
} }
hwy::profiler::ZoneHandle GetProfilerZone(Zones zone) { hwy::ProfilerFlags ZoneFlags(Zones zone) {
#if PROFILER_ENABLED switch (zone) {
return profiler_zone_handles[static_cast<size_t>(zone)]; case Zones::kFlashAttentionInclusive:
#else case Zones::kGenAttention:
return hwy::profiler::ZoneHandle(); case Zones::kGenAttentionComputeQKV:
#endif case Zones::kGenAttentionDotSoftmaxWeightedSumInclusive:
case Zones::kGenAttentionSumHeads:
case Zones::kGenEmbed:
case Zones::kGenEmbeddingMatmul:
case Zones::kGenFFW:
return hwy::ProfilerFlags::kInclusive;
default:
return hwy::ProfilerFlags::kDefault;
}
}
const char* CallerName(Callers caller) {
switch (caller) {
case Callers::kActivationBatched:
return "ActivationBatched";
case Callers::kAllocateAndBindAll:
return "AllocateAndBindAll";
case Callers::kAttComputeQKV:
return "Att.ComputeQKV";
case Callers::kAttDotSoftmaxWeightedSum:
return "Att.DotSoftmaxWeightedSum";
case Callers::kBlobWriter:
return "BlobWriter";
case Callers::kCompress:
return "Compress";
case Callers::kFixupWeights:
return "FixupWeights";
case Callers::kFlashAttention:
return "FlashAttention";
case Callers::kFlashRMSNormAndPositionalEncoding:
return "Flash.RMSNormAndPositionalEncoding";
case Callers::kFlashTransposeQ:
return "Flash.TransposeQ";
case Callers::kMMClusterForMC:
return "MM.ClusterForMC";
case Callers::kMMClusterForMCNC:
return "MM.ClusterForMCNC";
case Callers::kMMClusterForN:
return "MM.ClusterForN";
case Callers::kMMHierForMC:
return "MM.HierForMC";
case Callers::kMMHierForMCNC:
return "MM.HierForMCNC";
case Callers::kMMHierForN:
return "MM.HierForN";
case Callers::kOpsAddFromBatched:
return "Ops.AddFromBatched";
case Callers::kOpsMaybeLogitsSoftCapBatched:
return "Ops.MaybeLogitsSoftCapBatched";
case Callers::kOpsRMSNormBatched:
return "Ops.RMSNormBatched";
case Callers::kOpsRMSNormInplaceBatched:
return "Ops.RMSNormInplaceBatched";
case Callers::kReadAllToBF16:
return "ReadAllToBF16";
case Callers::kReadBatches:
return "ReadBatches";
case Callers::kSampleAndStream:
return "SampleAndStream";
case Callers::kTest: // only for unit tests.
return "Test-only!";
case Callers::kTunePool:
return "TunePool";
case Callers::kVitDotSoftmax1:
return "Vit.DotSoftmax1";
case Callers::kVitDotSoftmax2:
return "Vit.DotSoftmax2";
case Callers::kVitDotSoftmax3:
return "Vit.DotSoftmax3";
case Callers::kVitDotSoftmax4:
return "Vit.DotSoftmax4";
default:
HWY_ABORT("Invalid caller %d.", static_cast<int>(caller));
}
}
} // namespace
ProfilerZones::ProfilerZones(hwy::Profiler& profiler) {
for (size_t i = 0;; ++i) {
const Zones zone = static_cast<Zones>(i);
if (zone == Zones::kNumZones) break;
handles_[i] = profiler.AddZone(ZoneName(zone), ZoneFlags(zone));
}
}
PoolCallers::PoolCallers() {
for (size_t i = 0;; ++i) {
const Callers caller = static_cast<Callers>(i);
if (caller == Callers::kNumCallers) break;
callers_[i] = hwy::ThreadPool::AddCaller(CallerName(caller));
}
} }
} // namespace gcpp } // namespace gcpp
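ProfilerZones and PoolCallers eagerly register every enum value at construction, so lookups are a plain array index. A sketch of the lookup API follows; both classes are normally owned by ThreadingContext, and standalone construction here is only for illustration:

    // Hypothetical demo of Get(); assumes the usual gcpp/hwy headers.
    void Demo(hwy::Profiler& profiler, hwy::ThreadPool& pool) {
      gcpp::ProfilerZones zones(profiler);  // AddZone for every Zones value
      gcpp::PoolCallers callers;            // AddCaller for every Callers value
      PROFILER_ZONE3(profiler, /*global_idx=*/0,
                     zones.Get(gcpp::Zones::kMMMatMul));
      pool.Run(0, 8, callers.Get(gcpp::Callers::kCompress),
               [](uint64_t /*task*/, size_t /*thread*/) {});
    }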
@ -1,57 +1,123 @@
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_ZONES_H_ #ifndef THIRD_PARTY_GEMMA_CPP_UTIL_ZONES_H_
#define THIRD_PARTY_GEMMA_CPP_UTIL_ZONES_H_ #define THIRD_PARTY_GEMMA_CPP_UTIL_ZONES_H_
#include <stddef.h>
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h" #include "hwy/profiler.h"
namespace gcpp { namespace gcpp {
// Zones for the profiler. // Zones for the profiler.
enum class Zones { enum class Zones { // Keep sorted
kOpsRmsNormMul, kFlashAttentionFlashAttention,
kOpsRmsNorm, kFlashAttentionInclusive,
kOpsRmsNormInplace,
kOpsRope,
kOpsRopeAndMulBy,
kOpsAddFrom,
kOpsMulByConst,
kOpsMulByConstTo,
kOpsMulByConstAndAdd,
kOpsMulByConstAndAddTile,
kOpsMulByConstAndAddTile4,
kOpsMulByConstAndAddVector,
kOpsSoftmax,
kOpsLogitsSoftCap,
kFlashAttentionTransposeQ,
kFlashAttentionRmsNormAndPositionalEncoding, kFlashAttentionRmsNormAndPositionalEncoding,
kFlashAttentionSingleFlashAttention, kFlashAttentionSingleFlashAttention,
kFlashAttentionTileFlashAttention, kFlashAttentionTileFlashAttention,
kFlashAttentionTileFlashAttention4, kFlashAttentionTileFlashAttention4,
kFlashAttentionFlashAttention, kFlashAttentionTransposeQ,
kGenActivation, kGenActivation,
kGenActivationFused, kGenActivationFused,
kGenAttention,
kGenAttentionComputeQKV,
kGenAttentionDotSoftmaxWeightedSumInclusive,
kGenAttentionDotSoftmaxWeightedSumPar,
kGenAttentionQDotK,
kGenAttentionSumHeads,
kGenEmbed,
kGenEmbeddingMatmul,
kGenFFW,
kGenSampleTop1, kGenSampleTop1,
kGenSampleTopK, kGenSampleTopK,
kGenAttentionQDotK, kMMDecompressA,
kGenAttentionDotSoftmaxWeightedSumPar,
kStartupWeightsReadAllToBF16,
kStartupWeightsReadBatches,
kMMDispatch, kMMDispatch,
kMMMatMul, kMMMatMul,
kMMTwoMatMul,
kMMDecompressA,
kMMNT,
kMMNT_K, kMMNT_K,
kMMNT_MT,
kMMNT_MT_K, kMMNT_MT_K,
kNumZones kMMNT_MT,
kMMNT,
kMMTwoMatMul,
kOpsAddFrom,
kOpsLogitsSoftCap,
// kOpsMulByConst, // removed due to overhead
// kOpsMulByConstAndAdd, // removed due to overhead
// kOpsMulByConstAndAddTile, // removed due to overhead
// kOpsMulByConstAndAddTile4, // removed due to overhead
// kOpsMulByConstAndAddVector, // removed due to overhead
kOpsMulByConstTo,
kOpsRmsNorm,
kOpsRmsNormInplace,
kOpsRmsNormMul,
kOpsRope,
kOpsRopeAndMulBy,
kOpsSoftmax,
kStartupWeightsReadAllToBF16,
kStartupWeightsReadBatches,
kNumZones // must be last
}; };
// Initializes the profiler zones. Must be called before any other profiler // Owned by ThreadingContext.
// functions. class ProfilerZones {
void InitProfilerZones(hwy::Profiler& profiler); public:
ProfilerZones(hwy::Profiler& profiler);
// Returns the zone handle for the given zone enum value. hwy::profiler::ZoneHandle Get(Zones zone) {
hwy::profiler::ZoneHandle GetProfilerZone(Zones zone); HWY_DASSERT(zone != Zones::kNumZones);
return handles_[static_cast<size_t>(zone)];
}
private:
hwy::profiler::ZoneHandle handles_[static_cast<size_t>(Zones::kNumZones)];
};
enum class Callers { // Keep sorted
kActivationBatched,
kAllocateAndBindAll,
kAttComputeQKV,
kAttDotSoftmaxWeightedSum,
kBlobWriter,
kCompress,
kFixupWeights,
kFlashAttention,
kFlashRMSNormAndPositionalEncoding,
kFlashTransposeQ,
kMMClusterForMC,
kMMClusterForMCNC,
kMMClusterForN,
kMMHierForMC,
kMMHierForMCNC,
kMMHierForN,
kOpsAddFromBatched,
kOpsMaybeLogitsSoftCapBatched,
kOpsRMSNormBatched,
kOpsRMSNormInplaceBatched,
kReadAllToBF16,
kReadBatches,
kSampleAndStream,
kTest, // only for unit tests.
kTunePool,
kVitDotSoftmax1,
kVitDotSoftmax2,
kVitDotSoftmax3,
kVitDotSoftmax4,
kNumCallers // must be last
};
// Owned by ThreadingContext.
class PoolCallers {
public:
PoolCallers();
hwy::pool::Caller Get(Callers caller) {
HWY_DASSERT(caller != Callers::kNumCallers);
return callers_[static_cast<size_t>(caller)];
}
private:
hwy::pool::Caller callers_[static_cast<size_t>(Callers::kNumCallers)];
};
} // namespace gcpp } // namespace gcpp
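Because both enums index fixed-size handle arrays and a name switch, adding an entry touches the sorted enum here plus the matching name (and, for inclusive zones, flag) case in util/zones.cc. A sketch with an invented zone, not part of this commit:

    // util/zones.h, inserted in sorted order before kNumZones:
    //   kGenTopKSelect,            // hypothetical new zone
    // util/zones.cc, ZoneName():
    //   case Zones::kGenTopKSelect:
    //     return "Gen.TopKSelect";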