1.1x prefill and decode speedup (attention/activations)

Optimizations
- Better load-balancing in attention threading
  (previously, per-cluster parallelism was limited by the number of heads)
- Add MulByConstTo to avoid zero-init (see sketch below)
- Parallel activations

Cleanup
- Prepare for RowPtr in A or B
- Pass through thread_id to ops
- Avoid warning in bench_matmul

PiperOrigin-RevId: 773723423
Jan Wassenberg 2025-06-20 08:59:23 -07:00 committed by Copybara-Service
parent 7630ec0c92
commit 0f70f285e0
12 changed files with 266 additions and 191 deletions
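To illustrate the "Add MulByConstTo to avoid zero-init" item: WeightedSumV used to zero att_out and then accumulate every V row via MulByConstAndAdd; the first row now initializes the output directly, saving one pass over att_out. Below is a minimal scalar sketch of the idea; the real ops are Highway-vectorized and also take a worker index, and these bodies are simplified stand-ins rather than the actual implementation.

#include <cstddef>

// out[i] = c * x[i]; replaces ZeroBytes(att_out) plus a first MulByConstAndAdd.
static void MulByConstTo(float c, const float* x, float* out, size_t size) {
  for (size_t i = 0; i < size; ++i) out[i] = c * x[i];
}

// out[i] += c * x[i]
static void MulByConstAndAdd(float c, const float* x, float* out, size_t size) {
  for (size_t i = 0; i < size; ++i) out[i] += c * x[i];
}

// Weighted sum over V rows (no-wraparound case), shaped like the new
// WeightedSumV: the first row initializes att_out, the rest accumulate.
static void WeightedSumVScalar(size_t start_pos, size_t last_pos,
                               const float* att, const float* v, size_t stride,
                               size_t qkv_dim, float* att_out) {
  MulByConstTo(att[start_pos], v + start_pos * stride, att_out, qkv_dim);
  for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
    MulByConstAndAdd(att[pos], v + pos * stride, att_out, qkv_dim);
  }
}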

View File

@ -18,6 +18,7 @@
#include <math.h> // sqrtf
#include <stddef.h>
#include <stdint.h>
#include <atomic>
#include <vector>
@ -99,6 +100,7 @@ struct AttentionActivations {
1000000.0)),
div_seq_len(static_cast<uint32_t>(seq_len)),
div_heads(static_cast<uint32_t>(layer_config.heads)),
query_scale(ChooseQueryScale(config)) {
// Batch size can be 0 in experimental code so do not assert.
if (batch_size == 0) {
@ -125,10 +127,6 @@ struct AttentionActivations {
att_sums.OverrideRows(batch_size);
}
bool IsGlobalLayer(size_t layer_idx) const {
return config.attention_window_sizes[layer_idx] == div_seq_len.GetDivisor();
}
const ModelConfig& config;
MatStorageT<float> q; // query
@ -144,6 +142,8 @@ struct AttentionActivations {
MatStorageT<float> inv_timescale_global;
hwy::Divisor div_seq_len;
// Unfortunately, some models (Griffin) have non-power-of-two heads.
hwy::Divisor div_heads;
float query_scale;
};

View File

@ -52,7 +52,9 @@ namespace HWY_NAMESPACE {
static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
const hwy::Divisor& div_seq_len,
const float* HWY_RESTRICT q,
const MatPtrT<BF16>& k, float* HWY_RESTRICT att) {
const MatPtrT<BF16>& k, float* HWY_RESTRICT att,
const size_t worker) {
PROFILER_ZONE2(worker, "Gen.Attention.QDotK");
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
// Slightly faster: no wraparound.
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
@ -71,7 +73,8 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
static void PositionalEncodingQK(float* qk, const size_t layer_idx,
const LayerWeightsPtrs& layer,
const AttentionActivations& activations,
const size_t pos, const float mul = 1.0f) {
const size_t worker, const size_t pos,
const float mul = 1.0f) {
const size_t qkv_dim = layer.layer_config.qkv_dim;
const PostQKType& post_qk = layer.layer_config.post_qk;
// qk is either q or k, so qkv_dim is the length we operate on.
@ -83,10 +86,10 @@ static void PositionalEncodingQK(float* qk, const size_t layer_idx,
}
// PostQKType::Rope
if (post_qk == PostQKType::HalfRope) {
Rope(qk, qkv_dim / 2, inv_timescale, pos);
if (mul != 1.0f) MulByConst(mul, qk, qkv_dim);
Rope(qk, qkv_dim / 2, inv_timescale, pos, worker);
if (mul != 1.0f) MulByConst(mul, qk, qkv_dim, worker);
} else {
RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos);
RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, worker);
}
}
@ -94,39 +97,38 @@ static void PositionalEncodingQK(float* qk, const size_t layer_idx,
// `att_out`. Equivalent in gemma/modules.py:
// encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
// `v` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
static HWY_INLINE void WeightedSumV(const size_t start_pos,
const size_t last_pos,
const hwy::Divisor& div_seq_len,
const float* HWY_RESTRICT att,
const MatPtrT<BF16>& v,
float* HWY_RESTRICT att_out) {
const size_t qkv_dim = v.Cols();
hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
static HWY_INLINE void WeightedSumV(
const size_t start_pos, const size_t last_pos,
const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT att,
const MatPtrT<BF16>& v, float* HWY_RESTRICT att_out, const size_t worker) {
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
// Slightly faster: no wraparound.
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols());
// Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if
// we supported non-transposed B.
// TODO: 2..4x unroll
MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), worker);
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols(), worker);
}
} else {
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
const size_t pos_modulo = div_seq_len.Remainder(pos);
const BF16* HWY_RESTRICT v_ptr = v.Row(pos_modulo);
MulByConstAndAdd(att[pos_modulo], v_ptr, att_out, v.Cols());
{
const size_t pos_mod = div_seq_len.Remainder(start_pos);
MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), worker);
}
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
const size_t pos_mod = div_seq_len.Remainder(pos);
MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), worker);
}
}
}
// Calculates the attention outputs for a single q, which may be updated
// in place for RMSNorm.
void SingleDotSoftmaxWeightedSum(const size_t pos, const size_t start_pos,
const size_t last_pos, float* HWY_RESTRICT q,
const MatPtrT<BF16>& k, const MatPtrT<BF16>& v,
const size_t layer_idx,
const LayerWeightsPtrs& layer,
const AttentionActivations& activations,
float* HWY_RESTRICT att,
float* HWY_RESTRICT att_out) {
void SingleDotSoftmaxWeightedSum(
const size_t pos, const size_t start_pos, const size_t last_pos,
float* HWY_RESTRICT q, const MatPtrT<BF16>& k, const MatPtrT<BF16>& v,
const size_t layer_idx, const LayerWeightsPtrs& layer,
const AttentionActivations& activations, float* HWY_RESTRICT att,
float* HWY_RESTRICT att_out, const size_t worker) {
const float att_cap = activations.config.att_cap;
const float query_scale = activations.query_scale;
const size_t seq_len =
@ -136,20 +138,22 @@ void SingleDotSoftmaxWeightedSum(const size_t pos, const size_t start_pos,
if (layer.query_norm_scale.HasPtr()) {
CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), 0, q,
layer.layer_config.qkv_dim);
layer.layer_config.qkv_dim, worker);
});
}
PositionalEncodingQK(q, layer_idx, layer, activations, pos, query_scale);
PositionalEncodingQK(q, layer_idx, layer, activations, worker, pos,
query_scale);
QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att);
QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, worker);
// SoftMax with optional SoftCap yields "probabilities" in att.
const size_t att_len = HWY_MIN(last_pos + 1, seq_len);
MaybeLogitsSoftCap(att_cap, att, att_len);
Softmax(att, att_len);
MaybeLogitsSoftCap(att_cap, att, att_len, worker);
Softmax(att, att_len, /*temperature=*/1.0f, worker);
WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out);
WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out,
worker);
}
// The attention window usually starts at 0 unless `pos` is larger than
@ -179,22 +183,17 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
const size_t cache_layer_size = layer_config.CacheLayerSize();
const size_t seq_len =
static_cast<size_t>(activations.div_seq_len.GetDivisor());
// All layers should have the same number of heads.
HWY_DASSERT(activations.div_heads.GetDivisor() == layer_config.heads);
// For each head/token/query, compute Q.K, softmax, and weighted V.
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
const size_t tq_idx = activations.div_heads.Divide(task);
const size_t head = activations.div_heads.Remainder(task);
#if PROFILER_ENABLED
const hwy::Zone zone(worker, zone_id_par);
#endif
// Statically partition token/query across packages.
const size_t num_tq = num_tokens * div_qbatch.GetDivisor();
const IndexRangePartition tq_ranges =
StaticPartition(IndexRange(0, num_tq), pools.NumPackages(), 1);
ParallelizeOneRange(
tq_ranges, pools.AllPackages(),
[&](const IndexRange& tq_range, const size_t pkg_idx) {
const size_t pkg_base = pkg_idx * pools.MaxWorkersPerPackage();
pools.AllClusters(pkg_idx).Run(
tq_range.begin(), tq_range.end(),
[&](const size_t tq_idx, const size_t cluster_idx) {
const HWY_MAYBE_UNUSED size_t cluster_base =
pkg_base + cluster_idx * pools.MaxWorkersPerCluster();
const size_t qi = div_qbatch.Remainder(tq_idx);
const size_t batch_idx = div_qbatch.Divide(tq_idx);
auto& kv_cache = qbatch.KV(qi).kv_cache;
@ -202,8 +201,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
// Find the token position in the query and calculate
// the range of cache positions to attend to.
const size_t pos = qbatch.Pos(qi) + batch_idx;
const size_t start_pos =
StartPos(pos, activations.config, layer_idx);
const size_t start_pos = StartPos(pos, activations.config, layer_idx);
size_t last_pos = pos;
const size_t prefix_end = qbatch.PrefixEnd(qi);
if (prefix_end > 0 && prefix_end - 1 > last_pos) {
@ -211,43 +209,26 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
last_pos = prefix_end - 1;
}
pools.Cluster(pkg_idx, cluster_idx)
.Run(
0, layer_config.heads,
[&](const size_t head, size_t thread) HWY_ATTR {
#if PROFILER_ENABLED
const hwy::Zone zone(cluster_base + thread,
zone_id_par);
#endif
const size_t head_offset =
(head / kHeadGroups) * qkv_dim * 2;
float* HWY_RESTRICT q =
activations.q.Row(tq_idx) + head * qkv_dim;
float* HWY_RESTRICT att =
activations.att.Row(tq_idx) + head * seq_len;
float* HWY_RESTRICT q = activations.q.Row(tq_idx) + head * qkv_dim;
float* HWY_RESTRICT att = activations.att.Row(tq_idx) + head * seq_len;
float* HWY_RESTRICT att_out =
activations.att_out.Row(tq_idx) + head * qkv_dim;
// Make strided read-only views into the kv cache for
// this query and head.
const size_t kv_head_offset =
layer_idx * cache_layer_size + head_offset;
const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
const size_t kv_head_offset = layer_idx * cache_layer_size + head_offset;
MatPtrT<BF16> k("k_view", Extents2D(seq_len, qkv_dim));
k.SetPtr(kv_cache.Row(0) + kv_head_offset,
kv_cache.Stride());
k.SetPtr(kv_cache.Row(0) + kv_head_offset, kv_cache.Stride());
MatPtrT<BF16> v("v_view", Extents2D(seq_len, qkv_dim));
v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim,
kv_cache.Stride());
v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());
SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q,
k, v, layer_idx, layer,
activations, att, att_out);
});
});
});
SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx,
layer, activations, att, att_out, worker);
};
ParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads, pools,
/*pkg_idx=*/0, func);
}
// Different functions use different naming conventions for the number of
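The restructured loop above is the load-balancing item from the commit message: rather than statically splitting tokens/queries across packages and then limiting each cluster to layer_config.heads parallel tasks, every (token, query, head) combination is now one flat task for ParallelFor. A plain-integer sketch of the index decomposition, with ordinary division standing in for hwy::Divisor (which only makes the divisions cheaper):

#include <cstddef>

struct AttentionTask {
  size_t head;       // attention head
  size_t qi;         // query index within the qbatch
  size_t batch_idx;  // token index within the batch
};

// task is in [0, num_tokens * qbatch_size * num_heads).
static AttentionTask DecomposeTask(size_t task, size_t num_heads,
                                   size_t qbatch_size) {
  const size_t tq_idx = task / num_heads;         // div_heads.Divide(task)
  const size_t head = task % num_heads;           // div_heads.Remainder(task)
  const size_t qi = tq_idx % qbatch_size;         // div_qbatch.Remainder(tq_idx)
  const size_t batch_idx = tq_idx / qbatch_size;  // div_qbatch.Divide(tq_idx)
  return AttentionTask{head, qi, batch_idx};
}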
@ -286,10 +267,10 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
const size_t cache_pos =
activations.div_seq_len.Remainder(qbatch.Pos(qi) + batch_idx);
env.row_ptrs[0][interleaved_idx] = reinterpret_cast<uint8_t*>(
env.row_ptrs[2][interleaved_idx] = reinterpret_cast<uint8_t*>(
qbatch.KV(qi).kv_cache.Row(cache_pos) + layer_idx * cache_layer_size);
}
kv_rows.AttachRowPtrs(env.row_ptrs[0].get());
kv_rows.AttachRowPtrs(env.row_ptrs[2].get());
CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
/*add=*/nullptr, env, kv_rows);
@ -298,7 +279,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
// tasks are very lightweight.
env.ctx.pools.Pool(0).Run(
0, kv_heads * num_interleaved,
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
[&](uint64_t task, size_t thread) HWY_ATTR {
const size_t head = task % kv_heads;
const size_t interleaved_idx = task / kv_heads;
const size_t qi = div_qbatch.Remainder(interleaved_idx);
@ -318,11 +299,13 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
// Apply further processing to K.
if (layer.key_norm_scale.HasPtr()) {
CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), 0, kv_f32, qkv_dim);
RMSNormInplace(weights_t->PackedScale1(), 0, kv_f32, qkv_dim,
thread);
});
}
PositionalEncodingQK(kv_f32, layer_idx, layer, activations, pos);
PositionalEncodingQK(kv_f32, layer_idx, layer, activations, thread,
pos);
CompressPerThread tls;
Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
});

View File

@ -33,7 +33,7 @@ namespace gcpp {
float* HWY_RESTRICT q, const MatPtrT<BF16>& k, const MatPtrT<BF16>& v, \
size_t layer_idx, const LayerWeightsPtrs& layer, \
const AttentionActivations& activations, float* HWY_RESTRICT att, \
float* HWY_RESTRICT att_out); \
float* HWY_RESTRICT att_out, size_t worker); \
\
void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \
const LayerWeightsPtrs& layer, \

View File

@ -23,6 +23,7 @@
#include "gemma/weights.h"
#include "ops/matmul.h"
#include "util/mat.h"
#include "util/threading.h"
#include "hwy/profiler.h"
// Include guard (still compiled once per target)
@ -43,9 +44,10 @@ namespace gcpp {
namespace HWY_NAMESPACE {
template <typename T>
HWY_NOINLINE void Activation(ActivationType activation, T* HWY_RESTRICT c1,
const T* HWY_RESTRICT c2, size_t count) {
PROFILER_ZONE("Gen.Activation");
void Activation(ActivationType activation, T* HWY_RESTRICT c1,
const T* HWY_RESTRICT c2, const size_t count,
const size_t worker) {
PROFILER_ZONE2(worker, "Gen.Activation");
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<T>;
using VF = hn::Vec<DF>;
@ -62,29 +64,33 @@ HWY_NOINLINE void Activation(ActivationType activation, T* HWY_RESTRICT c1,
// No C2 multiplier.
template <class Mat>
void ActivationBatched(ActivationType activation, Mat& c1) {
void ActivationBatched(ActivationType activation, Mat& c1, NestedPools& pools) {
using T = typename Mat::T;
for (size_t i = 0; i < c1.Rows(); ++i) {
ParallelFor(c1.Rows(), pools, /*pkg_idx=*/0,
[&](uint64_t task, size_t worker) {
// Cast to correct type so type deduction works.
Activation(activation, c1.Row(i), static_cast<const T*>(nullptr),
c1.Cols());
}
Activation(activation, c1.Row(task),
static_cast<const T*>(nullptr), c1.Cols(), worker);
});
}
template <class Mat>
HWY_NOINLINE void ActivationBatched(ActivationType activation, Mat& c1,
const Mat* c2) {
const Mat* c2, NestedPools& pools) {
using T = typename Mat::T;
HWY_DASSERT(c1.SameShape(*c2));
if (c2 && c2->HasPtr()) {
for (size_t i = 0; i < c1.Rows(); ++i) {
Activation(activation, c1.Row(i), c2->Row(i), c1.Cols());
}
ParallelFor(c1.Rows(), pools, /*pkg_idx=*/0,
[&](uint64_t task, size_t worker) {
Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(),
worker);
});
} else { // No multiplier
for (size_t i = 0; i < c1.Rows(); ++i) {
Activation(activation, c1.Row(i), static_cast<const T*>(nullptr),
c1.Cols());
}
ParallelFor(c1.Rows(), pools, /*pkg_idx=*/0,
[&](uint64_t task, size_t worker) {
Activation(activation, c1.Row(task),
static_cast<const T*>(nullptr), c1.Cols(), worker);
});
}
}
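These ActivationBatched changes are the "Parallel activations" item: the per-row loop becomes a ParallelFor whose lambda also receives a worker index for PROFILER_ZONE2. A generic sketch of the row-parallel pattern, assuming only the ParallelFor helper added to util/threading.h in this commit; RowOp is a stand-in for Activation and the wrapper name is illustrative:

#include <cstddef>
#include <cstdint>

#include "util/threading.h"

template <class Mat, class RowOp>
void ForEachRowParallel(Mat& m, gcpp::NestedPools& pools, const RowOp& op) {
  gcpp::ParallelFor(m.Rows(), pools, /*pkg_idx=*/0,
                    [&](uint64_t task, size_t worker) {
                      // Rows are independent; `worker` identifies the thread
                      // for profiler zones and per-thread scratch.
                      op(m.Row(task), m.Cols(), worker);
                    });
}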
@ -126,7 +132,8 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer,
activations.C2);
// Activation (Gelu) and maybe multiply by gate. Store activations in act.
ActivationBatched(layer_config.activation, activations.C1, &activations.C2);
ActivationBatched(layer_config.activation, activations.C1, &activations.C2,
env.ctx.pools);
// Hidden layer -> output layer.
CallMatMul(activations.C1, layer.linear_w, output_bias, env,

View File

@ -224,7 +224,7 @@ void FFWVit(const LayerWeightsPtrs& layer, Activations& activations,
activations.C1);
// Activation (Gelu), store in C1.
ActivationBatched(layer_config.activation, activations.C1);
ActivationBatched(layer_config.activation, activations.C1, env.ctx.pools);
// Hidden layer -> output layer.
CallMatMul(activations.C1, layer.vit.linear_1_w, output_bias, env,

View File

@ -108,6 +108,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
// BindB/C if there is a single package: they will be a no-op.
BindB(b_trans, sizeof(TC), env.parallel);
BindC(C, env.parallel);
C.AllocateAndAttachRowPtrs(env.row_ptrs);
Tristate use_spinning = Tristate::kDefault;
env.ctx.pools.MaybeStartSpinning(use_spinning);
@ -139,8 +140,8 @@ using SFP = SfpStream;
void BenchAllMatMul() {
if (first_target == 0) first_target = HWY_TARGET;
// Disable the best-target-only limitation.
// if (HWY_TARGET != first_target) return;
// Comment out to disable the best-target-only limitation.
if (HWY_TARGET != first_target) return;
// Skip EMU128 (10x slower than SSE4 for SFP) and older x86.
if (HWY_TARGET == HWY_EMU128 || HWY_TARGET == HWY_SSSE3 ||

View File

@ -1320,20 +1320,7 @@ template <typename TA, typename TB, typename TC>
HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
const float* HWY_RESTRICT add, MatMulEnv& env,
MatPtrT<TC>& C) {
RowPtrs<TC> C_rows(C.GetRowPtrs());
if (HWY_UNLIKELY(!C.GetRowPtrs())) {
if constexpr (HWY_IS_DEBUG_BUILD) {
fprintf(stderr,
"MatMul perf warning: setting row pointers because "
"%s.AttachRowPtrs() was not called.\n",
C.Name());
}
HWY_DASSERT(C.HasPtr());
for (size_t r = 0; r < C.Rows(); ++r) {
env.row_ptrs[0][r] = reinterpret_cast<uint8_t*>(C.Row(r));
}
C_rows = RowPtrs<TC>(env.row_ptrs[0].get());
}
RowPtrs<TC> C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[2]);
const Allocator& allocator = env.ctx.allocator;
const size_t M = A.Rows();

View File

@ -428,7 +428,9 @@ MatMulEnv::MatMulEnv(ThreadingContext& ctx)
char cpu100[100];
have_timer_stop = hwy::platform::HaveTimerStop(cpu100);
row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(MMStorage::kMaxM));
row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(MMStorage::kMaxM)); // A
row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(MMStorage::kMaxN)); // B
row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(MMStorage::kMaxM)); // C
}
void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {

View File

@ -682,10 +682,11 @@ struct MatMulEnv {
// Storage for arbitrary output rows, see `MatPtr::AllocateAndAttachRowPtrs`.
// Most MatMul callers use strided MatPtr, but GemmaAttention::ComputeQKV
// writes to differing KV positions per query / output row.
// The first entry is sufficient for any MatMul, but also potentially
// overwritten by each MatMul. Subsequent entries are precomputed for tensors
// and not overwritten. Per-tensor allocations make it likelier that asan
// detects bugs such as use after free, overrun, and dangling references.
// The first three allocations are sufficient for any A, B, C, respectively,
// but also potentially overwritten by each MatMul. Subsequent entries are
// precomputed for tensors and not overwritten. Per-tensor allocations make
// it likelier that asan detects bugs such as use after free, overrun, and
// dangling references.
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
};
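For reference, the slot convention implied by the constructor change in matmul.cc and the call sites in this commit: entries 0, 1, and 2 are shared scratch for A, B, and C (sized kMaxM, kMaxN, and kMaxM rows), and anything appended afterwards is per-tensor and stable. The names below are illustrative only; the code indexes row_ptrs with literal 0/1/2:

#include <cstddef>

// Illustrative labels for the shared scratch slots (not part of the code).
enum RowPtrScratchSlot : size_t {
  kScratchA = 0,  // kMaxM rows; prepares for RowPtr in A
  kScratchB = 1,  // kMaxN rows; prepares for RowPtr in B
  kScratchC = 2,  // kMaxM rows; used by MatMul's fallback and ComputeQKV
};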

View File

@ -191,8 +191,9 @@ namespace detail {
// Shared by RMSNorm and RMSNormInplace.
template <typename VT>
float RMSNormMul(const VT* HWY_RESTRICT x, size_t size) {
PROFILER_ZONE("ops.RMSNormMul");
float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size,
const HWY_MAYBE_UNUSED size_t worker) {
PROFILER_ZONE2(worker, "ops.RMSNormMul");
const hn::ScalableTag<float> d;
const float l2 = DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault());
@ -204,18 +205,18 @@ float RMSNormMul(const VT* HWY_RESTRICT x, size_t size) {
// `w_ofs` is the offset within `weight`, required for NuqStream.
template <typename XT, typename WT, typename OT>
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x,
const WT* HWY_RESTRICT weight,
size_t w_ofs, OT* HWY_RESTRICT out,
const size_t size) {
PROFILER_ZONE("ops.RMSNorm");
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
const XT* HWY_RESTRICT x, const WT* HWY_RESTRICT weight, size_t w_ofs,
OT* HWY_RESTRICT out, const size_t size,
const size_t HWY_MAYBE_UNUSED worker = 0) {
PROFILER_ZONE2(worker, "ops.RMSNorm");
namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<float> df;
using VF = hn::Vec<decltype(df)>;
const size_t NF = hn::Lanes(df);
const VF mul = hn::Set(df, detail::RMSNormMul(x, size));
const VF mul = hn::Set(df, detail::RMSNormMul(x, size, worker));
const auto packed_x = MakeSpan(x, size);
const auto packed_w = MakeSpan(weight, w_ofs + size);
@ -237,18 +238,17 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x,
// Same as RMSNorm, but its HWY_RESTRICT forbids passing the same pointer.
template <typename WT, typename XT>
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(const WT* HWY_RESTRICT weight,
size_t w_ofs,
XT* HWY_RESTRICT inout,
const size_t size) {
PROFILER_ZONE("ops.RMSNormInplace");
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
const WT* HWY_RESTRICT weight, size_t w_ofs, XT* HWY_RESTRICT inout,
const size_t size, const HWY_MAYBE_UNUSED size_t worker = 0) {
PROFILER_ZONE2(worker, "ops.RMSNormInplace");
namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<float> df;
using VF = hn::Vec<decltype(df)>;
const size_t NF = hn::Lanes(df);
const VF mul = hn::Set(df, detail::RMSNormMul(inout, size));
const VF mul = hn::Set(df, detail::RMSNormMul(inout, size, worker));
const auto packed_w = MakeSpan(weight, w_ofs + size);
const auto packed_x = MakeSpan(inout, size);
@ -407,8 +407,9 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
// This overload is called if `post_qk == PostQKType::HalfRope`.
static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
float* HWY_RESTRICT x, const size_t dim_qkv,
const float* HWY_RESTRICT inv_timescale, const int pos) {
PROFILER_ZONE("ops.Rope");
const float* HWY_RESTRICT inv_timescale, const int pos,
const size_t HWY_MAYBE_UNUSED worker = 0) {
PROFILER_ZONE2(worker, "ops.Rope");
HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2;
@ -465,8 +466,9 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
// `inv_timescale[dim_qkv / 2]` is precomputed in AttentionActivations.
static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
const float mul, float* HWY_RESTRICT x, const size_t dim_qkv,
const float* HWY_RESTRICT inv_timescale, const int pos) {
PROFILER_ZONE("ops.RopeAndMulBy");
const float* HWY_RESTRICT inv_timescale, const int pos,
const size_t HWY_MAYBE_UNUSED worker = 0) {
PROFILER_ZONE2(worker, "ops.RopeAndMulBy");
HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2;
@ -523,10 +525,10 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
}
template <typename XT>
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x,
float* HWY_RESTRICT out,
const size_t size) {
PROFILER_ZONE("ops.AddFrom");
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
const XT* HWY_RESTRICT x, float* HWY_RESTRICT out, const size_t size,
const HWY_MAYBE_UNUSED size_t worker = 0) {
PROFILER_ZONE2(worker, "ops.AddFrom");
namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<float> df;
@ -621,9 +623,10 @@ static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x,
}
template <typename XT>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(float c, XT* HWY_RESTRICT x,
size_t size) {
PROFILER_ZONE("ops.MulByConst");
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(
const float c, XT* HWY_RESTRICT x, const size_t size,
const HWY_MAYBE_UNUSED size_t worker = 0) {
PROFILER_ZONE2(worker, "ops.MulByConst");
namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<float> df;
const size_t NF = hn::Lanes(df);
@ -659,12 +662,54 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(float c, XT* HWY_RESTRICT x,
}
}
// Same as above, but with a separate output. Same as below, without the add.
template <typename XT, typename OT>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(float c,
const XT* HWY_RESTRICT x,
OT* HWY_RESTRICT out,
size_t size) {
PROFILER_ZONE("ops.MulByConstAndAdd");
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo(
const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,
const size_t size, const HWY_MAYBE_UNUSED size_t worker = 0) {
PROFILER_ZONE2(worker, "ops.MulByConstTo");
namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<float> df;
const size_t NF = hn::Lanes(df);
using VF = hn::Vec<decltype(df)>;
const VF v_c = hn::Set(df, c);
const auto packed_x = MakeSpan(x, size);
const auto packed_out = MakeSpan(out, size);
size_t i = 0;
if (size >= 2 * NF) {
for (; i <= size - 2 * NF; i += 2 * NF) {
VF x0, x1;
Decompress2(df, packed_x, i, x0, x1);
const VF out0 = hn::Mul(x0, v_c);
const VF out1 = hn::Mul(x1, v_c);
Compress2(df, out0, out1, packed_out, i);
}
}
const size_t remaining = size - i;
HWY_DASSERT(remaining < 2 * NF);
if (HWY_UNLIKELY(remaining != 0)) {
HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
HWY_ALIGN float buf_out[2 * hn::MaxLanes(df)];
// Ensure the second vector is zeroed even if remaining <= NF.
hn::Store(hn::Zero(df), df, buf_x + NF);
DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
const VF x0 = hn::Load(df, buf_x);
const VF x1 = hn::Load(df, buf_x + NF);
const VF out0 = hn::Mul(x0, v_c);
const VF out1 = hn::Mul(x1, v_c);
Compress2(df, out0, out1, MakeSpan(buf_out, 2 * NF), 0);
hwy::CopyBytes(buf_out, out + i, remaining * sizeof(OT));
}
}
template <typename XT, typename OT>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,
const size_t size, const HWY_MAYBE_UNUSED size_t worker = 0) {
PROFILER_ZONE2(worker, "ops.MulByConstAndAdd");
namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<float> df;
const size_t NF = hn::Lanes(df);
@ -709,8 +754,9 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(float c,
// See below for a specialized version for top-1 sampling.
static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
float temperature = 1.0f) {
PROFILER_ZONE("ops.Softmax");
float temperature = 1.0f,
const HWY_MAYBE_UNUSED size_t worker = 0) {
PROFILER_ZONE2(worker, "ops.Softmax");
HWY_DASSERT(size != 0);
namespace hn = hwy::HWY_NAMESPACE;
@ -840,11 +886,10 @@ static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(float* HWY_RESTRICT x,
return TokenAndProb{.token = argmax.token, .prob = prob};
}
static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
const size_t size,
const size_t max_pos) {
PROFILER_ZONE("ops.LogitsSoftCap");
HWY_DASSERT(max_pos <= size);
static HWY_NOINLINE void LogitsSoftCap(
const float cap, float* HWY_RESTRICT x, const size_t size,
const HWY_MAYBE_UNUSED size_t worker = 0) {
PROFILER_ZONE2(worker, "ops.LogitsSoftCap");
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
@ -852,22 +897,18 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
const float inv_cap = 1.0f / cap;
hn::Transform(D(), x, max_pos, [cap, inv_cap](D d, V v) HWY_ATTR {
hn::Transform(D(), x, size, [cap, inv_cap](D d, V v) HWY_ATTR {
return hn::Mul(hn::Set(d, cap),
hn::Tanh(d, hn::Mul(v, hn::Set(d, inv_cap))));
});
}
static HWY_INLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
const size_t size) {
LogitsSoftCap(cap, x, size, size);
}
// Calls LogitsSoftCap if cap != 0.0f.
static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap(
const float cap, float* HWY_RESTRICT x, const size_t size) {
const float cap, float* HWY_RESTRICT x, const size_t size,
const size_t worker = 0) {
if (cap != 0.0f) {
LogitsSoftCap(cap, x, size, size);
LogitsSoftCap(cap, x, size, worker);
}
}

View File

@ -34,7 +34,7 @@
namespace gcpp {
// Type-safe wrapper over type-erased uint8_t row pointers from MatPtr. Used
// for C, in future also for A.
// for C (KV output), in future also for A or even B.
template <typename T>
class RowPtrs {
public:
@ -317,6 +317,25 @@ class MatPtrT : public MatPtr {
}
};
template <typename T>
RowPtrs<T> GetOrSetTempRowPtrs(
const MatPtrT<T>& mat,
const hwy::AlignedFreeUniquePtr<uint8_t*[]>& storage) {
if (HWY_LIKELY(mat.GetRowPtrs())) return RowPtrs<T>(mat.GetRowPtrs());
if constexpr (HWY_IS_DEBUG_BUILD) {
fprintf(stderr,
"MatMul perf warning: setting row pointers because "
"%s.AttachRowPtrs() was not called.\n",
mat.Name());
}
HWY_DASSERT(mat.HasPtr());
for (size_t r = 0; r < mat.Rows(); ++r) {
storage[r] = reinterpret_cast<uint8_t*>(const_cast<T*>(mat.Row(r)));
}
return RowPtrs<T>(storage.get());
}
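A hedged usage sketch of the new GetOrSetTempRowPtrs: MatMul calls it with the shared C scratch slot, so an output that already attached row pointers (for example via AllocateAndAttachRowPtrs, as BenchMatMul now does) takes the cheap path, and anything else falls back to filling the scratch slot with a debug-build perf warning. The wrapper name below is illustrative; the types come from util/mat.h and ops/matmul.h:

#include "ops/matmul.h"
#include "util/mat.h"

template <typename TC>
gcpp::RowPtrs<TC> OutputRows(const gcpp::MatPtrT<TC>& C, gcpp::MatMulEnv& env) {
  // Slot 2 is the shared scratch for C; see MatMulEnv::row_ptrs above.
  return gcpp::GetOrSetTempRowPtrs(C, env.row_ptrs[2]);
}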
// Calls `func` with `MatPtrT<T>*` plus the optional `args`. This supports all
// types used as weights.
template <class Func, typename... Args>

View File

@ -321,6 +321,40 @@ void ParallelizeTwoRanges(const IndexRangePartition& get1,
});
}
// Calls `func(task, worker)` for each task in `[0, num_tasks)`. Parallelizes
// over clusters of ONE package, then within each cluster.
template <class Func>
void ParallelFor(size_t num_tasks, NestedPools& pools, size_t pkg_idx,
const Func& func) {
const size_t pkg_base = pkg_idx * pools.MaxWorkersPerPackage();
// If few tasks, run on a single cluster. Also avoids a bit of overhead if
// there is only one cluster.
hwy::ThreadPool& all_clusters = pools.AllClusters(pkg_idx);
const size_t num_clusters = all_clusters.NumWorkers();
hwy::ThreadPool& cluster = pools.Cluster(pkg_idx, 0);
if (num_clusters == 1 || num_tasks <= cluster.NumWorkers()) {
return cluster.Run(0, num_tasks, [&](uint64_t task, size_t thread) {
func(task, pkg_base + thread);
});
}
// Assign each cluster a sub-range.
const IndexRangePartition ranges =
StaticPartition(IndexRange(0, num_tasks), num_clusters, 1);
ParallelizeOneRange(
ranges, all_clusters,
[&](const IndexRange& range, const size_t cluster_idx) {
hwy::ThreadPool& cluster = pools.Cluster(pkg_idx, cluster_idx);
const size_t cluster_base =
pkg_base + cluster_idx * pools.MaxWorkersPerCluster();
cluster.Run(range.begin(), range.end(),
[&](uint64_t task, size_t thread) {
func(task, cluster_base + thread);
});
});
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_
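Finally, a hedged usage sketch of the new ParallelFor helper: with few tasks it runs on a single cluster, otherwise it statically range-partitions tasks across the clusters of one package, and in both cases the lambda's second argument is a globally unique worker index (package base plus cluster base plus thread). The doubling loop is just a placeholder workload:

#include <cstddef>
#include <cstdint>
#include <vector>

#include "util/threading.h"

void DoubleAll(std::vector<float>& values, gcpp::NestedPools& pools) {
  gcpp::ParallelFor(values.size(), pools, /*pkg_idx=*/0,
                    [&](uint64_t task, size_t worker) {
                      (void)worker;  // would index per-worker scratch or zones
                      values[task] *= 2.0f;
                    });
}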