mirror of https://github.com/google/gemma.cpp.git
(Resubmit) Prepare profiler annotations for new API
Pass hwy::Profiler& to low-level functions. Use a ThreadingContext argument instead of NestedPools. Use the new PROFILER_ZONE3.

PiperOrigin-RevId: 794461159
parent a2d9133f7d
commit faa4102992
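Every call site in the diff below follows the same pattern. The sketch here is illustrative only: `ExampleOp` and its body are hypothetical, while `hwy::Profiler&`, `AddZone`, and `PROFILER_ZONE3(p, worker, zone)` are the pieces this commit actually introduces.

```cpp
#include <cstddef>

#include "hwy/profiler.h"
#include "util/basics.h"  // PROFILER_ZONE3, as included by this commit

// Before: zones were opened from a string literal via the old macro:
//   PROFILER_ZONE2(worker, "Ops.ExampleOp");
// After: the caller passes hwy::Profiler& down, the zone handle is registered
// once with AddZone, and PROFILER_ZONE3 opens it for the given worker.
// ExampleOp itself is a hypothetical placeholder, not a function in gemma.cpp.
void ExampleOp(float* data, size_t size, hwy::Profiler& p, size_t worker) {
  static const auto zone = p.AddZone("Ops.ExampleOp");
  PROFILER_ZONE3(p, worker, zone);
  for (size_t i = 0; i < size; ++i) data[i] *= 2.0f;  // placeholder work
}
```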
@@ -80,6 +80,7 @@ cc_library(
         ":topology",
         # Placeholder for container detection, do not remove
         "@highway//:hwy",
+        "@highway//:profiler",
         "@highway//:thread_pool",
         "@highway//:topology",
     ],
@@ -379,6 +380,7 @@ cc_test(
         "@highway//:hwy",
         "@highway//:hwy_test_util",
         "@highway//:nanobenchmark",  #buildcleaner: keep
+        "@highway//:profiler",
     ],
 )

@@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

-FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9414b48aeec251b69e6cadbfa42bebb5ddae1c34 EXCLUDE_FROM_ALL)
+FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 92d327e841d78e11ae888757a3e16d291951cf64 EXCLUDE_FROM_ALL)
 FetchContent_MakeAvailable(highway)

 ## Note: absl needs to be installed by sentencepiece. This will only happen if
@@ -18,7 +18,7 @@ bazel_dep(name = "google_benchmark", version = "1.8.5")
 # Require a more recent version.
 git_override(
     module_name = "highway",
-    commit = "9414b48aeec251b69e6cadbfa42bebb5ddae1c34",
+    commit = "92d327e841d78e11ae888757a3e16d291951cf64",
     remote = "https://github.com/google/highway",
 )

@@ -469,7 +469,7 @@ FetchContent_MakeAvailable(sentencepiece)
 FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG origin/main)
 FetchContent_MakeAvailable(gemma)

-FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f)
+FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 92d327e841d78e11ae888757a3e16d291951cf64)
 FetchContent_MakeAvailable(highway)
 ```

@@ -84,8 +84,9 @@ HWY_BEFORE_NAMESPACE();
 namespace gcpp {
 namespace HWY_NAMESPACE {

-void CallSoftmax(float* HWY_RESTRICT logits, size_t vocab_size) {
-  Softmax(logits, vocab_size, /*worker=*/0);
+void CallSoftmax(float* HWY_RESTRICT logits, size_t vocab_size,
+                 hwy::Profiler& p) {
+  Softmax(logits, vocab_size, p, hwy::Profiler::Thread());
 }

 }  // namespace HWY_NAMESPACE

@@ -109,7 +110,7 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
   const SampleFunc sample_token = [&](float* probs,
                                       size_t vocab_size) -> TokenAndProb {
     // input is logits, not yet probabilities
-    HWY_DYNAMIC_DISPATCH(CallSoftmax)(probs, vocab_size);
+    HWY_DYNAMIC_DISPATCH(CallSoftmax)(probs, vocab_size, env.ctx.profiler);
     // We are called for each token, but pos starts at 1. Clamping
     // max_generated_tokens to prompt.size() should prevent overrun.
     HWY_ASSERT(pos < prompt.size());
@@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)

 include(FetchContent)
-FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG bfc36a6e633af94e63ac4b91c687bf0354cb24e0)
+FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 92d327e841d78e11ae888757a3e16d291951cf64)
 FetchContent_MakeAvailable(highway)
 FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 9045b2f60fa2b323dfac0eaef8fc17565036f9f9)
 FetchContent_MakeAvailable(sentencepiece)

@@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)

 include(FetchContent)
-FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9414b48aeec251b69e6cadbfa42bebb5ddae1c34)
+FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 92d327e841d78e11ae888757a3e16d291951cf64)
 FetchContent_MakeAvailable(highway)
 FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
 FetchContent_MakeAvailable(sentencepiece)
@@ -19,6 +19,7 @@
 #include <vector>

 #include "compression/types.h"  // GEMMA_DISABLED_TARGETS
+#include "util/threading_context.h"
 #ifndef HWY_DISABLED_TARGETS
 #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
 #endif  // HWY_DISABLED_TARGETS

@@ -53,8 +54,9 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
                              const hwy::Divisor& div_seq_len,
                              const float* HWY_RESTRICT q,
                              const MatPtrT<KV_t>& k, float* HWY_RESTRICT att,
-                             const size_t worker) {
-  PROFILER_ZONE2(worker, "Gen.Attention.QDotK");
+                             hwy::Profiler& p, const size_t worker) {
+  static const auto zone = p.AddZone("Gen.Attention.QDotK");
+  PROFILER_ZONE3(p, worker, zone);
   if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
     // Slightly faster: no wraparound.
     for (size_t pos = start_pos; pos <= last_pos; ++pos) {

@@ -73,8 +75,8 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
 static void PositionalEncodingQK(float* qk, const size_t layer_idx,
                                  const LayerWeightsPtrs& layer,
                                  const AttentionActivations& activations,
-                                 const size_t worker, const size_t pos,
-                                 const float mul = 1.0f) {
+                                 hwy::Profiler& p, const size_t worker,
+                                 const size_t pos, const float mul = 1.0f) {
   const size_t qkv_dim = layer.layer_config.qkv_dim;
   const PostQKType& post_qk = layer.layer_config.post_qk;
   // qk is either q or k, so qkv_dim is the length we operate on.

@@ -86,10 +88,10 @@ static void PositionalEncodingQK(float* qk, const size_t layer_idx,
   }
   // PostQKType::Rope
   if (post_qk == PostQKType::HalfRope) {
-    Rope(qk, qkv_dim / 2, inv_timescale, pos, worker);
-    if (mul != 1.0f) MulByConst(mul, qk, qkv_dim, worker);
+    Rope(qk, qkv_dim / 2, inv_timescale, pos, p, worker);
+    if (mul != 1.0f) MulByConst(mul, qk, qkv_dim, p, worker);
   } else {
-    RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, worker);
+    RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, p, worker);
   }
 }

@@ -97,26 +99,31 @@ static void PositionalEncodingQK(float* qk, const size_t layer_idx,
 // `att_out`. Equivalent in gemma/modules.py:
 // encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
 // `v` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
-static HWY_INLINE void WeightedSumV(
-    const size_t start_pos, const size_t last_pos,
-    const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT att,
-    const MatPtrT<KV_t>& v, float* HWY_RESTRICT att_out, const size_t worker) {
+static HWY_INLINE void WeightedSumV(const size_t start_pos,
+                                    const size_t last_pos,
+                                    const hwy::Divisor& div_seq_len,
+                                    const float* HWY_RESTRICT att,
+                                    const MatPtrT<KV_t>& v,
+                                    float* HWY_RESTRICT att_out,
+                                    hwy::Profiler& p, const size_t worker) {
   if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
     // Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if
     // we supported non-transposed B.
     // TODO: 2..4x unroll
-    MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), worker);
+    MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), p,
+                 worker);
     for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
-      MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols(), worker);
+      MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols(), p, worker);
     }
   } else {
     {
       const size_t pos_mod = div_seq_len.Remainder(start_pos);
-      MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), worker);
+      MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), p, worker);
     }
     for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
       const size_t pos_mod = div_seq_len.Remainder(pos);
-      MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), worker);
+      MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), p,
+                       worker);
     }
   }
 }

@@ -128,7 +135,7 @@ void SingleDotSoftmaxWeightedSum(
     float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
     const size_t layer_idx, const LayerWeightsPtrs& layer,
     const AttentionActivations& activations, float* HWY_RESTRICT att,
-    float* HWY_RESTRICT att_out, const size_t worker) {
+    float* HWY_RESTRICT att_out, hwy::Profiler& p, const size_t worker) {
   const float att_cap = activations.config.att_cap;
   const float query_scale = activations.query_scale;
   const size_t seq_len =

@@ -138,21 +145,21 @@ void SingleDotSoftmaxWeightedSum(
   if (layer.query_norm_scale.HasPtr()) {
     CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
       RMSNormInplace(weights_t->PackedScale1(), 0, q,
-                     layer.layer_config.qkv_dim, worker);
+                     layer.layer_config.qkv_dim, p, worker);
     });
   }

-  PositionalEncodingQK(q, layer_idx, layer, activations, worker, pos,
+  PositionalEncodingQK(q, layer_idx, layer, activations, p, worker, pos,
                        query_scale);

-  QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, worker);
+  QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, p, worker);

   // SoftMax with optional SoftCap yields "probabilities" in att.
   const size_t att_len = HWY_MIN(last_pos + 1, seq_len);
-  MaybeLogitsSoftCap(att_cap, att, att_len, worker);
-  Softmax(att, att_len, worker, /*temperature=*/1.0f);
+  MaybeLogitsSoftCap(att_cap, att, att_len, p, worker);
+  Softmax(att, att_len, p, worker, /*temperature=*/1.0f);

-  WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out,
+  WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out, p,
                worker);
 }

@@ -167,9 +174,8 @@ static HWY_INLINE size_t StartPos(size_t pos, const ModelConfig& config,
 void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
                            const LayerWeightsPtrs& layer,
                            AttentionActivations& activations, QBatch& qbatch,
-                           NestedPools& pools) {
-  static const uint32_t HWY_MAYBE_UNUSED zone_id_par =
-      PROFILER_ADD_ZONE("Gen.Attention.DotSoftmax.par");
+                           ThreadingContext& ctx) {
+  static const auto zone = ctx.profiler.AddZone("Gen.Attention.DotSoftmax.par");

   const hwy::Divisor div_qbatch(qbatch.Size());
   const LayerConfig& layer_config = layer.layer_config;

@@ -189,9 +195,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
   const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
     const size_t tq_idx = activations.div_heads.Divide(task);
     const size_t head = activations.div_heads.Remainder(task);
-#if PROFILER_ENABLED
-    const hwy::Zone zone(worker, zone_id_par);
-#endif
+    PROFILER_ZONE3(ctx.profiler, worker, zone);

     const size_t qi = div_qbatch.Remainder(tq_idx);
     const size_t batch_idx = div_qbatch.Divide(tq_idx);

@@ -223,14 +227,15 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
     v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());

     SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx,
-                                layer, activations, att, att_out, worker);
+                                layer, activations, att, att_out, ctx.profiler,
+                                worker);
   };

   {
     PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin");
     // Full parallelism is helpful, SmallParallelFor is insufficient.
     ParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads,
-                pools, func);
+                ctx.pools, func);
   }
 }

@@ -303,12 +308,12 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
     if (layer.key_norm_scale.HasPtr()) {
       CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) {
         RMSNormInplace(weights_t->PackedScale1(), 0, kv_f32, qkv_dim,
-                       thread);
+                       env.ctx.profiler, thread);
       });
     }

-    PositionalEncodingQK(kv_f32, layer_idx, layer, activations, thread,
-                         pos);
+    PositionalEncodingQK(kv_f32, layer_idx, layer, activations,
+                         env.ctx.profiler, thread, pos);
     CompressPerThread tls;
     Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
   });

@@ -339,6 +344,10 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
                     const LayerWeightsPtrs& layer,
                     AttentionActivations& activations, QBatch& qbatch,
                     MatMulEnv& env, int flags) {
+  static const auto zone =
+      env.ctx.profiler.AddZone("Gen.Attention", ProfilerFlags::kInclusive);
+  PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone);
+
   const LayerConfig& layer_config = layer.layer_config;
   HWY_DASSERT(!layer_config.IsMHA());  // No longer supported.
   HWY_DASSERT_M((layer_config.heads % layer_config.kv_heads) == 0,

@@ -347,7 +356,7 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,

   ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env);
   DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
-                        env.ctx.pools);
+                        env.ctx);
   SumHeads(layer, activations, env);
 }

@@ -33,12 +33,12 @@ namespace gcpp {
       float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
       size_t layer_idx, const LayerWeightsPtrs& layer, \
       const AttentionActivations& activations, float* HWY_RESTRICT att, \
-      float* HWY_RESTRICT att_out, size_t worker); \
+      float* HWY_RESTRICT att_out, hwy::Profiler& p, size_t worker); \
 \
   void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \
                              const LayerWeightsPtrs& layer, \
                              AttentionActivations& activations, \
-                             QBatch& qbatch, NestedPools& pools); \
+                             QBatch& qbatch, ThreadingContext& ctx); \
 \
   void GemmaAttention(size_t num_tokens, const size_t layer_idx, \
                       const LayerWeightsPtrs& layer, \

@@ -45,9 +45,10 @@ namespace HWY_NAMESPACE {

 template <typename T>
 void Activation(ActivationType activation, T* HWY_RESTRICT c1,
-                const T* HWY_RESTRICT c2, const size_t count,
+                const T* HWY_RESTRICT c2, const size_t count, hwy::Profiler& p,
                 const size_t worker) {
-  PROFILER_ZONE2(worker, "Gen.Activation");
+  static const auto zone = p.AddZone("Gen.Activation");
+  PROFILER_ZONE3(p, worker, zone);
   namespace hn = hwy::HWY_NAMESPACE;
   using DF = hn::ScalableTag<T>;
   using VF = hn::Vec<DF>;

@@ -64,28 +65,30 @@ void Activation(ActivationType activation, T* HWY_RESTRICT c1,

 // No C2 multiplier.
 template <class Mat>
-void ActivationBatched(ActivationType activation, Mat& c1, NestedPools& pools) {
+void ActivationBatched(ActivationType activation, Mat& c1,
+                       ThreadingContext& ctx) {
   using T = typename Mat::T;
-  SmallParallelFor(c1.Rows(), pools, [&](uint64_t task, size_t worker) {
+  SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
     // Cast to correct type so type deduction works.
     Activation(activation, c1.Row(task), static_cast<const T*>(nullptr),
-               c1.Cols(), worker);
+               c1.Cols(), ctx.profiler, worker);
   });
 }

 template <class Mat>
 HWY_NOINLINE void ActivationBatched(ActivationType activation, Mat& c1,
-                                    const Mat* c2, NestedPools& pools) {
+                                    const Mat* c2, ThreadingContext& ctx) {
   using T = typename Mat::T;
   HWY_DASSERT(c1.SameShape(*c2));
   if (c2 && c2->HasPtr()) {
-    SmallParallelFor(c1.Rows(), pools, [&](uint64_t task, size_t worker) {
-      Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(), worker);
+    SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
+      Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(),
+                 ctx.profiler, worker);
     });
   } else {  // No multiplier
-    SmallParallelFor(c1.Rows(), pools, [&](uint64_t task, size_t worker) {
+    SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
       Activation(activation, c1.Row(task), static_cast<const T*>(nullptr),
-                 c1.Cols(), worker);
+                 c1.Cols(), ctx.profiler, worker);
     });
   }
 }

@@ -110,7 +113,9 @@ void PostNorm(PostNormType post_norm, const MatPtr& weights,

 static inline void FFWNoVit(const LayerWeightsPtrs& layer,
                             Activations& activations, MatMulEnv& env) {
-  PROFILER_ZONE("Gen.FFW");
+  static const auto zone =
+      env.ctx.profiler.AddZone("Gen.FFW", ProfilerFlags::kInclusive);
+  PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone);
   const LayerConfig& layer_config = layer.layer_config;
   const size_t ffh_hidden_dim = layer_config.ff_hidden_dim;

@@ -129,7 +134,7 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer,

   // Activation (Gelu) and maybe multiply by gate. Store activations in act.
   ActivationBatched(layer_config.activation, activations.C1, &activations.C2,
-                    env.ctx.pools);
+                    env.ctx);

   // Hidden layer -> output layer.
   CallMatMul(activations.C1, layer.linear_w, output_bias, env,

@@ -55,6 +55,7 @@
 #include "io/io.h"  // Path
 #include "ops/matmul.h"
 #include "paligemma/image.h"
+#include "util/basics.h"  // PROFILER_ZONE3
 #include "util/threading_context.h"
 #include "hwy/aligned_allocator.h"  // Span
 #include "hwy/base.h"

@@ -138,7 +139,8 @@ static float EmbeddingScaling(size_t model_dim) {
 static HWY_NOINLINE size_t
 EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
              const ModelConfig& model_config, const WeightsPtrs& weights,
-             MatStorageT<float>& x, const ImageTokens* image_tokens = nullptr,
+             MatStorageT<float>& x, ThreadingContext& ctx,
+             const ImageTokens* image_tokens = nullptr,
              size_t image_token_position = 0) {
   // Image tokens just need to be copied.
   if (model_config.wrapping == PromptWrapping::GEMMA_VLM &&

@@ -174,7 +176,8 @@ EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
     const hn::ScalableTag<float> df;
     DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(qi),
                          model_dim);
-    MulByConst(emb_scaling * weights_t->Scale(), x.Row(qi), model_dim, worker);
+    MulByConst(emb_scaling * weights_t->Scale(), x.Row(qi), model_dim,
+               ctx.profiler, worker);
   });

   if (model_config.absolute_pe) {

@@ -249,7 +252,7 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
       const int token = qbatch_1.Prompt(0)[pos_in_prompt];
       image_token_position = EmbedMMToken(
           token, ti, pos, pos_in_prompt, config, weights, activations.x,
-          runtime_config.image_tokens, image_token_position);
+          env.ctx, runtime_config.image_tokens, image_token_position);
     }

     // Transformer with one batch of tokens from a single query.

@@ -306,7 +309,7 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
   // TODO: parallelize?
   for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
     EmbedMMToken(qbatch.PrevToken(qi), qi, qbatch.Pos(qi),
-                 /*pos_in_prompt=*/0, config, weights, activations.x);
+                 /*pos_in_prompt=*/0, config, weights, activations.x, env.ctx);
   }

   for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) {

@@ -419,7 +422,8 @@ static void DecodeStepT(const ModelConfig& config,
   const size_t worker = 0;  // TODO: parallelize
   non_eos.Foreach([&](size_t qi) {
     float* HWY_RESTRICT logits = activations.logits.Row(qi);
-    MaybeLogitsSoftCap(config.final_cap, logits, config.vocab_size, worker);
+    MaybeLogitsSoftCap(config.final_cap, logits, config.vocab_size,
+                       env.ctx.profiler, worker);
     const TokenAndProb tp = sample_token(logits, config.vocab_size);
     timing_info.NotifyGenerated();

@@ -429,27 +433,28 @@ static void DecodeStepT(const ModelConfig& config,
 }

 static HWY_INLINE SampleFunc
-ChooseSampleFunc(const RuntimeConfig& runtime_config) {
+ChooseSampleFunc(const RuntimeConfig& runtime_config, ThreadingContext& ctx) {
   // If user provided a sample_func, use it.
   if (runtime_config.sample_func) return runtime_config.sample_func;

+  static const auto zone = ctx.profiler.AddZone("Gen.Sample Top1");
   const size_t worker = 0;  // TODO: parallelize

   // Fast path for top-1 with no accept_token.
   if (runtime_config.top_k == 1 && !runtime_config.accept_token) {
-    return [](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb {
-      PROFILER_ZONE2(worker, "Gen.Sample Top1");
+    return [&](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb {
+      PROFILER_ZONE3(ctx.profiler, worker, zone);
       return Top1OfSoftmax(logits, vocab_size);
     };
   }

   // General case: Softmax with top-k sampling.
-  return [&runtime_config](float* logits,
-                           size_t vocab_size) HWY_ATTR -> TokenAndProb {
+  return [&](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb {
     PROFILER_ZONE("Gen.Sample general");
     return FusedSoftmaxAndSampleTopK(
         logits, runtime_config.top_k, vocab_size, *runtime_config.gen,
-        runtime_config.temperature, runtime_config.accept_token, worker);
+        runtime_config.temperature, runtime_config.accept_token, ctx.profiler,
+        worker);
   };
 }

@@ -524,7 +529,7 @@ static void GenerateT(const ModelConfig& config,
     max_gen_steps = seq_len - max_prompt_size;
   }

-  const SampleFunc sample_token = ChooseSampleFunc(runtime_config);
+  const SampleFunc sample_token = ChooseSampleFunc(runtime_config, env.ctx);

   {
     timing_info.generate_start = hwy::platform::Now();

gemma/vit.cc (18 lines changed)
@@ -95,7 +95,7 @@ class VitAttention {
       float* HWY_RESTRICT q =
           activations_.attention.q.Row(token) + head * 3 * qkv_dim;
       // TODO: shift to MatMul with A.scale once MatMul is confirmed working
-      MulByConst(query_scale, q, qkv_dim, worker);
+      MulByConst(query_scale, q, qkv_dim, env_.ctx.profiler, worker);
       hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float));
     });

@@ -111,7 +111,7 @@ class VitAttention {

     pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR {
       float* HWY_RESTRICT c = C.Row(task);
-      Softmax(c, C.Cols(), worker);
+      Softmax(c, C.Cols(), env_.ctx.profiler, worker);
     });

     pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR {

@@ -121,7 +121,8 @@ class VitAttention {
       for (size_t i = 0; i < seq_len; ++i) {
         float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
                                 head * 3 * qkv_dim + 2 * qkv_dim;
-        MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim, worker);
+        MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim,
+                         env_.ctx.profiler, worker);
       }
     });
   }

@@ -144,7 +145,7 @@ class VitAttention {
       // Compute Q.K scores, which are "logits" stored in head_att.
       float* HWY_RESTRICT q =
           activations_.attention.q.Row(token) + head * 3 * qkv_dim;
-      MulByConst(query_scale, q, qkv_dim, worker);
+      MulByConst(query_scale, q, qkv_dim, env_.ctx.profiler, worker);
       float* HWY_RESTRICT head_att =
           activations_.attention.att.Row(token) + head * seq_len;
       for (size_t i = 0; i < seq_len; ++i) {

@@ -153,7 +154,7 @@ class VitAttention {
        head_att[i] = Dot(q, k, qkv_dim);  // score = q.k
       }
       // SoftMax yields "probabilities" in head_att.
-      Softmax(head_att, seq_len, worker);
+      Softmax(head_att, seq_len, env_.ctx.profiler, worker);
       // Compute weighted sum of v into att_out.
       float* HWY_RESTRICT att_out =
           activations_.attention.att_out.Row(token) + head * qkv_dim;

@@ -161,7 +162,8 @@ class VitAttention {
       for (size_t i = 0; i < seq_len; ++i) {
         float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
                                 head * 3 * qkv_dim + 2 * qkv_dim;
-        MulByConstAndAdd(head_att[i], v, att_out, qkv_dim, worker);
+        MulByConstAndAdd(head_att[i], v, att_out, qkv_dim,
+                         env_.ctx.profiler, worker);
       }
     });
   }

@@ -224,7 +226,7 @@ void FFWVit(const LayerWeightsPtrs& layer, Activations& activations,
              activations.C1);

   // Activation (Gelu), store in C1.
-  ActivationBatched(layer_config.activation, activations.C1, env.ctx.pools);
+  ActivationBatched(layer_config.activation, activations.C1, env.ctx);

   // Hidden layer -> output layer.
   CallMatMul(activations.C1, layer.vit.linear_1_w, output_bias, env,

@@ -334,7 +336,7 @@ void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights,
     // Apply soft embedding norm before input projection.
     CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) {
       RMSNormInplace(weights_t->PackedScale1(), 0, activations.x.Row(0),
-                     vit_model_dim, /*worker=*/0);
+                     vit_model_dim, env.ctx.profiler, hwy::Profiler::Thread());
     });
   }

@@ -381,9 +381,11 @@ static void DecompressToBF16(MatPtr& mat,
 }

 static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
-                          const BlobReader& reader, hwy::ThreadPool& pool) {
-  pool.Run(0, tensors.size(), [&](uint64_t task, size_t thread) {
-    PROFILER_ZONE2(thread, "Startup.Weights.ReadBF16");
+                          const BlobReader& reader, ThreadingContext& ctx) {
+  static const auto zone =
+      ctx.profiler.AddZone("Startup.Weights.ReadAllToBF16");
+  ctx.pools.Pool().Run(0, tensors.size(), [&](uint64_t task, size_t thread) {
+    PROFILER_ZONE3(ctx.profiler, thread, zone);
     const TensorToRead& tensor = tensors[task];
     MatPtr& mat = *tensor.mat;

@@ -460,10 +462,11 @@ static std::vector<IOBatch> MakeBatches(
 // want to use the OS cache between consecutive runs.
 static void ReadBatches(const BlobReader& reader,
                         const std::vector<IOBatch>& batches,
-                        hwy::ThreadPool& pool) {
+                        ThreadingContext& ctx) {
+  static const auto zone = ctx.profiler.AddZone("Startup.Weights.ReadBatches");
   // >5x speedup from parallel reads when cached.
-  pool.Run(0, batches.size(), [&](uint64_t i, size_t thread) {
-    PROFILER_ZONE2(thread, "Startup.Weights.Read");
+  ctx.pools.Pool().Run(0, batches.size(), [&](uint64_t i, size_t thread) {
+    PROFILER_ZONE3(ctx.profiler, thread, zone);
     const IOBatch& batch = batches[i];
     const std::string& key = reader.Keys()[batch.KeyIdx()];
     const uint64_t bytes_read = batch.Read(reader.file());

@@ -500,16 +503,14 @@ static MapPtr MapOrReadAll(std::vector<TensorToRead>& tensors,
     AllocateAndBindAll(tensors, *mode, mat_owners, ctx);
   }

-  hwy::ThreadPool& pool = ctx.pools.Pool();
-
   if (*mode == WeightsPtrs::Mode::kReadBF16) {
-    ReadAllToBF16(tensors, reader, pool);
+    ReadAllToBF16(tensors, reader, ctx);
     return MapPtr();
   }

   const std::vector<IOBatch> batches =
       MakeBatches(tensors, reader.file_bytes());
-  ReadBatches(reader, batches, pool);
+  ReadBatches(reader, batches, ctx);
   return MapPtr();
 }

@@ -519,7 +520,7 @@ WeightsPtrs::Mode WeightsPtrs::ReadFromBlobs(const ModelStore& model,
                                              const InferenceArgs& inference,
                                              std::vector<MatOwner>& mat_owners,
                                              ThreadingContext& ctx) {
-  PROFILER_ZONE("Startup.ReadFromBlobs");
+  PROFILER_ZONE("Startup.Weights.ReadFromBlobs");

   // List of tensors to read/map, and where from.
   std::vector<TensorToRead> tensors;

ops/matmul.h (17 lines changed)
@@ -720,32 +720,33 @@ struct MMArgs {
 // Wrapper over hwy::Zone that is only enabled when autotuning finished.
 #if PROFILER_ENABLED
 class MMZone {
-  using Zone = hwy::Zone;
-  static_assert(alignof(Zone) <= 8 && sizeof(Zone) <= 8);
+  using Zone = hwy::profiler::Zone;
+  static_assert(alignof(Zone) <= 8 && sizeof(Zone) <= 16);

  public:
   ~MMZone() {
-    if (used_) {
+    if (data_ != 0) {
       Zone* zone = reinterpret_cast<Zone*>(&data_);
       zone->~Zone();
     }
   }

   // `name` must be a string literal.
-  void MaybeEnter(size_t thread_id, uint32_t zone_id, const MMArgs& args) {
+  void MaybeEnter(size_t thread, hwy::profiler::ZoneHandle zone,
+                  const MMArgs& args) {
     if (args.per_key->WantProfile()) {
-      new (&data_) Zone(thread_id, zone_id);
-      used_ = true;
+      new (&data_) Zone(args.env->ctx.profiler, thread, zone);
+      HWY_DASSERT(data_ != 0);
     }
   }

  private:
   uint64_t data_ = 0;
-  bool used_ = false;
+  uint64_t data2_ = 0;
 };
 #else
 struct MMZone {
-  void MaybeEnter(size_t, uint32_t, const MMArgs&) {}
+  void MaybeEnter(size_t, hwy::profiler::ZoneHandle, const MMArgs&) {}
 };
 #endif  // PROFILER_ENABLED

ops/ops-inl.h (122 lines changed)
@@ -125,9 +125,9 @@ HWY_INLINE hn::Vec<D> Gelu(D d, hn::Vec<D> v) {
   return hn::Mul(v, cdf);
 }

+// Activation already has a profiler zone.
 static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(float* HWY_RESTRICT x,
                                                size_t size) {
-  PROFILER_ZONE("ops.Gelu");
   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
   hn::Transform(D(), x, size,

@@ -191,9 +191,10 @@ namespace detail {

 // Shared by RMSNorm and RMSNormInplace.
 template <typename VT>
-float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size,
-                 const HWY_MAYBE_UNUSED size_t worker) {
-  PROFILER_ZONE2(worker, "ops.RMSNormMul");
+float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size, hwy::Profiler& p,
+                 const size_t worker) {
+  static const auto zone = p.AddZone("Ops.RMSNormMul");
+  PROFILER_ZONE3(p, worker, zone);

   const hn::ScalableTag<float> d;
   const float l2 = DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault());

@@ -205,18 +206,20 @@ float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size,

 // `x_ofs` is the offset within `x`, required for NuqStream.
 template <typename XT, typename WT, typename OT>
-HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
-    const XT* HWY_RESTRICT x, const WT* HWY_RESTRICT weight, size_t w_ofs,
-    OT* HWY_RESTRICT out, const size_t size,
-    const size_t HWY_MAYBE_UNUSED worker) {
-  PROFILER_ZONE2(worker, "ops.RMSNorm");
+HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x,
+                                           const WT* HWY_RESTRICT weight,
+                                           size_t w_ofs, OT* HWY_RESTRICT out,
+                                           const size_t size, hwy::Profiler& p,
+                                           const size_t worker) {
+  static const auto zone = p.AddZone("Ops.RMSNorm");
+  PROFILER_ZONE3(p, worker, zone);

   namespace hn = hwy::HWY_NAMESPACE;
   const hn::ScalableTag<float> df;
   using VF = hn::Vec<decltype(df)>;
   const size_t NF = hn::Lanes(df);

-  const VF mul = hn::Set(df, detail::RMSNormMul(x, size, worker));
+  const VF mul = hn::Set(df, detail::RMSNormMul(x, size, p, worker));

   const auto packed_x = MakeSpan(x, size);
   const auto packed_w = MakeSpan(weight, w_ofs + size);

@@ -240,15 +243,16 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
 template <typename WT, typename XT>
 HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
     const WT* HWY_RESTRICT weight, size_t w_ofs, XT* HWY_RESTRICT inout,
-    const size_t size, const HWY_MAYBE_UNUSED size_t worker) {
-  PROFILER_ZONE2(worker, "ops.RMSNormInplace");
+    const size_t size, hwy::Profiler& p, const size_t worker) {
+  static const auto zone = p.AddZone("Ops.RMSNormInplace");
+  PROFILER_ZONE3(p, worker, zone);

   namespace hn = hwy::HWY_NAMESPACE;
   const hn::ScalableTag<float> df;
   using VF = hn::Vec<decltype(df)>;
   const size_t NF = hn::Lanes(df);

-  const VF mul = hn::Set(df, detail::RMSNormMul(inout, size, worker));
+  const VF mul = hn::Set(df, detail::RMSNormMul(inout, size, p, worker));

   const auto packed_w = MakeSpan(weight, w_ofs + size);
   const auto packed_x = MakeSpan(inout, size);

@@ -407,9 +411,10 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
 // This overload is called if `post_qk == PostQKType::HalfRope`.
 static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
     float* HWY_RESTRICT x, const size_t dim_qkv,
-    const float* HWY_RESTRICT inv_timescale, const int pos,
-    const size_t HWY_MAYBE_UNUSED worker = 0) {
-  PROFILER_ZONE2(worker, "ops.Rope");
+    const float* HWY_RESTRICT inv_timescale, const int pos, hwy::Profiler& p,
+    const size_t worker) {
+  static const auto zone = p.AddZone("Ops.Rope");
+  PROFILER_ZONE3(p, worker, zone);
   HWY_DASSERT(dim_qkv % 2 == 0);
   const size_t half_dim_qkv = dim_qkv / 2;

@@ -466,9 +471,10 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
 // `inv_timescale[dim_qkv / 2]` is precomputed in AttentionActivations.
 static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
     const float mul, float* HWY_RESTRICT x, const size_t dim_qkv,
-    const float* HWY_RESTRICT inv_timescale, const int pos,
-    const size_t HWY_MAYBE_UNUSED worker = 0) {
-  PROFILER_ZONE2(worker, "ops.RopeAndMulBy");
+    const float* HWY_RESTRICT inv_timescale, const int pos, hwy::Profiler& p,
+    const size_t worker) {
+  static const auto zone = p.AddZone("Ops.RopeAndMulBy");
+  PROFILER_ZONE3(p, worker, zone);
   HWY_DASSERT(dim_qkv % 2 == 0);
   const size_t half_dim_qkv = dim_qkv / 2;

@@ -525,10 +531,13 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
 }

 template <typename XT>
-static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
-    const XT* HWY_RESTRICT x, float* HWY_RESTRICT out, const size_t size,
-    const HWY_MAYBE_UNUSED size_t worker) {
-  PROFILER_ZONE2(worker, "ops.AddFrom");
+static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x,
+                                                  float* HWY_RESTRICT out,
+                                                  const size_t size,
+                                                  hwy::Profiler& p,
+                                                  const size_t worker) {
+  static const auto zone = p.AddZone("Ops.AddFrom");
+  PROFILER_ZONE3(p, worker, zone);

   namespace hn = hwy::HWY_NAMESPACE;
   const hn::ScalableTag<float> df;

@@ -576,11 +585,10 @@ void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
   HWY_DASSERT(activations.SameShape(out));

   CallUpcasted(&weights, [&](const auto* weights_t) {
-    SmallParallelFor(activations.Rows(), ctx.pools,
-                     [&](uint64_t token_idx, size_t worker) {
-                       RMSNorm(activations.Row(token_idx),
-                               weights_t->PackedScale1(), 0, out.Row(token_idx),
-                               activations.Cols(), worker);
+    SmallParallelFor(
+        activations.Rows(), ctx.pools, [&](uint64_t token_idx, size_t worker) {
+          RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(), 0,
+                  out.Row(token_idx), activations.Cols(), ctx.profiler, worker);
         });
   });
 }

@@ -592,11 +600,10 @@ void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT<XT>& inout,
   HWY_DASSERT(weights.Cols() == inout.Cols());

   CallUpcasted(&weights, [&](const auto* weights_t) {
-    SmallParallelFor(inout.Rows(), ctx.pools,
-                     [&](uint64_t token_idx, size_t worker) {
-                       RMSNormInplace(weights_t->PackedScale1(), 0,
-                                      inout.Row(token_idx), inout.Cols(),
-                                      worker);
+    SmallParallelFor(
+        inout.Rows(), ctx.pools, [&](uint64_t token_idx, size_t worker) {
+          RMSNormInplace(weights_t->PackedScale1(), 0, inout.Row(token_idx),
+                         inout.Cols(), ctx.profiler, worker);
         });
   });
 }

@@ -622,17 +629,20 @@ template <typename XT>
 static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x, MatPtrT<float>& out,
                                       ThreadingContext& ctx) {
   HWY_DASSERT(out.SameShape(x));
-  SmallParallelFor(
-      out.Rows(), ctx.pools, [&](uint64_t token_idx, size_t worker) {
-        AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(), worker);
+  SmallParallelFor(out.Rows(), ctx.pools,
+                   [&](uint64_t token_idx, size_t worker) {
+                     AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(),
+                             ctx.profiler, worker);
       });
 }

 template <typename XT>
-HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(
-    const float c, XT* HWY_RESTRICT x, const size_t size,
-    const HWY_MAYBE_UNUSED size_t worker) {
-  PROFILER_ZONE2(worker, "ops.MulByConst");
+HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(const float c, XT* HWY_RESTRICT x,
+                                              const size_t size,
+                                              hwy::Profiler& p,
+                                              const size_t worker) {
+  static const auto zone = p.AddZone("Ops.MulByConst");
+  PROFILER_ZONE3(p, worker, zone);
   namespace hn = hwy::HWY_NAMESPACE;
   const hn::ScalableTag<float> df;
   const size_t NF = hn::Lanes(df);

@@ -672,8 +682,9 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(
 template <typename XT, typename OT>
 HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo(
     const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,
-    const size_t size, const HWY_MAYBE_UNUSED size_t worker) {
-  PROFILER_ZONE2(worker, "ops.MulByConstTo");
+    const size_t size, hwy::Profiler& p, const size_t worker) {
+  static const auto zone = p.AddZone("Ops.MulByConstTo");
+  PROFILER_ZONE3(p, worker, zone);
   namespace hn = hwy::HWY_NAMESPACE;
   const hn::ScalableTag<float> df;
   const size_t NF = hn::Lanes(df);

@@ -714,8 +725,9 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo(
 template <typename XT, typename OT>
 HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
     const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,
-    const size_t size, const HWY_MAYBE_UNUSED size_t worker) {
-  PROFILER_ZONE2(worker, "ops.MulByConstAndAdd");
+    const size_t size, hwy::Profiler& p, const size_t worker) {
+  static const auto zone = p.AddZone("Ops.MulByConstAndAdd");
+  PROFILER_ZONE3(p, worker, zone);
   namespace hn = hwy::HWY_NAMESPACE;
   const hn::ScalableTag<float> df;
   const size_t NF = hn::Lanes(df);

@@ -760,9 +772,10 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(

 // See below for a specialized version for top-1 sampling.
 static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
-                                 const size_t worker,
+                                 hwy::Profiler& p, const size_t worker,
                                  float temperature = 1.0f) {
-  PROFILER_ZONE2(worker, "ops.Softmax");
+  static const auto zone = p.AddZone("Ops.Softmax");
+  PROFILER_ZONE3(p, worker, zone);
   HWY_DASSERT(size != 0);

   namespace hn = hwy::HWY_NAMESPACE;

@@ -803,7 +816,7 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
   const float sum_exp = Sum(d, x, size);
   // Double-precision reciprocal does not appear to affect the results.
   const float mul = 1.0f / sum_exp;
-  MulByConst(mul, x, size, worker);
+  MulByConst(mul, x, size, p, worker);
 }

 // Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max /

@@ -893,9 +906,10 @@ static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(float* HWY_RESTRICT x,
 }

 static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
-                                       const size_t size,
-                                       const HWY_MAYBE_UNUSED size_t worker) {
-  PROFILER_ZONE2(worker, "ops.LogitsSoftCap");
+                                       const size_t size, hwy::Profiler& p,
+                                       const size_t worker) {
+  static const auto zone = p.AddZone("Ops.LogitsSoftCap");
+  PROFILER_ZONE3(p, worker, zone);

   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;

@@ -911,10 +925,10 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,

 // Calls LogitsSoftCap if cap != 0.0f.
 static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap(
-    const float cap, float* HWY_RESTRICT x, const size_t size,
+    const float cap, float* HWY_RESTRICT x, const size_t size, hwy::Profiler& p,
     const size_t worker) {
   if (cap != 0.0f) {
-    LogitsSoftCap(cap, x, size, worker);
+    LogitsSoftCap(cap, x, size, p, worker);
   }
 }

@@ -998,7 +1012,7 @@ template <typename TAcceptToken>
 HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
     const float* HWY_RESTRICT logits, size_t k, size_t vocab_size,
     std::mt19937& gen, float temperature, TAcceptToken& accept_token,
-    size_t worker) {
+    hwy::Profiler& p, size_t worker) {
   // Softmax and sample top-K is equivalent to taking the top-K logits and
   // sampling from the softmax of the top-K logits. The latter is faster as it
   // avoids computing the softmax of all logits.

@@ -1012,7 +1026,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
   }

   size_t mask = token_logits.size();
-  Softmax(topk_logits.data(), mask, worker, temperature);
+  Softmax(topk_logits.data(), mask, p, worker, temperature);
   auto distribution = std::discrete_distribution<int>(
       std::begin(topk_logits), std::begin(topk_logits) + mask);
   int topk_sampled_index = distribution(gen);

@@ -18,8 +18,6 @@
 #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
 #endif  // HWY_DISABLED_TARGETS

-#include "ops/ops.h"
-
 #include <stddef.h>
 #include <stdio.h>

@@ -33,11 +31,13 @@

 #include "gemma/activations.h"  // ChooseQueryScale
 #include "gemma/configs.h"
+#include "ops/ops.h"
 #include "util/allocator.h"
 #include "util/basics.h"  // BF16
 #include "util/mat.h"  // MatStorageT
 #include "util/test_util.h"
 #include "util/threading_context.h"
+#include "hwy/profiler.h"
 #include "hwy/tests/hwy_gtest.h"

 // clang-format off

@@ -166,7 +166,7 @@ struct TestAddFrom {
     }

     SimpleAddFrom(o, e, count);
-    AddFrom(o, x, count, /*worker=*/0);
+    AddFrom(o, x, count, hwy::Profiler::Get(), /*worker=*/0);

     hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
                             __LINE__);

@@ -199,7 +199,7 @@ struct TestMulByConstAndAdd {
     T constant = Random<T>(rng);

     SimpleMulByConstAndAdd(constant, o, e, count);
-    MulByConstAndAdd(constant, o, x, count, /*worker=*/0);
+    MulByConstAndAdd(constant, o, x, count, hwy::Profiler::Get(), /*worker=*/0);

     hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
                             __LINE__);

@@ -229,7 +229,7 @@ struct TestMulByConst {
     T constant = Random<T>(rng);

     SimpleMulByConst(constant, e, count);
-    MulByConst(constant, x, count, /*worker=*/0);
+    MulByConst(constant, x, count, hwy::Profiler::Get(), /*worker=*/0);

     hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
                             __LINE__);

@@ -259,7 +259,7 @@ struct TestSoftmax {
     }

     SimpleSoftmax(e, count);
-    Softmax(x, count, /*worker=*/0);
+    Softmax(x, count, hwy::Profiler::Get(), /*worker=*/0);

     T sum = 0.0f;
     for (size_t i = 0; i < count; ++i) {

@@ -349,6 +349,9 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
 void TestRopeAndMulBy() {
+  ThreadingArgs threading_args;
+  ThreadingContext ctx(threading_args);
+  hwy::Profiler& p = ctx.profiler;
   const size_t worker = 0;

   const ModelConfig config(Model::GEMMA2_9B, Type::kSFP,
                            ChooseWrapping(Model::GEMMA2_9B));
   const size_t dim_qkv = config.layer_configs[0].qkv_dim;

@@ -381,7 +384,8 @@ void TestRopeAndMulBy() {
   CopyMat(x, qactual);
   ScalarRopeAndMulBy(qmul, qexpected.Row(0), dim_qkv, inv_timescale.Row(0),
                      pos);
-  RopeAndMulBy(qmul, qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos);
+  RopeAndMulBy(qmul, qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, p,
+               worker);
   for (size_t i = 0; i < dim_qkv; ++i) {
     EXPECT_NEAR(qexpected.Row(0)[i], qactual.Row(0)[i], 1e-4) << " " << i;
   }

@@ -391,7 +395,7 @@ void TestRopeAndMulBy() {
   CopyMat(x, qactual);
   ScalarRopeAndMulBy(1.0f, qexpected.Row(0), dim_qkv, inv_timescale.Row(0),
                      pos);
-  Rope(qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos);
+  Rope(qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, p, worker);
   for (size_t i = 0; i < dim_qkv; ++i) {
     EXPECT_NEAR(qexpected.Row(0)[i], qactual.Row(0)[i], 1e-4) << " " << i;
   }

@@ -402,9 +406,10 @@ void TestRopeAndMulBy() {
   CopyMat(x, kactual2);
   ScalarRopeAndMulBy(kmul, kexpected.Row(0), dim_qkv, inv_timescale.Row(0),
                      pos);
-  RopeAndMulBy(kmul, kactual.Row(0), dim_qkv, inv_timescale.Row(0), pos);
+  RopeAndMulBy(kmul, kactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, p,
+               worker);
   static_assert(kmul == 1.0f, "");
-  Rope(kactual2.Row(0), dim_qkv, inv_timescale.Row(0), pos);
+  Rope(kactual2.Row(0), dim_qkv, inv_timescale.Row(0), pos, p, worker);

   for (size_t i = 0; i < dim_qkv; ++i) {
     EXPECT_NEAR(kexpected.Row(0)[i], kactual.Row(0)[i], 1e-4) << " " << i;

@@ -454,7 +459,7 @@ void TestRMSNorm(hwy::RandomState& rng) {
   }

   ScalarRMSNorm(vec, weight, expected, kSize);
-  RMSNorm(vec, weight, 0, actual, kSize, /*worker=*/0);
+  RMSNorm(vec, weight, 0, actual, kSize, hwy::Profiler::Get(), /*worker=*/0);

   for (size_t i = 0; i < kSize; i++) {
     const float e = hwy::ConvertScalarTo<float>(expected[i]);

@@ -580,11 +585,13 @@ void TestAllLayerNorm() {
 }

 void TestSampleTopK() {
+  hwy::Profiler& p = hwy::Profiler::Get();
+  const size_t worker = 0;
   const size_t kSize = 52;
   std::vector<float> logits(kSize);
   // Create a vector going from -100 to -100+51=49 and take Softmax.
   std::iota(logits.begin(), logits.end(), -100.0f);
-  Softmax(logits.data(), kSize, /*worker=*/0);
+  Softmax(logits.data(), kSize, p, worker);
   std::mt19937 gen;
   gen.seed(0x12345678);
   float temperature = 1.0f;

@@ -600,7 +607,7 @@ void TestSampleTopK() {
   EXPECT_EQ(sample, 50);  // Last even index.
   // Reset the logits to a positive, increasing sequence and take Softmax.
   std::iota(logits.begin(), logits.end(), 1.0f);
-  Softmax(logits.data(), kSize, /*worker=*/0);
+  Softmax(logits.data(), kSize, p, worker);
   // Sample from the top 3, expect one of the top 3 even indices.
   for (int i = 0; i < 100; ++i) {
     sample = SampleTopK(logits.data(), /*k=*/3, kSize, gen, temperature,

@@ -29,6 +29,7 @@
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/contrib/thread_pool/topology.h"
+#include "hwy/profiler.h"

 namespace gcpp {

@@ -171,6 +172,8 @@ NestedPools::NestedPools(const BoundedTopology& topology,
   HWY_ASSERT(max_clusters_per_package_ <= 64);
   HWY_ASSERT(max_workers_per_cluster_ >= 1);
   HWY_ASSERT(max_workers_per_cluster_ <= 256);
+
+  hwy::Profiler::Get().SetMaxThreads(MaxWorkers());
 }

 // `max_or_zero` == 0 means no limit.

@@ -72,7 +72,8 @@ static void TunePool(hwy::ThreadPool& pool) {
 }

 ThreadingContext::ThreadingContext(const ThreadingArgs& args)
-    : topology(BoundedSlice(args.skip_packages, args.max_packages),
+    : profiler(hwy::Profiler::Get()),
+      topology(BoundedSlice(args.skip_packages, args.max_packages),
                BoundedSlice(args.skip_clusters, args.max_clusters),
                BoundedSlice(args.skip_lps, args.max_lps)),
       allocator(topology, args.bind != Tristate::kFalse),

@@ -90,6 +90,7 @@ struct ThreadingContext {
   // Expected to be called early in the program, before threading starts.
   explicit ThreadingContext(const ThreadingArgs& args);

+  hwy::Profiler& profiler;
   BoundedTopology topology;
   Allocator allocator;
   NestedPools pools;
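For orientation, a hedged caller-side sketch of how the new `profiler` member is reached once functions take `ThreadingContext&` instead of `NestedPools&`. `RunExample` is hypothetical; `ctx.profiler`, `ctx.pools`, `AddZone`, `hwy::Profiler::Thread()`, and `PROFILER_ZONE3` are the pieces used in the hunks above.

```cpp
#include "util/basics.h"             // PROFILER_ZONE3, as included by this commit
#include "util/threading_context.h"  // ThreadingContext

namespace gcpp {

// Hypothetical caller: with a ThreadingContext in hand, the profiler travels
// together with the thread pools instead of being reached through a global.
void RunExample(ThreadingContext& ctx) {
  static const auto zone = ctx.profiler.AddZone("Example.RunExample");
  PROFILER_ZONE3(ctx.profiler, hwy::Profiler::Thread(), zone);
  // Lower-level ops now receive ctx.profiler plus a worker index explicitly,
  // e.g. Softmax(x, size, ctx.profiler, worker) as in the diff above.
}

}  // namespace gcpp
```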