mirror of https://github.com/google/gemma.cpp.git
Prepare profiler annotations for new API
Pass hwy::Profiler& to low-level functions. Used ThreadingContext arg instead of NestedPools. Use new PROFILER_ZONE3. PiperOrigin-RevId: 793865287
This commit is contained in:
parent
4cbf63e6f0
commit
a2d9133f7d
|
|
@ -80,7 +80,6 @@ cc_library(
|
|||
":topology",
|
||||
# Placeholder for container detection, do not remove
|
||||
"@highway//:hwy",
|
||||
"@highway//:profiler",
|
||||
"@highway//:thread_pool",
|
||||
"@highway//:topology",
|
||||
],
|
||||
|
|
@ -380,7 +379,6 @@ cc_test(
|
|||
"@highway//:hwy",
|
||||
"@highway//:hwy_test_util",
|
||||
"@highway//:nanobenchmark", #buildcleaner: keep
|
||||
"@highway//:profiler",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17)
|
|||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 92d327e841d78e11ae888757a3e16d291951cf64 EXCLUDE_FROM_ALL)
|
||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9414b48aeec251b69e6cadbfa42bebb5ddae1c34 EXCLUDE_FROM_ALL)
|
||||
FetchContent_MakeAvailable(highway)
|
||||
|
||||
## Note: absl needs to be installed by sentencepiece. This will only happen if
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ bazel_dep(name = "google_benchmark", version = "1.8.5")
|
|||
# Require a more recent version.
|
||||
git_override(
|
||||
module_name = "highway",
|
||||
commit = "92d327e841d78e11ae888757a3e16d291951cf64",
|
||||
commit = "9414b48aeec251b69e6cadbfa42bebb5ddae1c34",
|
||||
remote = "https://github.com/google/highway",
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -469,7 +469,7 @@ FetchContent_MakeAvailable(sentencepiece)
|
|||
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG origin/main)
|
||||
FetchContent_MakeAvailable(gemma)
|
||||
|
||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 92d327e841d78e11ae888757a3e16d291951cf64)
|
||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f)
|
||||
FetchContent_MakeAvailable(highway)
|
||||
```
|
||||
|
||||
|
|
|
|||
|
|
@ -84,9 +84,8 @@ HWY_BEFORE_NAMESPACE();
|
|||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
|
||||
void CallSoftmax(float* HWY_RESTRICT logits, size_t vocab_size,
|
||||
hwy::Profiler& p) {
|
||||
Softmax(logits, vocab_size, p, hwy::Profiler::Thread());
|
||||
void CallSoftmax(float* HWY_RESTRICT logits, size_t vocab_size) {
|
||||
Softmax(logits, vocab_size, /*worker=*/0);
|
||||
}
|
||||
|
||||
} // namespace HWY_NAMESPACE
|
||||
|
|
@ -110,7 +109,7 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
|
|||
const SampleFunc sample_token = [&](float* probs,
|
||||
size_t vocab_size) -> TokenAndProb {
|
||||
// input is logits, not yet probabilities
|
||||
HWY_DYNAMIC_DISPATCH(CallSoftmax)(probs, vocab_size, env.ctx.profiler);
|
||||
HWY_DYNAMIC_DISPATCH(CallSoftmax)(probs, vocab_size);
|
||||
// We are called for each token, but pos starts at 1. Clamping
|
||||
// max_generated_tokens to prompt.size() should prevent overrun.
|
||||
HWY_ASSERT(pos < prompt.size());
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17)
|
|||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
|
||||
include(FetchContent)
|
||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 92d327e841d78e11ae888757a3e16d291951cf64)
|
||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG bfc36a6e633af94e63ac4b91c687bf0354cb24e0)
|
||||
FetchContent_MakeAvailable(highway)
|
||||
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 9045b2f60fa2b323dfac0eaef8fc17565036f9f9)
|
||||
FetchContent_MakeAvailable(sentencepiece)
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17)
|
|||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
|
||||
include(FetchContent)
|
||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 92d327e841d78e11ae888757a3e16d291951cf64)
|
||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9414b48aeec251b69e6cadbfa42bebb5ddae1c34)
|
||||
FetchContent_MakeAvailable(highway)
|
||||
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
|
||||
FetchContent_MakeAvailable(sentencepiece)
|
||||
|
|
|
|||
|
|
@ -19,7 +19,6 @@
|
|||
#include <vector>
|
||||
|
||||
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
|
||||
#include "util/threading_context.h"
|
||||
#ifndef HWY_DISABLED_TARGETS
|
||||
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
||||
#endif // HWY_DISABLED_TARGETS
|
||||
|
|
@ -54,9 +53,8 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
|
|||
const hwy::Divisor& div_seq_len,
|
||||
const float* HWY_RESTRICT q,
|
||||
const MatPtrT<KV_t>& k, float* HWY_RESTRICT att,
|
||||
hwy::Profiler& p, const size_t worker) {
|
||||
static const auto zone = p.AddZone("Gen.Attention.QDotK");
|
||||
PROFILER_ZONE3(p, worker, zone);
|
||||
const size_t worker) {
|
||||
PROFILER_ZONE2(worker, "Gen.Attention.QDotK");
|
||||
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
|
||||
// Slightly faster: no wraparound.
|
||||
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
|
||||
|
|
@ -75,8 +73,8 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
|
|||
static void PositionalEncodingQK(float* qk, const size_t layer_idx,
|
||||
const LayerWeightsPtrs& layer,
|
||||
const AttentionActivations& activations,
|
||||
hwy::Profiler& p, const size_t worker,
|
||||
const size_t pos, const float mul = 1.0f) {
|
||||
const size_t worker, const size_t pos,
|
||||
const float mul = 1.0f) {
|
||||
const size_t qkv_dim = layer.layer_config.qkv_dim;
|
||||
const PostQKType& post_qk = layer.layer_config.post_qk;
|
||||
// qk is either q or k, so qkv_dim is the length we operate on.
|
||||
|
|
@ -88,10 +86,10 @@ static void PositionalEncodingQK(float* qk, const size_t layer_idx,
|
|||
}
|
||||
// PostQKType::Rope
|
||||
if (post_qk == PostQKType::HalfRope) {
|
||||
Rope(qk, qkv_dim / 2, inv_timescale, pos, p, worker);
|
||||
if (mul != 1.0f) MulByConst(mul, qk, qkv_dim, p, worker);
|
||||
Rope(qk, qkv_dim / 2, inv_timescale, pos, worker);
|
||||
if (mul != 1.0f) MulByConst(mul, qk, qkv_dim, worker);
|
||||
} else {
|
||||
RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, p, worker);
|
||||
RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, worker);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -99,31 +97,26 @@ static void PositionalEncodingQK(float* qk, const size_t layer_idx,
|
|||
// `att_out`. Equivalent in gemma/modules.py:
|
||||
// encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
|
||||
// `v` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
|
||||
static HWY_INLINE void WeightedSumV(const size_t start_pos,
|
||||
const size_t last_pos,
|
||||
const hwy::Divisor& div_seq_len,
|
||||
const float* HWY_RESTRICT att,
|
||||
const MatPtrT<KV_t>& v,
|
||||
float* HWY_RESTRICT att_out,
|
||||
hwy::Profiler& p, const size_t worker) {
|
||||
static HWY_INLINE void WeightedSumV(
|
||||
const size_t start_pos, const size_t last_pos,
|
||||
const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT att,
|
||||
const MatPtrT<KV_t>& v, float* HWY_RESTRICT att_out, const size_t worker) {
|
||||
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
|
||||
// Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if
|
||||
// we supported non-transposed B.
|
||||
// TODO: 2..4x unroll
|
||||
MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), p,
|
||||
worker);
|
||||
MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), worker);
|
||||
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
|
||||
MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols(), p, worker);
|
||||
MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols(), worker);
|
||||
}
|
||||
} else {
|
||||
{
|
||||
const size_t pos_mod = div_seq_len.Remainder(start_pos);
|
||||
MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), p, worker);
|
||||
MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), worker);
|
||||
}
|
||||
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
|
||||
const size_t pos_mod = div_seq_len.Remainder(pos);
|
||||
MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), p,
|
||||
worker);
|
||||
MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), worker);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -135,7 +128,7 @@ void SingleDotSoftmaxWeightedSum(
|
|||
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
|
||||
const size_t layer_idx, const LayerWeightsPtrs& layer,
|
||||
const AttentionActivations& activations, float* HWY_RESTRICT att,
|
||||
float* HWY_RESTRICT att_out, hwy::Profiler& p, const size_t worker) {
|
||||
float* HWY_RESTRICT att_out, const size_t worker) {
|
||||
const float att_cap = activations.config.att_cap;
|
||||
const float query_scale = activations.query_scale;
|
||||
const size_t seq_len =
|
||||
|
|
@ -145,21 +138,21 @@ void SingleDotSoftmaxWeightedSum(
|
|||
if (layer.query_norm_scale.HasPtr()) {
|
||||
CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
|
||||
RMSNormInplace(weights_t->PackedScale1(), 0, q,
|
||||
layer.layer_config.qkv_dim, p, worker);
|
||||
layer.layer_config.qkv_dim, worker);
|
||||
});
|
||||
}
|
||||
|
||||
PositionalEncodingQK(q, layer_idx, layer, activations, p, worker, pos,
|
||||
PositionalEncodingQK(q, layer_idx, layer, activations, worker, pos,
|
||||
query_scale);
|
||||
|
||||
QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, p, worker);
|
||||
QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, worker);
|
||||
|
||||
// SoftMax with optional SoftCap yields "probabilities" in att.
|
||||
const size_t att_len = HWY_MIN(last_pos + 1, seq_len);
|
||||
MaybeLogitsSoftCap(att_cap, att, att_len, p, worker);
|
||||
Softmax(att, att_len, p, worker, /*temperature=*/1.0f);
|
||||
MaybeLogitsSoftCap(att_cap, att, att_len, worker);
|
||||
Softmax(att, att_len, worker, /*temperature=*/1.0f);
|
||||
|
||||
WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out, p,
|
||||
WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out,
|
||||
worker);
|
||||
}
|
||||
|
||||
|
|
@ -174,8 +167,9 @@ static HWY_INLINE size_t StartPos(size_t pos, const ModelConfig& config,
|
|||
void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
|
||||
const LayerWeightsPtrs& layer,
|
||||
AttentionActivations& activations, QBatch& qbatch,
|
||||
ThreadingContext& ctx) {
|
||||
static const auto zone = ctx.profiler.AddZone("Gen.Attention.DotSoftmax.par");
|
||||
NestedPools& pools) {
|
||||
static const uint32_t HWY_MAYBE_UNUSED zone_id_par =
|
||||
PROFILER_ADD_ZONE("Gen.Attention.DotSoftmax.par");
|
||||
|
||||
const hwy::Divisor div_qbatch(qbatch.Size());
|
||||
const LayerConfig& layer_config = layer.layer_config;
|
||||
|
|
@ -195,7 +189,9 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
|
|||
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
|
||||
const size_t tq_idx = activations.div_heads.Divide(task);
|
||||
const size_t head = activations.div_heads.Remainder(task);
|
||||
PROFILER_ZONE3(ctx.profiler, worker, zone);
|
||||
#if PROFILER_ENABLED
|
||||
const hwy::Zone zone(worker, zone_id_par);
|
||||
#endif
|
||||
|
||||
const size_t qi = div_qbatch.Remainder(tq_idx);
|
||||
const size_t batch_idx = div_qbatch.Divide(tq_idx);
|
||||
|
|
@ -227,15 +223,14 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
|
|||
v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());
|
||||
|
||||
SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx,
|
||||
layer, activations, att, att_out, ctx.profiler,
|
||||
worker);
|
||||
layer, activations, att, att_out, worker);
|
||||
};
|
||||
|
||||
{
|
||||
PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin");
|
||||
// Full parallelism is helpful, SmallParallelFor is insufficient.
|
||||
ParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads,
|
||||
ctx.pools, func);
|
||||
pools, func);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -308,12 +303,12 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
|
|||
if (layer.key_norm_scale.HasPtr()) {
|
||||
CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) {
|
||||
RMSNormInplace(weights_t->PackedScale1(), 0, kv_f32, qkv_dim,
|
||||
env.ctx.profiler, thread);
|
||||
thread);
|
||||
});
|
||||
}
|
||||
|
||||
PositionalEncodingQK(kv_f32, layer_idx, layer, activations,
|
||||
env.ctx.profiler, thread, pos);
|
||||
PositionalEncodingQK(kv_f32, layer_idx, layer, activations, thread,
|
||||
pos);
|
||||
CompressPerThread tls;
|
||||
Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
|
||||
});
|
||||
|
|
@ -344,10 +339,6 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
|
|||
const LayerWeightsPtrs& layer,
|
||||
AttentionActivations& activations, QBatch& qbatch,
|
||||
MatMulEnv& env, int flags) {
|
||||
static const auto zone =
|
||||
env.ctx.profiler.AddZone("Gen.Attention", ProfilerFlags::kInclusive);
|
||||
PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone);
|
||||
|
||||
const LayerConfig& layer_config = layer.layer_config;
|
||||
HWY_DASSERT(!layer_config.IsMHA()); // No longer supported.
|
||||
HWY_DASSERT_M((layer_config.heads % layer_config.kv_heads) == 0,
|
||||
|
|
@ -356,7 +347,7 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
|
|||
|
||||
ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env);
|
||||
DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
|
||||
env.ctx);
|
||||
env.ctx.pools);
|
||||
SumHeads(layer, activations, env);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ namespace gcpp {
|
|||
void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \
|
||||
const LayerWeightsPtrs& layer, \
|
||||
AttentionActivations& activations, \
|
||||
QBatch& qbatch, ThreadingContext& ctx); \
|
||||
QBatch& qbatch, NestedPools& pools); \
|
||||
\
|
||||
void GemmaAttention(size_t num_tokens, const size_t layer_idx, \
|
||||
const LayerWeightsPtrs& layer, \
|
||||
|
|
|
|||
|
|
@ -45,10 +45,9 @@ namespace HWY_NAMESPACE {
|
|||
|
||||
template <typename T>
|
||||
void Activation(ActivationType activation, T* HWY_RESTRICT c1,
|
||||
const T* HWY_RESTRICT c2, const size_t count, hwy::Profiler& p,
|
||||
const T* HWY_RESTRICT c2, const size_t count,
|
||||
const size_t worker) {
|
||||
static const auto zone = p.AddZone("Gen.Activation");
|
||||
PROFILER_ZONE3(p, worker, zone);
|
||||
PROFILER_ZONE2(worker, "Gen.Activation");
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using DF = hn::ScalableTag<T>;
|
||||
using VF = hn::Vec<DF>;
|
||||
|
|
@ -65,30 +64,28 @@ void Activation(ActivationType activation, T* HWY_RESTRICT c1,
|
|||
|
||||
// No C2 multiplier.
|
||||
template <class Mat>
|
||||
void ActivationBatched(ActivationType activation, Mat& c1,
|
||||
ThreadingContext& ctx) {
|
||||
void ActivationBatched(ActivationType activation, Mat& c1, NestedPools& pools) {
|
||||
using T = typename Mat::T;
|
||||
SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
|
||||
SmallParallelFor(c1.Rows(), pools, [&](uint64_t task, size_t worker) {
|
||||
// Cast to correct type so type deduction works.
|
||||
Activation(activation, c1.Row(task), static_cast<const T*>(nullptr),
|
||||
c1.Cols(), ctx.profiler, worker);
|
||||
c1.Cols(), worker);
|
||||
});
|
||||
}
|
||||
|
||||
template <class Mat>
|
||||
HWY_NOINLINE void ActivationBatched(ActivationType activation, Mat& c1,
|
||||
const Mat* c2, ThreadingContext& ctx) {
|
||||
const Mat* c2, NestedPools& pools) {
|
||||
using T = typename Mat::T;
|
||||
HWY_DASSERT(c1.SameShape(*c2));
|
||||
if (c2 && c2->HasPtr()) {
|
||||
SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
|
||||
Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(),
|
||||
ctx.profiler, worker);
|
||||
SmallParallelFor(c1.Rows(), pools, [&](uint64_t task, size_t worker) {
|
||||
Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(), worker);
|
||||
});
|
||||
} else { // No multiplier
|
||||
SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
|
||||
SmallParallelFor(c1.Rows(), pools, [&](uint64_t task, size_t worker) {
|
||||
Activation(activation, c1.Row(task), static_cast<const T*>(nullptr),
|
||||
c1.Cols(), ctx.profiler, worker);
|
||||
c1.Cols(), worker);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
@ -113,9 +110,7 @@ void PostNorm(PostNormType post_norm, const MatPtr& weights,
|
|||
|
||||
static inline void FFWNoVit(const LayerWeightsPtrs& layer,
|
||||
Activations& activations, MatMulEnv& env) {
|
||||
static const auto zone =
|
||||
env.ctx.profiler.AddZone("Gen.FFW", ProfilerFlags::kInclusive);
|
||||
PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone);
|
||||
PROFILER_ZONE("Gen.FFW");
|
||||
const LayerConfig& layer_config = layer.layer_config;
|
||||
const size_t ffh_hidden_dim = layer_config.ff_hidden_dim;
|
||||
|
||||
|
|
@ -134,7 +129,7 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer,
|
|||
|
||||
// Activation (Gelu) and maybe multiply by gate. Store activations in act.
|
||||
ActivationBatched(layer_config.activation, activations.C1, &activations.C2,
|
||||
env.ctx);
|
||||
env.ctx.pools);
|
||||
|
||||
// Hidden layer -> output layer.
|
||||
CallMatMul(activations.C1, layer.linear_w, output_bias, env,
|
||||
|
|
|
|||
|
|
@ -55,7 +55,6 @@
|
|||
#include "io/io.h" // Path
|
||||
#include "ops/matmul.h"
|
||||
#include "paligemma/image.h"
|
||||
#include "util/basics.h" // PROFILER_ZONE3
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/aligned_allocator.h" // Span
|
||||
#include "hwy/base.h"
|
||||
|
|
@ -139,8 +138,7 @@ static float EmbeddingScaling(size_t model_dim) {
|
|||
static HWY_NOINLINE size_t
|
||||
EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
|
||||
const ModelConfig& model_config, const WeightsPtrs& weights,
|
||||
MatStorageT<float>& x, ThreadingContext& ctx,
|
||||
const ImageTokens* image_tokens = nullptr,
|
||||
MatStorageT<float>& x, const ImageTokens* image_tokens = nullptr,
|
||||
size_t image_token_position = 0) {
|
||||
// Image tokens just need to be copied.
|
||||
if (model_config.wrapping == PromptWrapping::GEMMA_VLM &&
|
||||
|
|
@ -176,8 +174,7 @@ EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
|
|||
const hn::ScalableTag<float> df;
|
||||
DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(qi),
|
||||
model_dim);
|
||||
MulByConst(emb_scaling * weights_t->Scale(), x.Row(qi), model_dim,
|
||||
ctx.profiler, worker);
|
||||
MulByConst(emb_scaling * weights_t->Scale(), x.Row(qi), model_dim, worker);
|
||||
});
|
||||
|
||||
if (model_config.absolute_pe) {
|
||||
|
|
@ -252,7 +249,7 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
|
|||
const int token = qbatch_1.Prompt(0)[pos_in_prompt];
|
||||
image_token_position = EmbedMMToken(
|
||||
token, ti, pos, pos_in_prompt, config, weights, activations.x,
|
||||
env.ctx, runtime_config.image_tokens, image_token_position);
|
||||
runtime_config.image_tokens, image_token_position);
|
||||
}
|
||||
|
||||
// Transformer with one batch of tokens from a single query.
|
||||
|
|
@ -309,7 +306,7 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
|
|||
// TODO: parallelize?
|
||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||
EmbedMMToken(qbatch.PrevToken(qi), qi, qbatch.Pos(qi),
|
||||
/*pos_in_prompt=*/0, config, weights, activations.x, env.ctx);
|
||||
/*pos_in_prompt=*/0, config, weights, activations.x);
|
||||
}
|
||||
|
||||
for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) {
|
||||
|
|
@ -422,8 +419,7 @@ static void DecodeStepT(const ModelConfig& config,
|
|||
const size_t worker = 0; // TODO: parallelize
|
||||
non_eos.Foreach([&](size_t qi) {
|
||||
float* HWY_RESTRICT logits = activations.logits.Row(qi);
|
||||
MaybeLogitsSoftCap(config.final_cap, logits, config.vocab_size,
|
||||
env.ctx.profiler, worker);
|
||||
MaybeLogitsSoftCap(config.final_cap, logits, config.vocab_size, worker);
|
||||
const TokenAndProb tp = sample_token(logits, config.vocab_size);
|
||||
timing_info.NotifyGenerated();
|
||||
|
||||
|
|
@ -433,28 +429,27 @@ static void DecodeStepT(const ModelConfig& config,
|
|||
}
|
||||
|
||||
static HWY_INLINE SampleFunc
|
||||
ChooseSampleFunc(const RuntimeConfig& runtime_config, ThreadingContext& ctx) {
|
||||
ChooseSampleFunc(const RuntimeConfig& runtime_config) {
|
||||
// If user provided a sample_func, use it.
|
||||
if (runtime_config.sample_func) return runtime_config.sample_func;
|
||||
|
||||
static const auto zone = ctx.profiler.AddZone("Gen.Sample Top1");
|
||||
const size_t worker = 0; // TODO: parallelize
|
||||
|
||||
// Fast path for top-1 with no accept_token.
|
||||
if (runtime_config.top_k == 1 && !runtime_config.accept_token) {
|
||||
return [&](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb {
|
||||
PROFILER_ZONE3(ctx.profiler, worker, zone);
|
||||
return [](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb {
|
||||
PROFILER_ZONE2(worker, "Gen.Sample Top1");
|
||||
return Top1OfSoftmax(logits, vocab_size);
|
||||
};
|
||||
}
|
||||
|
||||
// General case: Softmax with top-k sampling.
|
||||
return [&](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb {
|
||||
return [&runtime_config](float* logits,
|
||||
size_t vocab_size) HWY_ATTR -> TokenAndProb {
|
||||
PROFILER_ZONE("Gen.Sample general");
|
||||
return FusedSoftmaxAndSampleTopK(
|
||||
logits, runtime_config.top_k, vocab_size, *runtime_config.gen,
|
||||
runtime_config.temperature, runtime_config.accept_token, ctx.profiler,
|
||||
worker);
|
||||
runtime_config.temperature, runtime_config.accept_token, worker);
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -529,7 +524,7 @@ static void GenerateT(const ModelConfig& config,
|
|||
max_gen_steps = seq_len - max_prompt_size;
|
||||
}
|
||||
|
||||
const SampleFunc sample_token = ChooseSampleFunc(runtime_config, env.ctx);
|
||||
const SampleFunc sample_token = ChooseSampleFunc(runtime_config);
|
||||
|
||||
{
|
||||
timing_info.generate_start = hwy::platform::Now();
|
||||
|
|
|
|||
18
gemma/vit.cc
18
gemma/vit.cc
|
|
@ -95,7 +95,7 @@ class VitAttention {
|
|||
float* HWY_RESTRICT q =
|
||||
activations_.attention.q.Row(token) + head * 3 * qkv_dim;
|
||||
// TODO: shift to MatMul with A.scale once MatMul is confirmed working
|
||||
MulByConst(query_scale, q, qkv_dim, env_.ctx.profiler, worker);
|
||||
MulByConst(query_scale, q, qkv_dim, worker);
|
||||
hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float));
|
||||
});
|
||||
|
||||
|
|
@ -111,7 +111,7 @@ class VitAttention {
|
|||
|
||||
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR {
|
||||
float* HWY_RESTRICT c = C.Row(task);
|
||||
Softmax(c, C.Cols(), env_.ctx.profiler, worker);
|
||||
Softmax(c, C.Cols(), worker);
|
||||
});
|
||||
|
||||
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR {
|
||||
|
|
@ -121,8 +121,7 @@ class VitAttention {
|
|||
for (size_t i = 0; i < seq_len; ++i) {
|
||||
float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
|
||||
head * 3 * qkv_dim + 2 * qkv_dim;
|
||||
MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim,
|
||||
env_.ctx.profiler, worker);
|
||||
MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim, worker);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
@ -145,7 +144,7 @@ class VitAttention {
|
|||
// Compute Q.K scores, which are "logits" stored in head_att.
|
||||
float* HWY_RESTRICT q =
|
||||
activations_.attention.q.Row(token) + head * 3 * qkv_dim;
|
||||
MulByConst(query_scale, q, qkv_dim, env_.ctx.profiler, worker);
|
||||
MulByConst(query_scale, q, qkv_dim, worker);
|
||||
float* HWY_RESTRICT head_att =
|
||||
activations_.attention.att.Row(token) + head * seq_len;
|
||||
for (size_t i = 0; i < seq_len; ++i) {
|
||||
|
|
@ -154,7 +153,7 @@ class VitAttention {
|
|||
head_att[i] = Dot(q, k, qkv_dim); // score = q.k
|
||||
}
|
||||
// SoftMax yields "probabilities" in head_att.
|
||||
Softmax(head_att, seq_len, env_.ctx.profiler, worker);
|
||||
Softmax(head_att, seq_len, worker);
|
||||
// Compute weighted sum of v into att_out.
|
||||
float* HWY_RESTRICT att_out =
|
||||
activations_.attention.att_out.Row(token) + head * qkv_dim;
|
||||
|
|
@ -162,8 +161,7 @@ class VitAttention {
|
|||
for (size_t i = 0; i < seq_len; ++i) {
|
||||
float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
|
||||
head * 3 * qkv_dim + 2 * qkv_dim;
|
||||
MulByConstAndAdd(head_att[i], v, att_out, qkv_dim,
|
||||
env_.ctx.profiler, worker);
|
||||
MulByConstAndAdd(head_att[i], v, att_out, qkv_dim, worker);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
@ -226,7 +224,7 @@ void FFWVit(const LayerWeightsPtrs& layer, Activations& activations,
|
|||
activations.C1);
|
||||
|
||||
// Activation (Gelu), store in C1.
|
||||
ActivationBatched(layer_config.activation, activations.C1, env.ctx);
|
||||
ActivationBatched(layer_config.activation, activations.C1, env.ctx.pools);
|
||||
|
||||
// Hidden layer -> output layer.
|
||||
CallMatMul(activations.C1, layer.vit.linear_1_w, output_bias, env,
|
||||
|
|
@ -336,7 +334,7 @@ void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights,
|
|||
// Apply soft embedding norm before input projection.
|
||||
CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) {
|
||||
RMSNormInplace(weights_t->PackedScale1(), 0, activations.x.Row(0),
|
||||
vit_model_dim, env.ctx.profiler, hwy::Profiler::Thread());
|
||||
vit_model_dim, /*worker=*/0);
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -381,11 +381,9 @@ static void DecompressToBF16(MatPtr& mat,
|
|||
}
|
||||
|
||||
static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
|
||||
const BlobReader& reader, ThreadingContext& ctx) {
|
||||
static const auto zone =
|
||||
ctx.profiler.AddZone("Startup.Weights.ReadAllToBF16");
|
||||
ctx.pools.Pool().Run(0, tensors.size(), [&](uint64_t task, size_t thread) {
|
||||
PROFILER_ZONE3(ctx.profiler, thread, zone);
|
||||
const BlobReader& reader, hwy::ThreadPool& pool) {
|
||||
pool.Run(0, tensors.size(), [&](uint64_t task, size_t thread) {
|
||||
PROFILER_ZONE2(thread, "Startup.Weights.ReadBF16");
|
||||
const TensorToRead& tensor = tensors[task];
|
||||
MatPtr& mat = *tensor.mat;
|
||||
|
||||
|
|
@ -462,11 +460,10 @@ static std::vector<IOBatch> MakeBatches(
|
|||
// want to use the OS cache between consecutive runs.
|
||||
static void ReadBatches(const BlobReader& reader,
|
||||
const std::vector<IOBatch>& batches,
|
||||
ThreadingContext& ctx) {
|
||||
static const auto zone = ctx.profiler.AddZone("Startup.Weights.ReadBatches");
|
||||
hwy::ThreadPool& pool) {
|
||||
// >5x speedup from parallel reads when cached.
|
||||
ctx.pools.Pool().Run(0, batches.size(), [&](uint64_t i, size_t thread) {
|
||||
PROFILER_ZONE3(ctx.profiler, thread, zone);
|
||||
pool.Run(0, batches.size(), [&](uint64_t i, size_t thread) {
|
||||
PROFILER_ZONE2(thread, "Startup.Weights.Read");
|
||||
const IOBatch& batch = batches[i];
|
||||
const std::string& key = reader.Keys()[batch.KeyIdx()];
|
||||
const uint64_t bytes_read = batch.Read(reader.file());
|
||||
|
|
@ -503,14 +500,16 @@ static MapPtr MapOrReadAll(std::vector<TensorToRead>& tensors,
|
|||
AllocateAndBindAll(tensors, *mode, mat_owners, ctx);
|
||||
}
|
||||
|
||||
hwy::ThreadPool& pool = ctx.pools.Pool();
|
||||
|
||||
if (*mode == WeightsPtrs::Mode::kReadBF16) {
|
||||
ReadAllToBF16(tensors, reader, ctx);
|
||||
ReadAllToBF16(tensors, reader, pool);
|
||||
return MapPtr();
|
||||
}
|
||||
|
||||
const std::vector<IOBatch> batches =
|
||||
MakeBatches(tensors, reader.file_bytes());
|
||||
ReadBatches(reader, batches, ctx);
|
||||
ReadBatches(reader, batches, pool);
|
||||
return MapPtr();
|
||||
}
|
||||
|
||||
|
|
@ -520,7 +519,7 @@ WeightsPtrs::Mode WeightsPtrs::ReadFromBlobs(const ModelStore& model,
|
|||
const InferenceArgs& inference,
|
||||
std::vector<MatOwner>& mat_owners,
|
||||
ThreadingContext& ctx) {
|
||||
PROFILER_ZONE("Startup.Weights.ReadFromBlobs");
|
||||
PROFILER_ZONE("Startup.ReadFromBlobs");
|
||||
|
||||
// List of tensors to read/map, and where from.
|
||||
std::vector<TensorToRead> tensors;
|
||||
|
|
|
|||
17
ops/matmul.h
17
ops/matmul.h
|
|
@ -720,33 +720,32 @@ struct MMArgs {
|
|||
// Wrapper over hwy::Zone that is only enabled when autotuning finished.
|
||||
#if PROFILER_ENABLED
|
||||
class MMZone {
|
||||
using Zone = hwy::profiler::Zone;
|
||||
static_assert(alignof(Zone) <= 8 && sizeof(Zone) <= 16);
|
||||
using Zone = hwy::Zone;
|
||||
static_assert(alignof(Zone) <= 8 && sizeof(Zone) <= 8);
|
||||
|
||||
public:
|
||||
~MMZone() {
|
||||
if (data_ != 0) {
|
||||
if (used_) {
|
||||
Zone* zone = reinterpret_cast<Zone*>(&data_);
|
||||
zone->~Zone();
|
||||
}
|
||||
}
|
||||
|
||||
// `name` must be a string literal.
|
||||
void MaybeEnter(size_t thread, hwy::profiler::ZoneHandle zone,
|
||||
const MMArgs& args) {
|
||||
void MaybeEnter(size_t thread_id, uint32_t zone_id, const MMArgs& args) {
|
||||
if (args.per_key->WantProfile()) {
|
||||
new (&data_) Zone(args.env->ctx.profiler, thread, zone);
|
||||
HWY_DASSERT(data_ != 0);
|
||||
new (&data_) Zone(thread_id, zone_id);
|
||||
used_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
uint64_t data_ = 0;
|
||||
uint64_t data2_ = 0;
|
||||
bool used_ = false;
|
||||
};
|
||||
#else
|
||||
struct MMZone {
|
||||
void MaybeEnter(size_t, hwy::profiler::ZoneHandle, const MMArgs&) {}
|
||||
void MaybeEnter(size_t, uint32_t, const MMArgs&) {}
|
||||
};
|
||||
#endif // PROFILER_ENABLED
|
||||
|
||||
|
|
|
|||
122
ops/ops-inl.h
122
ops/ops-inl.h
|
|
@ -125,9 +125,9 @@ HWY_INLINE hn::Vec<D> Gelu(D d, hn::Vec<D> v) {
|
|||
return hn::Mul(v, cdf);
|
||||
}
|
||||
|
||||
// Activation already has a profiler zone.
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(float* HWY_RESTRICT x,
|
||||
size_t size) {
|
||||
PROFILER_ZONE("ops.Gelu");
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using D = hn::ScalableTag<float>;
|
||||
hn::Transform(D(), x, size,
|
||||
|
|
@ -191,10 +191,9 @@ namespace detail {
|
|||
|
||||
// Shared by RMSNorm and RMSNormInplace.
|
||||
template <typename VT>
|
||||
float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size, hwy::Profiler& p,
|
||||
const size_t worker) {
|
||||
static const auto zone = p.AddZone("Ops.RMSNormMul");
|
||||
PROFILER_ZONE3(p, worker, zone);
|
||||
float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size,
|
||||
const HWY_MAYBE_UNUSED size_t worker) {
|
||||
PROFILER_ZONE2(worker, "ops.RMSNormMul");
|
||||
|
||||
const hn::ScalableTag<float> d;
|
||||
const float l2 = DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault());
|
||||
|
|
@ -206,20 +205,18 @@ float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size, hwy::Profiler& p,
|
|||
|
||||
// `x_ofs` is the offset within `x`, required for NuqStream.
|
||||
template <typename XT, typename WT, typename OT>
|
||||
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x,
|
||||
const WT* HWY_RESTRICT weight,
|
||||
size_t w_ofs, OT* HWY_RESTRICT out,
|
||||
const size_t size, hwy::Profiler& p,
|
||||
const size_t worker) {
|
||||
static const auto zone = p.AddZone("Ops.RMSNorm");
|
||||
PROFILER_ZONE3(p, worker, zone);
|
||||
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
||||
const XT* HWY_RESTRICT x, const WT* HWY_RESTRICT weight, size_t w_ofs,
|
||||
OT* HWY_RESTRICT out, const size_t size,
|
||||
const size_t HWY_MAYBE_UNUSED worker) {
|
||||
PROFILER_ZONE2(worker, "ops.RMSNorm");
|
||||
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
const hn::ScalableTag<float> df;
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
const size_t NF = hn::Lanes(df);
|
||||
|
||||
const VF mul = hn::Set(df, detail::RMSNormMul(x, size, p, worker));
|
||||
const VF mul = hn::Set(df, detail::RMSNormMul(x, size, worker));
|
||||
|
||||
const auto packed_x = MakeSpan(x, size);
|
||||
const auto packed_w = MakeSpan(weight, w_ofs + size);
|
||||
|
|
@ -243,16 +240,15 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x,
|
|||
template <typename WT, typename XT>
|
||||
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
|
||||
const WT* HWY_RESTRICT weight, size_t w_ofs, XT* HWY_RESTRICT inout,
|
||||
const size_t size, hwy::Profiler& p, const size_t worker) {
|
||||
static const auto zone = p.AddZone("Ops.RMSNormInplace");
|
||||
PROFILER_ZONE3(p, worker, zone);
|
||||
const size_t size, const HWY_MAYBE_UNUSED size_t worker) {
|
||||
PROFILER_ZONE2(worker, "ops.RMSNormInplace");
|
||||
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
const hn::ScalableTag<float> df;
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
const size_t NF = hn::Lanes(df);
|
||||
|
||||
const VF mul = hn::Set(df, detail::RMSNormMul(inout, size, p, worker));
|
||||
const VF mul = hn::Set(df, detail::RMSNormMul(inout, size, worker));
|
||||
|
||||
const auto packed_w = MakeSpan(weight, w_ofs + size);
|
||||
const auto packed_x = MakeSpan(inout, size);
|
||||
|
|
@ -411,10 +407,9 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
|
|||
// This overload is called if `post_qk == PostQKType::HalfRope`.
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
|
||||
float* HWY_RESTRICT x, const size_t dim_qkv,
|
||||
const float* HWY_RESTRICT inv_timescale, const int pos, hwy::Profiler& p,
|
||||
const size_t worker) {
|
||||
static const auto zone = p.AddZone("Ops.Rope");
|
||||
PROFILER_ZONE3(p, worker, zone);
|
||||
const float* HWY_RESTRICT inv_timescale, const int pos,
|
||||
const size_t HWY_MAYBE_UNUSED worker = 0) {
|
||||
PROFILER_ZONE2(worker, "ops.Rope");
|
||||
HWY_DASSERT(dim_qkv % 2 == 0);
|
||||
const size_t half_dim_qkv = dim_qkv / 2;
|
||||
|
||||
|
|
@ -471,10 +466,9 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
|
|||
// `inv_timescale[dim_qkv / 2]` is precomputed in AttentionActivations.
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
|
||||
const float mul, float* HWY_RESTRICT x, const size_t dim_qkv,
|
||||
const float* HWY_RESTRICT inv_timescale, const int pos, hwy::Profiler& p,
|
||||
const size_t worker) {
|
||||
static const auto zone = p.AddZone("Ops.RopeAndMulBy");
|
||||
PROFILER_ZONE3(p, worker, zone);
|
||||
const float* HWY_RESTRICT inv_timescale, const int pos,
|
||||
const size_t HWY_MAYBE_UNUSED worker = 0) {
|
||||
PROFILER_ZONE2(worker, "ops.RopeAndMulBy");
|
||||
HWY_DASSERT(dim_qkv % 2 == 0);
|
||||
const size_t half_dim_qkv = dim_qkv / 2;
|
||||
|
||||
|
|
@ -531,13 +525,10 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
|
|||
}
|
||||
|
||||
template <typename XT>
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x,
|
||||
float* HWY_RESTRICT out,
|
||||
const size_t size,
|
||||
hwy::Profiler& p,
|
||||
const size_t worker) {
|
||||
static const auto zone = p.AddZone("Ops.AddFrom");
|
||||
PROFILER_ZONE3(p, worker, zone);
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
|
||||
const XT* HWY_RESTRICT x, float* HWY_RESTRICT out, const size_t size,
|
||||
const HWY_MAYBE_UNUSED size_t worker) {
|
||||
PROFILER_ZONE2(worker, "ops.AddFrom");
|
||||
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
const hn::ScalableTag<float> df;
|
||||
|
|
@ -585,10 +576,11 @@ void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
|
|||
HWY_DASSERT(activations.SameShape(out));
|
||||
|
||||
CallUpcasted(&weights, [&](const auto* weights_t) {
|
||||
SmallParallelFor(
|
||||
activations.Rows(), ctx.pools, [&](uint64_t token_idx, size_t worker) {
|
||||
RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(), 0,
|
||||
out.Row(token_idx), activations.Cols(), ctx.profiler, worker);
|
||||
SmallParallelFor(activations.Rows(), ctx.pools,
|
||||
[&](uint64_t token_idx, size_t worker) {
|
||||
RMSNorm(activations.Row(token_idx),
|
||||
weights_t->PackedScale1(), 0, out.Row(token_idx),
|
||||
activations.Cols(), worker);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
|
@ -600,10 +592,11 @@ void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT<XT>& inout,
|
|||
HWY_DASSERT(weights.Cols() == inout.Cols());
|
||||
|
||||
CallUpcasted(&weights, [&](const auto* weights_t) {
|
||||
SmallParallelFor(
|
||||
inout.Rows(), ctx.pools, [&](uint64_t token_idx, size_t worker) {
|
||||
RMSNormInplace(weights_t->PackedScale1(), 0, inout.Row(token_idx),
|
||||
inout.Cols(), ctx.profiler, worker);
|
||||
SmallParallelFor(inout.Rows(), ctx.pools,
|
||||
[&](uint64_t token_idx, size_t worker) {
|
||||
RMSNormInplace(weights_t->PackedScale1(), 0,
|
||||
inout.Row(token_idx), inout.Cols(),
|
||||
worker);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
|
@ -629,20 +622,17 @@ template <typename XT>
|
|||
static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x, MatPtrT<float>& out,
|
||||
ThreadingContext& ctx) {
|
||||
HWY_DASSERT(out.SameShape(x));
|
||||
SmallParallelFor(out.Rows(), ctx.pools,
|
||||
[&](uint64_t token_idx, size_t worker) {
|
||||
AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(),
|
||||
ctx.profiler, worker);
|
||||
SmallParallelFor(
|
||||
out.Rows(), ctx.pools, [&](uint64_t token_idx, size_t worker) {
|
||||
AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(), worker);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename XT>
|
||||
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(const float c, XT* HWY_RESTRICT x,
|
||||
const size_t size,
|
||||
hwy::Profiler& p,
|
||||
const size_t worker) {
|
||||
static const auto zone = p.AddZone("Ops.MulByConst");
|
||||
PROFILER_ZONE3(p, worker, zone);
|
||||
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(
|
||||
const float c, XT* HWY_RESTRICT x, const size_t size,
|
||||
const HWY_MAYBE_UNUSED size_t worker) {
|
||||
PROFILER_ZONE2(worker, "ops.MulByConst");
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
const hn::ScalableTag<float> df;
|
||||
const size_t NF = hn::Lanes(df);
|
||||
|
|
@ -682,9 +672,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(const float c, XT* HWY_RESTRICT x,
|
|||
template <typename XT, typename OT>
|
||||
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo(
|
||||
const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,
|
||||
const size_t size, hwy::Profiler& p, const size_t worker) {
|
||||
static const auto zone = p.AddZone("Ops.MulByConstTo");
|
||||
PROFILER_ZONE3(p, worker, zone);
|
||||
const size_t size, const HWY_MAYBE_UNUSED size_t worker) {
|
||||
PROFILER_ZONE2(worker, "ops.MulByConstTo");
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
const hn::ScalableTag<float> df;
|
||||
const size_t NF = hn::Lanes(df);
|
||||
|
|
@ -725,9 +714,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo(
|
|||
template <typename XT, typename OT>
|
||||
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
|
||||
const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,
|
||||
const size_t size, hwy::Profiler& p, const size_t worker) {
|
||||
static const auto zone = p.AddZone("Ops.MulByConstAndAdd");
|
||||
PROFILER_ZONE3(p, worker, zone);
|
||||
const size_t size, const HWY_MAYBE_UNUSED size_t worker) {
|
||||
PROFILER_ZONE2(worker, "ops.MulByConstAndAdd");
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
const hn::ScalableTag<float> df;
|
||||
const size_t NF = hn::Lanes(df);
|
||||
|
|
@ -772,10 +760,9 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
|
|||
|
||||
// See below for a specialized version for top-1 sampling.
|
||||
static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
|
||||
hwy::Profiler& p, const size_t worker,
|
||||
const size_t worker,
|
||||
float temperature = 1.0f) {
|
||||
static const auto zone = p.AddZone("Ops.Softmax");
|
||||
PROFILER_ZONE3(p, worker, zone);
|
||||
PROFILER_ZONE2(worker, "ops.Softmax");
|
||||
HWY_DASSERT(size != 0);
|
||||
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
|
|
@ -816,7 +803,7 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
|
|||
const float sum_exp = Sum(d, x, size);
|
||||
// Double-precision reciprocal does not appear to affect the results.
|
||||
const float mul = 1.0f / sum_exp;
|
||||
MulByConst(mul, x, size, p, worker);
|
||||
MulByConst(mul, x, size, worker);
|
||||
}
|
||||
|
||||
// Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max /
|
||||
|
|
@ -906,10 +893,9 @@ static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(float* HWY_RESTRICT x,
|
|||
}
|
||||
|
||||
static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
|
||||
const size_t size, hwy::Profiler& p,
|
||||
const size_t worker) {
|
||||
static const auto zone = p.AddZone("Ops.LogitsSoftCap");
|
||||
PROFILER_ZONE3(p, worker, zone);
|
||||
const size_t size,
|
||||
const HWY_MAYBE_UNUSED size_t worker) {
|
||||
PROFILER_ZONE2(worker, "ops.LogitsSoftCap");
|
||||
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using D = hn::ScalableTag<float>;
|
||||
|
|
@ -925,10 +911,10 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
|
|||
|
||||
// Calls LogitsSoftCap if cap != 0.0f.
|
||||
static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap(
|
||||
const float cap, float* HWY_RESTRICT x, const size_t size, hwy::Profiler& p,
|
||||
const float cap, float* HWY_RESTRICT x, const size_t size,
|
||||
const size_t worker) {
|
||||
if (cap != 0.0f) {
|
||||
LogitsSoftCap(cap, x, size, p, worker);
|
||||
LogitsSoftCap(cap, x, size, worker);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1012,7 +998,7 @@ template <typename TAcceptToken>
|
|||
HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
|
||||
const float* HWY_RESTRICT logits, size_t k, size_t vocab_size,
|
||||
std::mt19937& gen, float temperature, TAcceptToken& accept_token,
|
||||
hwy::Profiler& p, size_t worker) {
|
||||
size_t worker) {
|
||||
// Softmax and sample top-K is equivalent to taking the top-K logits and
|
||||
// sampling from the softmax of the top-K logits. The latter is faster as it
|
||||
// avoids computing the softmax of all logits.
|
||||
|
|
@ -1026,7 +1012,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
|
|||
}
|
||||
|
||||
size_t mask = token_logits.size();
|
||||
Softmax(topk_logits.data(), mask, p, worker, temperature);
|
||||
Softmax(topk_logits.data(), mask, worker, temperature);
|
||||
auto distribution = std::discrete_distribution<int>(
|
||||
std::begin(topk_logits), std::begin(topk_logits) + mask);
|
||||
int topk_sampled_index = distribution(gen);
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@
|
|||
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
||||
#endif // HWY_DISABLED_TARGETS
|
||||
|
||||
#include "ops/ops.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
|
||||
|
|
@ -31,13 +33,11 @@
|
|||
|
||||
#include "gemma/activations.h" // ChooseQueryScale
|
||||
#include "gemma/configs.h"
|
||||
#include "ops/ops.h"
|
||||
#include "util/allocator.h"
|
||||
#include "util/basics.h" // BF16
|
||||
#include "util/mat.h" // MatStorageT
|
||||
#include "util/test_util.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/profiler.h"
|
||||
#include "hwy/tests/hwy_gtest.h"
|
||||
|
||||
// clang-format off
|
||||
|
|
@ -166,7 +166,7 @@ struct TestAddFrom {
|
|||
}
|
||||
|
||||
SimpleAddFrom(o, e, count);
|
||||
AddFrom(o, x, count, hwy::Profiler::Get(), /*worker=*/0);
|
||||
AddFrom(o, x, count, /*worker=*/0);
|
||||
|
||||
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
||||
__LINE__);
|
||||
|
|
@ -199,7 +199,7 @@ struct TestMulByConstAndAdd {
|
|||
T constant = Random<T>(rng);
|
||||
|
||||
SimpleMulByConstAndAdd(constant, o, e, count);
|
||||
MulByConstAndAdd(constant, o, x, count, hwy::Profiler::Get(), /*worker=*/0);
|
||||
MulByConstAndAdd(constant, o, x, count, /*worker=*/0);
|
||||
|
||||
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
||||
__LINE__);
|
||||
|
|
@ -229,7 +229,7 @@ struct TestMulByConst {
|
|||
T constant = Random<T>(rng);
|
||||
|
||||
SimpleMulByConst(constant, e, count);
|
||||
MulByConst(constant, x, count, hwy::Profiler::Get(), /*worker=*/0);
|
||||
MulByConst(constant, x, count, /*worker=*/0);
|
||||
|
||||
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
||||
__LINE__);
|
||||
|
|
@ -259,7 +259,7 @@ struct TestSoftmax {
|
|||
}
|
||||
|
||||
SimpleSoftmax(e, count);
|
||||
Softmax(x, count, hwy::Profiler::Get(), /*worker=*/0);
|
||||
Softmax(x, count, /*worker=*/0);
|
||||
|
||||
T sum = 0.0f;
|
||||
for (size_t i = 0; i < count; ++i) {
|
||||
|
|
@ -349,9 +349,6 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
|
|||
void TestRopeAndMulBy() {
|
||||
ThreadingArgs threading_args;
|
||||
ThreadingContext ctx(threading_args);
|
||||
hwy::Profiler& p = ctx.profiler;
|
||||
const size_t worker = 0;
|
||||
|
||||
const ModelConfig config(Model::GEMMA2_9B, Type::kSFP,
|
||||
ChooseWrapping(Model::GEMMA2_9B));
|
||||
const size_t dim_qkv = config.layer_configs[0].qkv_dim;
|
||||
|
|
@ -384,8 +381,7 @@ void TestRopeAndMulBy() {
|
|||
CopyMat(x, qactual);
|
||||
ScalarRopeAndMulBy(qmul, qexpected.Row(0), dim_qkv, inv_timescale.Row(0),
|
||||
pos);
|
||||
RopeAndMulBy(qmul, qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, p,
|
||||
worker);
|
||||
RopeAndMulBy(qmul, qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos);
|
||||
for (size_t i = 0; i < dim_qkv; ++i) {
|
||||
EXPECT_NEAR(qexpected.Row(0)[i], qactual.Row(0)[i], 1e-4) << " " << i;
|
||||
}
|
||||
|
|
@ -395,7 +391,7 @@ void TestRopeAndMulBy() {
|
|||
CopyMat(x, qactual);
|
||||
ScalarRopeAndMulBy(1.0f, qexpected.Row(0), dim_qkv, inv_timescale.Row(0),
|
||||
pos);
|
||||
Rope(qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, p, worker);
|
||||
Rope(qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos);
|
||||
for (size_t i = 0; i < dim_qkv; ++i) {
|
||||
EXPECT_NEAR(qexpected.Row(0)[i], qactual.Row(0)[i], 1e-4) << " " << i;
|
||||
}
|
||||
|
|
@ -406,10 +402,9 @@ void TestRopeAndMulBy() {
|
|||
CopyMat(x, kactual2);
|
||||
ScalarRopeAndMulBy(kmul, kexpected.Row(0), dim_qkv, inv_timescale.Row(0),
|
||||
pos);
|
||||
RopeAndMulBy(kmul, kactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, p,
|
||||
worker);
|
||||
RopeAndMulBy(kmul, kactual.Row(0), dim_qkv, inv_timescale.Row(0), pos);
|
||||
static_assert(kmul == 1.0f, "");
|
||||
Rope(kactual2.Row(0), dim_qkv, inv_timescale.Row(0), pos, p, worker);
|
||||
Rope(kactual2.Row(0), dim_qkv, inv_timescale.Row(0), pos);
|
||||
|
||||
for (size_t i = 0; i < dim_qkv; ++i) {
|
||||
EXPECT_NEAR(kexpected.Row(0)[i], kactual.Row(0)[i], 1e-4) << " " << i;
|
||||
|
|
@ -459,7 +454,7 @@ void TestRMSNorm(hwy::RandomState& rng) {
|
|||
}
|
||||
|
||||
ScalarRMSNorm(vec, weight, expected, kSize);
|
||||
RMSNorm(vec, weight, 0, actual, kSize, hwy::Profiler::Get(), /*worker=*/0);
|
||||
RMSNorm(vec, weight, 0, actual, kSize, /*worker=*/0);
|
||||
|
||||
for (size_t i = 0; i < kSize; i++) {
|
||||
const float e = hwy::ConvertScalarTo<float>(expected[i]);
|
||||
|
|
@ -585,13 +580,11 @@ void TestAllLayerNorm() {
|
|||
}
|
||||
|
||||
void TestSampleTopK() {
|
||||
hwy::Profiler& p = hwy::Profiler::Get();
|
||||
const size_t worker = 0;
|
||||
const size_t kSize = 52;
|
||||
std::vector<float> logits(kSize);
|
||||
// Create a vector going from -100 to -100+51=49 and take Softmax.
|
||||
std::iota(logits.begin(), logits.end(), -100.0f);
|
||||
Softmax(logits.data(), kSize, p, worker);
|
||||
Softmax(logits.data(), kSize, /*worker=*/0);
|
||||
std::mt19937 gen;
|
||||
gen.seed(0x12345678);
|
||||
float temperature = 1.0f;
|
||||
|
|
@ -607,7 +600,7 @@ void TestSampleTopK() {
|
|||
EXPECT_EQ(sample, 50); // Last even index.
|
||||
// Reset the logits to a positive, increasing sequence and take Softmax.
|
||||
std::iota(logits.begin(), logits.end(), 1.0f);
|
||||
Softmax(logits.data(), kSize, p, worker);
|
||||
Softmax(logits.data(), kSize, /*worker=*/0);
|
||||
// Sample from the top 3, expect one of the top 3 even indices.
|
||||
for (int i = 0; i < 100; ++i) {
|
||||
sample = SampleTopK(logits.data(), /*k=*/3, kSize, gen, temperature,
|
||||
|
|
|
|||
|
|
@ -29,7 +29,6 @@
|
|||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/contrib/thread_pool/topology.h"
|
||||
#include "hwy/profiler.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
|
|
@ -172,8 +171,6 @@ NestedPools::NestedPools(const BoundedTopology& topology,
|
|||
HWY_ASSERT(max_clusters_per_package_ <= 64);
|
||||
HWY_ASSERT(max_workers_per_cluster_ >= 1);
|
||||
HWY_ASSERT(max_workers_per_cluster_ <= 256);
|
||||
|
||||
hwy::Profiler::Get().SetMaxThreads(MaxWorkers());
|
||||
}
|
||||
|
||||
// `max_or_zero` == 0 means no limit.
|
||||
|
|
|
|||
|
|
@ -72,8 +72,7 @@ static void TunePool(hwy::ThreadPool& pool) {
|
|||
}
|
||||
|
||||
ThreadingContext::ThreadingContext(const ThreadingArgs& args)
|
||||
: profiler(hwy::Profiler::Get()),
|
||||
topology(BoundedSlice(args.skip_packages, args.max_packages),
|
||||
: topology(BoundedSlice(args.skip_packages, args.max_packages),
|
||||
BoundedSlice(args.skip_clusters, args.max_clusters),
|
||||
BoundedSlice(args.skip_lps, args.max_lps)),
|
||||
allocator(topology, args.bind != Tristate::kFalse),
|
||||
|
|
|
|||
|
|
@ -90,7 +90,6 @@ struct ThreadingContext {
|
|||
// Expected to be called early in the program, before threading starts.
|
||||
explicit ThreadingContext(const ThreadingArgs& args);
|
||||
|
||||
hwy::Profiler& profiler;
|
||||
BoundedTopology topology;
|
||||
Allocator allocator;
|
||||
NestedPools pools;
|
||||
|
|
|
|||
Loading…
Reference in New Issue