mirror of https://github.com/google/gemma.cpp.git
(Resubmit) Prepare profiler annotations for new API
Pass hwy::Profiler& to low-level functions. Used ThreadingContext arg instead of NestedPools. Use new PROFILER_ZONE3. PiperOrigin-RevId: 794461159
This commit is contained in:
parent
a2d9133f7d
commit
faa4102992
|
|
@ -80,6 +80,7 @@ cc_library(
|
||||||
":topology",
|
":topology",
|
||||||
# Placeholder for container detection, do not remove
|
# Placeholder for container detection, do not remove
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
|
"@highway//:profiler",
|
||||||
"@highway//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
"@highway//:topology",
|
"@highway//:topology",
|
||||||
],
|
],
|
||||||
|
|
@ -379,6 +380,7 @@ cc_test(
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
"@highway//:hwy_test_util",
|
"@highway//:hwy_test_util",
|
||||||
"@highway//:nanobenchmark", #buildcleaner: keep
|
"@highway//:nanobenchmark", #buildcleaner: keep
|
||||||
|
"@highway//:profiler",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17)
|
||||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||||
|
|
||||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9414b48aeec251b69e6cadbfa42bebb5ddae1c34 EXCLUDE_FROM_ALL)
|
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 92d327e841d78e11ae888757a3e16d291951cf64 EXCLUDE_FROM_ALL)
|
||||||
FetchContent_MakeAvailable(highway)
|
FetchContent_MakeAvailable(highway)
|
||||||
|
|
||||||
## Note: absl needs to be installed by sentencepiece. This will only happen if
|
## Note: absl needs to be installed by sentencepiece. This will only happen if
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ bazel_dep(name = "google_benchmark", version = "1.8.5")
|
||||||
# Require a more recent version.
|
# Require a more recent version.
|
||||||
git_override(
|
git_override(
|
||||||
module_name = "highway",
|
module_name = "highway",
|
||||||
commit = "9414b48aeec251b69e6cadbfa42bebb5ddae1c34",
|
commit = "92d327e841d78e11ae888757a3e16d291951cf64",
|
||||||
remote = "https://github.com/google/highway",
|
remote = "https://github.com/google/highway",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -469,7 +469,7 @@ FetchContent_MakeAvailable(sentencepiece)
|
||||||
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG origin/main)
|
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG origin/main)
|
||||||
FetchContent_MakeAvailable(gemma)
|
FetchContent_MakeAvailable(gemma)
|
||||||
|
|
||||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f)
|
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 92d327e841d78e11ae888757a3e16d291951cf64)
|
||||||
FetchContent_MakeAvailable(highway)
|
FetchContent_MakeAvailable(highway)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -84,8 +84,9 @@ HWY_BEFORE_NAMESPACE();
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
namespace HWY_NAMESPACE {
|
namespace HWY_NAMESPACE {
|
||||||
|
|
||||||
void CallSoftmax(float* HWY_RESTRICT logits, size_t vocab_size) {
|
void CallSoftmax(float* HWY_RESTRICT logits, size_t vocab_size,
|
||||||
Softmax(logits, vocab_size, /*worker=*/0);
|
hwy::Profiler& p) {
|
||||||
|
Softmax(logits, vocab_size, p, hwy::Profiler::Thread());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace HWY_NAMESPACE
|
} // namespace HWY_NAMESPACE
|
||||||
|
|
@ -109,7 +110,7 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
|
||||||
const SampleFunc sample_token = [&](float* probs,
|
const SampleFunc sample_token = [&](float* probs,
|
||||||
size_t vocab_size) -> TokenAndProb {
|
size_t vocab_size) -> TokenAndProb {
|
||||||
// input is logits, not yet probabilities
|
// input is logits, not yet probabilities
|
||||||
HWY_DYNAMIC_DISPATCH(CallSoftmax)(probs, vocab_size);
|
HWY_DYNAMIC_DISPATCH(CallSoftmax)(probs, vocab_size, env.ctx.profiler);
|
||||||
// We are called for each token, but pos starts at 1. Clamping
|
// We are called for each token, but pos starts at 1. Clamping
|
||||||
// max_generated_tokens to prompt.size() should prevent overrun.
|
// max_generated_tokens to prompt.size() should prevent overrun.
|
||||||
HWY_ASSERT(pos < prompt.size());
|
HWY_ASSERT(pos < prompt.size());
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17)
|
||||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||||
|
|
||||||
include(FetchContent)
|
include(FetchContent)
|
||||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG bfc36a6e633af94e63ac4b91c687bf0354cb24e0)
|
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 92d327e841d78e11ae888757a3e16d291951cf64)
|
||||||
FetchContent_MakeAvailable(highway)
|
FetchContent_MakeAvailable(highway)
|
||||||
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 9045b2f60fa2b323dfac0eaef8fc17565036f9f9)
|
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 9045b2f60fa2b323dfac0eaef8fc17565036f9f9)
|
||||||
FetchContent_MakeAvailable(sentencepiece)
|
FetchContent_MakeAvailable(sentencepiece)
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17)
|
||||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||||
|
|
||||||
include(FetchContent)
|
include(FetchContent)
|
||||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9414b48aeec251b69e6cadbfa42bebb5ddae1c34)
|
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 92d327e841d78e11ae888757a3e16d291951cf64)
|
||||||
FetchContent_MakeAvailable(highway)
|
FetchContent_MakeAvailable(highway)
|
||||||
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
|
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
|
||||||
FetchContent_MakeAvailable(sentencepiece)
|
FetchContent_MakeAvailable(sentencepiece)
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
|
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
|
||||||
|
#include "util/threading_context.h"
|
||||||
#ifndef HWY_DISABLED_TARGETS
|
#ifndef HWY_DISABLED_TARGETS
|
||||||
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
||||||
#endif // HWY_DISABLED_TARGETS
|
#endif // HWY_DISABLED_TARGETS
|
||||||
|
|
@ -53,8 +54,9 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
|
||||||
const hwy::Divisor& div_seq_len,
|
const hwy::Divisor& div_seq_len,
|
||||||
const float* HWY_RESTRICT q,
|
const float* HWY_RESTRICT q,
|
||||||
const MatPtrT<KV_t>& k, float* HWY_RESTRICT att,
|
const MatPtrT<KV_t>& k, float* HWY_RESTRICT att,
|
||||||
const size_t worker) {
|
hwy::Profiler& p, const size_t worker) {
|
||||||
PROFILER_ZONE2(worker, "Gen.Attention.QDotK");
|
static const auto zone = p.AddZone("Gen.Attention.QDotK");
|
||||||
|
PROFILER_ZONE3(p, worker, zone);
|
||||||
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
|
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
|
||||||
// Slightly faster: no wraparound.
|
// Slightly faster: no wraparound.
|
||||||
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
|
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
|
||||||
|
|
@ -73,8 +75,8 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
|
||||||
static void PositionalEncodingQK(float* qk, const size_t layer_idx,
|
static void PositionalEncodingQK(float* qk, const size_t layer_idx,
|
||||||
const LayerWeightsPtrs& layer,
|
const LayerWeightsPtrs& layer,
|
||||||
const AttentionActivations& activations,
|
const AttentionActivations& activations,
|
||||||
const size_t worker, const size_t pos,
|
hwy::Profiler& p, const size_t worker,
|
||||||
const float mul = 1.0f) {
|
const size_t pos, const float mul = 1.0f) {
|
||||||
const size_t qkv_dim = layer.layer_config.qkv_dim;
|
const size_t qkv_dim = layer.layer_config.qkv_dim;
|
||||||
const PostQKType& post_qk = layer.layer_config.post_qk;
|
const PostQKType& post_qk = layer.layer_config.post_qk;
|
||||||
// qk is either q or k, so qkv_dim is the length we operate on.
|
// qk is either q or k, so qkv_dim is the length we operate on.
|
||||||
|
|
@ -86,10 +88,10 @@ static void PositionalEncodingQK(float* qk, const size_t layer_idx,
|
||||||
}
|
}
|
||||||
// PostQKType::Rope
|
// PostQKType::Rope
|
||||||
if (post_qk == PostQKType::HalfRope) {
|
if (post_qk == PostQKType::HalfRope) {
|
||||||
Rope(qk, qkv_dim / 2, inv_timescale, pos, worker);
|
Rope(qk, qkv_dim / 2, inv_timescale, pos, p, worker);
|
||||||
if (mul != 1.0f) MulByConst(mul, qk, qkv_dim, worker);
|
if (mul != 1.0f) MulByConst(mul, qk, qkv_dim, p, worker);
|
||||||
} else {
|
} else {
|
||||||
RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, worker);
|
RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, p, worker);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -97,26 +99,31 @@ static void PositionalEncodingQK(float* qk, const size_t layer_idx,
|
||||||
// `att_out`. Equivalent in gemma/modules.py:
|
// `att_out`. Equivalent in gemma/modules.py:
|
||||||
// encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
|
// encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
|
||||||
// `v` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
|
// `v` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
|
||||||
static HWY_INLINE void WeightedSumV(
|
static HWY_INLINE void WeightedSumV(const size_t start_pos,
|
||||||
const size_t start_pos, const size_t last_pos,
|
const size_t last_pos,
|
||||||
const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT att,
|
const hwy::Divisor& div_seq_len,
|
||||||
const MatPtrT<KV_t>& v, float* HWY_RESTRICT att_out, const size_t worker) {
|
const float* HWY_RESTRICT att,
|
||||||
|
const MatPtrT<KV_t>& v,
|
||||||
|
float* HWY_RESTRICT att_out,
|
||||||
|
hwy::Profiler& p, const size_t worker) {
|
||||||
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
|
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
|
||||||
// Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if
|
// Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if
|
||||||
// we supported non-transposed B.
|
// we supported non-transposed B.
|
||||||
// TODO: 2..4x unroll
|
// TODO: 2..4x unroll
|
||||||
MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), worker);
|
MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), p,
|
||||||
|
worker);
|
||||||
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
|
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
|
||||||
MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols(), worker);
|
MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols(), p, worker);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
{
|
{
|
||||||
const size_t pos_mod = div_seq_len.Remainder(start_pos);
|
const size_t pos_mod = div_seq_len.Remainder(start_pos);
|
||||||
MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), worker);
|
MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), p, worker);
|
||||||
}
|
}
|
||||||
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
|
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
|
||||||
const size_t pos_mod = div_seq_len.Remainder(pos);
|
const size_t pos_mod = div_seq_len.Remainder(pos);
|
||||||
MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), worker);
|
MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), p,
|
||||||
|
worker);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -128,7 +135,7 @@ void SingleDotSoftmaxWeightedSum(
|
||||||
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
|
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
|
||||||
const size_t layer_idx, const LayerWeightsPtrs& layer,
|
const size_t layer_idx, const LayerWeightsPtrs& layer,
|
||||||
const AttentionActivations& activations, float* HWY_RESTRICT att,
|
const AttentionActivations& activations, float* HWY_RESTRICT att,
|
||||||
float* HWY_RESTRICT att_out, const size_t worker) {
|
float* HWY_RESTRICT att_out, hwy::Profiler& p, const size_t worker) {
|
||||||
const float att_cap = activations.config.att_cap;
|
const float att_cap = activations.config.att_cap;
|
||||||
const float query_scale = activations.query_scale;
|
const float query_scale = activations.query_scale;
|
||||||
const size_t seq_len =
|
const size_t seq_len =
|
||||||
|
|
@ -138,21 +145,21 @@ void SingleDotSoftmaxWeightedSum(
|
||||||
if (layer.query_norm_scale.HasPtr()) {
|
if (layer.query_norm_scale.HasPtr()) {
|
||||||
CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
|
CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
|
||||||
RMSNormInplace(weights_t->PackedScale1(), 0, q,
|
RMSNormInplace(weights_t->PackedScale1(), 0, q,
|
||||||
layer.layer_config.qkv_dim, worker);
|
layer.layer_config.qkv_dim, p, worker);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
PositionalEncodingQK(q, layer_idx, layer, activations, worker, pos,
|
PositionalEncodingQK(q, layer_idx, layer, activations, p, worker, pos,
|
||||||
query_scale);
|
query_scale);
|
||||||
|
|
||||||
QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, worker);
|
QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, p, worker);
|
||||||
|
|
||||||
// SoftMax with optional SoftCap yields "probabilities" in att.
|
// SoftMax with optional SoftCap yields "probabilities" in att.
|
||||||
const size_t att_len = HWY_MIN(last_pos + 1, seq_len);
|
const size_t att_len = HWY_MIN(last_pos + 1, seq_len);
|
||||||
MaybeLogitsSoftCap(att_cap, att, att_len, worker);
|
MaybeLogitsSoftCap(att_cap, att, att_len, p, worker);
|
||||||
Softmax(att, att_len, worker, /*temperature=*/1.0f);
|
Softmax(att, att_len, p, worker, /*temperature=*/1.0f);
|
||||||
|
|
||||||
WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out,
|
WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out, p,
|
||||||
worker);
|
worker);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -167,9 +174,8 @@ static HWY_INLINE size_t StartPos(size_t pos, const ModelConfig& config,
|
||||||
void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
|
void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
|
||||||
const LayerWeightsPtrs& layer,
|
const LayerWeightsPtrs& layer,
|
||||||
AttentionActivations& activations, QBatch& qbatch,
|
AttentionActivations& activations, QBatch& qbatch,
|
||||||
NestedPools& pools) {
|
ThreadingContext& ctx) {
|
||||||
static const uint32_t HWY_MAYBE_UNUSED zone_id_par =
|
static const auto zone = ctx.profiler.AddZone("Gen.Attention.DotSoftmax.par");
|
||||||
PROFILER_ADD_ZONE("Gen.Attention.DotSoftmax.par");
|
|
||||||
|
|
||||||
const hwy::Divisor div_qbatch(qbatch.Size());
|
const hwy::Divisor div_qbatch(qbatch.Size());
|
||||||
const LayerConfig& layer_config = layer.layer_config;
|
const LayerConfig& layer_config = layer.layer_config;
|
||||||
|
|
@ -189,9 +195,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
|
||||||
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
|
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
|
||||||
const size_t tq_idx = activations.div_heads.Divide(task);
|
const size_t tq_idx = activations.div_heads.Divide(task);
|
||||||
const size_t head = activations.div_heads.Remainder(task);
|
const size_t head = activations.div_heads.Remainder(task);
|
||||||
#if PROFILER_ENABLED
|
PROFILER_ZONE3(ctx.profiler, worker, zone);
|
||||||
const hwy::Zone zone(worker, zone_id_par);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
const size_t qi = div_qbatch.Remainder(tq_idx);
|
const size_t qi = div_qbatch.Remainder(tq_idx);
|
||||||
const size_t batch_idx = div_qbatch.Divide(tq_idx);
|
const size_t batch_idx = div_qbatch.Divide(tq_idx);
|
||||||
|
|
@ -223,14 +227,15 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
|
||||||
v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());
|
v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());
|
||||||
|
|
||||||
SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx,
|
SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx,
|
||||||
layer, activations, att, att_out, worker);
|
layer, activations, att, att_out, ctx.profiler,
|
||||||
|
worker);
|
||||||
};
|
};
|
||||||
|
|
||||||
{
|
{
|
||||||
PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin");
|
PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin");
|
||||||
// Full parallelism is helpful, SmallParallelFor is insufficient.
|
// Full parallelism is helpful, SmallParallelFor is insufficient.
|
||||||
ParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads,
|
ParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads,
|
||||||
pools, func);
|
ctx.pools, func);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -303,12 +308,12 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
|
||||||
if (layer.key_norm_scale.HasPtr()) {
|
if (layer.key_norm_scale.HasPtr()) {
|
||||||
CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) {
|
CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) {
|
||||||
RMSNormInplace(weights_t->PackedScale1(), 0, kv_f32, qkv_dim,
|
RMSNormInplace(weights_t->PackedScale1(), 0, kv_f32, qkv_dim,
|
||||||
thread);
|
env.ctx.profiler, thread);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
PositionalEncodingQK(kv_f32, layer_idx, layer, activations, thread,
|
PositionalEncodingQK(kv_f32, layer_idx, layer, activations,
|
||||||
pos);
|
env.ctx.profiler, thread, pos);
|
||||||
CompressPerThread tls;
|
CompressPerThread tls;
|
||||||
Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
|
Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
|
||||||
});
|
});
|
||||||
|
|
@ -339,6 +344,10 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
|
||||||
const LayerWeightsPtrs& layer,
|
const LayerWeightsPtrs& layer,
|
||||||
AttentionActivations& activations, QBatch& qbatch,
|
AttentionActivations& activations, QBatch& qbatch,
|
||||||
MatMulEnv& env, int flags) {
|
MatMulEnv& env, int flags) {
|
||||||
|
static const auto zone =
|
||||||
|
env.ctx.profiler.AddZone("Gen.Attention", ProfilerFlags::kInclusive);
|
||||||
|
PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone);
|
||||||
|
|
||||||
const LayerConfig& layer_config = layer.layer_config;
|
const LayerConfig& layer_config = layer.layer_config;
|
||||||
HWY_DASSERT(!layer_config.IsMHA()); // No longer supported.
|
HWY_DASSERT(!layer_config.IsMHA()); // No longer supported.
|
||||||
HWY_DASSERT_M((layer_config.heads % layer_config.kv_heads) == 0,
|
HWY_DASSERT_M((layer_config.heads % layer_config.kv_heads) == 0,
|
||||||
|
|
@ -347,7 +356,7 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
|
||||||
|
|
||||||
ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env);
|
ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env);
|
||||||
DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
|
DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
|
||||||
env.ctx.pools);
|
env.ctx);
|
||||||
SumHeads(layer, activations, env);
|
SumHeads(layer, activations, env);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -33,12 +33,12 @@ namespace gcpp {
|
||||||
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
|
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
|
||||||
size_t layer_idx, const LayerWeightsPtrs& layer, \
|
size_t layer_idx, const LayerWeightsPtrs& layer, \
|
||||||
const AttentionActivations& activations, float* HWY_RESTRICT att, \
|
const AttentionActivations& activations, float* HWY_RESTRICT att, \
|
||||||
float* HWY_RESTRICT att_out, size_t worker); \
|
float* HWY_RESTRICT att_out, hwy::Profiler& p, size_t worker); \
|
||||||
\
|
\
|
||||||
void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \
|
void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \
|
||||||
const LayerWeightsPtrs& layer, \
|
const LayerWeightsPtrs& layer, \
|
||||||
AttentionActivations& activations, \
|
AttentionActivations& activations, \
|
||||||
QBatch& qbatch, NestedPools& pools); \
|
QBatch& qbatch, ThreadingContext& ctx); \
|
||||||
\
|
\
|
||||||
void GemmaAttention(size_t num_tokens, const size_t layer_idx, \
|
void GemmaAttention(size_t num_tokens, const size_t layer_idx, \
|
||||||
const LayerWeightsPtrs& layer, \
|
const LayerWeightsPtrs& layer, \
|
||||||
|
|
|
||||||
|
|
@ -45,9 +45,10 @@ namespace HWY_NAMESPACE {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void Activation(ActivationType activation, T* HWY_RESTRICT c1,
|
void Activation(ActivationType activation, T* HWY_RESTRICT c1,
|
||||||
const T* HWY_RESTRICT c2, const size_t count,
|
const T* HWY_RESTRICT c2, const size_t count, hwy::Profiler& p,
|
||||||
const size_t worker) {
|
const size_t worker) {
|
||||||
PROFILER_ZONE2(worker, "Gen.Activation");
|
static const auto zone = p.AddZone("Gen.Activation");
|
||||||
|
PROFILER_ZONE3(p, worker, zone);
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
using DF = hn::ScalableTag<T>;
|
using DF = hn::ScalableTag<T>;
|
||||||
using VF = hn::Vec<DF>;
|
using VF = hn::Vec<DF>;
|
||||||
|
|
@ -64,28 +65,30 @@ void Activation(ActivationType activation, T* HWY_RESTRICT c1,
|
||||||
|
|
||||||
// No C2 multiplier.
|
// No C2 multiplier.
|
||||||
template <class Mat>
|
template <class Mat>
|
||||||
void ActivationBatched(ActivationType activation, Mat& c1, NestedPools& pools) {
|
void ActivationBatched(ActivationType activation, Mat& c1,
|
||||||
|
ThreadingContext& ctx) {
|
||||||
using T = typename Mat::T;
|
using T = typename Mat::T;
|
||||||
SmallParallelFor(c1.Rows(), pools, [&](uint64_t task, size_t worker) {
|
SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
|
||||||
// Cast to correct type so type deduction works.
|
// Cast to correct type so type deduction works.
|
||||||
Activation(activation, c1.Row(task), static_cast<const T*>(nullptr),
|
Activation(activation, c1.Row(task), static_cast<const T*>(nullptr),
|
||||||
c1.Cols(), worker);
|
c1.Cols(), ctx.profiler, worker);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class Mat>
|
template <class Mat>
|
||||||
HWY_NOINLINE void ActivationBatched(ActivationType activation, Mat& c1,
|
HWY_NOINLINE void ActivationBatched(ActivationType activation, Mat& c1,
|
||||||
const Mat* c2, NestedPools& pools) {
|
const Mat* c2, ThreadingContext& ctx) {
|
||||||
using T = typename Mat::T;
|
using T = typename Mat::T;
|
||||||
HWY_DASSERT(c1.SameShape(*c2));
|
HWY_DASSERT(c1.SameShape(*c2));
|
||||||
if (c2 && c2->HasPtr()) {
|
if (c2 && c2->HasPtr()) {
|
||||||
SmallParallelFor(c1.Rows(), pools, [&](uint64_t task, size_t worker) {
|
SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
|
||||||
Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(), worker);
|
Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(),
|
||||||
|
ctx.profiler, worker);
|
||||||
});
|
});
|
||||||
} else { // No multiplier
|
} else { // No multiplier
|
||||||
SmallParallelFor(c1.Rows(), pools, [&](uint64_t task, size_t worker) {
|
SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
|
||||||
Activation(activation, c1.Row(task), static_cast<const T*>(nullptr),
|
Activation(activation, c1.Row(task), static_cast<const T*>(nullptr),
|
||||||
c1.Cols(), worker);
|
c1.Cols(), ctx.profiler, worker);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -110,7 +113,9 @@ void PostNorm(PostNormType post_norm, const MatPtr& weights,
|
||||||
|
|
||||||
static inline void FFWNoVit(const LayerWeightsPtrs& layer,
|
static inline void FFWNoVit(const LayerWeightsPtrs& layer,
|
||||||
Activations& activations, MatMulEnv& env) {
|
Activations& activations, MatMulEnv& env) {
|
||||||
PROFILER_ZONE("Gen.FFW");
|
static const auto zone =
|
||||||
|
env.ctx.profiler.AddZone("Gen.FFW", ProfilerFlags::kInclusive);
|
||||||
|
PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone);
|
||||||
const LayerConfig& layer_config = layer.layer_config;
|
const LayerConfig& layer_config = layer.layer_config;
|
||||||
const size_t ffh_hidden_dim = layer_config.ff_hidden_dim;
|
const size_t ffh_hidden_dim = layer_config.ff_hidden_dim;
|
||||||
|
|
||||||
|
|
@ -129,7 +134,7 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer,
|
||||||
|
|
||||||
// Activation (Gelu) and maybe multiply by gate. Store activations in act.
|
// Activation (Gelu) and maybe multiply by gate. Store activations in act.
|
||||||
ActivationBatched(layer_config.activation, activations.C1, &activations.C2,
|
ActivationBatched(layer_config.activation, activations.C1, &activations.C2,
|
||||||
env.ctx.pools);
|
env.ctx);
|
||||||
|
|
||||||
// Hidden layer -> output layer.
|
// Hidden layer -> output layer.
|
||||||
CallMatMul(activations.C1, layer.linear_w, output_bias, env,
|
CallMatMul(activations.C1, layer.linear_w, output_bias, env,
|
||||||
|
|
|
||||||
|
|
@ -55,6 +55,7 @@
|
||||||
#include "io/io.h" // Path
|
#include "io/io.h" // Path
|
||||||
#include "ops/matmul.h"
|
#include "ops/matmul.h"
|
||||||
#include "paligemma/image.h"
|
#include "paligemma/image.h"
|
||||||
|
#include "util/basics.h" // PROFILER_ZONE3
|
||||||
#include "util/threading_context.h"
|
#include "util/threading_context.h"
|
||||||
#include "hwy/aligned_allocator.h" // Span
|
#include "hwy/aligned_allocator.h" // Span
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
|
@ -138,7 +139,8 @@ static float EmbeddingScaling(size_t model_dim) {
|
||||||
static HWY_NOINLINE size_t
|
static HWY_NOINLINE size_t
|
||||||
EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
|
EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
|
||||||
const ModelConfig& model_config, const WeightsPtrs& weights,
|
const ModelConfig& model_config, const WeightsPtrs& weights,
|
||||||
MatStorageT<float>& x, const ImageTokens* image_tokens = nullptr,
|
MatStorageT<float>& x, ThreadingContext& ctx,
|
||||||
|
const ImageTokens* image_tokens = nullptr,
|
||||||
size_t image_token_position = 0) {
|
size_t image_token_position = 0) {
|
||||||
// Image tokens just need to be copied.
|
// Image tokens just need to be copied.
|
||||||
if (model_config.wrapping == PromptWrapping::GEMMA_VLM &&
|
if (model_config.wrapping == PromptWrapping::GEMMA_VLM &&
|
||||||
|
|
@ -174,7 +176,8 @@ EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
|
||||||
const hn::ScalableTag<float> df;
|
const hn::ScalableTag<float> df;
|
||||||
DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(qi),
|
DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(qi),
|
||||||
model_dim);
|
model_dim);
|
||||||
MulByConst(emb_scaling * weights_t->Scale(), x.Row(qi), model_dim, worker);
|
MulByConst(emb_scaling * weights_t->Scale(), x.Row(qi), model_dim,
|
||||||
|
ctx.profiler, worker);
|
||||||
});
|
});
|
||||||
|
|
||||||
if (model_config.absolute_pe) {
|
if (model_config.absolute_pe) {
|
||||||
|
|
@ -249,7 +252,7 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
|
||||||
const int token = qbatch_1.Prompt(0)[pos_in_prompt];
|
const int token = qbatch_1.Prompt(0)[pos_in_prompt];
|
||||||
image_token_position = EmbedMMToken(
|
image_token_position = EmbedMMToken(
|
||||||
token, ti, pos, pos_in_prompt, config, weights, activations.x,
|
token, ti, pos, pos_in_prompt, config, weights, activations.x,
|
||||||
runtime_config.image_tokens, image_token_position);
|
env.ctx, runtime_config.image_tokens, image_token_position);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Transformer with one batch of tokens from a single query.
|
// Transformer with one batch of tokens from a single query.
|
||||||
|
|
@ -306,7 +309,7 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
|
||||||
// TODO: parallelize?
|
// TODO: parallelize?
|
||||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||||
EmbedMMToken(qbatch.PrevToken(qi), qi, qbatch.Pos(qi),
|
EmbedMMToken(qbatch.PrevToken(qi), qi, qbatch.Pos(qi),
|
||||||
/*pos_in_prompt=*/0, config, weights, activations.x);
|
/*pos_in_prompt=*/0, config, weights, activations.x, env.ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) {
|
for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) {
|
||||||
|
|
@ -419,7 +422,8 @@ static void DecodeStepT(const ModelConfig& config,
|
||||||
const size_t worker = 0; // TODO: parallelize
|
const size_t worker = 0; // TODO: parallelize
|
||||||
non_eos.Foreach([&](size_t qi) {
|
non_eos.Foreach([&](size_t qi) {
|
||||||
float* HWY_RESTRICT logits = activations.logits.Row(qi);
|
float* HWY_RESTRICT logits = activations.logits.Row(qi);
|
||||||
MaybeLogitsSoftCap(config.final_cap, logits, config.vocab_size, worker);
|
MaybeLogitsSoftCap(config.final_cap, logits, config.vocab_size,
|
||||||
|
env.ctx.profiler, worker);
|
||||||
const TokenAndProb tp = sample_token(logits, config.vocab_size);
|
const TokenAndProb tp = sample_token(logits, config.vocab_size);
|
||||||
timing_info.NotifyGenerated();
|
timing_info.NotifyGenerated();
|
||||||
|
|
||||||
|
|
@ -429,27 +433,28 @@ static void DecodeStepT(const ModelConfig& config,
|
||||||
}
|
}
|
||||||
|
|
||||||
static HWY_INLINE SampleFunc
|
static HWY_INLINE SampleFunc
|
||||||
ChooseSampleFunc(const RuntimeConfig& runtime_config) {
|
ChooseSampleFunc(const RuntimeConfig& runtime_config, ThreadingContext& ctx) {
|
||||||
// If user provided a sample_func, use it.
|
// If user provided a sample_func, use it.
|
||||||
if (runtime_config.sample_func) return runtime_config.sample_func;
|
if (runtime_config.sample_func) return runtime_config.sample_func;
|
||||||
|
|
||||||
|
static const auto zone = ctx.profiler.AddZone("Gen.Sample Top1");
|
||||||
const size_t worker = 0; // TODO: parallelize
|
const size_t worker = 0; // TODO: parallelize
|
||||||
|
|
||||||
// Fast path for top-1 with no accept_token.
|
// Fast path for top-1 with no accept_token.
|
||||||
if (runtime_config.top_k == 1 && !runtime_config.accept_token) {
|
if (runtime_config.top_k == 1 && !runtime_config.accept_token) {
|
||||||
return [](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb {
|
return [&](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb {
|
||||||
PROFILER_ZONE2(worker, "Gen.Sample Top1");
|
PROFILER_ZONE3(ctx.profiler, worker, zone);
|
||||||
return Top1OfSoftmax(logits, vocab_size);
|
return Top1OfSoftmax(logits, vocab_size);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// General case: Softmax with top-k sampling.
|
// General case: Softmax with top-k sampling.
|
||||||
return [&runtime_config](float* logits,
|
return [&](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb {
|
||||||
size_t vocab_size) HWY_ATTR -> TokenAndProb {
|
|
||||||
PROFILER_ZONE("Gen.Sample general");
|
PROFILER_ZONE("Gen.Sample general");
|
||||||
return FusedSoftmaxAndSampleTopK(
|
return FusedSoftmaxAndSampleTopK(
|
||||||
logits, runtime_config.top_k, vocab_size, *runtime_config.gen,
|
logits, runtime_config.top_k, vocab_size, *runtime_config.gen,
|
||||||
runtime_config.temperature, runtime_config.accept_token, worker);
|
runtime_config.temperature, runtime_config.accept_token, ctx.profiler,
|
||||||
|
worker);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -524,7 +529,7 @@ static void GenerateT(const ModelConfig& config,
|
||||||
max_gen_steps = seq_len - max_prompt_size;
|
max_gen_steps = seq_len - max_prompt_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
const SampleFunc sample_token = ChooseSampleFunc(runtime_config);
|
const SampleFunc sample_token = ChooseSampleFunc(runtime_config, env.ctx);
|
||||||
|
|
||||||
{
|
{
|
||||||
timing_info.generate_start = hwy::platform::Now();
|
timing_info.generate_start = hwy::platform::Now();
|
||||||
|
|
|
||||||
|
|
@ -21,8 +21,8 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "gemma/configs.h" // PromptWrapping
|
#include "gemma/configs.h" // PromptWrapping
|
||||||
#include "hwy/base.h" // HWY_ASSERT
|
#include "hwy/base.h" // HWY_ASSERT
|
||||||
#include "hwy/profiler.h"
|
#include "hwy/profiler.h"
|
||||||
// copybara:import_next_line:sentencepiece
|
// copybara:import_next_line:sentencepiece
|
||||||
#include "src/sentencepiece_processor.h"
|
#include "src/sentencepiece_processor.h"
|
||||||
|
|
|
||||||
18
gemma/vit.cc
18
gemma/vit.cc
|
|
@ -95,7 +95,7 @@ class VitAttention {
|
||||||
float* HWY_RESTRICT q =
|
float* HWY_RESTRICT q =
|
||||||
activations_.attention.q.Row(token) + head * 3 * qkv_dim;
|
activations_.attention.q.Row(token) + head * 3 * qkv_dim;
|
||||||
// TODO: shift to MatMul with A.scale once MatMul is confirmed working
|
// TODO: shift to MatMul with A.scale once MatMul is confirmed working
|
||||||
MulByConst(query_scale, q, qkv_dim, worker);
|
MulByConst(query_scale, q, qkv_dim, env_.ctx.profiler, worker);
|
||||||
hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float));
|
hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float));
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -111,7 +111,7 @@ class VitAttention {
|
||||||
|
|
||||||
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR {
|
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR {
|
||||||
float* HWY_RESTRICT c = C.Row(task);
|
float* HWY_RESTRICT c = C.Row(task);
|
||||||
Softmax(c, C.Cols(), worker);
|
Softmax(c, C.Cols(), env_.ctx.profiler, worker);
|
||||||
});
|
});
|
||||||
|
|
||||||
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR {
|
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR {
|
||||||
|
|
@ -121,7 +121,8 @@ class VitAttention {
|
||||||
for (size_t i = 0; i < seq_len; ++i) {
|
for (size_t i = 0; i < seq_len; ++i) {
|
||||||
float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
|
float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
|
||||||
head * 3 * qkv_dim + 2 * qkv_dim;
|
head * 3 * qkv_dim + 2 * qkv_dim;
|
||||||
MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim, worker);
|
MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim,
|
||||||
|
env_.ctx.profiler, worker);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
@ -144,7 +145,7 @@ class VitAttention {
|
||||||
// Compute Q.K scores, which are "logits" stored in head_att.
|
// Compute Q.K scores, which are "logits" stored in head_att.
|
||||||
float* HWY_RESTRICT q =
|
float* HWY_RESTRICT q =
|
||||||
activations_.attention.q.Row(token) + head * 3 * qkv_dim;
|
activations_.attention.q.Row(token) + head * 3 * qkv_dim;
|
||||||
MulByConst(query_scale, q, qkv_dim, worker);
|
MulByConst(query_scale, q, qkv_dim, env_.ctx.profiler, worker);
|
||||||
float* HWY_RESTRICT head_att =
|
float* HWY_RESTRICT head_att =
|
||||||
activations_.attention.att.Row(token) + head * seq_len;
|
activations_.attention.att.Row(token) + head * seq_len;
|
||||||
for (size_t i = 0; i < seq_len; ++i) {
|
for (size_t i = 0; i < seq_len; ++i) {
|
||||||
|
|
@ -153,7 +154,7 @@ class VitAttention {
|
||||||
head_att[i] = Dot(q, k, qkv_dim); // score = q.k
|
head_att[i] = Dot(q, k, qkv_dim); // score = q.k
|
||||||
}
|
}
|
||||||
// SoftMax yields "probabilities" in head_att.
|
// SoftMax yields "probabilities" in head_att.
|
||||||
Softmax(head_att, seq_len, worker);
|
Softmax(head_att, seq_len, env_.ctx.profiler, worker);
|
||||||
// Compute weighted sum of v into att_out.
|
// Compute weighted sum of v into att_out.
|
||||||
float* HWY_RESTRICT att_out =
|
float* HWY_RESTRICT att_out =
|
||||||
activations_.attention.att_out.Row(token) + head * qkv_dim;
|
activations_.attention.att_out.Row(token) + head * qkv_dim;
|
||||||
|
|
@ -161,7 +162,8 @@ class VitAttention {
|
||||||
for (size_t i = 0; i < seq_len; ++i) {
|
for (size_t i = 0; i < seq_len; ++i) {
|
||||||
float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
|
float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
|
||||||
head * 3 * qkv_dim + 2 * qkv_dim;
|
head * 3 * qkv_dim + 2 * qkv_dim;
|
||||||
MulByConstAndAdd(head_att[i], v, att_out, qkv_dim, worker);
|
MulByConstAndAdd(head_att[i], v, att_out, qkv_dim,
|
||||||
|
env_.ctx.profiler, worker);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
@ -224,7 +226,7 @@ void FFWVit(const LayerWeightsPtrs& layer, Activations& activations,
|
||||||
activations.C1);
|
activations.C1);
|
||||||
|
|
||||||
// Activation (Gelu), store in C1.
|
// Activation (Gelu), store in C1.
|
||||||
ActivationBatched(layer_config.activation, activations.C1, env.ctx.pools);
|
ActivationBatched(layer_config.activation, activations.C1, env.ctx);
|
||||||
|
|
||||||
// Hidden layer -> output layer.
|
// Hidden layer -> output layer.
|
||||||
CallMatMul(activations.C1, layer.vit.linear_1_w, output_bias, env,
|
CallMatMul(activations.C1, layer.vit.linear_1_w, output_bias, env,
|
||||||
|
|
@ -334,7 +336,7 @@ void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights,
|
||||||
// Apply soft embedding norm before input projection.
|
// Apply soft embedding norm before input projection.
|
||||||
CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) {
|
CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) {
|
||||||
RMSNormInplace(weights_t->PackedScale1(), 0, activations.x.Row(0),
|
RMSNormInplace(weights_t->PackedScale1(), 0, activations.x.Row(0),
|
||||||
vit_model_dim, /*worker=*/0);
|
vit_model_dim, env.ctx.profiler, hwy::Profiler::Thread());
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -381,9 +381,11 @@ static void DecompressToBF16(MatPtr& mat,
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
|
static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
|
||||||
const BlobReader& reader, hwy::ThreadPool& pool) {
|
const BlobReader& reader, ThreadingContext& ctx) {
|
||||||
pool.Run(0, tensors.size(), [&](uint64_t task, size_t thread) {
|
static const auto zone =
|
||||||
PROFILER_ZONE2(thread, "Startup.Weights.ReadBF16");
|
ctx.profiler.AddZone("Startup.Weights.ReadAllToBF16");
|
||||||
|
ctx.pools.Pool().Run(0, tensors.size(), [&](uint64_t task, size_t thread) {
|
||||||
|
PROFILER_ZONE3(ctx.profiler, thread, zone);
|
||||||
const TensorToRead& tensor = tensors[task];
|
const TensorToRead& tensor = tensors[task];
|
||||||
MatPtr& mat = *tensor.mat;
|
MatPtr& mat = *tensor.mat;
|
||||||
|
|
||||||
|
|
@ -460,10 +462,11 @@ static std::vector<IOBatch> MakeBatches(
|
||||||
// want to use the OS cache between consecutive runs.
|
// want to use the OS cache between consecutive runs.
|
||||||
static void ReadBatches(const BlobReader& reader,
|
static void ReadBatches(const BlobReader& reader,
|
||||||
const std::vector<IOBatch>& batches,
|
const std::vector<IOBatch>& batches,
|
||||||
hwy::ThreadPool& pool) {
|
ThreadingContext& ctx) {
|
||||||
|
static const auto zone = ctx.profiler.AddZone("Startup.Weights.ReadBatches");
|
||||||
// >5x speedup from parallel reads when cached.
|
// >5x speedup from parallel reads when cached.
|
||||||
pool.Run(0, batches.size(), [&](uint64_t i, size_t thread) {
|
ctx.pools.Pool().Run(0, batches.size(), [&](uint64_t i, size_t thread) {
|
||||||
PROFILER_ZONE2(thread, "Startup.Weights.Read");
|
PROFILER_ZONE3(ctx.profiler, thread, zone);
|
||||||
const IOBatch& batch = batches[i];
|
const IOBatch& batch = batches[i];
|
||||||
const std::string& key = reader.Keys()[batch.KeyIdx()];
|
const std::string& key = reader.Keys()[batch.KeyIdx()];
|
||||||
const uint64_t bytes_read = batch.Read(reader.file());
|
const uint64_t bytes_read = batch.Read(reader.file());
|
||||||
|
|
@ -500,16 +503,14 @@ static MapPtr MapOrReadAll(std::vector<TensorToRead>& tensors,
|
||||||
AllocateAndBindAll(tensors, *mode, mat_owners, ctx);
|
AllocateAndBindAll(tensors, *mode, mat_owners, ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
hwy::ThreadPool& pool = ctx.pools.Pool();
|
|
||||||
|
|
||||||
if (*mode == WeightsPtrs::Mode::kReadBF16) {
|
if (*mode == WeightsPtrs::Mode::kReadBF16) {
|
||||||
ReadAllToBF16(tensors, reader, pool);
|
ReadAllToBF16(tensors, reader, ctx);
|
||||||
return MapPtr();
|
return MapPtr();
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<IOBatch> batches =
|
const std::vector<IOBatch> batches =
|
||||||
MakeBatches(tensors, reader.file_bytes());
|
MakeBatches(tensors, reader.file_bytes());
|
||||||
ReadBatches(reader, batches, pool);
|
ReadBatches(reader, batches, ctx);
|
||||||
return MapPtr();
|
return MapPtr();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -519,7 +520,7 @@ WeightsPtrs::Mode WeightsPtrs::ReadFromBlobs(const ModelStore& model,
|
||||||
const InferenceArgs& inference,
|
const InferenceArgs& inference,
|
||||||
std::vector<MatOwner>& mat_owners,
|
std::vector<MatOwner>& mat_owners,
|
||||||
ThreadingContext& ctx) {
|
ThreadingContext& ctx) {
|
||||||
PROFILER_ZONE("Startup.ReadFromBlobs");
|
PROFILER_ZONE("Startup.Weights.ReadFromBlobs");
|
||||||
|
|
||||||
// List of tensors to read/map, and where from.
|
// List of tensors to read/map, and where from.
|
||||||
std::vector<TensorToRead> tensors;
|
std::vector<TensorToRead> tensors;
|
||||||
|
|
|
||||||
17
ops/matmul.h
17
ops/matmul.h
|
|
@ -720,32 +720,33 @@ struct MMArgs {
|
||||||
// Wrapper over hwy::Zone that is only enabled when autotuning finished.
|
// Wrapper over hwy::Zone that is only enabled when autotuning finished.
|
||||||
#if PROFILER_ENABLED
|
#if PROFILER_ENABLED
|
||||||
class MMZone {
|
class MMZone {
|
||||||
using Zone = hwy::Zone;
|
using Zone = hwy::profiler::Zone;
|
||||||
static_assert(alignof(Zone) <= 8 && sizeof(Zone) <= 8);
|
static_assert(alignof(Zone) <= 8 && sizeof(Zone) <= 16);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
~MMZone() {
|
~MMZone() {
|
||||||
if (used_) {
|
if (data_ != 0) {
|
||||||
Zone* zone = reinterpret_cast<Zone*>(&data_);
|
Zone* zone = reinterpret_cast<Zone*>(&data_);
|
||||||
zone->~Zone();
|
zone->~Zone();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// `name` must be a string literal.
|
// `name` must be a string literal.
|
||||||
void MaybeEnter(size_t thread_id, uint32_t zone_id, const MMArgs& args) {
|
void MaybeEnter(size_t thread, hwy::profiler::ZoneHandle zone,
|
||||||
|
const MMArgs& args) {
|
||||||
if (args.per_key->WantProfile()) {
|
if (args.per_key->WantProfile()) {
|
||||||
new (&data_) Zone(thread_id, zone_id);
|
new (&data_) Zone(args.env->ctx.profiler, thread, zone);
|
||||||
used_ = true;
|
HWY_DASSERT(data_ != 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
uint64_t data_ = 0;
|
uint64_t data_ = 0;
|
||||||
bool used_ = false;
|
uint64_t data2_ = 0;
|
||||||
};
|
};
|
||||||
#else
|
#else
|
||||||
struct MMZone {
|
struct MMZone {
|
||||||
void MaybeEnter(size_t, uint32_t, const MMArgs&) {}
|
void MaybeEnter(size_t, hwy::profiler::ZoneHandle, const MMArgs&) {}
|
||||||
};
|
};
|
||||||
#endif // PROFILER_ENABLED
|
#endif // PROFILER_ENABLED
|
||||||
|
|
||||||
|
|
|
||||||
128
ops/ops-inl.h
128
ops/ops-inl.h
|
|
@ -125,9 +125,9 @@ HWY_INLINE hn::Vec<D> Gelu(D d, hn::Vec<D> v) {
|
||||||
return hn::Mul(v, cdf);
|
return hn::Mul(v, cdf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Activation already has a profiler zone.
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(float* HWY_RESTRICT x,
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(float* HWY_RESTRICT x,
|
||||||
size_t size) {
|
size_t size) {
|
||||||
PROFILER_ZONE("ops.Gelu");
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
using D = hn::ScalableTag<float>;
|
using D = hn::ScalableTag<float>;
|
||||||
hn::Transform(D(), x, size,
|
hn::Transform(D(), x, size,
|
||||||
|
|
@ -191,9 +191,10 @@ namespace detail {
|
||||||
|
|
||||||
// Shared by RMSNorm and RMSNormInplace.
|
// Shared by RMSNorm and RMSNormInplace.
|
||||||
template <typename VT>
|
template <typename VT>
|
||||||
float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size,
|
float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size, hwy::Profiler& p,
|
||||||
const HWY_MAYBE_UNUSED size_t worker) {
|
const size_t worker) {
|
||||||
PROFILER_ZONE2(worker, "ops.RMSNormMul");
|
static const auto zone = p.AddZone("Ops.RMSNormMul");
|
||||||
|
PROFILER_ZONE3(p, worker, zone);
|
||||||
|
|
||||||
const hn::ScalableTag<float> d;
|
const hn::ScalableTag<float> d;
|
||||||
const float l2 = DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault());
|
const float l2 = DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault());
|
||||||
|
|
@ -205,18 +206,20 @@ float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size,
|
||||||
|
|
||||||
// `x_ofs` is the offset within `x`, required for NuqStream.
|
// `x_ofs` is the offset within `x`, required for NuqStream.
|
||||||
template <typename XT, typename WT, typename OT>
|
template <typename XT, typename WT, typename OT>
|
||||||
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x,
|
||||||
const XT* HWY_RESTRICT x, const WT* HWY_RESTRICT weight, size_t w_ofs,
|
const WT* HWY_RESTRICT weight,
|
||||||
OT* HWY_RESTRICT out, const size_t size,
|
size_t w_ofs, OT* HWY_RESTRICT out,
|
||||||
const size_t HWY_MAYBE_UNUSED worker) {
|
const size_t size, hwy::Profiler& p,
|
||||||
PROFILER_ZONE2(worker, "ops.RMSNorm");
|
const size_t worker) {
|
||||||
|
static const auto zone = p.AddZone("Ops.RMSNorm");
|
||||||
|
PROFILER_ZONE3(p, worker, zone);
|
||||||
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
const hn::ScalableTag<float> df;
|
const hn::ScalableTag<float> df;
|
||||||
using VF = hn::Vec<decltype(df)>;
|
using VF = hn::Vec<decltype(df)>;
|
||||||
const size_t NF = hn::Lanes(df);
|
const size_t NF = hn::Lanes(df);
|
||||||
|
|
||||||
const VF mul = hn::Set(df, detail::RMSNormMul(x, size, worker));
|
const VF mul = hn::Set(df, detail::RMSNormMul(x, size, p, worker));
|
||||||
|
|
||||||
const auto packed_x = MakeSpan(x, size);
|
const auto packed_x = MakeSpan(x, size);
|
||||||
const auto packed_w = MakeSpan(weight, w_ofs + size);
|
const auto packed_w = MakeSpan(weight, w_ofs + size);
|
||||||
|
|
@ -240,15 +243,16 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
||||||
template <typename WT, typename XT>
|
template <typename WT, typename XT>
|
||||||
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
|
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
|
||||||
const WT* HWY_RESTRICT weight, size_t w_ofs, XT* HWY_RESTRICT inout,
|
const WT* HWY_RESTRICT weight, size_t w_ofs, XT* HWY_RESTRICT inout,
|
||||||
const size_t size, const HWY_MAYBE_UNUSED size_t worker) {
|
const size_t size, hwy::Profiler& p, const size_t worker) {
|
||||||
PROFILER_ZONE2(worker, "ops.RMSNormInplace");
|
static const auto zone = p.AddZone("Ops.RMSNormInplace");
|
||||||
|
PROFILER_ZONE3(p, worker, zone);
|
||||||
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
const hn::ScalableTag<float> df;
|
const hn::ScalableTag<float> df;
|
||||||
using VF = hn::Vec<decltype(df)>;
|
using VF = hn::Vec<decltype(df)>;
|
||||||
const size_t NF = hn::Lanes(df);
|
const size_t NF = hn::Lanes(df);
|
||||||
|
|
||||||
const VF mul = hn::Set(df, detail::RMSNormMul(inout, size, worker));
|
const VF mul = hn::Set(df, detail::RMSNormMul(inout, size, p, worker));
|
||||||
|
|
||||||
const auto packed_w = MakeSpan(weight, w_ofs + size);
|
const auto packed_w = MakeSpan(weight, w_ofs + size);
|
||||||
const auto packed_x = MakeSpan(inout, size);
|
const auto packed_x = MakeSpan(inout, size);
|
||||||
|
|
@ -407,9 +411,10 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
|
||||||
// This overload is called if `post_qk == PostQKType::HalfRope`.
|
// This overload is called if `post_qk == PostQKType::HalfRope`.
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
|
||||||
float* HWY_RESTRICT x, const size_t dim_qkv,
|
float* HWY_RESTRICT x, const size_t dim_qkv,
|
||||||
const float* HWY_RESTRICT inv_timescale, const int pos,
|
const float* HWY_RESTRICT inv_timescale, const int pos, hwy::Profiler& p,
|
||||||
const size_t HWY_MAYBE_UNUSED worker = 0) {
|
const size_t worker) {
|
||||||
PROFILER_ZONE2(worker, "ops.Rope");
|
static const auto zone = p.AddZone("Ops.Rope");
|
||||||
|
PROFILER_ZONE3(p, worker, zone);
|
||||||
HWY_DASSERT(dim_qkv % 2 == 0);
|
HWY_DASSERT(dim_qkv % 2 == 0);
|
||||||
const size_t half_dim_qkv = dim_qkv / 2;
|
const size_t half_dim_qkv = dim_qkv / 2;
|
||||||
|
|
||||||
|
|
@ -466,9 +471,10 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
|
||||||
// `inv_timescale[dim_qkv / 2]` is precomputed in AttentionActivations.
|
// `inv_timescale[dim_qkv / 2]` is precomputed in AttentionActivations.
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
|
||||||
const float mul, float* HWY_RESTRICT x, const size_t dim_qkv,
|
const float mul, float* HWY_RESTRICT x, const size_t dim_qkv,
|
||||||
const float* HWY_RESTRICT inv_timescale, const int pos,
|
const float* HWY_RESTRICT inv_timescale, const int pos, hwy::Profiler& p,
|
||||||
const size_t HWY_MAYBE_UNUSED worker = 0) {
|
const size_t worker) {
|
||||||
PROFILER_ZONE2(worker, "ops.RopeAndMulBy");
|
static const auto zone = p.AddZone("Ops.RopeAndMulBy");
|
||||||
|
PROFILER_ZONE3(p, worker, zone);
|
||||||
HWY_DASSERT(dim_qkv % 2 == 0);
|
HWY_DASSERT(dim_qkv % 2 == 0);
|
||||||
const size_t half_dim_qkv = dim_qkv / 2;
|
const size_t half_dim_qkv = dim_qkv / 2;
|
||||||
|
|
||||||
|
|
@ -525,10 +531,13 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename XT>
|
template <typename XT>
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x,
|
||||||
const XT* HWY_RESTRICT x, float* HWY_RESTRICT out, const size_t size,
|
float* HWY_RESTRICT out,
|
||||||
const HWY_MAYBE_UNUSED size_t worker) {
|
const size_t size,
|
||||||
PROFILER_ZONE2(worker, "ops.AddFrom");
|
hwy::Profiler& p,
|
||||||
|
const size_t worker) {
|
||||||
|
static const auto zone = p.AddZone("Ops.AddFrom");
|
||||||
|
PROFILER_ZONE3(p, worker, zone);
|
||||||
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
const hn::ScalableTag<float> df;
|
const hn::ScalableTag<float> df;
|
||||||
|
|
@ -576,12 +585,11 @@ void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
|
||||||
HWY_DASSERT(activations.SameShape(out));
|
HWY_DASSERT(activations.SameShape(out));
|
||||||
|
|
||||||
CallUpcasted(&weights, [&](const auto* weights_t) {
|
CallUpcasted(&weights, [&](const auto* weights_t) {
|
||||||
SmallParallelFor(activations.Rows(), ctx.pools,
|
SmallParallelFor(
|
||||||
[&](uint64_t token_idx, size_t worker) {
|
activations.Rows(), ctx.pools, [&](uint64_t token_idx, size_t worker) {
|
||||||
RMSNorm(activations.Row(token_idx),
|
RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(), 0,
|
||||||
weights_t->PackedScale1(), 0, out.Row(token_idx),
|
out.Row(token_idx), activations.Cols(), ctx.profiler, worker);
|
||||||
activations.Cols(), worker);
|
});
|
||||||
});
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -592,12 +600,11 @@ void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT<XT>& inout,
|
||||||
HWY_DASSERT(weights.Cols() == inout.Cols());
|
HWY_DASSERT(weights.Cols() == inout.Cols());
|
||||||
|
|
||||||
CallUpcasted(&weights, [&](const auto* weights_t) {
|
CallUpcasted(&weights, [&](const auto* weights_t) {
|
||||||
SmallParallelFor(inout.Rows(), ctx.pools,
|
SmallParallelFor(
|
||||||
[&](uint64_t token_idx, size_t worker) {
|
inout.Rows(), ctx.pools, [&](uint64_t token_idx, size_t worker) {
|
||||||
RMSNormInplace(weights_t->PackedScale1(), 0,
|
RMSNormInplace(weights_t->PackedScale1(), 0, inout.Row(token_idx),
|
||||||
inout.Row(token_idx), inout.Cols(),
|
inout.Cols(), ctx.profiler, worker);
|
||||||
worker);
|
});
|
||||||
});
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -622,17 +629,20 @@ template <typename XT>
|
||||||
static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x, MatPtrT<float>& out,
|
static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x, MatPtrT<float>& out,
|
||||||
ThreadingContext& ctx) {
|
ThreadingContext& ctx) {
|
||||||
HWY_DASSERT(out.SameShape(x));
|
HWY_DASSERT(out.SameShape(x));
|
||||||
SmallParallelFor(
|
SmallParallelFor(out.Rows(), ctx.pools,
|
||||||
out.Rows(), ctx.pools, [&](uint64_t token_idx, size_t worker) {
|
[&](uint64_t token_idx, size_t worker) {
|
||||||
AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(), worker);
|
AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(),
|
||||||
});
|
ctx.profiler, worker);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename XT>
|
template <typename XT>
|
||||||
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(
|
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(const float c, XT* HWY_RESTRICT x,
|
||||||
const float c, XT* HWY_RESTRICT x, const size_t size,
|
const size_t size,
|
||||||
const HWY_MAYBE_UNUSED size_t worker) {
|
hwy::Profiler& p,
|
||||||
PROFILER_ZONE2(worker, "ops.MulByConst");
|
const size_t worker) {
|
||||||
|
static const auto zone = p.AddZone("Ops.MulByConst");
|
||||||
|
PROFILER_ZONE3(p, worker, zone);
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
const hn::ScalableTag<float> df;
|
const hn::ScalableTag<float> df;
|
||||||
const size_t NF = hn::Lanes(df);
|
const size_t NF = hn::Lanes(df);
|
||||||
|
|
@ -672,8 +682,9 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(
|
||||||
template <typename XT, typename OT>
|
template <typename XT, typename OT>
|
||||||
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo(
|
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo(
|
||||||
const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,
|
const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,
|
||||||
const size_t size, const HWY_MAYBE_UNUSED size_t worker) {
|
const size_t size, hwy::Profiler& p, const size_t worker) {
|
||||||
PROFILER_ZONE2(worker, "ops.MulByConstTo");
|
static const auto zone = p.AddZone("Ops.MulByConstTo");
|
||||||
|
PROFILER_ZONE3(p, worker, zone);
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
const hn::ScalableTag<float> df;
|
const hn::ScalableTag<float> df;
|
||||||
const size_t NF = hn::Lanes(df);
|
const size_t NF = hn::Lanes(df);
|
||||||
|
|
@ -714,8 +725,9 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo(
|
||||||
template <typename XT, typename OT>
|
template <typename XT, typename OT>
|
||||||
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
|
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
|
||||||
const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,
|
const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,
|
||||||
const size_t size, const HWY_MAYBE_UNUSED size_t worker) {
|
const size_t size, hwy::Profiler& p, const size_t worker) {
|
||||||
PROFILER_ZONE2(worker, "ops.MulByConstAndAdd");
|
static const auto zone = p.AddZone("Ops.MulByConstAndAdd");
|
||||||
|
PROFILER_ZONE3(p, worker, zone);
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
const hn::ScalableTag<float> df;
|
const hn::ScalableTag<float> df;
|
||||||
const size_t NF = hn::Lanes(df);
|
const size_t NF = hn::Lanes(df);
|
||||||
|
|
@ -760,9 +772,10 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
|
||||||
|
|
||||||
// See below for a specialized version for top-1 sampling.
|
// See below for a specialized version for top-1 sampling.
|
||||||
static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
|
static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
|
||||||
const size_t worker,
|
hwy::Profiler& p, const size_t worker,
|
||||||
float temperature = 1.0f) {
|
float temperature = 1.0f) {
|
||||||
PROFILER_ZONE2(worker, "ops.Softmax");
|
static const auto zone = p.AddZone("Ops.Softmax");
|
||||||
|
PROFILER_ZONE3(p, worker, zone);
|
||||||
HWY_DASSERT(size != 0);
|
HWY_DASSERT(size != 0);
|
||||||
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
|
@ -803,7 +816,7 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
|
||||||
const float sum_exp = Sum(d, x, size);
|
const float sum_exp = Sum(d, x, size);
|
||||||
// Double-precision reciprocal does not appear to affect the results.
|
// Double-precision reciprocal does not appear to affect the results.
|
||||||
const float mul = 1.0f / sum_exp;
|
const float mul = 1.0f / sum_exp;
|
||||||
MulByConst(mul, x, size, worker);
|
MulByConst(mul, x, size, p, worker);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max /
|
// Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max /
|
||||||
|
|
@ -893,9 +906,10 @@ static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(float* HWY_RESTRICT x,
|
||||||
}
|
}
|
||||||
|
|
||||||
static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
|
static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
|
||||||
const size_t size,
|
const size_t size, hwy::Profiler& p,
|
||||||
const HWY_MAYBE_UNUSED size_t worker) {
|
const size_t worker) {
|
||||||
PROFILER_ZONE2(worker, "ops.LogitsSoftCap");
|
static const auto zone = p.AddZone("Ops.LogitsSoftCap");
|
||||||
|
PROFILER_ZONE3(p, worker, zone);
|
||||||
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
using D = hn::ScalableTag<float>;
|
using D = hn::ScalableTag<float>;
|
||||||
|
|
@ -911,10 +925,10 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
|
||||||
|
|
||||||
// Calls LogitsSoftCap if cap != 0.0f.
|
// Calls LogitsSoftCap if cap != 0.0f.
|
||||||
static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap(
|
static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap(
|
||||||
const float cap, float* HWY_RESTRICT x, const size_t size,
|
const float cap, float* HWY_RESTRICT x, const size_t size, hwy::Profiler& p,
|
||||||
const size_t worker) {
|
const size_t worker) {
|
||||||
if (cap != 0.0f) {
|
if (cap != 0.0f) {
|
||||||
LogitsSoftCap(cap, x, size, worker);
|
LogitsSoftCap(cap, x, size, p, worker);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -998,7 +1012,7 @@ template <typename TAcceptToken>
|
||||||
HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
|
HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
|
||||||
const float* HWY_RESTRICT logits, size_t k, size_t vocab_size,
|
const float* HWY_RESTRICT logits, size_t k, size_t vocab_size,
|
||||||
std::mt19937& gen, float temperature, TAcceptToken& accept_token,
|
std::mt19937& gen, float temperature, TAcceptToken& accept_token,
|
||||||
size_t worker) {
|
hwy::Profiler& p, size_t worker) {
|
||||||
// Softmax and sample top-K is equivalent to taking the top-K logits and
|
// Softmax and sample top-K is equivalent to taking the top-K logits and
|
||||||
// sampling from the softmax of the top-K logits. The latter is faster as it
|
// sampling from the softmax of the top-K logits. The latter is faster as it
|
||||||
// avoids computing the softmax of all logits.
|
// avoids computing the softmax of all logits.
|
||||||
|
|
@ -1012,7 +1026,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t mask = token_logits.size();
|
size_t mask = token_logits.size();
|
||||||
Softmax(topk_logits.data(), mask, worker, temperature);
|
Softmax(topk_logits.data(), mask, p, worker, temperature);
|
||||||
auto distribution = std::discrete_distribution<int>(
|
auto distribution = std::discrete_distribution<int>(
|
||||||
std::begin(topk_logits), std::begin(topk_logits) + mask);
|
std::begin(topk_logits), std::begin(topk_logits) + mask);
|
||||||
int topk_sampled_index = distribution(gen);
|
int topk_sampled_index = distribution(gen);
|
||||||
|
|
|
||||||
|
|
@ -18,8 +18,6 @@
|
||||||
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
||||||
#endif // HWY_DISABLED_TARGETS
|
#endif // HWY_DISABLED_TARGETS
|
||||||
|
|
||||||
#include "ops/ops.h"
|
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
|
|
@ -33,11 +31,13 @@
|
||||||
|
|
||||||
#include "gemma/activations.h" // ChooseQueryScale
|
#include "gemma/activations.h" // ChooseQueryScale
|
||||||
#include "gemma/configs.h"
|
#include "gemma/configs.h"
|
||||||
|
#include "ops/ops.h"
|
||||||
#include "util/allocator.h"
|
#include "util/allocator.h"
|
||||||
#include "util/basics.h" // BF16
|
#include "util/basics.h" // BF16
|
||||||
#include "util/mat.h" // MatStorageT
|
#include "util/mat.h" // MatStorageT
|
||||||
#include "util/test_util.h"
|
#include "util/test_util.h"
|
||||||
#include "util/threading_context.h"
|
#include "util/threading_context.h"
|
||||||
|
#include "hwy/profiler.h"
|
||||||
#include "hwy/tests/hwy_gtest.h"
|
#include "hwy/tests/hwy_gtest.h"
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
|
|
@ -166,7 +166,7 @@ struct TestAddFrom {
|
||||||
}
|
}
|
||||||
|
|
||||||
SimpleAddFrom(o, e, count);
|
SimpleAddFrom(o, e, count);
|
||||||
AddFrom(o, x, count, /*worker=*/0);
|
AddFrom(o, x, count, hwy::Profiler::Get(), /*worker=*/0);
|
||||||
|
|
||||||
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
||||||
__LINE__);
|
__LINE__);
|
||||||
|
|
@ -199,7 +199,7 @@ struct TestMulByConstAndAdd {
|
||||||
T constant = Random<T>(rng);
|
T constant = Random<T>(rng);
|
||||||
|
|
||||||
SimpleMulByConstAndAdd(constant, o, e, count);
|
SimpleMulByConstAndAdd(constant, o, e, count);
|
||||||
MulByConstAndAdd(constant, o, x, count, /*worker=*/0);
|
MulByConstAndAdd(constant, o, x, count, hwy::Profiler::Get(), /*worker=*/0);
|
||||||
|
|
||||||
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
||||||
__LINE__);
|
__LINE__);
|
||||||
|
|
@ -229,7 +229,7 @@ struct TestMulByConst {
|
||||||
T constant = Random<T>(rng);
|
T constant = Random<T>(rng);
|
||||||
|
|
||||||
SimpleMulByConst(constant, e, count);
|
SimpleMulByConst(constant, e, count);
|
||||||
MulByConst(constant, x, count, /*worker=*/0);
|
MulByConst(constant, x, count, hwy::Profiler::Get(), /*worker=*/0);
|
||||||
|
|
||||||
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
||||||
__LINE__);
|
__LINE__);
|
||||||
|
|
@ -259,7 +259,7 @@ struct TestSoftmax {
|
||||||
}
|
}
|
||||||
|
|
||||||
SimpleSoftmax(e, count);
|
SimpleSoftmax(e, count);
|
||||||
Softmax(x, count, /*worker=*/0);
|
Softmax(x, count, hwy::Profiler::Get(), /*worker=*/0);
|
||||||
|
|
||||||
T sum = 0.0f;
|
T sum = 0.0f;
|
||||||
for (size_t i = 0; i < count; ++i) {
|
for (size_t i = 0; i < count; ++i) {
|
||||||
|
|
@ -349,6 +349,9 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
|
||||||
void TestRopeAndMulBy() {
|
void TestRopeAndMulBy() {
|
||||||
ThreadingArgs threading_args;
|
ThreadingArgs threading_args;
|
||||||
ThreadingContext ctx(threading_args);
|
ThreadingContext ctx(threading_args);
|
||||||
|
hwy::Profiler& p = ctx.profiler;
|
||||||
|
const size_t worker = 0;
|
||||||
|
|
||||||
const ModelConfig config(Model::GEMMA2_9B, Type::kSFP,
|
const ModelConfig config(Model::GEMMA2_9B, Type::kSFP,
|
||||||
ChooseWrapping(Model::GEMMA2_9B));
|
ChooseWrapping(Model::GEMMA2_9B));
|
||||||
const size_t dim_qkv = config.layer_configs[0].qkv_dim;
|
const size_t dim_qkv = config.layer_configs[0].qkv_dim;
|
||||||
|
|
@ -381,7 +384,8 @@ void TestRopeAndMulBy() {
|
||||||
CopyMat(x, qactual);
|
CopyMat(x, qactual);
|
||||||
ScalarRopeAndMulBy(qmul, qexpected.Row(0), dim_qkv, inv_timescale.Row(0),
|
ScalarRopeAndMulBy(qmul, qexpected.Row(0), dim_qkv, inv_timescale.Row(0),
|
||||||
pos);
|
pos);
|
||||||
RopeAndMulBy(qmul, qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos);
|
RopeAndMulBy(qmul, qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, p,
|
||||||
|
worker);
|
||||||
for (size_t i = 0; i < dim_qkv; ++i) {
|
for (size_t i = 0; i < dim_qkv; ++i) {
|
||||||
EXPECT_NEAR(qexpected.Row(0)[i], qactual.Row(0)[i], 1e-4) << " " << i;
|
EXPECT_NEAR(qexpected.Row(0)[i], qactual.Row(0)[i], 1e-4) << " " << i;
|
||||||
}
|
}
|
||||||
|
|
@ -391,7 +395,7 @@ void TestRopeAndMulBy() {
|
||||||
CopyMat(x, qactual);
|
CopyMat(x, qactual);
|
||||||
ScalarRopeAndMulBy(1.0f, qexpected.Row(0), dim_qkv, inv_timescale.Row(0),
|
ScalarRopeAndMulBy(1.0f, qexpected.Row(0), dim_qkv, inv_timescale.Row(0),
|
||||||
pos);
|
pos);
|
||||||
Rope(qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos);
|
Rope(qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, p, worker);
|
||||||
for (size_t i = 0; i < dim_qkv; ++i) {
|
for (size_t i = 0; i < dim_qkv; ++i) {
|
||||||
EXPECT_NEAR(qexpected.Row(0)[i], qactual.Row(0)[i], 1e-4) << " " << i;
|
EXPECT_NEAR(qexpected.Row(0)[i], qactual.Row(0)[i], 1e-4) << " " << i;
|
||||||
}
|
}
|
||||||
|
|
@ -402,9 +406,10 @@ void TestRopeAndMulBy() {
|
||||||
CopyMat(x, kactual2);
|
CopyMat(x, kactual2);
|
||||||
ScalarRopeAndMulBy(kmul, kexpected.Row(0), dim_qkv, inv_timescale.Row(0),
|
ScalarRopeAndMulBy(kmul, kexpected.Row(0), dim_qkv, inv_timescale.Row(0),
|
||||||
pos);
|
pos);
|
||||||
RopeAndMulBy(kmul, kactual.Row(0), dim_qkv, inv_timescale.Row(0), pos);
|
RopeAndMulBy(kmul, kactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, p,
|
||||||
|
worker);
|
||||||
static_assert(kmul == 1.0f, "");
|
static_assert(kmul == 1.0f, "");
|
||||||
Rope(kactual2.Row(0), dim_qkv, inv_timescale.Row(0), pos);
|
Rope(kactual2.Row(0), dim_qkv, inv_timescale.Row(0), pos, p, worker);
|
||||||
|
|
||||||
for (size_t i = 0; i < dim_qkv; ++i) {
|
for (size_t i = 0; i < dim_qkv; ++i) {
|
||||||
EXPECT_NEAR(kexpected.Row(0)[i], kactual.Row(0)[i], 1e-4) << " " << i;
|
EXPECT_NEAR(kexpected.Row(0)[i], kactual.Row(0)[i], 1e-4) << " " << i;
|
||||||
|
|
@ -454,7 +459,7 @@ void TestRMSNorm(hwy::RandomState& rng) {
|
||||||
}
|
}
|
||||||
|
|
||||||
ScalarRMSNorm(vec, weight, expected, kSize);
|
ScalarRMSNorm(vec, weight, expected, kSize);
|
||||||
RMSNorm(vec, weight, 0, actual, kSize, /*worker=*/0);
|
RMSNorm(vec, weight, 0, actual, kSize, hwy::Profiler::Get(), /*worker=*/0);
|
||||||
|
|
||||||
for (size_t i = 0; i < kSize; i++) {
|
for (size_t i = 0; i < kSize; i++) {
|
||||||
const float e = hwy::ConvertScalarTo<float>(expected[i]);
|
const float e = hwy::ConvertScalarTo<float>(expected[i]);
|
||||||
|
|
@ -580,11 +585,13 @@ void TestAllLayerNorm() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestSampleTopK() {
|
void TestSampleTopK() {
|
||||||
|
hwy::Profiler& p = hwy::Profiler::Get();
|
||||||
|
const size_t worker = 0;
|
||||||
const size_t kSize = 52;
|
const size_t kSize = 52;
|
||||||
std::vector<float> logits(kSize);
|
std::vector<float> logits(kSize);
|
||||||
// Create a vector going from -100 to -100+51=49 and take Softmax.
|
// Create a vector going from -100 to -100+51=49 and take Softmax.
|
||||||
std::iota(logits.begin(), logits.end(), -100.0f);
|
std::iota(logits.begin(), logits.end(), -100.0f);
|
||||||
Softmax(logits.data(), kSize, /*worker=*/0);
|
Softmax(logits.data(), kSize, p, worker);
|
||||||
std::mt19937 gen;
|
std::mt19937 gen;
|
||||||
gen.seed(0x12345678);
|
gen.seed(0x12345678);
|
||||||
float temperature = 1.0f;
|
float temperature = 1.0f;
|
||||||
|
|
@ -600,7 +607,7 @@ void TestSampleTopK() {
|
||||||
EXPECT_EQ(sample, 50); // Last even index.
|
EXPECT_EQ(sample, 50); // Last even index.
|
||||||
// Reset the logits to a positive, increasing sequence and take Softmax.
|
// Reset the logits to a positive, increasing sequence and take Softmax.
|
||||||
std::iota(logits.begin(), logits.end(), 1.0f);
|
std::iota(logits.begin(), logits.end(), 1.0f);
|
||||||
Softmax(logits.data(), kSize, /*worker=*/0);
|
Softmax(logits.data(), kSize, p, worker);
|
||||||
// Sample from the top 3, expect one of the top 3 even indices.
|
// Sample from the top 3, expect one of the top 3 even indices.
|
||||||
for (int i = 0; i < 100; ++i) {
|
for (int i = 0; i < 100; ++i) {
|
||||||
sample = SampleTopK(logits.data(), /*k=*/3, kSize, gen, temperature,
|
sample = SampleTopK(logits.data(), /*k=*/3, kSize, gen, temperature,
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
#include "hwy/contrib/thread_pool/topology.h"
|
#include "hwy/contrib/thread_pool/topology.h"
|
||||||
|
#include "hwy/profiler.h"
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
|
@ -171,6 +172,8 @@ NestedPools::NestedPools(const BoundedTopology& topology,
|
||||||
HWY_ASSERT(max_clusters_per_package_ <= 64);
|
HWY_ASSERT(max_clusters_per_package_ <= 64);
|
||||||
HWY_ASSERT(max_workers_per_cluster_ >= 1);
|
HWY_ASSERT(max_workers_per_cluster_ >= 1);
|
||||||
HWY_ASSERT(max_workers_per_cluster_ <= 256);
|
HWY_ASSERT(max_workers_per_cluster_ <= 256);
|
||||||
|
|
||||||
|
hwy::Profiler::Get().SetMaxThreads(MaxWorkers());
|
||||||
}
|
}
|
||||||
|
|
||||||
// `max_or_zero` == 0 means no limit.
|
// `max_or_zero` == 0 means no limit.
|
||||||
|
|
|
||||||
|
|
@ -72,7 +72,8 @@ static void TunePool(hwy::ThreadPool& pool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
ThreadingContext::ThreadingContext(const ThreadingArgs& args)
|
ThreadingContext::ThreadingContext(const ThreadingArgs& args)
|
||||||
: topology(BoundedSlice(args.skip_packages, args.max_packages),
|
: profiler(hwy::Profiler::Get()),
|
||||||
|
topology(BoundedSlice(args.skip_packages, args.max_packages),
|
||||||
BoundedSlice(args.skip_clusters, args.max_clusters),
|
BoundedSlice(args.skip_clusters, args.max_clusters),
|
||||||
BoundedSlice(args.skip_lps, args.max_lps)),
|
BoundedSlice(args.skip_lps, args.max_lps)),
|
||||||
allocator(topology, args.bind != Tristate::kFalse),
|
allocator(topology, args.bind != Tristate::kFalse),
|
||||||
|
|
|
||||||
|
|
@ -90,6 +90,7 @@ struct ThreadingContext {
|
||||||
// Expected to be called early in the program, before threading starts.
|
// Expected to be called early in the program, before threading starts.
|
||||||
explicit ThreadingContext(const ThreadingArgs& args);
|
explicit ThreadingContext(const ThreadingArgs& args);
|
||||||
|
|
||||||
|
hwy::Profiler& profiler;
|
||||||
BoundedTopology topology;
|
BoundedTopology topology;
|
||||||
Allocator allocator;
|
Allocator allocator;
|
||||||
NestedPools pools;
|
NestedPools pools;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue