Add parameter for base_frequency to CreateInvTimeScale().

Extract a few local variables to make code easier to read (hopefully).

PiperOrigin-RevId: 718749053
This commit is contained in:
Daniel Keysers 2025-01-23 00:56:04 -08:00 committed by Copybara-Service
parent a133b3d062
commit f37402da57
2 changed files with 33 additions and 29 deletions

View File

@ -74,17 +74,18 @@ struct Activations {
size_t seq_len; size_t seq_len;
size_t cache_pos_size = 0; size_t cache_pos_size = 0;
static RowVectorBatch<float> CreateInvTimescale(size_t qkv_dim, static RowVectorBatch<float> CreateInvTimescale(
PostQKType post_qk) { size_t qkv_dim, PostQKType post_qk, double base_frequency = 10000.0) {
const size_t rope_dim = const size_t rope_dim =
post_qk == PostQKType::HalfRope ? qkv_dim / 2 : qkv_dim; post_qk == PostQKType::HalfRope ? qkv_dim / 2 : qkv_dim;
RowVectorBatch<float> inv_timescale(Extents2D(1, rope_dim / 2)); RowVectorBatch<float> inv_timescale(Extents2D(1, rope_dim / 2));
for (size_t dim = 0; dim < rope_dim / 2; ++dim) { for (size_t dim = 0; dim < rope_dim / 2; ++dim) {
const float freq_exponents = const double freq_exponents =
static_cast<float>(2 * dim) / static_cast<float>(rope_dim); static_cast<double>(2 * dim) / static_cast<double>(rope_dim);
// Replacing with expf(ln(1E4) * freq_exponents) changes results // Replacing with expf(ln(1E4) * freq_exponents) changes results
// noticeably. // noticeably.
inv_timescale.Batch(0)[dim] = 1.0f / std::pow(10000.0f, freq_exponents); inv_timescale.Batch(0)[dim] =
static_cast<float>(1.0 / std::pow(base_frequency, freq_exponents));
} }
return inv_timescale; return inv_timescale;
} }
@ -94,19 +95,20 @@ struct Activations {
const size_t model_dim = weights_config.model_dim; const size_t model_dim = weights_config.model_dim;
const size_t ff_hidden_dim = layer_config.ff_hidden_dim; const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
const size_t vocab_size = weights_config.vocab_size; const size_t vocab_size = weights_config.vocab_size;
const size_t qkv_dim = layer_config.qkv_dim;
const size_t heads = layer_config.heads;
x = RowVectorBatch<float>(Extents2D(batch_size, model_dim)); x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
q = RowVectorBatch<float>( q = RowVectorBatch<float>(
Extents2D(batch_size, layer_config.heads * layer_config.QStride())); Extents2D(batch_size, heads * layer_config.QStride()));
if (vocab_size > 0) { if (vocab_size > 0) {
logits = RowVectorBatch<float>(Extents2D(batch_size, vocab_size)); logits = RowVectorBatch<float>(Extents2D(batch_size, vocab_size));
} }
pre_att_rms_out = RowVectorBatch<float>(Extents2D(batch_size, model_dim)); pre_att_rms_out = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
att = RowVectorBatch<float>( att = RowVectorBatch<float>(
Extents2D(batch_size, layer_config.heads * weights_config.seq_len)); Extents2D(batch_size, heads * weights_config.seq_len));
att_out = RowVectorBatch<float>( att_out = RowVectorBatch<float>(Extents2D(batch_size, heads * qkv_dim));
Extents2D(batch_size, layer_config.heads * layer_config.qkv_dim));
att_sums = RowVectorBatch<float>(Extents2D(batch_size, model_dim)); att_sums = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
bf_pre_ffw_rms_out = RowVectorBatch<BF16>(Extents2D(batch_size, model_dim)); bf_pre_ffw_rms_out = RowVectorBatch<BF16>(Extents2D(batch_size, model_dim));
@ -122,7 +124,7 @@ struct Activations {
RowVectorBatch<float>(Extents2D(batch_size, model_dim)); RowVectorBatch<float>(Extents2D(batch_size, model_dim));
} }
inv_timescale = CreateInvTimescale(layer_config.qkv_dim, post_qk); inv_timescale = CreateInvTimescale(qkv_dim, post_qk);
env = std::make_unique<MatMulEnv>(pools); env = std::make_unique<MatMulEnv>(pools);
} }

View File

@ -216,15 +216,17 @@ class GemmaAttention {
template <typename U> template <typename U>
HWY_INLINE void PositionalEncodingQK(const U* qk, size_t pos, size_t layer, HWY_INLINE void PositionalEncodingQK(const U* qk, size_t pos, size_t layer,
const float mul, U* qk_out) { const float mul, U* qk_out) {
// qk is either q or k, so qkv_dim is the length we operate on.
const size_t qkv_dim = layer_config_.qkv_dim;
const float* inv_timescale = activations_.inv_timescale.Const(); const float* inv_timescale = activations_.inv_timescale.Const();
// PostQKType::Rope // PostQKType::Rope
(void)layer; (void)layer;
if (layer_weights_.layer_config.post_qk == PostQKType::HalfRope) { if (layer_weights_.layer_config.post_qk == PostQKType::HalfRope) {
hwy::CopyBytes(qk, qk_out, layer_config_.qkv_dim * sizeof(*qk)); hwy::CopyBytes(qk, qk_out, qkv_dim * sizeof(*qk));
Rope(qk_out, layer_config_.qkv_dim / 2, inv_timescale, pos); Rope(qk_out, qkv_dim / 2, inv_timescale, pos);
MulByConst(mul, qk_out, layer_config_.qkv_dim); MulByConst(mul, qk_out, qkv_dim);
} else { } else {
RopeAndMulBy(mul, qk, layer_config_.qkv_dim, inv_timescale, pos, qk_out); RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, qk_out);
} }
} }
@ -334,13 +336,14 @@ class GemmaAttention {
HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos, HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
const size_t head_offset, const float* HWY_RESTRICT q, const size_t head_offset, const float* HWY_RESTRICT q,
const KVCache& kv_cache, float* HWY_RESTRICT head_att) { const KVCache& kv_cache, float* HWY_RESTRICT head_att) {
const size_t qkv_dim = layer_config_.qkv_dim;
if (HWY_LIKELY(last_pos < activations_.seq_len)) { if (HWY_LIKELY(last_pos < activations_.seq_len)) {
// Slightly faster: no wraparound. // Slightly faster: no wraparound.
for (size_t pos = start_pos; pos <= last_pos; ++pos) { for (size_t pos = start_pos; pos <= last_pos; ++pos) {
const size_t kv_offset = const size_t kv_offset =
pos * cache_pos_size_ + layer_ * cache_layer_size_ + head_offset; pos * cache_pos_size_ + layer_ * cache_layer_size_ + head_offset;
const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset]; const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset];
const float score = Dot(q, k, layer_config_.qkv_dim); const float score = Dot(q, k, qkv_dim);
head_att[pos] = score; head_att[pos] = score;
} }
} else { } else {
@ -349,7 +352,7 @@ class GemmaAttention {
const size_t kv_offset = cache_pos * cache_pos_size_ + const size_t kv_offset = cache_pos * cache_pos_size_ +
layer_ * cache_layer_size_ + head_offset; layer_ * cache_layer_size_ + head_offset;
const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset]; const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset];
const float score = Dot(q, k, layer_config_.qkv_dim); const float score = Dot(q, k, qkv_dim);
head_att[pos % activations_.seq_len] = score; head_att[pos % activations_.seq_len] = score;
} }
} }
@ -364,7 +367,8 @@ class GemmaAttention {
const hwy::Divisor& div_seq_len, const hwy::Divisor& div_seq_len,
const KVCache& kv_cache, const KVCache& kv_cache,
float* HWY_RESTRICT att_out) const { float* HWY_RESTRICT att_out) const {
hwy::ZeroBytes(att_out, layer_config_.qkv_dim * sizeof(*att_out)); const size_t qkv_dim = layer_config_.qkv_dim;
hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
if (HWY_LIKELY(last_pos < activations_.seq_len)) { if (HWY_LIKELY(last_pos < activations_.seq_len)) {
// Slightly faster: no wraparound. // Slightly faster: no wraparound.
@ -372,8 +376,8 @@ class GemmaAttention {
const size_t kv_offset = const size_t kv_offset =
pos * cache_pos_size_ + layer * cache_layer_size_ + head_offset; pos * cache_pos_size_ + layer * cache_layer_size_ + head_offset;
const float* HWY_RESTRICT v = const float* HWY_RESTRICT v =
kv_cache.kv_cache.get() + kv_offset + layer_config_.qkv_dim; kv_cache.kv_cache.get() + kv_offset + qkv_dim;
MulByConstAndAdd(head_att[pos], v, att_out, layer_config_.qkv_dim); MulByConstAndAdd(head_att[pos], v, att_out, qkv_dim);
} }
} else { } else {
for (size_t pos = start_pos; pos <= last_pos; ++pos) { for (size_t pos = start_pos; pos <= last_pos; ++pos) {
@ -381,9 +385,9 @@ class GemmaAttention {
const size_t kv_offset = cache_pos * cache_pos_size_ + const size_t kv_offset = cache_pos * cache_pos_size_ +
layer * cache_layer_size_ + head_offset; layer * cache_layer_size_ + head_offset;
const float* HWY_RESTRICT v = const float* HWY_RESTRICT v =
kv_cache.kv_cache.get() + kv_offset + layer_config_.qkv_dim; kv_cache.kv_cache.get() + kv_offset + qkv_dim;
MulByConstAndAdd(head_att[pos % activations_.seq_len], v, att_out, MulByConstAndAdd(head_att[pos % activations_.seq_len], v, att_out,
layer_config_.qkv_dim); qkv_dim);
} }
} }
} }
@ -403,8 +407,8 @@ class GemmaAttention {
const size_t interleaved_idx = task / layer_config_.heads; const size_t interleaved_idx = task / layer_config_.heads;
const size_t query_idx = interleaved_idx % num_queries_; const size_t query_idx = interleaved_idx % num_queries_;
const size_t batch_idx = interleaved_idx / num_queries_; const size_t batch_idx = interleaved_idx / num_queries_;
const size_t head_offset = const size_t qkv_dim = layer_config_.qkv_dim;
(head / kHeadGroups) * layer_config_.qkv_dim * 2; const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
KVCache& kv_cache = kv_caches_[query_idx]; KVCache& kv_cache = kv_caches_[query_idx];
float* HWY_RESTRICT q = float* HWY_RESTRICT q =
activations_.q.Batch(interleaved_idx) + head * q_stride_; activations_.q.Batch(interleaved_idx) + head * q_stride_;
@ -435,15 +439,14 @@ class GemmaAttention {
float* HWY_RESTRICT att_out = float* HWY_RESTRICT att_out =
activations_.att_out.Batch(interleaved_idx) + activations_.att_out.Batch(interleaved_idx) +
head * layer_config_.qkv_dim; head * qkv_dim;
WeightedSumV(start_pos, last_pos, head_att, layer_, head_offset, WeightedSumV(start_pos, last_pos, head_att, layer_, head_offset,
div_seq_len_, kv_cache, att_out); div_seq_len_, kv_cache, att_out);
}); });
} }
// Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
// head_dim // head_dim (`qkv_dim`) into output (`layer_out`).
// (`layer_config_.qkv_dim`) into output (`layer_out`).
HWY_NOINLINE void SumHeads(const size_t num_interleaved) { HWY_NOINLINE void SumHeads(const size_t num_interleaved) {
PROFILER_ZONE("Gen.Attention.SumHeads"); PROFILER_ZONE("Gen.Attention.SumHeads");
// att_weights and att_out are concatenated heads, each of length // att_weights and att_out are concatenated heads, each of length
@ -630,13 +633,12 @@ class VitAttention {
} }
// Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
// head_dim // head_dim (`qkv_dim`) into output (`att_sums`).
// (`layer_config_.qkv_dim`) into output (`att_sums`).
HWY_NOINLINE void SumHeads() { HWY_NOINLINE void SumHeads() {
PROFILER_ZONE("Gen.VitAttention.SumHeads"); PROFILER_ZONE("Gen.VitAttention.SumHeads");
auto* bias = layer_weights_.vit.attn_out_b.data_scale1(); auto* bias = layer_weights_.vit.attn_out_b.data_scale1();
// att_weights and att_out are concatenated heads, each of length // att_weights and att_out are concatenated heads, each of length
// layer_config_.qkv_dim. Thus the [num_tokens_, layer_config_.model_dim] // qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
// matmul output is the sum over heads. // matmul output is the sum over heads.
auto att_out = ConstMatFromBatch(num_tokens_, activations_.att_out); auto att_out = ConstMatFromBatch(num_tokens_, activations_.att_out);
auto att_weights = ConstMatFromWeights(layer_weights_.vit.attn_out_w); auto att_weights = ConstMatFromWeights(layer_weights_.vit.attn_out_w);