mirror of https://github.com/google/gemma.cpp.git
1.1x prefill and decode speedup (attention/activations)
Optimizations
- Better load-balancing in attention threading (previously, clusters were limited by #heads)
- Add MulByConstTo to avoid zero-init
- Parallel activations

Cleanup
- Prepare for RowPtr in A or B
- Pass through thread_id to ops
- Avoid warning in bench_matmul

PiperOrigin-RevId: 773723423
This commit is contained in:
parent 7630ec0c92
commit 0f70f285e0
@@ -18,6 +18,7 @@

 #include <math.h>  // sqrtf
 #include <stddef.h>
+#include <stdint.h>

 #include <atomic>
 #include <vector>
@@ -99,6 +100,7 @@ struct AttentionActivations {
                   1000000.0)),

         div_seq_len(static_cast<uint32_t>(seq_len)),
+        div_heads(static_cast<uint32_t>(layer_config.heads)),
         query_scale(ChooseQueryScale(config)) {
     // Batch size can be 0 in experimental code so do not assert.
     if (batch_size == 0) {
@@ -125,10 +127,6 @@ struct AttentionActivations {
     att_sums.OverrideRows(batch_size);
   }

-  bool IsGlobalLayer(size_t layer_idx) const {
-    return config.attention_window_sizes[layer_idx] == div_seq_len.GetDivisor();
-  }
-
  const ModelConfig& config;

  MatStorageT<float> q;  // query
@@ -144,6 +142,8 @@ struct AttentionActivations {
  MatStorageT<float> inv_timescale_global;

  hwy::Divisor div_seq_len;
+  // Unfortunately, some models (Griffin) have non-power-of-two heads.
+  hwy::Divisor div_heads;
  float query_scale;
};

@@ -52,7 +52,9 @@ namespace HWY_NAMESPACE {
 static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
                              const hwy::Divisor& div_seq_len,
                              const float* HWY_RESTRICT q,
-                             const MatPtrT<BF16>& k, float* HWY_RESTRICT att) {
+                             const MatPtrT<BF16>& k, float* HWY_RESTRICT att,
+                             const size_t worker) {
+  PROFILER_ZONE2(worker, "Gen.Attention.QDotK");
   if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
     // Slightly faster: no wraparound.
     for (size_t pos = start_pos; pos <= last_pos; ++pos) {
@@ -71,7 +73,8 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
 static void PositionalEncodingQK(float* qk, const size_t layer_idx,
                                  const LayerWeightsPtrs& layer,
                                  const AttentionActivations& activations,
-                                 const size_t pos, const float mul = 1.0f) {
+                                 const size_t worker, const size_t pos,
+                                 const float mul = 1.0f) {
   const size_t qkv_dim = layer.layer_config.qkv_dim;
   const PostQKType& post_qk = layer.layer_config.post_qk;
   // qk is either q or k, so qkv_dim is the length we operate on.
@@ -83,10 +86,10 @@ static void PositionalEncodingQK(float* qk, const size_t layer_idx,
   }
   // PostQKType::Rope
   if (post_qk == PostQKType::HalfRope) {
-    Rope(qk, qkv_dim / 2, inv_timescale, pos);
-    if (mul != 1.0f) MulByConst(mul, qk, qkv_dim);
+    Rope(qk, qkv_dim / 2, inv_timescale, pos, worker);
+    if (mul != 1.0f) MulByConst(mul, qk, qkv_dim, worker);
   } else {
-    RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos);
+    RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, worker);
   }
 }

@@ -94,39 +97,38 @@ static void PositionalEncodingQK(float* qk, const size_t layer_idx,
 // `att_out`. Equivalent in gemma/modules.py:
 // encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
 // `v` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
-static HWY_INLINE void WeightedSumV(const size_t start_pos,
-                                    const size_t last_pos,
-                                    const hwy::Divisor& div_seq_len,
-                                    const float* HWY_RESTRICT att,
-                                    const MatPtrT<BF16>& v,
-                                    float* HWY_RESTRICT att_out) {
-  const size_t qkv_dim = v.Cols();
-  hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
-
+static HWY_INLINE void WeightedSumV(
+    const size_t start_pos, const size_t last_pos,
+    const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT att,
+    const MatPtrT<BF16>& v, float* HWY_RESTRICT att_out, const size_t worker) {
   if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
-    // Slightly faster: no wraparound.
-    for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-      MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols());
+    // Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if
+    // we supported non-transposed B.
+    // TODO: 2..4x unroll
+    MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), worker);
+    for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
+      MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols(), worker);
     }
   } else {
-    for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-      const size_t pos_modulo = div_seq_len.Remainder(pos);
-      const BF16* HWY_RESTRICT v_ptr = v.Row(pos_modulo);
-      MulByConstAndAdd(att[pos_modulo], v_ptr, att_out, v.Cols());
+    {
+      const size_t pos_mod = div_seq_len.Remainder(start_pos);
+      MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), worker);
+    }
+    for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
+      const size_t pos_mod = div_seq_len.Remainder(pos);
+      MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), worker);
     }
   }
 }

 // Calculates the attention outputs for a single q, which may be updated
 // in place for RMSNorm.
-void SingleDotSoftmaxWeightedSum(const size_t pos, const size_t start_pos,
-                                 const size_t last_pos, float* HWY_RESTRICT q,
-                                 const MatPtrT<BF16>& k, const MatPtrT<BF16>& v,
-                                 const size_t layer_idx,
-                                 const LayerWeightsPtrs& layer,
-                                 const AttentionActivations& activations,
-                                 float* HWY_RESTRICT att,
-                                 float* HWY_RESTRICT att_out) {
+void SingleDotSoftmaxWeightedSum(
+    const size_t pos, const size_t start_pos, const size_t last_pos,
+    float* HWY_RESTRICT q, const MatPtrT<BF16>& k, const MatPtrT<BF16>& v,
+    const size_t layer_idx, const LayerWeightsPtrs& layer,
+    const AttentionActivations& activations, float* HWY_RESTRICT att,
+    float* HWY_RESTRICT att_out, const size_t worker) {
   const float att_cap = activations.config.att_cap;
   const float query_scale = activations.query_scale;
   const size_t seq_len =
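Editor's note (not part of the diff): the WeightedSumV change above peels the first position out of the accumulation so that the new MulByConstTo can overwrite att_out directly, removing the hwy::ZeroBytes pass. A minimal scalar sketch of the before/after, assuming plain float rows rather than the BF16 MatPtrT views and Highway ops used by the real code:

#include <cstddef>
#include <cstring>

// Before: zero att_out, then accumulate every position (sketch only).
void WeightedSumBefore(const float* att, const float* v, size_t stride,
                       size_t start, size_t last, size_t dim, float* out) {
  std::memset(out, 0, dim * sizeof(float));  // extra pass over att_out
  for (size_t pos = start; pos <= last; ++pos)
    for (size_t i = 0; i < dim; ++i) out[i] += att[pos] * v[pos * stride + i];
}

// After: the first position writes att_out directly, so no zero-init pass.
void WeightedSumAfter(const float* att, const float* v, size_t stride,
                      size_t start, size_t last, size_t dim, float* out) {
  for (size_t i = 0; i < dim; ++i) out[i] = att[start] * v[start * stride + i];
  for (size_t pos = start + 1; pos <= last; ++pos)
    for (size_t i = 0; i < dim; ++i) out[i] += att[pos] * v[pos * stride + i];
}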
@@ -136,20 +138,22 @@ void SingleDotSoftmaxWeightedSum(const size_t pos, const size_t start_pos,
   if (layer.query_norm_scale.HasPtr()) {
     CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
       RMSNormInplace(weights_t->PackedScale1(), 0, q,
-                     layer.layer_config.qkv_dim);
+                     layer.layer_config.qkv_dim, worker);
     });
   }

-  PositionalEncodingQK(q, layer_idx, layer, activations, pos, query_scale);
+  PositionalEncodingQK(q, layer_idx, layer, activations, worker, pos,
+                       query_scale);

-  QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att);
+  QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, worker);

   // SoftMax with optional SoftCap yields "probabilities" in att.
   const size_t att_len = HWY_MIN(last_pos + 1, seq_len);
-  MaybeLogitsSoftCap(att_cap, att, att_len);
-  Softmax(att, att_len);
+  MaybeLogitsSoftCap(att_cap, att, att_len, worker);
+  Softmax(att, att_len, /*temperature=*/1.0f, worker);

-  WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out);
+  WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out,
+               worker);
 }

 // The attention window usually starts at 0 unless `pos` is larger than
@@ -179,22 +183,17 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
   const size_t cache_layer_size = layer_config.CacheLayerSize();
   const size_t seq_len =
       static_cast<size_t>(activations.div_seq_len.GetDivisor());
+  // All layers should have the same number of heads.
+  HWY_DASSERT(activations.div_heads.GetDivisor() == layer_config.heads);

   // For each head/token/query, compute Q.K, softmax, and weighted V.
-  // Statically partition token/query across packages.
-  const size_t num_tq = num_tokens * div_qbatch.GetDivisor();
-  const IndexRangePartition tq_ranges =
-      StaticPartition(IndexRange(0, num_tq), pools.NumPackages(), 1);
-  ParallelizeOneRange(
-      tq_ranges, pools.AllPackages(),
-      [&](const IndexRange& tq_range, const size_t pkg_idx) {
-        const size_t pkg_base = pkg_idx * pools.MaxWorkersPerPackage();
-        pools.AllClusters(pkg_idx).Run(
-            tq_range.begin(), tq_range.end(),
-            [&](const size_t tq_idx, const size_t cluster_idx) {
-              const HWY_MAYBE_UNUSED size_t cluster_base =
-                  pkg_base + cluster_idx * pools.MaxWorkersPerCluster();
+  const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
+    const size_t tq_idx = activations.div_heads.Divide(task);
+    const size_t head = activations.div_heads.Remainder(task);
+#if PROFILER_ENABLED
+    const hwy::Zone zone(worker, zone_id_par);
+#endif
+
    const size_t qi = div_qbatch.Remainder(tq_idx);
    const size_t batch_idx = div_qbatch.Divide(tq_idx);
    auto& kv_cache = qbatch.KV(qi).kv_cache;
@@ -202,8 +201,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
    // Find the token position in the query and calculate
    // the range of cache positions to attend to.
    const size_t pos = qbatch.Pos(qi) + batch_idx;
-    const size_t start_pos =
-        StartPos(pos, activations.config, layer_idx);
+    const size_t start_pos = StartPos(pos, activations.config, layer_idx);
    size_t last_pos = pos;
    const size_t prefix_end = qbatch.PrefixEnd(qi);
    if (prefix_end > 0 && prefix_end - 1 > last_pos) {
@@ -211,43 +209,26 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
      last_pos = prefix_end - 1;
    }

-    pools.Cluster(pkg_idx, cluster_idx)
-        .Run(
-            0, layer_config.heads,
-            [&](const size_t head, size_t thread) HWY_ATTR {
-#if PROFILER_ENABLED
-              const hwy::Zone zone(cluster_base + thread,
-                                   zone_id_par);
-#endif
-
-              const size_t head_offset =
-                  (head / kHeadGroups) * qkv_dim * 2;
-
-              float* HWY_RESTRICT q =
-                  activations.q.Row(tq_idx) + head * qkv_dim;
-
-              float* HWY_RESTRICT att =
-                  activations.att.Row(tq_idx) + head * seq_len;
+    float* HWY_RESTRICT q = activations.q.Row(tq_idx) + head * qkv_dim;
+    float* HWY_RESTRICT att = activations.att.Row(tq_idx) + head * seq_len;
    float* HWY_RESTRICT att_out =
        activations.att_out.Row(tq_idx) + head * qkv_dim;

    // Make strided read-only views into the kv cache for
    // this query and head.
-    const size_t kv_head_offset =
-        layer_idx * cache_layer_size + head_offset;
+    const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
+    const size_t kv_head_offset = layer_idx * cache_layer_size + head_offset;
    MatPtrT<BF16> k("k_view", Extents2D(seq_len, qkv_dim));
-    k.SetPtr(kv_cache.Row(0) + kv_head_offset,
-             kv_cache.Stride());
+    k.SetPtr(kv_cache.Row(0) + kv_head_offset, kv_cache.Stride());
    MatPtrT<BF16> v("v_view", Extents2D(seq_len, qkv_dim));
-    v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim,
-             kv_cache.Stride());
+    v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());

-    SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q,
-                                k, v, layer_idx, layer,
-                                activations, att, att_out);
-            });
-        });
-      });
+    SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx,
+                                layer, activations, att, att_out, worker);
+  };
+
+  ParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads, pools,
+              /*pkg_idx=*/0, func);
 }

 // Different functions use different naming conventions for the number of
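Editor's note (not part of the diff): this is where the "better load-balancing" from the commit message lands. Previously the per-cluster pool ran over layer_config.heads only, so a cluster could never use more workers than there are heads; now every (token, query, head) combination is one task of a single ParallelFor. A rough sketch of the index decomposition, using plain / and % in place of hwy::Divisor:

#include <cstddef>

// Sketch of the task decomposition used by DotSoftmaxWeightedSum above, for
// task in [0, num_tokens * qbatch * heads).
struct TaskIndex {
  size_t tq_idx;     // combined token/query index
  size_t head;       // attention head
  size_t batch_idx;  // token within this query's batch
  size_t qi;         // query index within the qbatch
};

TaskIndex DecomposeTask(size_t task, size_t heads, size_t qbatch) {
  TaskIndex t;
  t.tq_idx = task / heads;          // activations.div_heads.Divide(task)
  t.head = task % heads;            // activations.div_heads.Remainder(task)
  t.qi = t.tq_idx % qbatch;         // div_qbatch.Remainder(tq_idx)
  t.batch_idx = t.tq_idx / qbatch;  // div_qbatch.Divide(tq_idx)
  return t;
}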
@@ -286,10 +267,10 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
    const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
    const size_t cache_pos =
        activations.div_seq_len.Remainder(qbatch.Pos(qi) + batch_idx);
-    env.row_ptrs[0][interleaved_idx] = reinterpret_cast<uint8_t*>(
+    env.row_ptrs[2][interleaved_idx] = reinterpret_cast<uint8_t*>(
        qbatch.KV(qi).kv_cache.Row(cache_pos) + layer_idx * cache_layer_size);
  }
-  kv_rows.AttachRowPtrs(env.row_ptrs[0].get());
+  kv_rows.AttachRowPtrs(env.row_ptrs[2].get());
  CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
             /*add=*/nullptr, env, kv_rows);

@@ -298,7 +279,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
  // tasks are very lightweight.
  env.ctx.pools.Pool(0).Run(
      0, kv_heads * num_interleaved,
-      [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
+      [&](uint64_t task, size_t thread) HWY_ATTR {
        const size_t head = task % kv_heads;
        const size_t interleaved_idx = task / kv_heads;
        const size_t qi = div_qbatch.Remainder(interleaved_idx);
@@ -318,11 +299,13 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
        // Apply further processing to K.
        if (layer.key_norm_scale.HasPtr()) {
          CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) {
-            RMSNormInplace(weights_t->PackedScale1(), 0, kv_f32, qkv_dim);
+            RMSNormInplace(weights_t->PackedScale1(), 0, kv_f32, qkv_dim,
+                           thread);
          });
        }

-        PositionalEncodingQK(kv_f32, layer_idx, layer, activations, pos);
+        PositionalEncodingQK(kv_f32, layer_idx, layer, activations, thread,
+                             pos);
        CompressPerThread tls;
        Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
      });
@@ -33,7 +33,7 @@ namespace gcpp {
      float* HWY_RESTRICT q, const MatPtrT<BF16>& k, const MatPtrT<BF16>& v, \
      size_t layer_idx, const LayerWeightsPtrs& layer, \
      const AttentionActivations& activations, float* HWY_RESTRICT att, \
-      float* HWY_RESTRICT att_out); \
+      float* HWY_RESTRICT att_out, size_t worker); \
                                                                           \
  void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \
                             const LayerWeightsPtrs& layer, \
@@ -23,6 +23,7 @@
 #include "gemma/weights.h"
 #include "ops/matmul.h"
 #include "util/mat.h"
+#include "util/threading.h"
 #include "hwy/profiler.h"

 // Include guard (still compiled once per target)
@@ -43,9 +44,10 @@ namespace gcpp {
 namespace HWY_NAMESPACE {

 template <typename T>
-HWY_NOINLINE void Activation(ActivationType activation, T* HWY_RESTRICT c1,
-                             const T* HWY_RESTRICT c2, size_t count) {
-  PROFILER_ZONE("Gen.Activation");
+void Activation(ActivationType activation, T* HWY_RESTRICT c1,
+                const T* HWY_RESTRICT c2, const size_t count,
+                const size_t worker) {
+  PROFILER_ZONE2(worker, "Gen.Activation");
   namespace hn = hwy::HWY_NAMESPACE;
   using DF = hn::ScalableTag<T>;
   using VF = hn::Vec<DF>;
|
||||||
|
|
||||||
// No C2 multiplier.
|
// No C2 multiplier.
|
||||||
template <class Mat>
|
template <class Mat>
|
||||||
void ActivationBatched(ActivationType activation, Mat& c1) {
|
void ActivationBatched(ActivationType activation, Mat& c1, NestedPools& pools) {
|
||||||
using T = typename Mat::T;
|
using T = typename Mat::T;
|
||||||
for (size_t i = 0; i < c1.Rows(); ++i) {
|
ParallelFor(c1.Rows(), pools, /*pkg_idx=*/0,
|
||||||
|
[&](uint64_t task, size_t worker) {
|
||||||
// Cast to correct type so type deduction works.
|
// Cast to correct type so type deduction works.
|
||||||
Activation(activation, c1.Row(i), static_cast<const T*>(nullptr),
|
Activation(activation, c1.Row(task),
|
||||||
c1.Cols());
|
static_cast<const T*>(nullptr), c1.Cols(), worker);
|
||||||
}
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class Mat>
|
template <class Mat>
|
||||||
HWY_NOINLINE void ActivationBatched(ActivationType activation, Mat& c1,
|
HWY_NOINLINE void ActivationBatched(ActivationType activation, Mat& c1,
|
||||||
const Mat* c2) {
|
const Mat* c2, NestedPools& pools) {
|
||||||
using T = typename Mat::T;
|
using T = typename Mat::T;
|
||||||
HWY_DASSERT(c1.SameShape(*c2));
|
HWY_DASSERT(c1.SameShape(*c2));
|
||||||
if (c2 && c2->HasPtr()) {
|
if (c2 && c2->HasPtr()) {
|
||||||
for (size_t i = 0; i < c1.Rows(); ++i) {
|
ParallelFor(c1.Rows(), pools, /*pkg_idx=*/0,
|
||||||
Activation(activation, c1.Row(i), c2->Row(i), c1.Cols());
|
[&](uint64_t task, size_t worker) {
|
||||||
}
|
Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(),
|
||||||
|
worker);
|
||||||
|
});
|
||||||
} else { // No multiplier
|
} else { // No multiplier
|
||||||
for (size_t i = 0; i < c1.Rows(); ++i) {
|
ParallelFor(c1.Rows(), pools, /*pkg_idx=*/0,
|
||||||
Activation(activation, c1.Row(i), static_cast<const T*>(nullptr),
|
[&](uint64_t task, size_t worker) {
|
||||||
c1.Cols());
|
Activation(activation, c1.Row(task),
|
||||||
}
|
static_cast<const T*>(nullptr), c1.Cols(), worker);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@@ -126,7 +132,8 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer,
             activations.C2);

   // Activation (Gelu) and maybe multiply by gate. Store activations in act.
-  ActivationBatched(layer_config.activation, activations.C1, &activations.C2);
+  ActivationBatched(layer_config.activation, activations.C1, &activations.C2,
+                    env.ctx.pools);

   // Hidden layer -> output layer.
   CallMatMul(activations.C1, layer.linear_w, output_bias, env,
@@ -224,7 +224,7 @@ void FFWVit(const LayerWeightsPtrs& layer, Activations& activations,
             activations.C1);

   // Activation (Gelu), store in C1.
-  ActivationBatched(layer_config.activation, activations.C1);
+  ActivationBatched(layer_config.activation, activations.C1, env.ctx.pools);

   // Hidden layer -> output layer.
   CallMatMul(activations.C1, layer.vit.linear_1_w, output_bias, env,
@@ -108,6 +108,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
   // BindB/C if there is a single package: they will be a no-op.
   BindB(b_trans, sizeof(TC), env.parallel);
   BindC(C, env.parallel);
+  C.AllocateAndAttachRowPtrs(env.row_ptrs);

   Tristate use_spinning = Tristate::kDefault;
   env.ctx.pools.MaybeStartSpinning(use_spinning);
@@ -139,8 +140,8 @@ using SFP = SfpStream;

 void BenchAllMatMul() {
   if (first_target == 0) first_target = HWY_TARGET;
-  // Disable the best-target-only limitation.
-  // if (HWY_TARGET != first_target) return;
+  // Comment out to disable the best-target-only limitation.
+  if (HWY_TARGET != first_target) return;

   // Skip EMU128 (10x slower than SSE4 for SFP) and older x86.
   if (HWY_TARGET == HWY_EMU128 || HWY_TARGET == HWY_SSSE3 ||
@@ -1320,20 +1320,7 @@ template <typename TA, typename TB, typename TC>
 HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                               const float* HWY_RESTRICT add, MatMulEnv& env,
                               MatPtrT<TC>& C) {
-  RowPtrs<TC> C_rows(C.GetRowPtrs());
-  if (HWY_UNLIKELY(!C.GetRowPtrs())) {
-    if constexpr (HWY_IS_DEBUG_BUILD) {
-      fprintf(stderr,
-              "MatMul perf warning: setting row pointers because "
-              "%s.AttachRowPtrs() was not called.\n",
-              C.Name());
-    }
-    HWY_DASSERT(C.HasPtr());
-    for (size_t r = 0; r < C.Rows(); ++r) {
-      env.row_ptrs[0][r] = reinterpret_cast<uint8_t*>(C.Row(r));
-    }
-    C_rows = RowPtrs<TC>(env.row_ptrs[0].get());
-  }
+  RowPtrs<TC> C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[2]);

   const Allocator& allocator = env.ctx.allocator;
   const size_t M = A.Rows();
@@ -428,7 +428,9 @@ MatMulEnv::MatMulEnv(ThreadingContext& ctx)
   char cpu100[100];
   have_timer_stop = hwy::platform::HaveTimerStop(cpu100);

-  row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(MMStorage::kMaxM));
+  row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(MMStorage::kMaxM));  // A
+  row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(MMStorage::kMaxN));  // B
+  row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(MMStorage::kMaxM));  // C
 }

 void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
@@ -682,10 +682,11 @@ struct MatMulEnv {
   // Storage for arbitrary output rows, see `MatPtr::AllocateAndAttachRowPtrs`.
   // Most MatMul callers use strided MatPtr, but GemmaAttention::ComputeQKV
   // writes to differing KV positions per query / output row.
-  // The first entry is sufficient for any MatMul, but also potentially
-  // overwritten by each MatMul. Subsequent entries are precomputed for tensors
-  // and not overwritten. Per-tensor allocations make it likelier that asan
-  // detects bugs such as use after free, overrun, and dangling references.
+  // The first three allocations are sufficient for any A, B, C, respectively,
+  // but also potentially overwritten by each MatMul. Subsequent entries are
+  // precomputed for tensors and not overwritten. Per-tensor allocations make
+  // it likelier that asan detects bugs such as use after free, overrun, and
+  // dangling references.
   std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
 };
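Editor's note (not part of the diff): the three leading row_ptrs slots are now reserved per operand, which is why ComputeQKV and MatMul switched from index 0 to index 2 for the C output. A purely illustrative naming of that convention (the enum below is hypothetical, not in the source):

// Hypothetical names for the per-operand row_ptrs slots described above.
enum RowPtrSlot : size_t { kRowPtrsA = 0, kRowPtrsB = 1, kRowPtrsC = 2 };
// e.g. kv_rows.AttachRowPtrs(env.row_ptrs[kRowPtrsC].get());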
131  ops/ops-inl.h
@@ -191,8 +191,9 @@ namespace detail {

 // Shared by RMSNorm and RMSNormInplace.
 template <typename VT>
-float RMSNormMul(const VT* HWY_RESTRICT x, size_t size) {
-  PROFILER_ZONE("ops.RMSNormMul");
+float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size,
+                 const HWY_MAYBE_UNUSED size_t worker) {
+  PROFILER_ZONE2(worker, "ops.RMSNormMul");

   const hn::ScalableTag<float> d;
   const float l2 = DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault());
@@ -204,18 +205,18 @@ float RMSNormMul(const VT* HWY_RESTRICT x, size_t size) {

 // `x_ofs` is the offset within `x`, required for NuqStream.
 template <typename XT, typename WT, typename OT>
-HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x,
-                                           const WT* HWY_RESTRICT weight,
-                                           size_t w_ofs, OT* HWY_RESTRICT out,
-                                           const size_t size) {
-  PROFILER_ZONE("ops.RMSNorm");
+HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
+    const XT* HWY_RESTRICT x, const WT* HWY_RESTRICT weight, size_t w_ofs,
+    OT* HWY_RESTRICT out, const size_t size,
+    const size_t HWY_MAYBE_UNUSED worker = 0) {
+  PROFILER_ZONE2(worker, "ops.RMSNorm");

   namespace hn = hwy::HWY_NAMESPACE;
   const hn::ScalableTag<float> df;
   using VF = hn::Vec<decltype(df)>;
   const size_t NF = hn::Lanes(df);

-  const VF mul = hn::Set(df, detail::RMSNormMul(x, size));
+  const VF mul = hn::Set(df, detail::RMSNormMul(x, size, worker));

   const auto packed_x = MakeSpan(x, size);
   const auto packed_w = MakeSpan(weight, w_ofs + size);
@@ -237,18 +238,17 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x,

 // Same as RMSNorm, but its HWY_RESTRICT forbids passing the same pointer.
 template <typename WT, typename XT>
-HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(const WT* HWY_RESTRICT weight,
-                                                  size_t w_ofs,
-                                                  XT* HWY_RESTRICT inout,
-                                                  const size_t size) {
-  PROFILER_ZONE("ops.RMSNormInplace");
+HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
+    const WT* HWY_RESTRICT weight, size_t w_ofs, XT* HWY_RESTRICT inout,
+    const size_t size, const HWY_MAYBE_UNUSED size_t worker = 0) {
+  PROFILER_ZONE2(worker, "ops.RMSNormInplace");

   namespace hn = hwy::HWY_NAMESPACE;
   const hn::ScalableTag<float> df;
   using VF = hn::Vec<decltype(df)>;
   const size_t NF = hn::Lanes(df);

-  const VF mul = hn::Set(df, detail::RMSNormMul(inout, size));
+  const VF mul = hn::Set(df, detail::RMSNormMul(inout, size, worker));

   const auto packed_w = MakeSpan(weight, w_ofs + size);
   const auto packed_x = MakeSpan(inout, size);
@@ -407,8 +407,9 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
 // This overload is called if `post_qk == PostQKType::HalfRope`.
 static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
     float* HWY_RESTRICT x, const size_t dim_qkv,
-    const float* HWY_RESTRICT inv_timescale, const int pos) {
-  PROFILER_ZONE("ops.Rope");
+    const float* HWY_RESTRICT inv_timescale, const int pos,
+    const size_t HWY_MAYBE_UNUSED worker = 0) {
+  PROFILER_ZONE2(worker, "ops.Rope");
   HWY_DASSERT(dim_qkv % 2 == 0);
   const size_t half_dim_qkv = dim_qkv / 2;
@@ -465,8 +466,9 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
 // `inv_timescale[dim_qkv / 2]` is precomputed in AttentionActivations.
 static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
     const float mul, float* HWY_RESTRICT x, const size_t dim_qkv,
-    const float* HWY_RESTRICT inv_timescale, const int pos) {
-  PROFILER_ZONE("ops.RopeAndMulBy");
+    const float* HWY_RESTRICT inv_timescale, const int pos,
+    const size_t HWY_MAYBE_UNUSED worker = 0) {
+  PROFILER_ZONE2(worker, "ops.RopeAndMulBy");
   HWY_DASSERT(dim_qkv % 2 == 0);
   const size_t half_dim_qkv = dim_qkv / 2;
@@ -523,10 +525,10 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
 }

 template <typename XT>
-static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x,
-                                                  float* HWY_RESTRICT out,
-                                                  const size_t size) {
-  PROFILER_ZONE("ops.AddFrom");
+static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
+    const XT* HWY_RESTRICT x, float* HWY_RESTRICT out, const size_t size,
+    const HWY_MAYBE_UNUSED size_t worker = 0) {
+  PROFILER_ZONE2(worker, "ops.AddFrom");

   namespace hn = hwy::HWY_NAMESPACE;
   const hn::ScalableTag<float> df;
@@ -621,9 +623,10 @@ static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x,
 }

 template <typename XT>
-HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(float c, XT* HWY_RESTRICT x,
-                                              size_t size) {
-  PROFILER_ZONE("ops.MulByConst");
+HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(
+    const float c, XT* HWY_RESTRICT x, const size_t size,
+    const HWY_MAYBE_UNUSED size_t worker = 0) {
+  PROFILER_ZONE2(worker, "ops.MulByConst");
   namespace hn = hwy::HWY_NAMESPACE;
   const hn::ScalableTag<float> df;
   const size_t NF = hn::Lanes(df);
@@ -659,12 +662,54 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(float c, XT* HWY_RESTRICT x,
   }
 }

+// Same as above, but without a separate output. Same as below without the add.
 template <typename XT, typename OT>
-HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(float c,
-                                                    const XT* HWY_RESTRICT x,
-                                                    OT* HWY_RESTRICT out,
-                                                    size_t size) {
-  PROFILER_ZONE("ops.MulByConstAndAdd");
+HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo(
+    const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,
+    const size_t size, const HWY_MAYBE_UNUSED size_t worker = 0) {
+  PROFILER_ZONE2(worker, "ops.MulByConstTo");
+  namespace hn = hwy::HWY_NAMESPACE;
+  const hn::ScalableTag<float> df;
+  const size_t NF = hn::Lanes(df);
+  using VF = hn::Vec<decltype(df)>;
+
+  const VF v_c = hn::Set(df, c);
+  const auto packed_x = MakeSpan(x, size);
+  const auto packed_out = MakeSpan(out, size);
+
+  size_t i = 0;
+  if (size >= 2 * NF) {
+    for (; i <= size - 2 * NF; i += 2 * NF) {
+      VF x0, x1;
+      Decompress2(df, packed_x, i, x0, x1);
+      const VF out0 = hn::Mul(x0, v_c);
+      const VF out1 = hn::Mul(x1, v_c);
+      Compress2(df, out0, out1, packed_out, i);
+    }
+  }
+
+  const size_t remaining = size - i;
+  HWY_DASSERT(remaining < 2 * NF);
+  if (HWY_UNLIKELY(remaining != 0)) {
+    HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
+    HWY_ALIGN float buf_out[2 * hn::MaxLanes(df)];
+    // Ensure the second vector is zeroed even if remaining <= NF.
+    hn::Store(hn::Zero(df), df, buf_x + NF);
+    DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
+    const VF x0 = hn::Load(df, buf_x);
+    const VF x1 = hn::Load(df, buf_x + NF);
+    const VF out0 = hn::Mul(x0, v_c);
+    const VF out1 = hn::Mul(x1, v_c);
+    Compress2(df, out0, out1, MakeSpan(buf_out, 2 * NF), 0);
+    hwy::CopyBytes(buf_out, out + i, remaining * sizeof(OT));
+  }
+}
+
+template <typename XT, typename OT>
+HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
+    const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,
+    const size_t size, const HWY_MAYBE_UNUSED size_t worker = 0) {
+  PROFILER_ZONE2(worker, "ops.MulByConstAndAdd");
   namespace hn = hwy::HWY_NAMESPACE;
   const hn::ScalableTag<float> df;
   const size_t NF = hn::Lanes(df);
@@ -709,8 +754,9 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(float c,

 // See below for a specialized version for top-1 sampling.
 static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
-                                 float temperature = 1.0f) {
-  PROFILER_ZONE("ops.Softmax");
+                                 float temperature = 1.0f,
+                                 const HWY_MAYBE_UNUSED size_t worker = 0) {
+  PROFILER_ZONE2(worker, "ops.Softmax");
   HWY_DASSERT(size != 0);

   namespace hn = hwy::HWY_NAMESPACE;
@@ -840,11 +886,10 @@ static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(float* HWY_RESTRICT x,
   return TokenAndProb{.token = argmax.token, .prob = prob};
 }

-static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
-                                       const size_t size,
-                                       const size_t max_pos) {
-  PROFILER_ZONE("ops.LogitsSoftCap");
-  HWY_DASSERT(max_pos <= size);
+static HWY_NOINLINE void LogitsSoftCap(
+    const float cap, float* HWY_RESTRICT x, const size_t size,
+    const HWY_MAYBE_UNUSED size_t worker = 0) {
+  PROFILER_ZONE2(worker, "ops.LogitsSoftCap");

   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
@@ -852,22 +897,18 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,

   const float inv_cap = 1.0f / cap;

-  hn::Transform(D(), x, max_pos, [cap, inv_cap](D d, V v) HWY_ATTR {
+  hn::Transform(D(), x, size, [cap, inv_cap](D d, V v) HWY_ATTR {
     return hn::Mul(hn::Set(d, cap),
                    hn::Tanh(d, hn::Mul(v, hn::Set(d, inv_cap))));
   });
 }

-static HWY_INLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
-                                     const size_t size) {
-  LogitsSoftCap(cap, x, size, size);
-}
-
 // Calls LogitsSoftCap if cap != 0.0f.
 static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap(
-    const float cap, float* HWY_RESTRICT x, const size_t size) {
+    const float cap, float* HWY_RESTRICT x, const size_t size,
+    const size_t worker = 0) {
   if (cap != 0.0f) {
-    LogitsSoftCap(cap, x, size, size);
+    LogitsSoftCap(cap, x, size, worker);
   }
 }
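Editor's note (not part of the diff): the max_pos overload of LogitsSoftCap is gone; callers now always cap the full size. Stripped of Highway vectors, the transform above amounts to the following scalar soft cap:

#include <cmath>
#include <cstddef>

// Scalar meaning of LogitsSoftCap: squash each logit into (-cap, cap).
void LogitsSoftCapScalar(float cap, float* x, size_t size) {
  const float inv_cap = 1.0f / cap;
  for (size_t i = 0; i < size; ++i) x[i] = cap * std::tanh(x[i] * inv_cap);
}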
21  util/mat.h
@@ -34,7 +34,7 @@
 namespace gcpp {

 // Type-safe wrapper over type-erased uint8_t row pointers from MatPtr. Used
-// for C, in future also for A.
+// for C (KV output), in future also for A or even B.
 template <typename T>
 class RowPtrs {
  public:
@@ -317,6 +317,25 @@ class MatPtrT : public MatPtr {
   }
 };

+template <typename T>
+RowPtrs<T> GetOrSetTempRowPtrs(
+    const MatPtrT<T>& mat,
+    const hwy::AlignedFreeUniquePtr<uint8_t*[]>& storage) {
+  if (HWY_LIKELY(mat.GetRowPtrs())) return RowPtrs<T>(mat.GetRowPtrs());
+
+  if constexpr (HWY_IS_DEBUG_BUILD) {
+    fprintf(stderr,
+            "MatMul perf warning: setting row pointers because "
+            "%s.AttachRowPtrs() was not called.\n",
+            mat.Name());
+  }
+  HWY_DASSERT(mat.HasPtr());
+  for (size_t r = 0; r < mat.Rows(); ++r) {
+    storage[r] = reinterpret_cast<uint8_t*>(const_cast<T*>(mat.Row(r)));
+  }
+  return RowPtrs<T>(storage.get());
+}
+
 // Calls `func` with `MatPtrT<T>*` plus the optional `args`. This supports all
 // types used as weights.
 template <class Func, typename... Args>
@@ -321,6 +321,40 @@ void ParallelizeTwoRanges(const IndexRangePartition& get1,
       });
 }

+// Calls `func(task, worker)` for each task in `[0, num_tasks)`. Parallelizes
+// over clusters of ONE package, then within each cluster.
+template <class Func>
+void ParallelFor(size_t num_tasks, NestedPools& pools, size_t pkg_idx,
+                 const Func& func) {
+  const size_t pkg_base = pkg_idx * pools.MaxWorkersPerPackage();
+
+  // If few tasks, run on a single cluster. Also avoids a bit of overhead if
+  // there is only one cluster.
+  hwy::ThreadPool& all_clusters = pools.AllClusters(pkg_idx);
+  const size_t num_clusters = all_clusters.NumWorkers();
+  hwy::ThreadPool& cluster = pools.Cluster(pkg_idx, 0);
+  if (num_clusters == 1 || num_tasks <= cluster.NumWorkers()) {
+    return cluster.Run(0, num_tasks, [&](uint64_t task, size_t thread) {
+      func(task, pkg_base + thread);
+    });
+  }
+
+  // Assign each cluster a sub-range.
+  const IndexRangePartition ranges =
+      StaticPartition(IndexRange(0, num_tasks), num_clusters, 1);
+  ParallelizeOneRange(
+      ranges, all_clusters,
+      [&](const IndexRange& range, const size_t cluster_idx) {
+        hwy::ThreadPool& cluster = pools.Cluster(pkg_idx, cluster_idx);
+        const size_t cluster_base =
+            pkg_base + cluster_idx * pools.MaxWorkersPerCluster();
+        cluster.Run(range.begin(), range.end(),
+                    [&](uint64_t task, size_t thread) {
+                      func(task, cluster_base + thread);
+                    });
+      });
+}
+
 }  // namespace gcpp

 #endif  // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_
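Editor's note (not part of the diff): ParallelFor is the new building block used by ActivationBatched and DotSoftmaxWeightedSum in the hunks above. A hypothetical stand-alone call site, assuming an existing NestedPools instance; the worker argument is a flat index intended for profiler zones and per-thread state:

#include <cstddef>
#include <cstdint>

#include "util/threading.h"

// Scale each row of a hypothetical row-pointer array in parallel.
void ScaleRows(float** rows, size_t num_rows, size_t cols, float c,
               gcpp::NestedPools& pools) {
  gcpp::ParallelFor(num_rows, pools, /*pkg_idx=*/0,
                    [&](uint64_t task, size_t worker) {
                      (void)worker;  // e.g. pass to PROFILER_ZONE2 or ops
                      float* row = rows[task];
                      for (size_t i = 0; i < cols; ++i) row[i] *= c;
                    });
}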