mirror of https://github.com/google/gemma.cpp.git
Major cleanup of profiler zones, add Caller annotation for all pool.Run
Pass ThreadingContext instead of Pools/Profiler individually, for access to Zones.
Add GCPP_ZONE helper.
Add Caller argument to pool.Run to enable new stats.
Remove most direct dependencies on ThreadPool, prefer ParallelFor.

PiperOrigin-RevId: 822934530
This commit is contained in: parent 9e8ac7e2f0, commit 3ed403e287
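For orientation, the pattern this change moves call sites to looks roughly like the sketch below. ThreadingContext, GCPP_ZONE, ParallelFor, Zones and Callers are taken from the hunks that follow; the specific enum values (kGenFFW, kActivationBatched) are merely examples that appear in this diff, and the abbreviated signatures may differ in detail.

#include "util/threading_context.h"  // ThreadingContext, ParallelFor, Callers
#include "util/zones.h"              // GCPP_ZONE, Zones

namespace gcpp {

// Sketch only: GCPP_ZONE replaces PROFILER_ZONE3(p, worker, GetProfilerZone(...)),
// and the Callers tag lets the pool attribute Run statistics to this call site.
void ExampleRowLoop(size_t num_rows, ThreadingContext& ctx) {
  GCPP_ZONE(ctx, /*worker=*/0, Zones::kGenFFW);
  ParallelFor(ParallelismStrategy::kFlat, num_rows, ctx, /*cluster_idx=*/0,
              Callers::kActivationBatched, [&](uint64_t row, size_t worker) {
                // Per-row work; `worker` indexes per-thread state.
              });
}

}  // namespace gcpp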
@ -95,7 +95,6 @@ cc_library(
|
|||
":topology",
|
||||
# Placeholder for container detection, do not remove
|
||||
"@highway//:hwy",
|
||||
"@highway//:profiler",
|
||||
"@highway//:thread_pool",
|
||||
"@highway//:topology",
|
||||
],
|
||||
|
|
@ -124,7 +123,9 @@ cc_library(
|
|||
srcs = ["util/zones.cc"],
|
||||
hdrs = ["util/zones.h"],
|
||||
deps = [
|
||||
"@highway//:hwy",
|
||||
"@highway//:profiler",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -258,7 +259,6 @@ cc_library(
|
|||
"//io:fields",
|
||||
"@highway//:hwy",
|
||||
"@highway//:profiler",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -278,7 +278,6 @@ cc_library(
|
|||
"//io:blob_store",
|
||||
"@highway//:hwy",
|
||||
"@highway//:profiler",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -309,7 +308,6 @@ cc_library(
|
|||
deps = [
|
||||
":allocator",
|
||||
":basics",
|
||||
":configs",
|
||||
":mat",
|
||||
":threading",
|
||||
":threading_context",
|
||||
|
|
@ -397,7 +395,6 @@ cc_library(
|
|||
"@highway//:hwy",
|
||||
"@highway//:math",
|
||||
"@highway//:profiler",
|
||||
"@highway//:thread_pool",
|
||||
"@highway//hwy/contrib/sort:vqsort",
|
||||
],
|
||||
)
|
||||
|
|
@ -424,7 +421,6 @@ cc_test(
|
|||
"@highway//:nanobenchmark", #buildcleaner: keep
|
||||
"@highway//:profiler",
|
||||
"@highway//:stats",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -507,7 +503,6 @@ cc_test(
|
|||
"@highway//:hwy_test_util",
|
||||
"@highway//:nanobenchmark",
|
||||
"@highway//:profiler",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17)
|
|||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9781a1698ee0756ef1eaaf96930113ed7cb6d3ee EXCLUDE_FROM_ALL)
|
||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2a16a50ff61071bb25ddef0ce35d92b0e2b9c579 EXCLUDE_FROM_ALL)
|
||||
FetchContent_MakeAvailable(highway)
|
||||
|
||||
## Note: absl needs to be installed by sentencepiece. This will only happen if
@ -18,7 +18,7 @@ bazel_dep(name = "google_benchmark", version = "1.8.5")
|
|||
# Require a more recent version.
|
||||
git_override(
|
||||
module_name = "highway",
|
||||
commit = "9781a1698ee0756ef1eaaf96930113ed7cb6d3ee",
|
||||
commit = "2a16a50ff61071bb25ddef0ce35d92b0e2b9c579",
|
||||
remote = "https://github.com/google/highway",
|
||||
)
@ -452,7 +452,7 @@ FetchContent_MakeAvailable(sentencepiece)
|
|||
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG origin/main)
|
||||
FetchContent_MakeAvailable(gemma)
|
||||
|
||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9781a1698ee0756ef1eaaf96930113ed7cb6d3ee)
|
||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2a16a50ff61071bb25ddef0ce35d92b0e2b9c579)
|
||||
FetchContent_MakeAvailable(highway)
|
||||
```
@ -120,9 +120,9 @@ cc_library(
|
|||
":compress",
|
||||
":distortion",
|
||||
"//:mat",
|
||||
"//:threading_context",
|
||||
"@highway//:hwy",
|
||||
"@highway//:hwy_test_util",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -180,6 +180,7 @@ cc_library(
|
|||
":sfp",
|
||||
"//:basics",
|
||||
"//:mat",
|
||||
"//:threading_context",
|
||||
"@highway//:hwy",
|
||||
"@highway//:nanobenchmark",
|
||||
"@highway//:profiler",
|
||||
|
|
@ -203,9 +204,9 @@ cc_test(
|
|||
":test_util",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"//:test_util",
|
||||
"//:threading_context",
|
||||
"@highway//:hwy",
|
||||
"@highway//:hwy_test_util",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
@ -26,10 +26,10 @@
|
|||
|
||||
#include "compression/compress.h" // IWYU pragma: export
|
||||
#include "compression/distortion.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/timer.h"
|
||||
|
||||
#if COMPRESS_STATS
|
||||
#include <cmath> // lroundf
|
||||
|
|
@ -493,13 +493,13 @@ struct CompressTraits<NuqStream> {
|
|||
}
|
||||
};
|
||||
|
||||
// Compresses `num` elements of `raw` to `packed` starting at `packed_ofs`,
|
||||
// which is useful for compressing sub-regions of an array.
|
||||
// DEPRECATED: Use the overload with ThreadingContext instead.
|
||||
template <typename Packed>
|
||||
HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
|
||||
CompressWorkingSet& work,
|
||||
const PackedSpan<Packed>& packed,
|
||||
const size_t packed_ofs, hwy::ThreadPool& pool) {
|
||||
const size_t packed_ofs, hwy::ThreadPool& pool,
|
||||
hwy::pool::Caller caller = hwy::pool::Caller()) {
|
||||
packed.BoundsCheck(packed_ofs, num);
|
||||
work.tls.resize(pool.NumWorkers());
|
||||
if constexpr (COMPRESS_STATS) {
|
||||
|
|
@ -511,7 +511,7 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
|
|||
using Traits = CompressTraits<hwy::RemoveConst<Packed>>;
|
||||
constexpr size_t kBatch = 8192;
|
||||
const size_t num_batches = hwy::DivCeil(num, kBatch);
|
||||
pool.Run(0, num_batches,
|
||||
pool.Run(0, num_batches, caller,
|
||||
[&](const uint32_t idx_batch, size_t thread) HWY_ATTR {
|
||||
const hn::ScalableTag<float> df;
|
||||
|
||||
|
|
@ -530,6 +530,17 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
|
|||
}
|
||||
}
|
||||
|
||||
// Compresses `num` elements of `raw` to `packed` starting at `packed_ofs`,
|
||||
// which is useful for compressing sub-regions of an array.
|
||||
template <typename Packed>
|
||||
HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
|
||||
CompressWorkingSet& work,
|
||||
const PackedSpan<Packed>& packed,
|
||||
const size_t packed_ofs, ThreadingContext& ctx) {
|
||||
Compress(raw, num, work, packed, packed_ofs, ctx.pools.Pool(),
|
||||
ctx.pool_callers.Get(Callers::kCompress));
|
||||
}
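A minimal usage sketch for this new overload, assuming MakeSpan and SfpStream behave as they do elsewhere in this diff (buffer handling is simplified; real call sites pass MatPtrT/MatStorageT spans):

// Hypothetical call site: pack `num` floats into an SFP buffer.
ThreadingArgs args;
ThreadingContext ctx(args);
CompressWorkingSet work;
const size_t num = 4096;
std::vector<float> raw(num, 0.5f);
std::vector<SfpStream> packed(num);  // SFP stores one byte per element.
Compress(raw.data(), num, work, MakeSpan(packed.data(), num),
         /*packed_ofs=*/0, ctx);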
|
||||
|
||||
// Same as above, but without parallelization nor benchmarking.
|
||||
template <typename Packed>
|
||||
HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
@ -24,9 +24,9 @@
|
|||
#include "compression/compress.h"
|
||||
#include "compression/distortion.h"
|
||||
#include "util/test_util.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/tests/hwy_gtest.h"
|
||||
// clang-format off
|
||||
#undef HWY_TARGET_INCLUDE
|
||||
|
|
@ -42,6 +42,17 @@ namespace gcpp {
|
|||
namespace HWY_NAMESPACE {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
|
||||
static ThreadingArgs SingleThreadArgs() {
|
||||
ThreadingArgs args;
|
||||
args.max_lps = 1;
|
||||
return args;
|
||||
}
|
||||
|
||||
static ThreadingContext& Ctx() {
|
||||
static ThreadingContext* ctx = new ThreadingContext(SingleThreadArgs());
|
||||
return *ctx;
|
||||
}
|
||||
|
||||
// Calls Compress and Decompress2 and verifies the distortion/error.
|
||||
template <typename Packed>
|
||||
struct TestDecompress2 {
|
||||
|
|
@ -49,7 +60,9 @@ struct TestDecompress2 {
|
|||
HWY_INLINE void operator()(T /*unused*/, D d) {
|
||||
const size_t N = hn::Lanes(d);
|
||||
CompressWorkingSet work;
|
||||
hwy::ThreadPool pool(0);
|
||||
ThreadingArgs args;
|
||||
args.max_lps = 1;
|
||||
ThreadingContext ctx(args);
|
||||
hwy::RandomState rng;
|
||||
|
||||
const size_t num = 2 * N;
|
||||
|
|
@ -68,7 +81,7 @@ struct TestDecompress2 {
|
|||
// Short inputs fail VerifyGaussian.
|
||||
|
||||
const size_t packed_ofs = 0;
|
||||
Compress(raw.get(), num, work, packed_span, packed_ofs, pool);
|
||||
Compress(raw.get(), num, work, packed_span, packed_ofs, ctx);
|
||||
hn::Vec<D> raw0, raw1;
|
||||
Decompress2(d, MakeConst(packed_span), packed_ofs, raw0, raw1);
|
||||
hn::Store(raw0, d, dec.get());
|
||||
|
|
@ -129,7 +142,6 @@ struct TestShortLengths {
|
|||
HWY_INLINE void operator()(T /*unused*/, D d) {
|
||||
const size_t N = hn::Lanes(d);
|
||||
CompressWorkingSet work;
|
||||
hwy::ThreadPool pool(0);
|
||||
hwy::RandomState rng;
|
||||
|
||||
for (size_t num = 1; num < 5 * hn::Lanes(d); ++num) {
|
||||
|
|
@ -149,7 +161,7 @@ struct TestShortLengths {
|
|||
// Short inputs fail VerifyGaussian.
|
||||
|
||||
const size_t packed_ofs = 0;
|
||||
Compress(raw.get(), num, work, packed_span, packed_ofs, pool);
|
||||
Compress(raw.get(), num, work, packed_span, packed_ofs, Ctx());
|
||||
DecompressAndZeroPad(d, MakeConst(packed_span), packed_ofs, dec.get(),
|
||||
num);
@ -37,7 +37,6 @@
|
|||
#include "util/basics.h"
|
||||
#include "util/mat.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
#undef HWY_TARGET_INCLUDE
|
||||
#define HWY_TARGET_INCLUDE \
|
||||
|
|
@ -57,9 +56,6 @@ class SbsWriterImpl : public ISbsWriter {
|
|||
template <typename Packed>
|
||||
void InsertT(const char* name, F32Span weights,
|
||||
const TensorInfo& tensor_info) {
|
||||
// TODO(janwas): 1D parallel-for.
|
||||
hwy::ThreadPool& pool = ctx_.pools.Pool();
|
||||
|
||||
MatPtrT<Packed> mat(name, ExtentsFromInfo(&tensor_info));
|
||||
// SFP and NUQ (which uses SFP for cluster centers) have a limited range
|
||||
// and depending on the input values may require rescaling. Scaling is
|
||||
|
|
@ -82,21 +78,20 @@ class SbsWriterImpl : public ISbsWriter {
|
|||
// succeeds, but we only have 10 floats, not the full tensor.
|
||||
if (weights.size() == 10 && mat.Extents().Area() != 10) {
|
||||
Compress(weights.data(), weights.size(), working_set_, mat.Span(),
|
||||
/*packed_ofs=*/0, pool);
|
||||
/*packed_ofs=*/0, ctx_);
|
||||
writer_.Add(name, mat.Packed(), mat.ElementBytes() * 10);
|
||||
return;
|
||||
}
|
||||
|
||||
HWY_ASSERT(weights.size() == mat.Extents().Area());
|
||||
Compress(weights.data(), weights.size(), working_set_, mat.Span(),
|
||||
/*packed_ofs=*/0, pool);
|
||||
/*packed_ofs=*/0, ctx_);
|
||||
writer_.Add(name, mat.Packed(), mat.PackedBytes());
|
||||
}
|
||||
|
||||
public:
|
||||
SbsWriterImpl(const std::string& sbs_path)
|
||||
: ctx_(ThreadingArgs()),
|
||||
writer_(gcpp::Path(sbs_path), ctx_.pools.Pool()) {}
|
||||
: ctx_(ThreadingArgs()), writer_(gcpp::Path(sbs_path), ctx_) {}
|
||||
|
||||
void Insert(const char* name, F32Span weights, Type type,
|
||||
const TensorInfo& tensor_info) override {
@ -23,7 +23,7 @@
|
|||
// IWYU pragma: end_exports
|
||||
|
||||
#include "compression/compress.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "util/threading_context.h"
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_
|
||||
|
||||
|
|
@ -98,19 +98,20 @@ void ForeachActivationType3(D d) {
|
|||
|
||||
// Generates inputs: deterministic, within max SfpStream range.
|
||||
template <typename MatT>
|
||||
MatStorageT<MatT> GenerateMat(const Extents2D& extents,
|
||||
const Allocator& allocator, MatPadding padding,
|
||||
hwy::ThreadPool& pool) {
|
||||
MatStorageT<MatT> GenerateMat(const Extents2D& extents, MatPadding padding,
|
||||
ThreadingContext& ctx) {
|
||||
gcpp::CompressWorkingSet ws;
|
||||
ws.tls.resize(pool.NumWorkers());
|
||||
MatStorageT<float> raw("raw", extents, allocator, MatPadding::kPacked);
|
||||
MatStorageT<MatT> compressed("mat", extents, allocator, padding);
|
||||
ws.tls.resize(ctx.pools.MaxWorkers());
|
||||
MatStorageT<float> raw("raw", extents, ctx.allocator, MatPadding::kPacked);
|
||||
MatStorageT<MatT> compressed("mat", extents, ctx.allocator, padding);
|
||||
const float scale = SfpStream::kMax / extents.Area();
|
||||
pool.Run(0, extents.rows, [&](const size_t r, size_t thread) {
|
||||
ParallelFor(ParallelismStrategy::kFlat, extents.rows, ctx, /*cluster_idx=*/0,
|
||||
Callers::kTest, [&](size_t r, size_t thread) {
|
||||
float* HWY_RESTRICT row = raw.Row(r);
|
||||
for (size_t c = 0; c < extents.cols; c++) {
|
||||
float f = static_cast<float>(r * extents.cols + c) * scale;
|
||||
if ((r + c) & 1) f = -f; // Also generate some negative values.
|
||||
if ((r + c) & 1)
|
||||
f = -f; // Also generate some negative values.
|
||||
row[c] = f;
|
||||
}
|
||||
Compress(raw.Row(r), raw.Cols(), ws.tls[thread],
|
||||
|
|
@ -126,19 +127,20 @@ MatStorageT<MatT> GenerateMat(const Extents2D& extents,
|
|||
// `f` swaps `r` and `c`.
|
||||
template <typename MatT>
|
||||
MatStorageT<MatT> GenerateTransposedMat(const Extents2D extents,
|
||||
const Allocator& allocator,
|
||||
MatPadding padding,
|
||||
hwy::ThreadPool& pool) {
|
||||
ThreadingContext& ctx) {
|
||||
gcpp::CompressWorkingSet ws;
|
||||
ws.tls.resize(pool.NumWorkers());
|
||||
MatStorageT<float> raw("raw", extents, allocator, MatPadding::kPacked);
|
||||
MatStorageT<MatT> compressed("trans", extents, allocator, padding);
|
||||
ws.tls.resize(ctx.pools.MaxWorkers());
|
||||
MatStorageT<float> raw("raw", extents, ctx.allocator, MatPadding::kPacked);
|
||||
MatStorageT<MatT> compressed("trans", extents, ctx.allocator, padding);
|
||||
const float scale = SfpStream::kMax / extents.Area();
|
||||
pool.Run(0, extents.rows, [&](const size_t r, size_t thread) {
|
||||
ParallelFor(ParallelismStrategy::kFlat, extents.rows, ctx, /*cluster_idx=*/0,
|
||||
Callers::kTest, [&](size_t r, size_t thread) {
|
||||
float* HWY_RESTRICT row = raw.Row(r);
|
||||
for (size_t c = 0; c < extents.cols; c++) {
|
||||
float f = static_cast<float>(c * extents.rows + r) * scale;
|
||||
if ((r + c) & 1) f = -f; // Also generate some negative values.
|
||||
if ((r + c) & 1)
|
||||
f = -f; // Also generate some negative values.
|
||||
row[c] = f;
|
||||
}
|
||||
Compress(raw.Row(r), raw.Cols(), ws.tls[thread],
@ -83,8 +83,8 @@ HWY_BEFORE_NAMESPACE();
|
|||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
|
||||
void CallSoftmax(Logits logits, hwy::Profiler& p) {
|
||||
Softmax(logits, p, hwy::Profiler::GlobalIdx());
|
||||
void CallSoftmax(Logits logits, ThreadingContext& ctx) {
|
||||
Softmax(logits, ctx, hwy::Profiler::GlobalIdx());
|
||||
}
|
||||
|
||||
} // namespace HWY_NAMESPACE
|
||||
|
|
@ -109,7 +109,7 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
|
|||
const SampleFunc sample_token = [&](size_t qi, size_t pos, Logits logits,
|
||||
size_t /*worker*/) -> TokenAndProb {
|
||||
// input is logits, not yet probabilities
|
||||
HWY_DYNAMIC_DISPATCH(CallSoftmax)(logits, env.ctx.profiler);
|
||||
HWY_DYNAMIC_DISPATCH(CallSoftmax)(logits, env.ctx);
|
||||
// We are called for each token, but pos starts at 1. Clamping
|
||||
// max_generated_tokens to prompt.size() should prevent overrun.
|
||||
HWY_ASSERT(pos < prompt.size());
@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17)
|
|||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
|
||||
include(FetchContent)
|
||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9781a1698ee0756ef1eaaf96930113ed7cb6d3ee)
|
||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2a16a50ff61071bb25ddef0ce35d92b0e2b9c579)
|
||||
FetchContent_MakeAvailable(highway)
|
||||
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 9045b2f60fa2b323dfac0eaef8fc17565036f9f9)
|
||||
FetchContent_MakeAvailable(sentencepiece)
@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17)
|
|||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
|
||||
include(FetchContent)
|
||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9781a1698ee0756ef1eaaf96930113ed7cb6d3ee)
|
||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2a16a50ff61071bb25ddef0ce35d92b0e2b9c579)
|
||||
FetchContent_MakeAvailable(highway)
|
||||
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
|
||||
FetchContent_MakeAvailable(sentencepiece)
@ -55,8 +55,8 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
|
|||
const hwy::Divisor& div_seq_len,
|
||||
const float* HWY_RESTRICT q,
|
||||
const MatPtrT<KV_t>& k, float* HWY_RESTRICT att,
|
||||
hwy::Profiler& p, const size_t worker) {
|
||||
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kGenAttentionQDotK));
|
||||
ThreadingContext& ctx, const size_t worker) {
|
||||
GCPP_ZONE(ctx, worker, Zones::kGenAttentionQDotK);
|
||||
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
|
||||
// Slightly faster: no wraparound.
|
||||
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
|
||||
|
|
@ -75,7 +75,7 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
|
|||
void PositionalEncodingQK(float* qk, const size_t layer_idx,
|
||||
const LayerWeightsPtrs& layer,
|
||||
const AttentionActivations& activations,
|
||||
hwy::Profiler& p, const size_t worker,
|
||||
ThreadingContext& ctx, const size_t worker,
|
||||
const size_t pos, const float mul) {
|
||||
const size_t qkv_dim = layer.layer_config.qkv_dim;
|
||||
const PostQKType& post_qk = layer.layer_config.post_qk;
|
||||
|
|
@ -88,10 +88,10 @@ void PositionalEncodingQK(float* qk, const size_t layer_idx,
|
|||
}
|
||||
// PostQKType::Rope
|
||||
if (post_qk == PostQKType::HalfRope) {
|
||||
Rope(qk, qkv_dim / 2, inv_timescale, pos, p, worker);
|
||||
Rope(qk, qkv_dim / 2, inv_timescale, pos, ctx, worker);
|
||||
if (mul != 1.0f) MulByConst(mul, qk, qkv_dim);
|
||||
} else {
|
||||
RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, p, worker);
|
||||
RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, ctx, worker);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -99,18 +99,16 @@ void PositionalEncodingQK(float* qk, const size_t layer_idx,
|
|||
// `att_out`. Equivalent in gemma/modules.py:
|
||||
// encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
|
||||
// `v` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
|
||||
static HWY_INLINE void WeightedSumV(const size_t start_pos,
|
||||
const size_t last_pos,
|
||||
const hwy::Divisor& div_seq_len,
|
||||
const float* HWY_RESTRICT att,
|
||||
const MatPtrT<KV_t>& v,
|
||||
float* HWY_RESTRICT att_out,
|
||||
hwy::Profiler& p, const size_t worker) {
|
||||
static HWY_INLINE void WeightedSumV(
|
||||
const size_t start_pos, const size_t last_pos,
|
||||
const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT att,
|
||||
const MatPtrT<KV_t>& v, float* HWY_RESTRICT att_out, ThreadingContext& ctx,
|
||||
const size_t worker) {
|
||||
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
|
||||
// Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if
|
||||
// we supported non-transposed B.
|
||||
// TODO: 2..4x unroll
|
||||
MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), p,
|
||||
MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), ctx,
|
||||
worker);
|
||||
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
|
||||
MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols());
|
||||
|
|
@ -118,7 +116,8 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
|
|||
} else {
|
||||
{
|
||||
const size_t pos_mod = div_seq_len.Remainder(start_pos);
|
||||
MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), p, worker);
|
||||
MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), ctx,
|
||||
worker);
|
||||
}
|
||||
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
|
||||
const size_t pos_mod = div_seq_len.Remainder(pos);
|
||||
|
|
@ -134,7 +133,7 @@ void SingleDotSoftmaxWeightedSum(
|
|||
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
|
||||
const size_t layer_idx, const LayerWeightsPtrs& layer,
|
||||
const AttentionActivations& activations, float* HWY_RESTRICT att,
|
||||
float* HWY_RESTRICT att_out, hwy::Profiler& p, const size_t worker) {
|
||||
float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) {
|
||||
const float att_cap = activations.config.att_cap;
|
||||
const float query_scale = activations.query_scale;
|
||||
const size_t seq_len =
|
||||
|
|
@ -144,23 +143,23 @@ void SingleDotSoftmaxWeightedSum(
|
|||
if (layer.query_norm_scale.HasPtr()) {
|
||||
CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
|
||||
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q,
|
||||
layer.layer_config.qkv_dim, p, worker);
|
||||
layer.layer_config.qkv_dim, ctx, worker);
|
||||
});
|
||||
}
|
||||
|
||||
PositionalEncodingQK(q, layer_idx, layer, activations, p, worker, pos,
|
||||
PositionalEncodingQK(q, layer_idx, layer, activations, ctx, worker, pos,
|
||||
query_scale);
|
||||
|
||||
QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, p, worker);
|
||||
QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, ctx, worker);
|
||||
|
||||
// SoftMax with optional SoftCap yields "probabilities" in att.
|
||||
const size_t att_len = HWY_MIN(last_pos + 1, seq_len);
|
||||
const Logits logits(att, att_len);
|
||||
MaybeLogitsSoftCap(att_cap, logits, p, worker);
|
||||
Softmax(logits, p, worker, /*temperature=*/1.0f);
|
||||
MaybeLogitsSoftCap(att_cap, logits, ctx, worker);
|
||||
Softmax(logits, ctx, worker, /*temperature=*/1.0f);
|
||||
|
||||
WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out, p,
|
||||
worker);
|
||||
WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out,
|
||||
ctx, worker);
|
||||
}
|
||||
|
||||
// The attention window usually starts at 0 unless `pos` is larger than
|
||||
|
|
@ -174,12 +173,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
|
|||
const LayerWeightsPtrs& layer,
|
||||
AttentionActivations& activations, QBatch& qbatch,
|
||||
ThreadingContext& ctx) {
|
||||
static const auto root_zone =
|
||||
ctx.profiler.AddZone("Gen.Attention.DotSoftmaxWeightedSumInclusive",
|
||||
hwy::ProfilerFlags::kInclusive);
|
||||
PROFILER_ZONE3(ctx.profiler, 0, root_zone);
|
||||
const auto zone =
|
||||
GetProfilerZone(Zones::kGenAttentionDotSoftmaxWeightedSumPar);
|
||||
GCPP_ZONE(ctx, 0, Zones::kGenAttentionDotSoftmaxWeightedSumInclusive);
|
||||
|
||||
const hwy::Divisor div_qbatch(qbatch.Size());
|
||||
const LayerConfig& layer_config = layer.layer_config;
|
||||
|
|
@ -199,7 +193,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
|
|||
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
|
||||
const size_t tq_idx = activations.div_heads.Divide(task);
|
||||
const size_t head = activations.div_heads.Remainder(task);
|
||||
PROFILER_ZONE3(ctx.profiler, worker, zone);
|
||||
GCPP_ZONE(ctx, worker, Zones::kGenAttentionDotSoftmaxWeightedSumPar);
|
||||
|
||||
const size_t qi = div_qbatch.Remainder(tq_idx);
|
||||
const size_t batch_idx = div_qbatch.Divide(tq_idx);
|
||||
|
|
@ -231,16 +225,15 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
|
|||
v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());
|
||||
|
||||
SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx,
|
||||
layer, activations, att, att_out, ctx.profiler,
|
||||
worker);
|
||||
layer, activations, att, att_out, ctx, worker);
|
||||
};
|
||||
|
||||
{
|
||||
PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin");
|
||||
// Full parallelism is helpful, kAcrossClusters is insufficient.
|
||||
HierarchicalParallelFor(
|
||||
num_tokens * div_qbatch.GetDivisor() * layer_config.heads, ctx.pools,
|
||||
func);
|
||||
num_tokens * div_qbatch.GetDivisor() * layer_config.heads, ctx,
|
||||
Callers::kAttDotSoftmaxWeightedSum, func);
|
||||
}
|
||||
}
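The fork/join path follows the same shape; below is a sketch of the updated HierarchicalParallelFor call using the task count and enum values from the hunk above (exact signature may differ).

// Sketch: HierarchicalParallelFor now takes the ThreadingContext plus a
// Callers tag instead of ctx.pools, so per-caller Run stats cover this path too.
HierarchicalParallelFor(
    num_tokens * div_qbatch.GetDivisor() * layer_config.heads, ctx,
    Callers::kAttDotSoftmaxWeightedSum,
    [&](size_t task, size_t worker) HWY_ATTR {
      GCPP_ZONE(ctx, worker, Zones::kGenAttentionDotSoftmaxWeightedSumPar);
      // per head/token/query work
    });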
|
||||
|
||||
|
|
@ -256,9 +249,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
|
|||
AttentionActivations& activations,
|
||||
const QBatch& qbatch, const int flags,
|
||||
MatMulEnv& env) {
|
||||
static const auto zone = env.ctx.profiler.AddZone(
|
||||
"Gen.Attention.ComputeQKV", hwy::ProfilerFlags::kInclusive);
|
||||
PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::GlobalIdx(), zone);
|
||||
GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(),
|
||||
Zones::kGenAttentionComputeQKV);
|
||||
|
||||
const hwy::Divisor div_qbatch(qbatch.Size());
|
||||
const size_t num_interleaved = num_tokens * div_qbatch.GetDivisor();
|
||||
|
|
@ -295,7 +287,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
|
|||
// tasks are very lightweight.
|
||||
ParallelFor(
|
||||
ParallelismStrategy::kFlat, kv_heads * num_interleaved, env.ctx,
|
||||
/*cluster_idx=*/0, [&](size_t task, size_t worker) HWY_ATTR {
|
||||
/*cluster_idx=*/0, Callers::kAttComputeQKV,
|
||||
[&](size_t task, size_t worker) HWY_ATTR {
|
||||
const size_t head = task % kv_heads;
|
||||
const size_t interleaved_idx = task / kv_heads;
|
||||
const size_t qi = div_qbatch.Remainder(interleaved_idx);
|
||||
|
|
@ -316,12 +309,12 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
|
|||
if (layer.key_norm_scale.HasPtr()) {
|
||||
CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) {
|
||||
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, kv_f32,
|
||||
qkv_dim, env.ctx.profiler, worker);
|
||||
qkv_dim, env.ctx, worker);
|
||||
});
|
||||
}
|
||||
|
||||
PositionalEncodingQK(kv_f32, layer_idx, layer, activations,
|
||||
env.ctx.profiler, worker, pos, /*mul=*/1.0f);
|
||||
PositionalEncodingQK(kv_f32, layer_idx, layer, activations, env.ctx,
|
||||
worker, pos, /*mul=*/1.0f);
|
||||
CompressPerThread tls;
|
||||
Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
|
||||
});
|
||||
|
|
@ -332,9 +325,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
|
|||
static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
|
||||
AttentionActivations& activations,
|
||||
MatMulEnv& env) {
|
||||
static const auto zone = env.ctx.profiler.AddZone(
|
||||
"Gen.Attention.SumHeads", hwy::ProfilerFlags::kInclusive);
|
||||
PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::GlobalIdx(), zone);
|
||||
GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttentionSumHeads);
|
||||
const LayerConfig& layer_config = layer.layer_config;
|
||||
(void)layer_config; // For HWY_DASSERT
|
||||
// att_weights and att_out are concatenated heads, each of length
|
||||
|
|
@ -352,9 +343,7 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
|
|||
const LayerWeightsPtrs& layer,
|
||||
AttentionActivations& activations, QBatch& qbatch,
|
||||
MatMulEnv& env, int flags) {
|
||||
static const auto zone =
|
||||
env.ctx.profiler.AddZone("Gen.Attention", hwy::ProfilerFlags::kInclusive);
|
||||
PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::GlobalIdx(), zone);
|
||||
GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttention);
|
||||
|
||||
const LayerConfig& layer_config = layer.layer_config;
|
||||
HWY_DASSERT(!layer_config.IsMHA()); // No longer supported.
@ -31,7 +31,7 @@ namespace gcpp {
|
|||
void PositionalEncodingQK(float* qk, size_t layer_idx, \
|
||||
const LayerWeightsPtrs& layer, \
|
||||
const AttentionActivations& activations, \
|
||||
hwy::Profiler& p, size_t worker, size_t pos, \
|
||||
ThreadingContext& ctx, size_t worker, size_t pos, \
|
||||
float mul); \
|
||||
\
|
||||
size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx); \
|
||||
|
|
@ -41,7 +41,7 @@ namespace gcpp {
|
|||
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
|
||||
size_t layer_idx, const LayerWeightsPtrs& layer, \
|
||||
const AttentionActivations& activations, float* HWY_RESTRICT att, \
|
||||
float* HWY_RESTRICT att_out, hwy::Profiler& p, size_t worker); \
|
||||
float* HWY_RESTRICT att_out, ThreadingContext& ctx, size_t worker); \
|
||||
\
|
||||
void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \
|
||||
const LayerWeightsPtrs& layer, \
@ -61,13 +61,12 @@ static constexpr size_t kNFx8HTileSize = 8;
|
|||
// possible consecutive elements have the same KV.
|
||||
static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
|
||||
const size_t qbatch_size, ThreadingContext& ctx) {
|
||||
const auto zone = GetProfilerZone(Zones::kFlashAttentionTransposeQ);
|
||||
// Group floats by the number of floats in a cache line.
|
||||
const size_t kNF = ctx.cache_info.LineBytes() / sizeof(float);
|
||||
const size_t num_heads = q.Cols() / q_t.Rows();
|
||||
const size_t batch_size = q.Rows() / qbatch_size;
|
||||
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
|
||||
PROFILER_ZONE3(ctx.profiler, worker, zone);
|
||||
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTransposeQ);
|
||||
for (size_t lane = 0; lane < kNF; ++lane) {
|
||||
size_t q_row = task * kNF + lane;
|
||||
if (q_row >= q_t.Rows()) break;
|
||||
|
|
@ -83,10 +82,10 @@ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
|
|||
}
|
||||
};
|
||||
{
|
||||
const size_t num_tasks = hwy::DivCeil(q_t.Rows(), kNF);
|
||||
// Better than kFlat.
|
||||
size_t num_tasks = hwy::DivCeil(q_t.Rows(), kNF);
|
||||
ParallelFor(ParallelismStrategy::kHierarchical, num_tasks, ctx,
|
||||
/*cluster_idx=*/0, func);
|
||||
/*cluster_idx=*/0, Callers::kFlashTransposeQ, func);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -96,12 +95,10 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
|
|||
const LayerWeightsPtrs& layer,
|
||||
const AttentionActivations& activations,
|
||||
ThreadingContext& ctx) {
|
||||
const auto zone =
|
||||
GetProfilerZone(Zones::kFlashAttentionRmsNormAndPositionalEncoding);
|
||||
const float query_scale = activations.query_scale;
|
||||
const hwy::Divisor div_qbatch(qbatch.Size());
|
||||
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
|
||||
PROFILER_ZONE3(ctx.profiler, worker, zone);
|
||||
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionRmsNormAndPositionalEncoding);
|
||||
size_t qi = div_qbatch.Remainder(task);
|
||||
size_t batch_idx = div_qbatch.Divide(task);
|
||||
for (size_t h = 0; h < layer.layer_config.heads; ++h) {
|
||||
|
|
@ -115,18 +112,19 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
|
|||
if (layer.query_norm_scale.HasPtr()) {
|
||||
CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
|
||||
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q_row,
|
||||
layer.layer_config.qkv_dim, ctx.profiler, worker);
|
||||
layer.layer_config.qkv_dim, ctx, worker);
|
||||
});
|
||||
}
|
||||
PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx.profiler,
|
||||
worker, pos, query_scale);
|
||||
PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx, worker,
|
||||
pos, query_scale);
|
||||
}
|
||||
};
|
||||
{
|
||||
// kHierarchical is not worth the extra sync overhead because the tasks are
|
||||
// very lightweight.
|
||||
ParallelFor(ParallelismStrategy::kFlat, num_tokens * qbatch.Size(), ctx,
|
||||
/*cluster_idx=*/0, func);
|
||||
/*cluster_idx=*/0, Callers::kFlashRMSNormAndPositionalEncoding,
|
||||
func);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -158,10 +156,9 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
|
|||
const MatPtrT<KV_t>& v, const size_t layer_idx,
|
||||
const LayerWeightsPtrs& layer,
|
||||
const AttentionActivations& activations,
|
||||
float* HWY_RESTRICT att_out, hwy::Profiler& p,
|
||||
float* HWY_RESTRICT att_out, ThreadingContext& ctx,
|
||||
const size_t worker) {
|
||||
PROFILER_ZONE3(p, worker,
|
||||
GetProfilerZone(Zones::kFlashAttentionSingleFlashAttention));
|
||||
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionSingleFlashAttention);
|
||||
const size_t pos_mod = activations.div_seq_len.Remainder(start_pos);
|
||||
float m = Dot(q, k.Row(pos_mod), k.Cols());
|
||||
if (float cap = activations.config.att_cap; cap > 0.0f) {
|
||||
|
|
@ -170,7 +167,7 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
|
|||
}
|
||||
float d = 1.0f;
|
||||
// This is just a copy of the first token.
|
||||
MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), p, worker);
|
||||
MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), ctx, worker);
|
||||
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
|
||||
const size_t pos_mod = activations.div_seq_len.Remainder(pos);
|
||||
float x = Dot(q, k.Row(pos_mod), k.Cols());
|
||||
|
|
@ -276,9 +273,8 @@ void TileFlashAttention(
|
|||
const MatPtrT<KV_t>& v, const size_t layer_idx,
|
||||
const LayerWeightsPtrs& layer, const AttentionActivations& activations,
|
||||
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets,
|
||||
hwy::Profiler& p, const size_t worker) {
|
||||
PROFILER_ZONE3(p, worker,
|
||||
GetProfilerZone(Zones::kFlashAttentionTileFlashAttention));
|
||||
ThreadingContext& ctx, const size_t worker) {
|
||||
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention);
|
||||
constexpr int kHTileSize = kNFx8HTileSize;
|
||||
using DF = hn::ScalableTag<float>;
|
||||
const DF df;
|
||||
|
|
@ -430,9 +426,8 @@ void TileFlashAttention4(
|
|||
const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
|
||||
const LayerWeightsPtrs& layer, const AttentionActivations& activations,
|
||||
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets,
|
||||
hwy::Profiler& p, const size_t worker) {
|
||||
PROFILER_ZONE3(p, worker,
|
||||
GetProfilerZone(Zones::kFlashAttentionTileFlashAttention4));
|
||||
ThreadingContext& ctx, const size_t worker) {
|
||||
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention4);
|
||||
using DF = hn::ScalableTag<float>;
|
||||
const DF df;
|
||||
using VF = hn::Vec<DF>;
|
||||
|
|
@ -597,10 +592,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
|
|||
const size_t layer_idx, const LayerWeightsPtrs& layer,
|
||||
AttentionActivations& activations, QBatch& qbatch,
|
||||
ThreadingContext& ctx) {
|
||||
static const auto root_zone = ctx.profiler.AddZone(
|
||||
"FlashAttention.Inclusive", hwy::ProfilerFlags::kInclusive);
|
||||
PROFILER_ZONE3(ctx.profiler, 0, root_zone);
|
||||
const auto zone = GetProfilerZone(Zones::kFlashAttentionFlashAttention);
|
||||
GCPP_ZONE(ctx, 0, Zones::kFlashAttentionInclusive);
|
||||
RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q, layer_idx,
|
||||
layer, activations, ctx);
|
||||
const hwy::Divisor div_qbatch(qbatch.Size());
|
||||
|
|
@ -653,7 +645,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
|
|||
|
||||
// For each head/token/query, compute fused flash Q.K, softmax and weighted V.
|
||||
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
|
||||
PROFILER_ZONE3(ctx.profiler, worker, zone);
|
||||
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionFlashAttention);
|
||||
// Offsets into original Q for each row in the tile.
|
||||
uint32_t q_offsets[kMaxNF];
|
||||
// Offsets into att_out for each row in the tile.
|
||||
|
|
@ -741,13 +733,12 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
|
|||
TileFlashAttention(activations.q, q_offsets, qT, k,
|
||||
start_positions[offset], last_pos, min_last_pos,
|
||||
max_last_pos, v, layer_idx, layer, activations,
|
||||
activations.att_out, out_offsets, ctx.profiler,
|
||||
worker);
|
||||
activations.att_out, out_offsets, ctx, worker);
|
||||
} else if (kVTileSize == 4) {
|
||||
TileFlashAttention4(
|
||||
activations.q, q_offsets, k, start_positions[offset], last_pos,
|
||||
min_last_pos, max_last_pos, v, layer_idx, layer, activations,
|
||||
activations.att_out, out_offsets, ctx.profiler, worker);
|
||||
TileFlashAttention4(activations.q, q_offsets, k,
|
||||
start_positions[offset], last_pos, min_last_pos,
|
||||
max_last_pos, v, layer_idx, layer, activations,
|
||||
activations.att_out, out_offsets, ctx, worker);
|
||||
} else {
|
||||
HWY_UNREACHABLE;
|
||||
}
|
||||
|
|
@ -757,7 +748,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
|
|||
activations.q.Row(0) + q_offsets[offset], k, v,
|
||||
layer_idx, layer, activations,
|
||||
activations.att_out.Row(0) + out_offsets[offset],
|
||||
ctx.profiler, worker);
|
||||
ctx, worker);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
@ -765,7 +756,8 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
|
|||
{
|
||||
PROFILER_ZONE("Gen.FlashAttention.ForkJoin");
|
||||
// Full parallelism is helpful, SmallParallelFor is insufficient.
|
||||
HierarchicalParallelFor(num_thread_tasks, ctx.pools, func);
|
||||
HierarchicalParallelFor(num_thread_tasks, ctx, Callers::kFlashAttention,
|
||||
func);
|
||||
}
|
||||
}
@ -39,8 +39,8 @@ namespace gcpp {
|
|||
const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
|
||||
size_t layer_idx, const LayerWeightsPtrs& layer, \
|
||||
const AttentionActivations& activations, \
|
||||
float* HWY_RESTRICT att_out, hwy::Profiler& p, \
|
||||
size_t worker); \
|
||||
float* HWY_RESTRICT att_out, \
|
||||
ThreadingContext& ctx, size_t worker); \
|
||||
\
|
||||
size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \
|
||||
size_t total_tasks, size_t target_parallelism); \
@ -47,9 +47,9 @@ namespace HWY_NAMESPACE {
|
|||
// For use by Vit even if !GEMMA_FUSED_FFN.
|
||||
template <typename T1, typename T2>
|
||||
void Activation(ActivationType activation, T1* HWY_RESTRICT c1,
|
||||
const T2* HWY_RESTRICT c2, const size_t count, hwy::Profiler& p,
|
||||
const size_t worker) {
|
||||
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kGenActivation));
|
||||
const T2* HWY_RESTRICT c2, const size_t count,
|
||||
ThreadingContext& ctx, const size_t worker) {
|
||||
GCPP_ZONE(ctx, worker, Zones::kGenActivation);
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using DF = hn::ScalableTag<float>;
|
||||
using VF = hn::Vec<DF>;
|
||||
|
|
@ -73,11 +73,11 @@ void ActivationBatched(
|
|||
ParallelismStrategy parallelism = ParallelismStrategy::kFlat) {
|
||||
using T = typename Mat::T;
|
||||
ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,
|
||||
[&](uint64_t task, size_t worker) {
|
||||
Callers::kActivationBatched, [&](uint64_t task, size_t worker) {
|
||||
// Cast to correct type so type deduction works.
|
||||
Activation(activation, c1.Row(task),
|
||||
static_cast<const T*>(nullptr), c1.Cols(),
|
||||
ctx.profiler, worker);
|
||||
static_cast<const T*>(nullptr), c1.Cols(), ctx,
|
||||
worker);
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -87,8 +87,8 @@ void ActivationBatched(
|
|||
static inline void Activation(ActivationType activation, const RowPtrsBF C1,
|
||||
const IndexRange range_r,
|
||||
const IndexRange range_c, const StridedViewBF C2,
|
||||
hwy::Profiler& p, const size_t worker) {
|
||||
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kGenActivationFused));
|
||||
ThreadingContext& ctx, const size_t worker) {
|
||||
GCPP_ZONE(ctx, worker, Zones::kGenActivationFused);
|
||||
|
||||
const size_t cols = range_c.Num();
|
||||
HWY_DASSERT(C2.Cols() == cols);
|
||||
|
|
@ -119,16 +119,16 @@ HWY_NOINLINE void ActivationBatched(
|
|||
HWY_DASSERT(c1.SameShape(*c2));
|
||||
if (c2 && c2->HasPtr()) {
|
||||
ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,
|
||||
[&](uint64_t task, size_t worker) {
|
||||
Callers::kActivationBatched, [&](uint64_t task, size_t worker) {
|
||||
Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(),
|
||||
ctx.profiler, worker);
|
||||
ctx, worker);
|
||||
});
|
||||
} else { // No multiplier
|
||||
ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,
|
||||
[&](uint64_t task, size_t worker) {
|
||||
Callers::kActivationBatched, [&](uint64_t task, size_t worker) {
|
||||
Activation(activation, c1.Row(task),
|
||||
static_cast<const typename Mat2::T*>(nullptr),
|
||||
c1.Cols(), ctx.profiler, worker);
|
||||
c1.Cols(), ctx, worker);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
@ -153,9 +153,7 @@ void PostNorm(PostNormType post_norm, const MatPtr& weights,
|
|||
|
||||
static inline void FFWNoVit(const LayerWeightsPtrs& layer,
|
||||
Activations& activations, MatMulEnv& env) {
|
||||
static const auto zone =
|
||||
env.ctx.profiler.AddZone("Gen.FFW", hwy::ProfilerFlags::kInclusive);
|
||||
PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::GlobalIdx(), zone);
|
||||
GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenFFW);
|
||||
const LayerConfig& layer_config = layer.layer_config;
|
||||
|
||||
HWY_DASSERT(!layer_config.ff_biases); // Only used in Vit.
|
||||
|
|
@ -163,8 +161,8 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer,
|
|||
#if GEMMA_FUSED_FFN
|
||||
const auto fused = [&](RowPtrsBF C1, IndexRange range_r, IndexRange range_c,
|
||||
StridedViewBF C2, size_t worker) {
|
||||
Activation(layer_config.activation, C1, range_r, range_c, C2,
|
||||
env.ctx.profiler, worker);
|
||||
Activation(layer_config.activation, C1, range_r, range_c, C2, env.ctx,
|
||||
worker);
|
||||
};
|
||||
MMOptions options;
|
||||
options.SetFunc(fused);
@ -55,7 +55,7 @@
|
|||
#include "io/io.h" // Path
|
||||
#include "ops/matmul.h"
|
||||
#include "paligemma/image.h"
|
||||
#include "util/basics.h" // PROFILER_ZONE3
|
||||
#include "util/basics.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/aligned_allocator.h" // Span
|
||||
#include "hwy/base.h"
|
||||
|
|
@ -138,9 +138,7 @@ EmbedMMToken(int token, size_t x_row, size_t pos, size_t pos_in_prompt,
|
|||
MatStorageT<float>& x, ThreadingContext& ctx,
|
||||
const ImageTokens* image_tokens = nullptr,
|
||||
size_t image_token_position = 0) {
|
||||
static const auto zone =
|
||||
ctx.profiler.AddZone("Gen.Embed", hwy::ProfilerFlags::kInclusive);
|
||||
PROFILER_ZONE3(ctx.profiler, hwy::Profiler::GlobalIdx(), zone);
|
||||
GCPP_ZONE(ctx, hwy::Profiler::GlobalIdx(), Zones::kGenEmbed);
|
||||
|
||||
// Image tokens just need to be copied.
|
||||
if (model_config.wrapping == PromptWrapping::GEMMA_VLM &&
|
||||
|
|
@ -415,9 +413,7 @@ static void SampleAndStream(const ModelConfig& config,
|
|||
MaybeObserve(runtime_config, activations, qbatch, -1);
|
||||
|
||||
{
|
||||
static const auto zone = env.ctx.profiler.AddZone(
|
||||
"Gen.EmbeddingMatmul", hwy::ProfilerFlags::kInclusive);
|
||||
PROFILER_ZONE3(env.ctx.profiler, /*worker=*/0, zone);
|
||||
GCPP_ZONE(env.ctx, /*worker=*/0, Zones::kGenEmbeddingMatmul);
|
||||
// Compute logits from last layer activations.
|
||||
CallMatMul(activations.x_bf, weights.embedder_input_embedding,
|
||||
/*add=*/nullptr, env, activations.logits);
|
||||
|
|
@ -431,7 +427,8 @@ static void SampleAndStream(const ModelConfig& config,
|
|||
|
||||
ParallelFor(
|
||||
ParallelismStrategy::kFlat, qbatch.Size(), env.ctx,
|
||||
/*cluster_idx=*/0, [&](size_t qi, size_t worker) {
|
||||
/*cluster_idx=*/0, Callers::kSampleAndStream,
|
||||
[&](size_t qi, size_t worker) {
|
||||
if (!non_eos.Get(qi)) return;
|
||||
|
||||
// We streamed all prefill tokens, but pos is still one behind
|
||||
|
|
@ -469,8 +466,7 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config,
|
|||
if (runtime_config.top_k == 1 && !runtime_config.accept_token) {
|
||||
return [&](size_t /*qi*/, size_t /*pos*/, Logits logits, size_t worker)
|
||||
HWY_ATTR -> TokenAndProb {
|
||||
PROFILER_ZONE3(ctx.profiler, worker,
|
||||
GetProfilerZone(Zones::kGenSampleTop1));
|
||||
GCPP_ZONE(ctx, worker, Zones::kGenSampleTop1);
|
||||
return Top1OfSoftmax(logits);
|
||||
};
|
||||
}
|
||||
|
|
@ -478,14 +474,13 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config,
|
|||
// General case: Softmax with top-k sampling.
|
||||
return [&](size_t qi, size_t pos, Logits logits,
|
||||
size_t worker) HWY_ATTR -> TokenAndProb {
|
||||
PROFILER_ZONE3(ctx.profiler, worker,
|
||||
GetProfilerZone(Zones::kGenSampleTopK));
|
||||
GCPP_ZONE(ctx, worker, Zones::kGenSampleTopK);
|
||||
// We want a different sequence for each batch element and position.
|
||||
const uint64_t stream = (static_cast<uint64_t>(qi) << 32) | pos;
|
||||
RngStream gen(engine, stream);
|
||||
return FusedSoftmaxAndSampleTopK(
|
||||
logits, runtime_config.top_k, gen, runtime_config.temperature,
|
||||
runtime_config.accept_token, ctx.profiler, worker);
|
||||
return FusedSoftmaxAndSampleTopK(logits, runtime_config.top_k, gen,
|
||||
runtime_config.temperature,
|
||||
runtime_config.accept_token, ctx, worker);
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -657,8 +652,8 @@ Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
|
|||
|
||||
Gemma::~Gemma() = default;
|
||||
|
||||
void Gemma::Save(const Path& weights_path, NestedPools& pools) const {
|
||||
BlobWriter writer(weights_path, pools.Pool());
|
||||
void Gemma::Save(const Path& weights_path, ThreadingContext& ctx) const {
|
||||
BlobWriter writer(weights_path, ctx);
|
||||
const std::vector<uint32_t> serialized_mat_ptrs =
|
||||
weights_.AddTensorDataToWriter(writer);
|
||||
WriteSingleFile(model_.Config(), model_.Tokenizer(), serialized_mat_ptrs,
@ -246,7 +246,7 @@ class Gemma {
|
|||
const GemmaChatTemplate& ChatTemplate() const { return chat_template_; }
|
||||
const InferenceArgs& Inference() const { return inference_; }
|
||||
|
||||
void Save(const Path& weights_path, NestedPools& pools) const;
|
||||
void Save(const Path& weights_path, ThreadingContext& ctx) const;
|
||||
|
||||
// `pos` is the position in the KV cache. Users are responsible for
|
||||
// incrementing it in the `*StreamFunc`, or setting to zero for single-turn.
@ -36,7 +36,6 @@
|
|||
// IWYU pragma: end_exports
|
||||
|
||||
#include "util/allocator.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
gemma/vit.cc (34 changed lines)
@ -90,16 +90,19 @@ class VitAttention {
|
|||
ZeroInit(activations_.attention.att_out);
|
||||
|
||||
for (size_t head = 0; head < heads; ++head) {
|
||||
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR {
|
||||
pool_.Run(0, num_tokens_, caller1_,
|
||||
[&](uint64_t task, size_t worker) HWY_ATTR {
|
||||
const size_t token = task;
|
||||
float* HWY_RESTRICT q =
|
||||
activations_.attention.q.Row(token) + head * 3 * qkv_dim;
|
||||
// TODO: shift to MatMul with A.scale once MatMul is confirmed working
|
||||
// TODO: shift to MatMul with A.scale once MatMul is confirmed
|
||||
// working
|
||||
MulByConst(query_scale, q, qkv_dim);
|
||||
hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float));
|
||||
});
|
||||
|
||||
pool_.Run(0, seq_len, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
|
||||
pool_.Run(
|
||||
0, seq_len, caller2_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
|
||||
const size_t seq_idx = task;
|
||||
float* HWY_RESTRICT k = activations_.attention.q.Row(seq_idx) +
|
||||
head * 3 * qkv_dim + qkv_dim;
|
||||
|
|
@ -109,11 +112,12 @@ class VitAttention {
|
|||
// this produces C, a (num_tokens_, seq_len) matrix of dot products
|
||||
CallMatMul(Q, K, nullptr, env_, C);
|
||||
|
||||
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR {
|
||||
Softmax(C.RowSpan(task), env_.ctx.profiler, worker);
|
||||
});
|
||||
pool_.Run(0, num_tokens_, caller3_,
|
||||
[&](uint64_t task, size_t worker)
|
||||
HWY_ATTR { Softmax(C.RowSpan(task), env_.ctx, worker); });
|
||||
|
||||
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR {
|
||||
pool_.Run(
|
||||
0, num_tokens_, caller4_, [&](uint64_t task, size_t worker) HWY_ATTR {
|
||||
size_t token = task;
|
||||
float* HWY_RESTRICT att_out =
|
||||
activations_.attention.att_out.Row(token) + head * qkv_dim;
|
||||
|
|
@ -136,7 +140,7 @@ class VitAttention {
|
|||
PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
|
||||
|
||||
// Compute Q.K, softmax, and weighted V.
|
||||
pool_.Run(0, layer_config_.heads * num_tokens_,
|
||||
pool_.Run(0, layer_config_.heads * num_tokens_, caller1_,
|
||||
[&](uint64_t task, size_t worker) HWY_ATTR {
|
||||
const size_t head = task % layer_config_.heads;
|
||||
const size_t token = task / layer_config_.heads;
|
||||
|
|
@ -152,7 +156,7 @@ class VitAttention {
|
|||
head_att[i] = Dot(q, k, qkv_dim); // score = q.k
|
||||
}
|
||||
// SoftMax yields "probabilities" in head_att.
|
||||
Softmax(Logits(head_att, seq_len), env_.ctx.profiler, worker);
|
||||
Softmax(Logits(head_att, seq_len), env_.ctx, worker);
|
||||
// Compute weighted sum of v into att_out.
|
||||
float* HWY_RESTRICT att_out =
|
||||
activations_.attention.att_out.Row(token) + head * qkv_dim;
|
||||
|
|
@ -185,7 +189,11 @@ class VitAttention {
|
|||
layer_(layer),
|
||||
layer_config_(layer.layer_config),
|
||||
env_(env),
|
||||
pool_(env_.ctx.pools.Pool(0)) {}
|
||||
pool_(env_.ctx.pools.Pool(0)),
|
||||
caller1_(env_.ctx.pool_callers.Get(Callers::kVitDotSoftmax1)),
|
||||
caller2_(env_.ctx.pool_callers.Get(Callers::kVitDotSoftmax2)),
|
||||
caller3_(env_.ctx.pool_callers.Get(Callers::kVitDotSoftmax3)),
|
||||
caller4_(env_.ctx.pool_callers.Get(Callers::kVitDotSoftmax4)) {}
|
||||
|
||||
HWY_INLINE void operator()() {
|
||||
ComputeQKV();
|
||||
|
|
@ -204,6 +212,10 @@ class VitAttention {
|
|||
const LayerConfig& layer_config_;
|
||||
MatMulEnv& env_;
|
||||
hwy::ThreadPool& pool_;
|
||||
hwy::pool::Caller caller1_;
|
||||
hwy::pool::Caller caller2_;
|
||||
hwy::pool::Caller caller3_;
|
||||
hwy::pool::Caller caller4_;
|
||||
};
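Fetching each hwy::pool::Caller once in the constructor, as above, avoids a pool_callers lookup on every Run. A sketch of how one cached handle is then used (member names as in this hunk; the body is elided):

// Sketch: pass the cached Caller so the pool can attribute this Run's stats.
pool_.Run(0, num_tokens_, caller1_,
          [&](uint64_t task, size_t worker) HWY_ATTR {
            // per-token work for this head
          });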
|
||||
|
||||
// Same as FFWNoVit, but with different layer members and no second
|
||||
|
|
@ -333,7 +345,7 @@ void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights,
|
|||
// Apply soft embedding norm before input projection.
|
||||
CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) {
|
||||
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0,
|
||||
activations.x.Row(0), vit_model_dim, env.ctx.profiler,
|
||||
activations.x.Row(0), vit_model_dim, env.ctx,
|
||||
hwy::Profiler::GlobalIdx());
|
||||
});
|
||||
}
|
|
|||
#include "util/threading_context.h"
|
||||
#include "util/zones.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/highway.h"
|
||||
#include "hwy/profiler.h"
|
||||
|
||||
|
|
@ -150,7 +149,7 @@ void LayerWeightsPtrs::SplitAttW1() {
|
|||
static void HWY_MAYBE_UNUSED InitAttWeightsI8(
|
||||
const LayerConfig& layer_config, MatPtrT<I8Stream>& attn_vec_einsum_w,
|
||||
MatPtrT<I8Stream>& att_weights, std::vector<MatOwner>& mat_owners,
|
||||
const Allocator& allocator) {
|
||||
ThreadingContext& ctx) {
|
||||
if (!attn_vec_einsum_w.HasPtr()) return;
|
||||
HWY_ASSERT(attn_vec_einsum_w.GetType() == Type::kI8);
|
||||
|
||||
|
|
@ -160,7 +159,8 @@ static void HWY_MAYBE_UNUSED InitAttWeightsI8(
|
|||
static std::mutex m;
|
||||
std::lock_guard<std::mutex> lock(m);
|
||||
mat_owners.emplace_back();
|
||||
mat_owners.back().AllocateFor(att_weights, allocator, MatPadding::kPacked);
|
||||
mat_owners.back().AllocateFor(att_weights, ctx.allocator,
|
||||
MatPadding::kPacked);
|
||||
}
|
||||
|
||||
const size_t model_dim = layer_config.model_dim;
|
||||
|
|
@ -188,10 +188,9 @@ static void HWY_MAYBE_UNUSED InitAttWeightsI8(
|
|||
}
|
||||
|
||||
CompressWorkingSet work;
|
||||
hwy::ThreadPool pool(0);
|
||||
HWY_NAMESPACE::Compress(att_weights_tmp.get(), model_dim * heads * qkv_dim,
|
||||
work, att_weights.Span(),
|
||||
/*packed_ofs=*/0, pool);
|
||||
/*packed_ofs=*/0, ctx);
|
||||
|
||||
att_weights.SetScale(attn_vec_einsum_w.Scale());
|
||||
}
|
||||
|
|
@ -201,7 +200,7 @@ static void HWY_MAYBE_UNUSED SplitW1I8(const LayerConfig& layer_config,
|
|||
MatPtrT<I8Stream>& gating_einsum_w1,
|
||||
MatPtrT<I8Stream>& gating_einsum_w2,
|
||||
std::vector<MatOwner>& mat_owners,
|
||||
const Allocator& allocator) {
|
||||
ThreadingContext& ctx) {
|
||||
// Files have both or neither of w1 and w2.
|
||||
HWY_ASSERT(gating_einsum_w1.HasPtr() == gating_einsum_w2.HasPtr());
|
||||
// w is mutually exclusive with w1 and w2 in the file.
|
||||
|
|
@ -228,10 +227,10 @@ static void HWY_MAYBE_UNUSED SplitW1I8(const LayerConfig& layer_config,
|
|||
static std::mutex m;
|
||||
std::lock_guard<std::mutex> lock(m);
|
||||
mat_owners.emplace_back();
|
||||
mat_owners.back().AllocateFor(gating_einsum_w1, allocator,
|
||||
mat_owners.back().AllocateFor(gating_einsum_w1, ctx.allocator,
|
||||
MatPadding::kPacked);
|
||||
mat_owners.emplace_back();
|
||||
mat_owners.back().AllocateFor(gating_einsum_w2, allocator,
|
||||
mat_owners.back().AllocateFor(gating_einsum_w2, ctx.allocator,
|
||||
MatPadding::kPacked);
|
||||
}
|
||||
|
||||
|
|
@ -248,11 +247,10 @@ static void HWY_MAYBE_UNUSED SplitW1I8(const LayerConfig& layer_config,
|
|||
float* w2_tmp = w_tmp.get() + split_size;
|
||||
|
||||
CompressWorkingSet work;
|
||||
hwy::ThreadPool pool(0);
|
||||
HWY_NAMESPACE::Compress(w1_tmp, split_size, work, gating_einsum_w1.Span(), 0,
|
||||
pool);
|
||||
ctx);
|
||||
HWY_NAMESPACE::Compress(w2_tmp, split_size, work, gating_einsum_w2.Span(), 0,
|
||||
pool);
|
||||
ctx);
|
||||
|
||||
gating_einsum_w1.SetScale(1.0f);
|
||||
gating_einsum_w2.SetScale(1.0f);
|
||||
|
|
@ -265,7 +263,7 @@ static void HWY_MAYBE_UNUSED SplitAttW1I8(const LayerConfig& layer_config,
|
|||
MatPtrT<I8Stream>& qkv_einsum_w1,
|
||||
MatPtrT<I8Stream>& qkv_einsum_w2,
|
||||
std::vector<MatOwner>& mat_owners,
|
||||
const Allocator& allocator) {
|
||||
ThreadingContext& ctx) {
|
||||
// w is mutually exclusive with w1 in the file.
|
||||
HWY_ASSERT(qkv_einsum_w.HasPtr() ^ qkv_einsum_w1.HasPtr());
|
||||
// Done if we already read split tensors.
|
||||
|
|
@ -291,10 +289,10 @@ static void HWY_MAYBE_UNUSED SplitAttW1I8(const LayerConfig& layer_config,
|
|||
static std::mutex m;
|
||||
std::lock_guard<std::mutex> lock(m);
|
||||
mat_owners.emplace_back();
|
||||
mat_owners.back().AllocateFor(qkv_einsum_w1, allocator,
|
||||
mat_owners.back().AllocateFor(qkv_einsum_w1, ctx.allocator,
|
||||
MatPadding::kPacked);
|
||||
mat_owners.emplace_back();
|
||||
mat_owners.back().AllocateFor(qkv_einsum_w2, allocator,
|
||||
mat_owners.back().AllocateFor(qkv_einsum_w2, ctx.allocator,
|
||||
MatPadding::kPacked);
|
||||
}
|
||||
|
||||
|
|
@ -312,9 +310,8 @@ static void HWY_MAYBE_UNUSED SplitAttW1I8(const LayerConfig& layer_config,
|
|||
float* w2_tmp = w_tmp.get() + w1_size;
|
||||
|
||||
CompressWorkingSet work;
|
||||
hwy::ThreadPool pool(0);
|
||||
HWY_NAMESPACE::Compress(w1_tmp, w1_size, work, qkv_einsum_w1.Span(), 0, pool);
|
||||
HWY_NAMESPACE::Compress(w2_tmp, w2_size, work, qkv_einsum_w2.Span(), 0, pool);
|
||||
HWY_NAMESPACE::Compress(w1_tmp, w1_size, work, qkv_einsum_w1.Span(), 0, ctx);
|
||||
HWY_NAMESPACE::Compress(w2_tmp, w2_size, work, qkv_einsum_w2.Span(), 0, ctx);
|
||||
|
||||
qkv_einsum_w1.SetScale(1.0f);
|
||||
qkv_einsum_w2.SetScale(1.0f);
|
||||
|
|
@ -326,16 +323,16 @@ static void HWY_MAYBE_UNUSED SplitAttW1I8(const LayerConfig& layer_config,
|
|||
// TODO: exporters should bake this into the weights already.
|
||||
// WARNING: called from multiple threads; `mat_owners` requires a lock.
|
||||
void LayerWeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
|
||||
const Allocator& allocator) {
|
||||
ThreadingContext& ctx) {
|
||||
if (attn_vec_einsum_w.GetType() == Type::kI8) {
|
||||
MatPtrT<I8Stream> attn_vec_einsum_w_i8(attn_vec_einsum_w);
|
||||
MatPtrT<I8Stream> att_weights_i8(att_weights);
|
||||
InitAttWeightsI8(layer_config, attn_vec_einsum_w_i8, att_weights_i8,
|
||||
mat_owners, allocator);
|
||||
mat_owners, ctx);
|
||||
attn_vec_einsum_w = attn_vec_einsum_w_i8;
|
||||
att_weights = att_weights_i8;
|
||||
} else {
|
||||
InitAttWeights(mat_owners, allocator);
|
||||
InitAttWeights(mat_owners, ctx.allocator);
|
||||
}
|
||||
|
||||
if (gating_einsum_w.GetType() == Type::kI8) {
|
||||
|
|
@@ -343,7 +340,7 @@ void LayerWeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
MatPtrT<I8Stream> gating_einsum_w1_i8(gating_einsum_w1);
MatPtrT<I8Stream> gating_einsum_w2_i8(gating_einsum_w2);
SplitW1I8(layer_config, gating_einsum_w_i8, gating_einsum_w1_i8,
gating_einsum_w2_i8, mat_owners, allocator);
gating_einsum_w2_i8, mat_owners, ctx);
gating_einsum_w = gating_einsum_w_i8;
gating_einsum_w1 = gating_einsum_w1_i8;
gating_einsum_w2 = gating_einsum_w2_i8;
@@ -356,7 +353,7 @@ void LayerWeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
MatPtrT<I8Stream> qkv_einsum_w1_i8(qkv_einsum_w1);
MatPtrT<I8Stream> qkv_einsum_w2_i8(qkv_einsum_w2);
SplitAttW1I8(layer_config, qkv_einsum_w_i8, qkv_einsum_w1_i8,
qkv_einsum_w2_i8, mat_owners, allocator);
qkv_einsum_w2_i8, mat_owners, ctx);
qkv_einsum_w = qkv_einsum_w_i8;
qkv_einsum_w1 = qkv_einsum_w1_i8;
qkv_einsum_w2 = qkv_einsum_w2_i8;
@@ -367,7 +364,8 @@ void LayerWeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
static void HWY_MAYBE_UNUSED InitAttWeightsNUQ(
const LayerConfig& layer_config, MatPtrT<NuqStream>& attn_vec_einsum_w,
MatPtrT<NuqStream>& att_weights, std::vector<MatOwner>& mat_owners) {
MatPtrT<NuqStream>& att_weights, std::vector<MatOwner>& mat_owners,
ThreadingContext& ctx) {
if (!attn_vec_einsum_w.HasPtr()) return;
HWY_ASSERT(attn_vec_einsum_w.GetType() == Type::kNUQ);
@@ -399,10 +397,9 @@ static void HWY_MAYBE_UNUSED InitAttWeightsNUQ(
}
CompressWorkingSet work;
hwy::ThreadPool pool(0);
HWY_NAMESPACE::Compress(att_weights_tmp.get(), model_dim * heads * qkv_dim,
work, att_weights.Span(),
/*packed_ofs=*/0, pool);
/*packed_ofs=*/0, ctx);
att_weights.SetScale(attn_vec_einsum_w.Scale());
}
@@ -435,13 +432,13 @@ void WeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
ThreadingContext& ctx) {
const size_t cluster_idx = 0;
ParallelFor(ParallelismStrategy::kFlat, c_layers.size(), ctx, cluster_idx,
[&](uint64_t layer, size_t /*worker*/) {
GetLayer(layer)->Fixup(mat_owners, ctx.allocator);
Callers::kFixupWeights, [&](uint64_t layer, size_t /*worker*/) {
GetLayer(layer)->Fixup(mat_owners, ctx);
});
ParallelFor(ParallelismStrategy::kFlat, vit_layers.size(), ctx, cluster_idx,
[&](uint64_t layer, size_t /*worker*/) {
VitLayer(layer)->Fixup(mat_owners, ctx.allocator);
Callers::kFixupWeights, [&](uint64_t layer, size_t /*worker*/) {
VitLayer(layer)->Fixup(mat_owners, ctx);
});
}
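The hunks above collapse the separate Allocator, ThreadPool and Profiler parameters into the single ThreadingContext. A minimal caller sketch under that assumption, using only identifiers shown in these hunks (the `weights` and `owners` objects are hypothetical setup):

// Sketch only: construct a context once and pass it to the fixup entry point.
ThreadingArgs args;
ThreadingContext ctx(args);
std::vector<MatOwner> owners;
// WeightsPtrs::Fixup fans out per layer via ParallelFor with
// Callers::kFixupWeights and forwards `ctx` to each LayerWeightsPtrs::Fixup.
weights.Fixup(owners, ctx);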
@@ -529,8 +526,9 @@ static void AllocateAndBindAll(std::vector<TensorToRead>& tensors,
owners.resize(start + tensors.size());
// Allocate in parallel because faulting in large tensors is slow.
ctx.pools.Pool().Run(
0, tensors.size(), [&](uint64_t task, size_t /*thread*/) {
ParallelFor(
ParallelismStrategy::kFlat, tensors.size(), ctx, /*cluster_idx=*/0,
Callers::kAllocateAndBindAll, [&](uint64_t task, size_t /*thread*/) {
TensorToRead& tensor = tensors[task];
MatPtr& mat = *tensor.mat;
@@ -587,14 +585,13 @@ static void DecompressToBF16(MatPtr& mat,
static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
const BlobReader& reader, ThreadingContext& ctx) {
const auto zone = GetProfilerZone(Zones::kStartupWeightsReadAllToBF16);
// Especially TSAN is slow enough to warrant hierarchical parallelism.
const ParallelismStrategy strategy = HWY_IS_DEBUG_BUILD
? ParallelismStrategy::kHierarchical
: ParallelismStrategy::kFlat;
ParallelFor(strategy, tensors.size(), ctx, /*cluster_idx=*/0,
[&](uint64_t task, size_t thread) {
PROFILER_ZONE3(ctx.profiler, thread, zone);
Callers::kReadAllToBF16, [&](uint64_t task, size_t thread) {
GCPP_ZONE(ctx, thread, Zones::kStartupWeightsReadAllToBF16);
const TensorToRead& tensor = tensors[task];
MatPtr& mat = *tensor.mat;
@@ -679,12 +676,11 @@ static std::vector<IOBatch> MakeBatches(
static void ReadBatches(const BlobReader& reader,
const std::vector<IOBatch>& batches,
ThreadingContext& ctx) {
const auto zone = GetProfilerZone(Zones::kStartupWeightsReadBatches);
// >5x speedup from parallel reads when cached.
ParallelFor(ParallelismStrategy::kHierarchical,
batches.size(), ctx, /*cluster_idx=*/0,
ParallelFor(ParallelismStrategy::kHierarchical, batches.size(), ctx,
/*cluster_idx=*/0, Callers::kReadBatches,
[&](uint64_t task, size_t thread) {
PROFILER_ZONE3(ctx.profiler, thread, zone);
GCPP_ZONE(ctx, thread, Zones::kStartupWeightsReadBatches);
const IOBatch& batch = batches[task];
const std::string& key = reader.Keys()[batch.KeyIdx()];
const uint64_t bytes_read = batch.Read(reader.file());
@@ -254,7 +254,7 @@ struct LayerWeightsPtrs {
// Must be called after reading weights via `ForEachTensor`.
// TODO: exporters should bake this into the weights already.
// WARNING: called from multiple threads; `mat_owners` requires a lock.
void Fixup(std::vector<MatOwner>& mat_owners, const Allocator& allocator);
void Fixup(std::vector<MatOwner>& mat_owners, ThreadingContext& ctx);
private:
// Copies att_weights from `attn_vec_einsum_w`.
@@ -79,7 +79,6 @@ cc_library(
"//:threading_context",
"@highway//:hwy",
"@highway//:profiler",
"@highway//:thread_pool",
],
)
@@ -108,7 +107,6 @@ cc_binary(
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:nanobenchmark",
"@highway//:thread_pool",
],
)
@@ -107,7 +107,7 @@ void ReadBlobs(BlobReader& reader, const RangeVec& ranges, BlobVec& blobs,
HWY_ASSERT(reader.Keys().size() == blobs.size());
HWY_ASSERT(ranges.size() == blobs.size());
ParallelFor(ParallelismStrategy::kWithinCluster, blobs.size(), ctx,
cluster_idx, [&](size_t i, size_t /*thread*/) {
cluster_idx, Callers::kTest, [&](size_t i, size_t /*thread*/) {
HWY_ASSERT(ranges[i].bytes == blobs[i].size());
reader.file().Read(ranges[i].offset, ranges[i].bytes,
blobs[i].data());
@@ -122,7 +122,7 @@ void ReadBothBlobs(BlobReader& reader1, BlobReader& reader2,
const double t0 = hwy::platform::Now();
HWY_WARN("Reading %zu GiB, %zu clusters: ", total_bytes >> 30,
ctx.pools.NumClusters());
ParallelFor(ParallelismStrategy::kAcrossClusters, 2, ctx, 0,
ParallelFor(ParallelismStrategy::kAcrossClusters, 2, ctx, 0, Callers::kTest,
[&](const size_t task, size_t cluster_idx) {
ReadBlobs(task ? reader1 : reader2, task ? ranges1 : ranges2,
task ? blobs1 : blobs2, ctx, cluster_idx);
@@ -190,7 +190,7 @@ void CompareBlobs(const KeyVec& keys, BlobVec& blobs1, BlobVec& blobs2,
std::atomic<size_t> blobs_equal{};
std::atomic<size_t> blobs_diff{};
ParallelFor(ParallelismStrategy::kHierarchical, keys.size(), ctx, 0,
[&](size_t i, size_t /*thread*/) {
Callers::kTest, [&](size_t i, size_t /*thread*/) {
const size_t mismatches =
BlobDifferences(blobs1[i], blobs2[i], keys[i]);
if (mismatches != 0) {
@@ -28,7 +28,6 @@
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/detect_compiler_arch.h"
#include "hwy/profiler.h"
@@ -469,8 +468,8 @@ static void EnqueueChunks(size_t key_idx, uint64_t offset, uint64_t bytes,
}
}
BlobWriter::BlobWriter(const Path& filename, hwy::ThreadPool& pool)
: file_(OpenFileOrNull(filename, "w+")), pool_(pool) {
BlobWriter::BlobWriter(const Path& filename, ThreadingContext& ctx)
: file_(OpenFileOrNull(filename, "w+")), ctx_(ctx) {
if (!file_) HWY_ABORT("Failed to open for writing %s", filename.path.c_str());
// Write a placeholder header to the beginning of the file. If append-only,
// we will later write a footer, else we will update the header.
@@ -489,10 +488,13 @@ void BlobWriter::Add(const std::string& key, const void* data, size_t bytes) {
EnqueueChunks(keys_.size() - 1, curr_offset_, bytes,
static_cast<const uint8_t*>(data), writes);
hwy::ThreadPool null_pool(0);
hwy::ThreadPool& pool_or_serial = file_->IsAppendOnly() ? null_pool : pool_;
pool_or_serial.Run(
0, writes.size(), [this, &writes](uint64_t i, size_t /*thread*/) {
const ParallelismStrategy strategy = file_->IsAppendOnly()
? ParallelismStrategy::kNone
: ParallelismStrategy::kFlat;
ParallelFor(
strategy, writes.size(), ctx_,
/*cluster_idx=*/0, Callers::kBlobWriter,
[this, &writes](uint64_t i, size_t /*thread*/) {
const BlobRange& range = writes[i].range;
if (!file_->Write(writes[i].data, range.bytes, range.offset)) {
const std::string& key = StringFromKey(keys_[range.key_idx]);
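BlobWriter now captures the ThreadingContext at construction and chooses kNone or kFlat parallelism per write. A usage sketch mirroring the test below (the file name is a placeholder):

ThreadingArgs threading_args;
ThreadingContext ctx(threading_args);
BlobWriter writer(Path("example.sbs"), ctx);  // was: BlobWriter(path, pool)
writer.Add("key0", "DATA", 5);
writer.Finalize();  // writes run via ParallelFor unless the file is append-only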
@@ -28,9 +28,9 @@
#include "io/io.h" // File, Path, MapPtr
#include "util/basics.h" // Tristate
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h" // HWY_ASSERT
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
@@ -117,7 +117,7 @@ class BlobReader {
// does not make sense to call the methods concurrently.
class BlobWriter {
public:
BlobWriter(const Path& filename, hwy::ThreadPool& pool);
BlobWriter(const Path& filename, ThreadingContext& ctx);
// Writes the blob to disk with padding for alignment. Aborts on error.
void Add(const std::string& key, const void* data, size_t bytes);
@@ -129,7 +129,7 @@ class BlobWriter {
std::unique_ptr<File> file_;
std::vector<hwy::uint128_t> keys_;
std::vector<size_t> blob_sizes_;
hwy::ThreadPool& pool_;
ThreadingContext& ctx_;
// Current offset in the file used for writing.
int64_t curr_offset_ = 0;
};
@@ -38,7 +38,6 @@ class BlobStoreTest : public testing::Test {};
TEST(BlobStoreTest, TestReadWrite) {
ThreadingArgs threading_args;
ThreadingContext ctx(threading_args);
hwy::ThreadPool& pool = ctx.pools.Pool();
static const std::array<float, 4> kOriginalData = {-1, 0, 3.14159, 2.71828};
@@ -52,7 +51,7 @@ TEST(BlobStoreTest, TestReadWrite) {
const std::string keyA("0123456789abcdef"); // max 16 characters
const std::string keyB("q");
BlobWriter writer(path, pool);
BlobWriter writer(path, ctx);
writer.Add(keyA, "DATA", 5);
writer.Add(keyB, buffer.data(), sizeof(buffer));
writer.Finalize();
@@ -96,7 +95,6 @@ TEST(BlobStoreTest, TestReadWrite) {
TEST(BlobStoreTest, TestNumBlobs) {
ThreadingArgs threading_args;
ThreadingContext ctx(threading_args);
hwy::ThreadPool& pool = ctx.pools.Pool();
hwy::RandomState rng;
for (size_t num_blobs = 1; num_blobs <= 512; ++num_blobs) {
@@ -106,7 +104,7 @@ TEST(BlobStoreTest, TestNumBlobs) {
HWY_ASSERT(fd > 0);
const Path path(path_str);
BlobWriter writer(path, pool);
BlobWriter writer(path, ctx);
std::vector<std::string> keys;
keys.reserve(num_blobs);
std::vector<std::vector<uint8_t>> blobs;
@@ -130,8 +128,12 @@ TEST(BlobStoreTest, TestNumBlobs) {
BlobReader reader(path);
HWY_ASSERT_EQ(reader.Keys().size(), num_blobs);
pool.Run(0, num_blobs, [&](uint64_t i, size_t /*thread*/) {
HWY_ASSERT_STRING_EQ(reader.Keys()[i].c_str(), std::to_string(i).c_str());
ParallelFor(
ParallelismStrategy::kFlat, num_blobs, ctx, /*cluster_idx=*/0,
Callers::kTest, [&](uint64_t i, size_t /*thread*/) {
HWY_ASSERT_STRING_EQ(reader.Keys()[i].c_str(),
std::to_string(i).c_str());
const BlobRange* range = reader.Find(keys[i]);
HWY_ASSERT(range);
HWY_ASSERT_EQ(blobs[i].size(), range->bytes);
@@ -145,7 +147,8 @@ TEST(BlobStoreTest, TestNumBlobs) {
span.size() == 1 ||
span[span.size() - 1] == static_cast<uint8_t>(i >> 8);
if (!match1 || !match2) {
HWY_ABORT("%s num_blobs %zu blob %zu offset %zu is corrupted.",
HWY_ABORT(
"%s num_blobs %zu blob %zu offset %zu is corrupted.",
path_str, num_blobs, i, range->offset);
}
}));
@@ -44,6 +44,6 @@ int main(int argc, char** argv) {
}
gcpp::GemmaEnv env(argc, argv);
env.GetGemma()->Save(args.output_weights, env.Env().ctx.pools);
env.GetGemma()->Save(args.output_weights, env.Env().ctx);
return 0;
}
@@ -30,7 +30,6 @@
#include "ops/matmul.h"
#include "util/basics.h"
#include "util/threading_context.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/nanobenchmark.h"
#include "hwy/profiler.h"
#include "hwy/timer.h"
@@ -72,7 +71,6 @@ void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents,
// M = A rows, K = A cols, N = C cols.
template <typename TA, typename TB = TA, typename TC = float>
void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
hwy::ThreadPool& pool = env.ctx.pools.Pool(0);
if (env.print_config || env.print_measurement) {
fprintf(stderr, "\n");
}
@@ -91,15 +89,14 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
MatStorageT<float> add_storage("add", Extents2D(), env.ctx.allocator,
MatPadding::kPacked);
if (add) {
add_storage = GenerateMat<float>(Extents2D(1, N), env.ctx.allocator,
MatPadding::kPacked, pool);
add_storage =
GenerateMat<float>(Extents2D(1, N), MatPadding::kPacked, env.ctx);
add_storage.SetScale(1.0f);
}
MatStorageT<TA> a =
GenerateMat<TA>(A_extents, env.ctx.allocator, MatPadding::kOdd, pool);
MatStorageT<TB> b_trans = GenerateTransposedMat<TB>(
B_extents, env.ctx.allocator, MatPadding::kOdd, pool);
MatStorageT<TA> a = GenerateMat<TA>(A_extents, MatPadding::kOdd, env.ctx);
MatStorageT<TB> b_trans =
GenerateTransposedMat<TB>(B_extents, MatPadding::kOdd, env.ctx);
const float* add_row = add ? add_storage.PackedScale1() : nullptr;
@@ -31,7 +31,6 @@
#include "util/test_util.h"
#include "util/threading_context.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"
#include "hwy/stats.h"
#include "hwy/timer.h"
@@ -922,9 +921,11 @@ void GenerateWellConditionedInputs(const size_t num, float* HWY_RESTRICT raw,
(void)ScaleWeights(raw, num);
}
hwy::ThreadPool pool(0); // num is too small for parallelization
ThreadingArgs threading_args;
threading_args.max_lps = 1; // num is too small for parallelization
ThreadingContext ctx(threading_args);
const size_t packed_ofs = 0;
Compress(raw, num, work, packed, packed_ofs, pool);
Compress(raw, num, work, packed, packed_ofs, ctx);
const hn::ScalableTag<float> df;
DecompressAndZeroPad(df, MakeConst(packed), packed_ofs, raw, num);
@@ -1125,7 +1126,7 @@ void TestAllDot() {
std::array<DotStats, kMaxWorkers> all_stats;
ParallelFor(
ParallelismStrategy::kWithinCluster, kReps, ctx, 0,
ParallelismStrategy::kWithinCluster, kReps, ctx, 0, Callers::kTest,
[&](size_t rep, size_t thread) {
float* HWY_RESTRICT pa = a.Row(thread);
float* HWY_RESTRICT pb = b.Row(thread);
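The dot test now limits the context to one logical processor rather than constructing a zero-worker pool; a small sketch of the same idea, assuming only the types shown above:

ThreadingArgs threading_args;
threading_args.max_lps = 1;  // effectively serial for tiny inputs
ThreadingContext ctx(threading_args);
CompressWorkingSet work;
// Compress(raw, num, work, packed, /*packed_ofs=*/0, ctx);  // ctx replaces the pool argument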
@@ -291,7 +291,7 @@ class MMDecompress {
const hn::ScalableTag<BF16> dbf;
const size_t NBF = hn::Lanes(dbf);
const auto zone = GetProfilerZone(Zones::kMMDecompressA);
const auto zone = env.ctx.profiler_zones.Get(Zones::kMMDecompressA);
const auto do_range =
[&](const IndexRange& range_M, const IndexRange& range_K, size_t worker)
@@ -879,9 +879,8 @@ class MMLoops {
static HWY_NOINLINE void Dispatch(const StridedViewBF A, const MatPtrT<TB>& B,
const MatPtrT<TB>* B2, RowPtrs<TC> C,
const MMArgs& args) {
PROFILER_ZONE3(args.env.ctx.profiler,
args.env.ctx.Worker(args.options.cluster_idx),
GetProfilerZone(Zones::kMMDispatch));
GCPP_ZONE(args.env.ctx, args.env.ctx.Worker(args.options.cluster_idx),
Zones::kMMDispatch);
DispatchParallelism(
args.options.parallelism, [&](const auto& parallel) HWY_ATTR {
@@ -904,7 +903,7 @@ class MMLoops {
const StridedViewBF A, const MatPtrT<TB>& B,
const MatPtrT<TB>* B2, RowPtrs<TC> C,
const MMArgs& args) {
const auto zone = GetProfilerZone(Zones::kMMNT);
const auto zone = args.env.ctx.profiler_zones.Get(Zones::kMMNT);
HWY_DASSERT(args.ranges_mc.NumTasks() == 1);
HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
const IndexRange& range_mc = args.ranges_mc.Range(0);
@@ -940,7 +939,7 @@ class MMLoops {
const StridedViewBF A, const MatPtrT<TB>& B,
const MatPtrT<TB>* B2, RowPtrs<TC> C,
const MMArgs& args) {
const auto zone = GetProfilerZone(Zones::kMMNT_K);
const auto zone = args.env.ctx.profiler_zones.Get(Zones::kMMNT_K);
HWY_DASSERT(args.ranges_mc.NumTasks() == 1);
const IndexRange& range_mc = args.ranges_mc.Range(0);
@@ -976,7 +975,7 @@ class MMLoops {
const StridedViewBF A, const MatPtrT<TB>& B,
const MatPtrT<TB>* B2, RowPtrs<TC> C,
const MMArgs& args) {
const auto zone = GetProfilerZone(Zones::kMMNT_MT);
const auto zone = args.env.ctx.profiler_zones.Get(Zones::kMMNT_MT);
HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
const IndexRange& range_kc = args.ranges_kc.Range(0);
@@ -1010,7 +1009,7 @@ class MMLoops {
const StridedViewBF A, const MatPtrT<TB>& B,
const MatPtrT<TB>* B2, RowPtrs<TC> C,
const MMArgs& args) {
const auto zone = GetProfilerZone(Zones::kMMNT_MT_K);
const auto zone = args.env.ctx.profiler_zones.Get(Zones::kMMNT_MT_K);
parallel.ForRangesMC_NC(
args.env.ctx, args.ranges_mc, args.ranges_nc, args.options.cluster_idx,
@@ -1063,8 +1062,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
MatPtrT<TC>& C, MMOptions options = MMOptions()) {
const size_t cluster_idx = options.cluster_idx;
HWY_DASSERT(cluster_idx < env.row_ptrs.size());
PROFILER_ZONE3(env.ctx.profiler, env.ctx.Worker(cluster_idx),
GetProfilerZone(Zones::kMMMatMul));
GCPP_ZONE(env.ctx, env.ctx.Worker(cluster_idx), Zones::kMMMatMul);
RowPtrs<TC> C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[cluster_idx]);
@@ -1124,8 +1122,7 @@ HWY_NOINLINE MMPerKey* TwoMatMul(const MatPtrT<BF16>& A, const MatPtrT<TB>& B1,
MatPtrT<BF16>& C, MMOptions options) {
const size_t cluster_idx = options.cluster_idx;
HWY_DASSERT(cluster_idx < env.row_ptrs.size());
PROFILER_ZONE3(env.ctx.profiler, env.ctx.Worker(cluster_idx),
GetProfilerZone(Zones::kMMTwoMatMul));
GCPP_ZONE(env.ctx, env.ctx.Worker(cluster_idx), Zones::kMMTwoMatMul);
HWY_DASSERT(options.func != nullptr); // no other way to get access to C2.
ops/matmul.h
@@ -111,6 +111,7 @@ struct MMParallelWithinCluster {
const IndexRangePartition ranges_n = StaticPartition(
range_n, cluster.NumWorkers() * inner_tasks, n_multiple);
ParallelizeOneRange(ranges_n, cluster,
ctx.pool_callers.Get(Callers::kMMClusterForN),
[&](const IndexRange& worker_range, size_t worker) {
func(worker_range, base + worker);
});
@@ -127,12 +128,14 @@ struct MMParallelWithinCluster {
// Low-batch: avoid Divide/Remainder.
if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
ParallelizeOneRange(ranges_nc, cluster,
ctx.pool_callers.Get(Callers::kMMClusterForMCNC),
[&](const IndexRange& range_nc, size_t worker) {
func(ranges_mc.Range(0), range_nc, base + worker);
});
} else {
ParallelizeTwoRanges(
ranges_mc, ranges_nc, cluster,
ctx.pool_callers.Get(Callers::kMMClusterForMCNC),
[&](const IndexRange& range_mc, const IndexRange& range_nc,
size_t worker) { func(range_mc, range_nc, base + worker); });
}
@@ -146,6 +149,7 @@ struct MMParallelWithinCluster {
cluster.Run(
range_mc.begin(), range_mc.end(),
ctx.pool_callers.Get(Callers::kMMClusterForMC),
[&](uint64_t row_a, size_t worker) { func(row_a, base + worker); });
}
};
@@ -159,6 +163,7 @@ struct MMParallelHierarchical {
const Func& func) const {
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
HWY_DASSERT(caller_cluster_idx == 0);
const hwy::pool::Caller caller = ctx.pool_callers.Get(Callers::kMMHierForN);
// Single cluster: parallel-for over static partition of `range_n`.
hwy::ThreadPool& all_clusters = ctx.pools.AllClusters();
@@ -169,7 +174,7 @@ struct MMParallelHierarchical {
const IndexRangePartition ranges_n = StaticPartition(
range_n, cluster.NumWorkers() * inner_tasks, n_multiple);
return ParallelizeOneRange(
ranges_n, cluster,
ranges_n, cluster, caller,
[&](const IndexRange& worker_range, size_t worker) {
func(worker_range, worker);
});
@@ -179,7 +184,7 @@ struct MMParallelHierarchical {
const IndexRangePartition ranges_n =
StaticPartition(range_n, num_clusters, n_multiple);
ParallelizeOneRange(
ranges_n, all_clusters,
ranges_n, all_clusters, caller,
[&](const IndexRange& n_range, const size_t cluster_idx) {
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
const size_t cluster_base = ctx.Worker(cluster_idx);
@@ -187,7 +192,7 @@ struct MMParallelHierarchical {
const IndexRangePartition worker_ranges = StaticPartition(
n_range, cluster.NumWorkers() * inner_tasks, n_multiple);
ParallelizeOneRange(
worker_ranges, cluster,
worker_ranges, cluster, caller,
[&](const IndexRange& worker_range, size_t worker) {
func(worker_range, cluster_base + worker);
});
@@ -203,6 +208,8 @@ struct MMParallelHierarchical {
HWY_MAYBE_UNUSED size_t caller_cluster_idx,
const Func& func) const {
HWY_DASSERT(caller_cluster_idx == 0);
const hwy::pool::Caller caller =
ctx.pool_callers.Get(Callers::kMMHierForMCNC);
hwy::ThreadPool& all_clusters = ctx.pools.AllClusters();
// `all_clusters` is a pool with one worker per cluster in a package.
@@ -215,12 +222,13 @@ struct MMParallelHierarchical {
// Low-batch: avoid Divide/Remainder.
if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
return ParallelizeOneRange(
ranges_nc, cluster, [&](const IndexRange& range_nc, size_t worker) {
ranges_nc, cluster, caller,
[&](const IndexRange& range_nc, size_t worker) {
func(ranges_mc.Range(0), range_nc, worker);
});
} else {
return ParallelizeTwoRanges(
ranges_mc, ranges_nc, cluster,
ranges_mc, ranges_nc, cluster, caller,
[&](const IndexRange& range_mc, const IndexRange& range_nc,
size_t worker) { func(range_mc, range_nc, worker); });
}
@@ -229,11 +237,11 @@ struct MMParallelHierarchical {
// Multiple clusters: N across clusters (both are usually the larger), and
// M within each cluster. We assume auto-tuning finds small MC/NC tasks.
ParallelizeOneRange(
ranges_nc, all_clusters,
ranges_nc, all_clusters, caller,
[&](const IndexRange range_nc, size_t cluster_idx) {
const size_t cluster_base = ctx.Worker(cluster_idx);
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
ParallelizeOneRange(ranges_mc, cluster,
ParallelizeOneRange(ranges_mc, cluster, caller,
[&](const IndexRange& range_mc, size_t worker) {
func(range_mc, range_nc, cluster_base + worker);
});
@@ -244,7 +252,7 @@ struct MMParallelHierarchical {
template <class Func>
void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
size_t caller_cluster_idx, const Func& func) const {
HierarchicalParallelFor(range_mc.Num(), ctx.pools,
HierarchicalParallelFor(range_mc.Num(), ctx, Callers::kMMHierForMC,
[&](size_t task, size_t worker) {
func(range_mc.begin() + task, worker);
});
@@ -811,7 +819,7 @@ class MMZone {
private:
uint64_t data_ = 0;
uint64_t data2_ = 0;
HWY_MEMBER_VAR_MAYBE_UNUSED uint64_t data2_ = 0;
};
#else
struct MMZone {
@@ -196,7 +196,7 @@ HWY_INLINE void MatMulSlow(const MatPtrT<TA> A, const MatPtrT<TB> B,
const IndexRangePartition get_col_c =
StaticPartition(all_cols_c, all_clusters.NumWorkers(), multiple);
ParallelizeOneRange(
get_col_c, all_clusters,
get_col_c, all_clusters, env.ctx.pool_callers.Get(Callers::kTest),
[&](const IndexRange& cols_c, size_t cluster_idx) HWY_ATTR {
for (size_t r : all_rows_c) {
TC* HWY_RESTRICT C_row = C.Row(r);
@@ -221,7 +221,6 @@ void PrintSpeed(const char* algo, const Extents2D& A_extents,
template <typename TA, typename TB = TA, typename TC = float>
void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
MatMulEnv& env, int line) {
hwy::ThreadPool& pool = env.ctx.pools.Pool();
fprintf(stderr, "TestMatMul %zu, K=%zu, %zu, add=%d, TA=%s, TB=%s, TC=%s\n",
rows_ac, cols_a_rows_b, cols_bc, add, TypeName<TA>(), TypeName<TB>(),
TypeName<TC>());
@@ -233,11 +232,10 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
const Extents2D B_extents(cols_bc, cols_a_rows_b); // already transposed
const Extents2D C_extents(rows_ac, cols_bc);
MatStorageT<TA> A(
GenerateMat<TA>(A_extents, env.ctx.allocator, MatPadding::kOdd, pool));
MatStorageT<TA> A(GenerateMat<TA>(A_extents, MatPadding::kOdd, env.ctx));
// Must be packed because we call Span() on it.
MatStorageT<TB> BT(GenerateTransposedMat<TB>(B_extents, env.ctx.allocator,
MatPadding::kPacked, pool));
MatStorageT<TB> BT(
GenerateTransposedMat<TB>(B_extents, MatPadding::kPacked, env.ctx));
MatStorageT<TC> C_slow("C_slow", C_extents, env.ctx.allocator,
MatPadding::kOdd);
MatStorageT<TC> C("C", C_extents, env.ctx.allocator, MatPadding::kOdd);
@@ -246,8 +244,8 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
C2.AllocateAndAttachRowPtrs(env.row_ptrs);
MatStorageT<float> add_storage =
add ? GenerateMat<float>(Extents2D(1, cols_bc), env.ctx.allocator,
MatPadding::kPacked, pool)
add ? GenerateMat<float>(Extents2D(1, cols_bc), MatPadding::kPacked,
env.ctx)
: MatStorageT<float>("add", Extents2D(), env.ctx.allocator,
MatPadding::kPacked);
add_storage.SetScale(1.0f);
@@ -205,9 +205,9 @@ namespace detail {
// Shared by RMSNorm and RMSNormInplace.
template <typename VT>
float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size, hwy::Profiler& p,
const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRmsNormMul));
float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size,
ThreadingContext& ctx, const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kOpsRmsNormMul);
const hn::ScalableTag<float> d;
const float l2 = DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault());
@@ -218,19 +218,17 @@ float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size, hwy::Profiler& p,
} // namespace detail
template <typename XT, typename WT, typename OT>
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x,
const WT* HWY_RESTRICT weight,
const size_t w_ofs,
OT* HWY_RESTRICT out,
const size_t size, hwy::Profiler& p,
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
const XT* HWY_RESTRICT x, const WT* HWY_RESTRICT weight, const size_t w_ofs,
OT* HWY_RESTRICT out, const size_t size, ThreadingContext& ctx,
const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRmsNorm));
GCPP_ZONE(ctx, worker, Zones::kOpsRmsNorm);
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>;
const VF mul = hn::Set(DF(), detail::RMSNormMul(x, size, p, worker));
const VF mul = hn::Set(DF(), detail::RMSNormMul(x, size, ctx, worker));
const VF* HWY_RESTRICT pmul = &mul;
Decompress2AndCompressTo(DF(), out, size, x, weight, w_ofs,
@@ -245,13 +243,13 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x,
template <typename WT, typename XT>
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
const WT* HWY_RESTRICT weight, const size_t w_ofs, XT* HWY_RESTRICT inout,
const size_t size, hwy::Profiler& p, const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRmsNormInplace));
const size_t size, ThreadingContext& ctx, const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kOpsRmsNormInplace);
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>;
const VF mul = hn::Set(DF(), detail::RMSNormMul(inout, size, p, worker));
const VF mul = hn::Set(DF(), detail::RMSNormMul(inout, size, ctx, worker));
const VF* HWY_RESTRICT pmul = &mul;
Decompress1AndCompressInplace(DF(), inout, size, weight, w_ofs,
@@ -359,9 +357,9 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
// This overload is called if `post_qk == PostQKType::HalfRope`.
static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
float* HWY_RESTRICT x, const size_t dim_qkv,
const float* HWY_RESTRICT inv_timescale, const int pos, hwy::Profiler& p,
const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRope));
const float* HWY_RESTRICT inv_timescale, const int pos,
ThreadingContext& ctx, const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kOpsRope);
HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2;
@@ -418,9 +416,9 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
// `inv_timescale[dim_qkv / 2]` is precomputed in AttentionActivations.
static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
const float mul, float* HWY_RESTRICT x, const size_t dim_qkv,
const float* HWY_RESTRICT inv_timescale, const int pos, hwy::Profiler& p,
const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRopeAndMulBy));
const float* HWY_RESTRICT inv_timescale, const int pos,
ThreadingContext& ctx, const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kOpsRopeAndMulBy);
HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2;
@@ -480,9 +478,9 @@ template <typename XT>
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x,
float* HWY_RESTRICT out,
const size_t size,
hwy::Profiler& p,
ThreadingContext& ctx,
const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsAddFrom));
GCPP_ZONE(ctx, worker, Zones::kOpsAddFrom);
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
@@ -503,10 +501,11 @@ void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
CallUpcasted(&weights, [&](const auto* weights_t) {
ParallelFor(ParallelismStrategy::kFlat, activations.Rows(), ctx,
cluster_idx, [&](uint64_t token_idx, size_t worker) {
cluster_idx, Callers::kOpsRMSNormBatched,
[&](uint64_t token_idx, size_t worker) {
RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(),
/*w_ofs=*/0, out.Row(token_idx), activations.Cols(),
ctx.profiler, worker);
ctx, worker);
});
});
}
@@ -519,10 +518,11 @@ void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT<XT>& inout,
CallUpcasted(&weights, [&](const auto* weights_t) {
ParallelFor(ParallelismStrategy::kFlat, inout.Rows(), ctx, cluster_idx,
Callers::kOpsRMSNormInplaceBatched,
[&](uint64_t token_idx, size_t worker) {
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0,
inout.Row(token_idx), inout.Cols(),
ctx.profiler, worker);
inout.Row(token_idx), inout.Cols(), ctx,
worker);
});
});
}
@@ -549,13 +549,14 @@ static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x, MatPtrT<float>& out,
ThreadingContext& ctx,
size_t cluster_idx = 0) {
HWY_DASSERT(out.SameShape(x));
ParallelFor(ParallelismStrategy::kFlat, out.Rows(), ctx, cluster_idx,
[&](uint64_t token_idx, size_t worker) {
AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(),
ctx.profiler, worker);
ParallelFor(
ParallelismStrategy::kFlat, out.Rows(), ctx, cluster_idx,
Callers::kOpsAddFromBatched, [&](uint64_t token_idx, size_t worker) {
AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(), ctx, worker);
});
}
// No profiler zone because this is short and frequently called.
template <typename XT>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(const float c, XT* HWY_RESTRICT x,
const size_t size) {
@@ -575,8 +576,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(const float c, XT* HWY_RESTRICT x,
template <typename XT, typename OT>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo(
const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,
const size_t size, hwy::Profiler& p, const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstTo));
const size_t size, ThreadingContext& ctx, const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kOpsMulByConstTo);
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>;
@@ -1121,10 +1122,10 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
// See below for a specialized version for top-1 sampling.
// TODO: support bf16 logits using Decompress2.
static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p,
static HWY_NOINLINE void Softmax(Logits logits, ThreadingContext& ctx,
const size_t worker,
float temperature = 1.0f) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsSoftmax));
GCPP_ZONE(ctx, worker, Zones::kOpsSoftmax);
HWY_DASSERT(logits.size() != 0);
namespace hn = hwy::HWY_NAMESPACE;
@@ -1256,8 +1257,9 @@ static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(Logits logits) {
}
static HWY_NOINLINE void LogitsSoftCap(const float cap, Logits logits,
hwy::Profiler& p, const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsLogitsSoftCap));
ThreadingContext& ctx,
const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kOpsLogitsSoftCap);
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
@@ -1277,9 +1279,10 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, Logits logits,
// Calls LogitsSoftCap if cap != 0.0f.
static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap(
const float cap, Logits logits, hwy::Profiler& p, const size_t worker) {
const float cap, Logits logits, ThreadingContext& ctx,
const size_t worker) {
if (cap != 0.0f) {
LogitsSoftCap(cap, logits, p, worker);
LogitsSoftCap(cap, logits, ctx, worker);
}
}
@@ -1288,9 +1291,10 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCapBatched(
ThreadingContext& ctx, size_t cluster_idx = 0) {
if (cap == 0.0f) return;
ParallelFor(ParallelismStrategy::kFlat, x.Rows(), ctx, cluster_idx,
Callers::kOpsMaybeLogitsSoftCapBatched,
[&](uint64_t task, size_t worker) {
if (non_eos.Get(task)) {
LogitsSoftCap(cap, x.RowSpan(task), ctx.profiler, worker);
LogitsSoftCap(cap, x.RowSpan(task), ctx, worker);
}
});
}
@@ -1371,7 +1375,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(Logits logits, size_t k,
template <typename TAcceptToken>
HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
Logits logits, size_t k, RngStream& gen, float temperature,
TAcceptToken& accept_token, hwy::Profiler& p, size_t worker) {
TAcceptToken& accept_token, ThreadingContext& ctx, size_t worker) {
// Softmax and sample top-K is equivalent to taking the top-K logits and
// sampling from the softmax of the top-K logits. The latter is faster as it
// avoids computing the softmax of all logits.
@@ -1384,7 +1388,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
}
const size_t mask = token_logits.size();
Softmax(Logits(topk_logits.data(), mask), p, worker, temperature);
Softmax(Logits(topk_logits.data(), mask), ctx, worker, temperature);
auto distribution = std::discrete_distribution<int>(
std::begin(topk_logits), std::begin(topk_logits) + mask);
int topk_sampled_index = distribution(gen);
@@ -58,6 +58,11 @@ namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
static ThreadingContext& Ctx() {
static ThreadingContext* ctx = new ThreadingContext(ThreadingArgs());
return *ctx;
}
static RngStream MakeRng() {
static AesCtrEngine engine(/*deterministic=*/true);
static uint64_t stream = 0;
@@ -133,8 +138,7 @@ class TestAddFrom {
}
SimpleAddFrom(o, e, count);
InitProfilerZones(hwy::Profiler::Get());
AddFrom(o, x, count, hwy::Profiler::Get(), /*worker=*/0);
AddFrom(o, x, count, Ctx(), /*worker=*/0);
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
__LINE__);
@@ -182,7 +186,6 @@ class TestMulByConstAndAdd {
T constant = Random<T>(rng);
SimpleMulByConstAndAdd(constant, o, e, count);
InitProfilerZones(hwy::Profiler::Get());
MulByConstAndAdd(constant, o, x, count);
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
@@ -231,7 +234,6 @@ class TestMulByConst {
T constant = Random<T>(rng);
SimpleMulByConst(constant, e, count);
InitProfilerZones(hwy::Profiler::Get());
MulByConst(constant, x, count);
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
@@ -278,8 +280,7 @@ struct TestMulByConstTo {
hwy::ConvertScalarTo<float>(constant));
}
InitProfilerZones(hwy::Profiler::Get());
MulByConstTo(constant, x, actual, count, hwy::Profiler::Get(),
MulByConstTo(constant, x, actual, count, Ctx(),
/*worker=*/0);
hwy::AssertArraySimilar(e, actual, count, hwy::TargetName(HWY_TARGET),
@@ -315,8 +316,7 @@ class TestSoftmax {
}
SimpleSoftmax(e, count);
InitProfilerZones(hwy::Profiler::Get());
Softmax(Logits(x, count), hwy::Profiler::Get(), /*worker=*/0);
Softmax(Logits(x, count), Ctx(), /*worker=*/0);
T sum = 0.0f;
for (size_t i = 0; i < count; ++i) {
@@ -440,9 +440,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
}
void TestRopeAndMulBy() {
ThreadingArgs threading_args;
ThreadingContext ctx(threading_args);
hwy::Profiler& p = ctx.profiler;
ThreadingContext& ctx = Ctx();
const size_t worker = 0;
const ModelConfig config(Model::GEMMA2_9B, Type::kSFP,
@@ -476,7 +474,7 @@ void TestRopeAndMulBy() {
CopyMat(x, qactual);
ScalarRopeAndMulBy(qmul, qexpected.Row(0), dim_qkv, inv_timescale.Row(0),
pos);
RopeAndMulBy(qmul, qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, p,
RopeAndMulBy(qmul, qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, ctx,
worker);
for (size_t i = 0; i < dim_qkv; ++i) {
EXPECT_NEAR(qexpected.Row(0)[i], qactual.Row(0)[i], 1e-4) << " " << i;
@@ -487,7 +485,7 @@ void TestRopeAndMulBy() {
CopyMat(x, qactual);
ScalarRopeAndMulBy(1.0f, qexpected.Row(0), dim_qkv, inv_timescale.Row(0),
pos);
Rope(qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, p, worker);
Rope(qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, ctx, worker);
for (size_t i = 0; i < dim_qkv; ++i) {
EXPECT_NEAR(qexpected.Row(0)[i], qactual.Row(0)[i], 1e-4) << " " << i;
}
@@ -498,10 +496,10 @@ void TestRopeAndMulBy() {
CopyMat(x, kactual2);
ScalarRopeAndMulBy(kmul, kexpected.Row(0), dim_qkv, inv_timescale.Row(0),
pos);
RopeAndMulBy(kmul, kactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, p,
RopeAndMulBy(kmul, kactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, ctx,
worker);
static_assert(kmul == 1.0f, "");
Rope(kactual2.Row(0), dim_qkv, inv_timescale.Row(0), pos, p, worker);
Rope(kactual2.Row(0), dim_qkv, inv_timescale.Row(0), pos, ctx, worker);
for (size_t i = 0; i < dim_qkv; ++i) {
EXPECT_NEAR(kexpected.Row(0)[i], kactual.Row(0)[i], 1e-4) << " " << i;
@@ -557,8 +555,7 @@ struct TestRMSNorm {
}
ScalarRMSNorm(vec, weight, expected, kSize);
InitProfilerZones(hwy::Profiler::Get());
RMSNorm(vec, weight, /*w_ofs=*/0, actual, kSize, hwy::Profiler::Get(),
RMSNorm(vec, weight, /*w_ofs=*/0, actual, kSize, Ctx(),
/*worker=*/0);
for (size_t i = 0; i < kSize; i++) {
@@ -593,8 +590,7 @@ struct TestRMSNormInplace {
}
ScalarRMSNorm(expected, weight, expected, kSize);
InitProfilerZones(hwy::Profiler::Get());
RMSNormInplace(weight, /*w_ofs=*/0, actual, kSize, hwy::Profiler::Get(),
RMSNormInplace(weight, /*w_ofs=*/0, actual, kSize, Ctx(),
/*worker=*/0);
for (size_t i = 0; i < kSize; i++) {
@@ -715,15 +711,14 @@ void TestAllLayerNorm() {
}
void TestSampleTopK() {
hwy::Profiler& p = hwy::Profiler::Get();
InitProfilerZones(p);
ThreadingContext& ctx = Ctx();
const size_t worker = 0;
const size_t kSize = 52;
std::vector<float> logits_vec(kSize);
Logits logits(logits_vec.data(), kSize);
// Create a vector going from -100 to -100+51=49 and take Softmax.
std::iota(logits.begin(), logits.end(), -100.0f);
Softmax(logits, p, worker);
Softmax(logits, ctx, worker);
RngStream rng = MakeRng();
float temperature = 1.0f;
// SampleTopK<1> should return the argmax.
@@ -736,7 +731,7 @@ void TestSampleTopK() {
EXPECT_EQ(sample, 50); // Last even index.
// Reset the logits to a positive, increasing sequence and take Softmax.
std::iota(logits.begin(), logits.end(), 1.0f);
Softmax(logits, p, worker);
Softmax(logits, ctx, worker);
// Sample from the top 3, expect one of the top 3 even indices.
for (int i = 0; i < 100; ++i) {
sample = SampleTopK(logits, /*k=*/3, rng, temperature, accept_token);
@@ -49,9 +49,10 @@ PinningPolicy::PinningPolicy(Tristate pin) {
static void MaybePin(const BoundedTopology& topology, size_t cluster_idx,
const BoundedTopology::Cluster& cluster,
PinningPolicy& pinning, hwy::ThreadPool& pool) {
static hwy::pool::Caller caller = hwy::ThreadPool::AddCaller("MaybePin");
const std::vector<size_t> lps = cluster.LPVector();
HWY_ASSERT(pool.NumWorkers() <= lps.size());
pool.Run(0, pool.NumWorkers(), [&](uint64_t task, size_t thread) {
pool.Run(0, pool.NumWorkers(), caller, [&](uint64_t task, size_t thread) {
HWY_ASSERT(task == thread); // each worker has one task
char buf[16]; // Linux limitation
@@ -141,11 +142,14 @@ NestedPools::NestedPools(const BoundedTopology& topology,
// Parallel so we also pin the calling worker in `all_clusters` to
// `cluster.lps`.
all_clusters_->Run(0, num_clusters, [&](size_t cluster_idx, size_t thread) {
static hwy::pool::Caller caller = hwy::ThreadPool::AddCaller("NestedPools");
all_clusters_->Run(
0, num_clusters, caller, [&](size_t cluster_idx, size_t thread) {
HWY_ASSERT(cluster_idx == thread); // each thread has one task
const BoundedTopology::Cluster& tcluster = topology.GetCluster(cluster_idx);
clusters_[cluster_idx] =
MakePool(allocator, workers_per_cluster[cluster_idx],
const BoundedTopology::Cluster& tcluster =
topology.GetCluster(cluster_idx);
clusters_[cluster_idx] = MakePool(
allocator, workers_per_cluster[cluster_idx],
hwy::PoolWorkerMapping(cluster_idx, max_workers_per_cluster_),
tcluster.Node());
// Pin workers AND the calling thread from `all_clusters_`.
@@ -266,9 +266,9 @@ static inline IndexRangePartition StaticPartition(const IndexRange& range,
// index to a range.
template <class Func>
void ParallelizeOneRange(const IndexRangePartition& get1, hwy::ThreadPool& pool,
const Func& func) {
hwy::pool::Caller caller, const Func& func) {
const size_t num_tasks = get1.NumTasks();
pool.Run(0, num_tasks, [&](uint64_t task, size_t thread) {
pool.Run(0, num_tasks, caller, [&](uint64_t task, size_t thread) {
const IndexRange range1 = get1.Range(task);
func(range1, thread);
});
@@ -282,11 +282,12 @@ void ParallelizeOneRange(const IndexRangePartition& get1, hwy::ThreadPool& pool,
template <class Func>
void ParallelizeTwoRanges(const IndexRangePartition& get1,
const IndexRangePartition& get2,
hwy::ThreadPool& pool, const Func& func) {
hwy::ThreadPool& pool, hwy::pool::Caller caller,
const Func& func) {
const hwy::Divisor div1(static_cast<uint32_t>(get1.NumTasks()));
const size_t num_tasks = get1.NumTasks() * get2.NumTasks();
pool.Run(0, num_tasks, [&](uint64_t task, size_t thread) {
pool.Run(0, num_tasks, caller, [&](uint64_t task, size_t thread) {
HWY_DASSERT(task < (uint64_t{1} << 32));
const size_t idx2 = div1.Divide(static_cast<uint32_t>(task));
const size_t idx1 = div1.Remainder(static_cast<uint32_t>(task));
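ParallelizeOneRange and ParallelizeTwoRanges now take a hwy::pool::Caller so the pool can attribute fork/join statistics to the call site. A sketch of a caller, assuming a caller registered once via AddCaller (the "Example" name is made up):

static const hwy::pool::Caller kExampleCaller =
    hwy::ThreadPool::AddCaller("Example");

void ForEachRange(hwy::ThreadPool& pool) {
  const IndexRangePartition ranges =
      StaticPartition(IndexRange(0, 64), pool.NumWorkers(), /*multiple=*/1);
  ParallelizeOneRange(ranges, pool, kExampleCaller,
                      [&](const IndexRange& r, size_t thread) {
                        // per-range work for [r.begin(), r.end())
                      });
}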
@@ -298,37 +299,6 @@ void ParallelizeTwoRanges(const IndexRangePartition& get1,
});
}
// Calls `func(task, worker)` for each task in `[0, num_tasks)`. Parallelizes
// over clusters of ONE package, then within each cluster.
template <class Func>
void HierarchicalParallelFor(size_t num_tasks, NestedPools& pools,
const Func& func) {
// If few tasks, run on a single cluster. Also avoids a bit of overhead if
// there is only one cluster.
hwy::ThreadPool& all_clusters = pools.AllClusters();
const size_t num_clusters = all_clusters.NumWorkers();
hwy::ThreadPool& cluster = pools.Cluster(0);
if (num_clusters == 1 || num_tasks <= cluster.NumWorkers()) {
return cluster.Run(0, num_tasks, [&](uint64_t task, size_t thread) {
func(task, thread);
});
}
// Assign each cluster a sub-range.
const IndexRangePartition ranges =
StaticPartition(IndexRange(0, num_tasks), num_clusters, 1);
ParallelizeOneRange(
ranges, all_clusters,
[&](const IndexRange& range, const size_t cluster_idx) {
hwy::ThreadPool& cluster = pools.Cluster(cluster_idx);
const size_t cluster_base = cluster_idx * pools.MaxWorkersPerCluster();
cluster.Run(range.begin(), range.end(),
[&](uint64_t task, size_t thread) {
func(task, cluster_base + thread);
});
});
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_
@@ -78,31 +78,33 @@ static void TunePool(hwy::PoolWaitMode wait_mode, hwy::ThreadPool& pool) {
#endif
}
static void TunePools(hwy::PoolWaitMode wait_mode, NestedPools& pools) {
hwy::ThreadPool& clusters = pools.AllClusters();
static void TunePools(hwy::PoolWaitMode wait_mode, ThreadingContext& ctx) {
hwy::ThreadPool& clusters = ctx.pools.AllClusters();
TunePool(wait_mode, clusters);
// Run in parallel because Turin CPUs have 16, and in real usage, we often
// run all at the same time.
clusters.Run(0, clusters.NumWorkers(),
ctx.pool_callers.Get(Callers::kTunePool),
[&](uint64_t cluster_idx, size_t /*thread*/) {
TunePool(wait_mode, pools.Cluster(cluster_idx));
TunePool(wait_mode, ctx.pools.Cluster(cluster_idx));
});
}
ThreadingContext::ThreadingContext(const ThreadingArgs& args)
: profiler(hwy::Profiler::Get()),
profiler_zones(profiler),
pool_callers(),
topology(BoundedSlice(args.skip_packages, args.max_packages),
BoundedSlice(args.skip_clusters, args.max_clusters),
BoundedSlice(args.skip_lps, args.max_lps)),
cache_info(topology),
allocator(topology, cache_info, args.bind != Tristate::kFalse),
pools(topology, allocator, args.max_threads, args.pin) {
InitProfilerZones(profiler);
PROFILER_ZONE("Startup.ThreadingContext autotune");
TunePools(hwy::PoolWaitMode::kSpin, pools);
TunePools(hwy::PoolWaitMode::kSpin, *this);
// kBlock is the default, hence set/tune it last.
TunePools(hwy::PoolWaitMode::kBlock, pools);
TunePools(hwy::PoolWaitMode::kBlock, *this);
}
} // namespace gcpp
@@ -28,6 +28,7 @@
#include "util/basics.h" // Tristate
#include "util/threading.h"
#include "util/topology.h"
#include "util/zones.h"
#include "hwy/profiler.h"
// IWYU pragma: end_exports
@@ -107,6 +108,9 @@ struct ThreadingContext {
// Singleton; pass around a reference to reduce overhead.
hwy::Profiler& profiler;
ProfilerZones profiler_zones;
PoolCallers pool_callers;
// Detects topology, subject to limits imposed by user-specified `args`.
// For example, if `args.max_clusters` is 1, then `topology.NumClusters()`
// will be 1 regardless of the actual system topology.
@@ -122,6 +126,9 @@ struct ThreadingContext {
NestedPools pools;
};
#define GCPP_ZONE(ctx, global_idx, zone_enum) \
PROFILER_ZONE3(ctx.profiler, global_idx, ctx.profiler_zones.Get(zone_enum))
// Describes the strategy for distributing parallel work across cores.
enum class ParallelismStrategy : uint8_t {
// Execute using a single-threaded loop on the calling thread. The `worker`
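GCPP_ZONE bundles the profiler reference, the global worker index and the zone lookup that call sites previously spelled out as PROFILER_ZONE3(ctx.profiler, worker, GetProfilerZone(...)). A sketch of a call site, using a zone enumerator that appears elsewhere in this change:

void ExampleOp(ThreadingContext& ctx, size_t worker) {
  GCPP_ZONE(ctx, worker, Zones::kOpsSoftmax);
  // ... SIMD work attributed to this zone and worker ...
}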
@@ -147,18 +154,53 @@ enum class ParallelismStrategy : uint8_t {
kHierarchical,
};
// Calls `func(task, worker)` for each task in `[0, num_tasks)`. Parallelizes
// over clusters of ONE package, then within each cluster.
template <class Func>
void HierarchicalParallelFor(size_t num_tasks, ThreadingContext& ctx,
Callers callers, const Func& func) {
const hwy::pool::Caller caller = ctx.pool_callers.Get(callers);
// If few tasks, run on a single cluster. Also avoids a bit of overhead if
// there is only one cluster.
hwy::ThreadPool& all_clusters = ctx.pools.AllClusters();
const size_t num_clusters = all_clusters.NumWorkers();
hwy::ThreadPool& cluster = ctx.pools.Cluster(0);
if (num_clusters == 1 || num_tasks <= cluster.NumWorkers()) {
return cluster.Run(0, num_tasks, caller, [&](uint64_t task, size_t thread) {
func(task, thread);
});
}
// Assign each cluster a sub-range.
const IndexRangePartition ranges =
StaticPartition(IndexRange(0, num_tasks), num_clusters, 1);
ParallelizeOneRange(ranges, all_clusters, caller,
[&](const IndexRange& range, const size_t cluster_idx) {
hwy::ThreadPool& cluster =
ctx.pools.Cluster(cluster_idx);
const size_t cluster_base =
cluster_idx * ctx.pools.MaxWorkersPerCluster();
cluster.Run(range.begin(), range.end(), caller,
[&](uint64_t task, size_t thread) {
func(task, cluster_base + thread);
});
});
}
// Calls `func(task, worker)` for each `task` in `[0, num_tasks)`, with the
// number/type of workers determined by `parallelism`. `cluster_idx` is for
// `parallelism == kWithinCluster`, and should be 0 if unknown.
template <class Func>
void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
ThreadingContext& ctx, size_t cluster_idx, const Func& func) {
ThreadingContext& ctx, size_t cluster_idx, Callers callers,
const Func& func) {
HWY_DASSERT(cluster_idx < ctx.topology.NumClusters());
if (cluster_idx != 0) {
// If already running across clusters, only use within-cluster modes.
HWY_DASSERT(parallelism == ParallelismStrategy::kNone ||
parallelism == ParallelismStrategy::kWithinCluster);
}
const hwy::pool::Caller caller = ctx.pool_callers.Get(callers);
switch (parallelism) {
case ParallelismStrategy::kNone: {
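Every ParallelFor call site now names itself via the Callers enum, which the pool uses for per-caller statistics. A minimal sketch of the new signature, using Callers::kTest as in the tests above (the loop body is illustrative only):

void ScaleRows(MatPtrT<float>& mat, ThreadingContext& ctx) {
  ParallelFor(ParallelismStrategy::kFlat, mat.Rows(), ctx,
              /*cluster_idx=*/0, Callers::kTest,
              [&](uint64_t row, size_t worker) {
                MulByConst(0.5f, mat.Row(row), mat.Cols());
              });
}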
@ -171,7 +213,7 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,

    case ParallelismStrategy::kAcrossClusters:
      return ctx.pools.AllClusters().Run(
          0, num_tasks,
          0, num_tasks, caller,
          [&](uint64_t task, size_t cluster_idx) { func(task, cluster_idx); });

    case ParallelismStrategy::kWithinCluster: {

@ -179,7 +221,7 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
      // used for TLS indexing for example in profiler.h.
      const size_t base = ctx.Worker(cluster_idx);
      return ctx.pools.Cluster(cluster_idx)
          .Run(0, num_tasks, [&](uint64_t task, size_t worker) {
          .Run(0, num_tasks, caller, [&](uint64_t task, size_t worker) {
            func(task, base + worker);
          });
    }

@ -191,19 +233,19 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
      const size_t num_clusters = all_clusters.NumWorkers();
      if (num_clusters == 1) {
        return ctx.pools.Cluster(cluster_idx)
            .Run(0, num_tasks,
            .Run(0, num_tasks, caller,
                 [&](uint64_t task, size_t worker) { func(task, worker); });
      }

      return ctx.pools.AllClusters().Run(
          0, num_tasks, [&](uint64_t task, size_t cluster_idx) {
      return all_clusters.Run(0, num_tasks, caller,
                              [&](uint64_t task, size_t cluster_idx) {
                                const size_t worker = ctx.Worker(cluster_idx);
                                func(task, worker);
                              });
    }

    case ParallelismStrategy::kHierarchical:
      return HierarchicalParallelFor(num_tasks, ctx.pools, func);
      return HierarchicalParallelFor(num_tasks, ctx, callers, func);
  }
}
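
Also illustrative, not part of the diff: how a call site might pass the new `Callers` argument, assuming `ctx` is a `ThreadingContext`; `kRows`, `RowSum` and `sums` are hypothetical.

    // Fan out across clusters and then across workers within each cluster; the
    // caller tag lets pool.Run attribute fork-join statistics to this call site.
    ParallelFor(ParallelismStrategy::kHierarchical, kRows, ctx, /*cluster_idx=*/0,
                Callers::kActivationBatched,
                [&](uint64_t row, size_t worker) { sums[worker] += RowSum(row); });
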
@ -37,6 +37,8 @@ namespace {

using ::testing::ElementsAre;

static const hwy::pool::Caller kCaller = hwy::ThreadPool::AddCaller("Test");

TEST(ThreadingTest, TestBoundedSlice) {
  const char* name = "test";
  // No args = no limit.

@ -205,7 +207,7 @@ TEST(ThreadingTest, TestParallelizeOneRange) {
  const IndexRangePartition partition = StaticPartition(range, 2, 4);
  hwy::ThreadPool null_pool(0);
  size_t calls = 0;
  ParallelizeOneRange(partition, null_pool,
  ParallelizeOneRange(partition, null_pool, kCaller,
                      [&](const IndexRange& range, size_t) {
                        if (++calls == 1) {
                          HWY_ASSERT(range.begin() == 0 && range.end() == 8);

@ -226,7 +228,7 @@ TEST(ThreadingTest, TestParallelizeTwoRanges) {
  {
    size_t calls = 0;
    ParallelizeTwoRanges(
        partition1, partition2, null_pool,
        partition1, partition2, null_pool, kCaller,
        [&](const IndexRange& range1, const IndexRange& range2, size_t) {
          ++calls;
          HWY_ASSERT(range1.begin() == 0 || range1.begin() == 8);

@ -240,7 +242,7 @@ TEST(ThreadingTest, TestParallelizeTwoRanges) {
  {
    size_t calls = 0;
    ParallelizeTwoRanges(
        partition2, partition1, null_pool,
        partition2, partition1, null_pool, kCaller,
        [&](const IndexRange& range2, const IndexRange& range1, size_t) {
          ++calls;
          HWY_ASSERT(range1.begin() == 0 || range1.begin() == 8);

@ -265,7 +267,7 @@ std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {

  const double t0 = hwy::platform::Now();
  for (size_t reps = 0; reps < 1200; ++reps) {
    pool.Run(0, pool.NumWorkers(), [&](uint64_t task, size_t thread) {
    pool.Run(0, pool.NumWorkers(), kCaller, [&](uint64_t task, size_t thread) {
      outputs[thread * kU64PerThread] = base + thread;
    });
    hwy::PreventElision(outputs[base]);

@ -305,7 +307,8 @@ std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
  if (have_stop) {
    for (size_t rep = 0; rep < max_reps; ++rep) {
      const uint64_t t0 = hwy::timer::Start();
      pool.Run(0, pool.NumWorkers(), [&](uint64_t task, size_t thread) {
      pool.Run(0, pool.NumWorkers(), kCaller,
               [&](uint64_t task, size_t thread) {
                 outputs[thread * kU64PerThread] = base + thread;
               });
      const uint64_t t1 = hwy::timer::Stop();

@ -314,7 +317,8 @@ std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
  } else {
    for (size_t rep = 0; rep < max_reps; ++rep) {
      const uint64_t t0 = hwy::timer::Start();
      pool.Run(0, pool.NumWorkers(), [&](uint64_t task, size_t thread) {
      pool.Run(0, pool.NumWorkers(), kCaller,
               [&](uint64_t task, size_t thread) {
                 outputs[thread * kU64PerThread] = base + thread;
               });
      const uint64_t t1 = hwy::timer::Start();

247	util/zones.cc

@ -1,70 +1,201 @@
#include "util/zones.h"

#include <stddef.h>

#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"

namespace gcpp {
namespace {

#if PROFILER_ENABLED
static constexpr size_t kNumZones = static_cast<size_t>(Zones::kNumZones);

static const char* kProfilerZoneNames[kNumZones] = {
    // Keep in sync with Zones enum.
    "Ops.RMSNormMul",
    "Ops.RMSNorm",
    "Ops.RMSNormInplace",
    "Ops.Rope",
    "Ops.RopeAndMulBy",
    "Ops.AddFrom",
    "Ops.MulByConst",
    "Ops.MulByConstTo",
    "Ops.MulByConstAndAdd",
    "Ops.MulByConstAndAddTile",
    "Ops.MulByConstAndAddTile4",
    "Ops.MulByConstAndAddVector",
    "Ops.Softmax",
    "Ops.LogitsSoftCap",
    "FlashAttention.TransposeQ",
    "FlashAttention.RMSNormAndPositionalEncoding",
    "FlashAttention.SingleFlashAttention",
    "FlashAttention.TileFlashAttention",
    "FlashAttention.TileFlashAttention4",
    "FlashAttention.FlashAttention",
    "Gen.Activation",
    "Gen.ActivationFused",
    "Gen.SampleTop1",
    "Gen.SampleTopK",
    "Gen.Attention.QDotK",
    "Gen.Attention.DotSoftmaxWeightedSum.par",
    "Startup.Weights.ReadAllToBF16",
    "Startup.Weights.ReadBatches",
    "MM.Dispatch",
    "MM.MatMul",
    "MM.TwoMatMul",
    "MM.DecompressA",
    "MM.NT",
    "MM.NT_K",
    "MM.NT_MT",
    "MM.NT_MT_K",
};

static hwy::profiler::ZoneHandle profiler_zone_handles[kNumZones];
#endif

void InitProfilerZones(hwy::Profiler& profiler) {
#if PROFILER_ENABLED
  // Initialize the zone handles. This is done once at startup.
  for (size_t i = 0; i < kNumZones; ++i) {
    profiler_zone_handles[i] = profiler.AddZone(kProfilerZoneNames[i]);
const char* ZoneName(Zones zone) {
  switch (zone) {
    case Zones::kFlashAttentionFlashAttention:
      return "FlashAttention.FlashAttention";
    case Zones::kFlashAttentionInclusive:
      return "FlashAttention.Inclusive";
    case Zones::kFlashAttentionRmsNormAndPositionalEncoding:
      return "FlashAttention.RMSNormAndPositionalEncoding";
    case Zones::kFlashAttentionSingleFlashAttention:
      return "FlashAttention.SingleFlashAttention";
    case Zones::kFlashAttentionTileFlashAttention:
      return "FlashAttention.TileFlashAttention";
    case Zones::kFlashAttentionTileFlashAttention4:
      return "FlashAttention.TileFlashAttention4";
    case Zones::kFlashAttentionTransposeQ:
      return "FlashAttention.TransposeQ";
    case Zones::kGenActivation:
      return "Gen.Activation";
    case Zones::kGenActivationFused:
      return "Gen.ActivationFused";
    case Zones::kGenAttention:
      return "Gen.Attention";
    case Zones::kGenAttentionComputeQKV:
      return "Gen.Attention.ComputeQKV";
    case Zones::kGenAttentionDotSoftmaxWeightedSumInclusive:
      return "Gen.Attention.DotSoftmaxWeightedSumInclusive";
    case Zones::kGenAttentionDotSoftmaxWeightedSumPar:
      return "Gen.Attention.DotSoftmaxWeightedSum.par";
    case Zones::kGenAttentionQDotK:
      return "Gen.Attention.QDotK";
    case Zones::kGenAttentionSumHeads:
      return "Gen.Attention.SumHeads";
    case Zones::kGenEmbed:
      return "Gen.Embed";
    case Zones::kGenEmbeddingMatmul:
      return "Gen.EmbeddingMatmul";
    case Zones::kGenFFW:
      return "Gen.FFW";
    case Zones::kGenSampleTop1:
      return "Gen.SampleTop1";
    case Zones::kGenSampleTopK:
      return "Gen.SampleTopK";
    case Zones::kMMDecompressA:
      return "MM.DecompressA";
    case Zones::kMMDispatch:
      return "MM.Dispatch";
    case Zones::kMMMatMul:
      return "MM.MatMul";
    case Zones::kMMNT_K:
      return "MM.NT_K";
    case Zones::kMMNT_MT_K:
      return "MM.NT_MT_K";
    case Zones::kMMNT_MT:
      return "MM.NT_MT";
    case Zones::kMMNT:
      return "MM.NT";
    case Zones::kMMTwoMatMul:
      return "MM.TwoMatMul";
    case Zones::kOpsAddFrom:
      return "Ops.AddFrom";
    case Zones::kOpsLogitsSoftCap:
      return "Ops.LogitsSoftCap";
    // case Zones::kOpsMulByConst:             // removed due to overhead
    // case Zones::kOpsMulByConstAndAdd:       // removed due to overhead
    // case Zones::kOpsMulByConstAndAddTile:   // removed due to overhead
    // case Zones::kOpsMulByConstAndAddTile4:  // removed due to overhead
    // case Zones::kOpsMulByConstAndAddVector: // removed due to overhead
    case Zones::kOpsMulByConstTo:
      return "Ops.MulByConstTo";
    case Zones::kOpsRmsNorm:
      return "Ops.RMSNorm";
    case Zones::kOpsRmsNormInplace:
      return "Ops.RMSNormInplace";
    case Zones::kOpsRmsNormMul:
      return "Ops.RMSNormMul";
    case Zones::kOpsRope:
      return "Ops.Rope";
    case Zones::kOpsRopeAndMulBy:
      return "Ops.RopeAndMulBy";
    case Zones::kOpsSoftmax:
      return "Ops.Softmax";
    case Zones::kStartupWeightsReadAllToBF16:
      return "Startup.Weights.ReadAllToBF16";
    case Zones::kStartupWeightsReadBatches:
      return "Startup.Weights.ReadBatches";
    default:
      HWY_ABORT("Invalid zone %d.", static_cast<int>(zone));
  }
#endif
}

hwy::profiler::ZoneHandle GetProfilerZone(Zones zone) {
#if PROFILER_ENABLED
  return profiler_zone_handles[static_cast<size_t>(zone)];
#else
  return hwy::profiler::ZoneHandle();
#endif
hwy::ProfilerFlags ZoneFlags(Zones zone) {
  switch (zone) {
    case Zones::kFlashAttentionInclusive:
    case Zones::kGenAttention:
    case Zones::kGenAttentionComputeQKV:
    case Zones::kGenAttentionDotSoftmaxWeightedSumInclusive:
    case Zones::kGenAttentionSumHeads:
    case Zones::kGenEmbed:
    case Zones::kGenEmbeddingMatmul:
    case Zones::kGenFFW:
      return hwy::ProfilerFlags::kInclusive;
    default:
      return hwy::ProfilerFlags::kDefault;
  }
}

const char* CallerName(Callers caller) {
  switch (caller) {
    case Callers::kActivationBatched:
      return "ActivationBatched";
    case Callers::kAllocateAndBindAll:
      return "AllocateAndBindAll";
    case Callers::kAttComputeQKV:
      return "Att.ComputeQKV";
    case Callers::kAttDotSoftmaxWeightedSum:
      return "Att.DotSoftmaxWeightedSum";
    case Callers::kBlobWriter:
      return "BlobWriter";
    case Callers::kCompress:
      return "Compress";
    case Callers::kFixupWeights:
      return "FixupWeights";
    case Callers::kFlashAttention:
      return "FlashAttention";
    case Callers::kFlashRMSNormAndPositionalEncoding:
      return "Flash.RMSNormAndPositionalEncoding";
    case Callers::kFlashTransposeQ:
      return "Flash.TransposeQ";
    case Callers::kMMClusterForMC:
      return "MM.ClusterForMC";
    case Callers::kMMClusterForMCNC:
      return "MM.ClusterForMCNC";
    case Callers::kMMClusterForN:
      return "MM.ClusterForN";
    case Callers::kMMHierForMC:
      return "MM.HierForMC";
    case Callers::kMMHierForMCNC:
      return "MM.HierForMCNC";
    case Callers::kMMHierForN:
      return "MM.HierForN";
    case Callers::kOpsAddFromBatched:
      return "Ops.AddFromBatched";
    case Callers::kOpsMaybeLogitsSoftCapBatched:
      return "Ops.MaybeLogitsSoftCapBatched";
    case Callers::kOpsRMSNormBatched:
      return "Ops.RMSNormBatched";
    case Callers::kOpsRMSNormInplaceBatched:
      return "Ops.RMSNormInplaceBatched";
    case Callers::kReadAllToBF16:
      return "ReadAllToBF16";
    case Callers::kReadBatches:
      return "ReadBatches";
    case Callers::kSampleAndStream:
      return "SampleAndStream";
    case Callers::kTest:  // only for unit tests.
      return "Test-only!";
    case Callers::kTunePool:
      return "TunePool";
    case Callers::kVitDotSoftmax1:
      return "Vit.DotSoftmax1";
    case Callers::kVitDotSoftmax2:
      return "Vit.DotSoftmax2";
    case Callers::kVitDotSoftmax3:
      return "Vit.DotSoftmax3";
    case Callers::kVitDotSoftmax4:
      return "Vit.DotSoftmax4";
    default:
      HWY_ABORT("Invalid caller %d.", static_cast<int>(caller));
  }
}

}  // namespace

ProfilerZones::ProfilerZones(hwy::Profiler& profiler) {
  for (size_t i = 0;; ++i) {
    const Zones zone = static_cast<Zones>(i);
    if (zone == Zones::kNumZones) break;
    handles_[i] = profiler.AddZone(ZoneName(zone), ZoneFlags(zone));
  }
}

PoolCallers::PoolCallers() {
  for (size_t i = 0;; ++i) {
    const Callers caller = static_cast<Callers>(i);
    if (caller == Callers::kNumCallers) break;
    callers_[i] = hwy::ThreadPool::AddCaller(CallerName(caller));
  }
}

}  // namespace gcpp
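
A sketch of the intended lookup path (not part of the diff): the two constructors above run once, when the owning ThreadingContext is created, so call sites only index an array. The member name `profiler_zones` is an assumption; only `pool_callers` is visible elsewhere in this change.

    // Assuming a ThreadingContext `ctx` that owns these registries:
    const hwy::profiler::ZoneHandle zone = ctx.profiler_zones.Get(Zones::kGenFFW);
    const hwy::pool::Caller caller = ctx.pool_callers.Get(Callers::kSampleAndStream);
    // `caller` is then forwarded to pool.Run, as in ParallelFor above.
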
128	util/zones.h

@ -1,57 +1,123 @@
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_ZONES_H_
#define THIRD_PARTY_GEMMA_CPP_UTIL_ZONES_H_

#include <stddef.h>

#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"

namespace gcpp {

// Zones for the profiler.
enum class Zones {
  kOpsRmsNormMul,
  kOpsRmsNorm,
  kOpsRmsNormInplace,
  kOpsRope,
  kOpsRopeAndMulBy,
  kOpsAddFrom,
  kOpsMulByConst,
  kOpsMulByConstTo,
  kOpsMulByConstAndAdd,
  kOpsMulByConstAndAddTile,
  kOpsMulByConstAndAddTile4,
  kOpsMulByConstAndAddVector,
  kOpsSoftmax,
  kOpsLogitsSoftCap,
  kFlashAttentionTransposeQ,
enum class Zones {  // Keep sorted
  kFlashAttentionFlashAttention,
  kFlashAttentionInclusive,
  kFlashAttentionRmsNormAndPositionalEncoding,
  kFlashAttentionSingleFlashAttention,
  kFlashAttentionTileFlashAttention,
  kFlashAttentionTileFlashAttention4,
  kFlashAttentionFlashAttention,
  kFlashAttentionTransposeQ,
  kGenActivation,
  kGenActivationFused,
  kGenAttention,
  kGenAttentionComputeQKV,
  kGenAttentionDotSoftmaxWeightedSumInclusive,
  kGenAttentionDotSoftmaxWeightedSumPar,
  kGenAttentionQDotK,
  kGenAttentionSumHeads,
  kGenEmbed,
  kGenEmbeddingMatmul,
  kGenFFW,
  kGenSampleTop1,
  kGenSampleTopK,
  kGenAttentionQDotK,
  kGenAttentionDotSoftmaxWeightedSumPar,
  kStartupWeightsReadAllToBF16,
  kStartupWeightsReadBatches,
  kMMDecompressA,
  kMMDispatch,
  kMMMatMul,
  kMMTwoMatMul,
  kMMDecompressA,
  kMMNT,
  kMMNT_K,
  kMMNT_MT,
  kMMNT_MT_K,
  kNumZones
  kMMNT_MT,
  kMMNT,
  kMMTwoMatMul,
  kOpsAddFrom,
  kOpsLogitsSoftCap,
  // kOpsMulByConst,             // removed due to overhead
  // kOpsMulByConstAndAdd,       // removed due to overhead
  // kOpsMulByConstAndAddTile,   // removed due to overhead
  // kOpsMulByConstAndAddTile4,  // removed due to overhead
  // kOpsMulByConstAndAddVector, // removed due to overhead
  kOpsMulByConstTo,
  kOpsRmsNorm,
  kOpsRmsNormInplace,
  kOpsRmsNormMul,
  kOpsRope,
  kOpsRopeAndMulBy,
  kOpsSoftmax,
  kStartupWeightsReadAllToBF16,
  kStartupWeightsReadBatches,
  kNumZones  // must be last
};

// Initializes the profiler zones. Must be called before any other profiler
// functions.
void InitProfilerZones(hwy::Profiler& profiler);
// Owned by ThreadingContext.
class ProfilerZones {
 public:
  ProfilerZones(hwy::Profiler& profiler);

  // Returns the zone handle for the given zone enum value.
  hwy::profiler::ZoneHandle GetProfilerZone(Zones zone);
  hwy::profiler::ZoneHandle Get(Zones zone) {
    HWY_DASSERT(zone != Zones::kNumZones);
    return handles_[static_cast<size_t>(zone)];
  }

 private:
  hwy::profiler::ZoneHandle handles_[static_cast<size_t>(Zones::kNumZones)];
};

enum class Callers {  // Keep sorted
  kActivationBatched,
  kAllocateAndBindAll,
  kAttComputeQKV,
  kAttDotSoftmaxWeightedSum,
  kBlobWriter,
  kCompress,
  kFixupWeights,
  kFlashAttention,
  kFlashRMSNormAndPositionalEncoding,
  kFlashTransposeQ,
  kMMClusterForMC,
  kMMClusterForMCNC,
  kMMClusterForN,
  kMMHierForMC,
  kMMHierForMCNC,
  kMMHierForN,
  kOpsAddFromBatched,
  kOpsMaybeLogitsSoftCapBatched,
  kOpsRMSNormBatched,
  kOpsRMSNormInplaceBatched,
  kReadAllToBF16,
  kReadBatches,
  kSampleAndStream,
  kTest,  // only for unit tests.
  kTunePool,
  kVitDotSoftmax1,
  kVitDotSoftmax2,
  kVitDotSoftmax3,
  kVitDotSoftmax4,
  kNumCallers  // must be last
};

// Owned by ThreadingContext.
class PoolCallers {
 public:
  PoolCallers();

  hwy::pool::Caller Get(Callers caller) {
    HWY_DASSERT(caller != Callers::kNumCallers);
    return callers_[static_cast<size_t>(caller)];
  }

 private:
  hwy::pool::Caller callers_[static_cast<size_t>(Callers::kNumCallers)];
};

}  // namespace gcpp
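
For orientation (not part of the diff): under this scheme, adding a new zone or caller is a two-step edit. The enumerator and label below are made up.

    // 1) zones.h: insert the enumerator in sorted order, before kNumZones:
    //      kGenMyNewStep,
    // 2) zones.cc: add the matching case to ZoneName() (and to ZoneFlags() only
    //    if it should return hwy::ProfilerFlags::kInclusive):
    //      case Zones::kGenMyNewStep:
    //        return "Gen.MyNewStep";
    // Callers follow the same pattern via the Callers enum and CallerName().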