mirror of https://github.com/google/gemma.cpp.git
Add a base_frequency parameter to CreateInvTimescale().
Extract a few local variables to make the code easier to read (hopefully).

PiperOrigin-RevId: 718749053
commit f37402da57
parent a133b3d062
@@ -74,17 +74,18 @@ struct Activations {
   size_t seq_len;
   size_t cache_pos_size = 0;
 
-  static RowVectorBatch<float> CreateInvTimescale(size_t qkv_dim,
-                                                  PostQKType post_qk) {
+  static RowVectorBatch<float> CreateInvTimescale(
+      size_t qkv_dim, PostQKType post_qk, double base_frequency = 10000.0) {
     const size_t rope_dim =
         post_qk == PostQKType::HalfRope ? qkv_dim / 2 : qkv_dim;
     RowVectorBatch<float> inv_timescale(Extents2D(1, rope_dim / 2));
     for (size_t dim = 0; dim < rope_dim / 2; ++dim) {
-      const float freq_exponents =
-          static_cast<float>(2 * dim) / static_cast<float>(rope_dim);
+      const double freq_exponents =
+          static_cast<double>(2 * dim) / static_cast<double>(rope_dim);
       // Replacing with expf(ln(1E4) * freq_exponents) changes results
       // noticeably.
-      inv_timescale.Batch(0)[dim] = 1.0f / std::pow(10000.0f, freq_exponents);
+      inv_timescale.Batch(0)[dim] =
+          static_cast<float>(1.0 / std::pow(base_frequency, freq_exponents));
     }
     return inv_timescale;
   }
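For reference, the quantity being parameterized is the standard RoPE frequency table, inv_timescale[d] = base_frequency^(-2d / rope_dim). Below is a minimal standalone sketch of that formula using plain std::vector instead of the repo's RowVectorBatch so it compiles in isolation; the helper name MakeInvTimescale is ours, not part of gemma.cpp.

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Sketch: computes RoPE inverse timescales for one head.
    // rope_dim is qkv_dim (or qkv_dim / 2 for the half-RoPE variant).
    std::vector<float> MakeInvTimescale(size_t rope_dim,
                                        double base_frequency = 10000.0) {
      std::vector<float> inv_timescale(rope_dim / 2);
      for (size_t dim = 0; dim < rope_dim / 2; ++dim) {
        // Exponent grows linearly from 0 toward 1 across the rotated dims.
        const double freq_exponent =
            static_cast<double>(2 * dim) / static_cast<double>(rope_dim);
        // Larger base_frequency => slower-rotating (longer-wavelength) channels.
        inv_timescale[dim] =
            static_cast<float>(1.0 / std::pow(base_frequency, freq_exponent));
      }
      return inv_timescale;
    }

Raising the base above the 10000.0 default stretches the rotation wavelengths, a common knob for longer-context RoPE variants; this commit only threads the parameter through with the old value as the default, so existing callers are unchanged.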
@@ -94,19 +95,20 @@ struct Activations {
     const size_t model_dim = weights_config.model_dim;
     const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
     const size_t vocab_size = weights_config.vocab_size;
+    const size_t qkv_dim = layer_config.qkv_dim;
+    const size_t heads = layer_config.heads;
 
     x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
     q = RowVectorBatch<float>(
-        Extents2D(batch_size, layer_config.heads * layer_config.QStride()));
+        Extents2D(batch_size, heads * layer_config.QStride()));
     if (vocab_size > 0) {
       logits = RowVectorBatch<float>(Extents2D(batch_size, vocab_size));
     }
 
     pre_att_rms_out = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
     att = RowVectorBatch<float>(
-        Extents2D(batch_size, layer_config.heads * weights_config.seq_len));
-    att_out = RowVectorBatch<float>(
-        Extents2D(batch_size, layer_config.heads * layer_config.qkv_dim));
+        Extents2D(batch_size, heads * weights_config.seq_len));
+    att_out = RowVectorBatch<float>(Extents2D(batch_size, heads * qkv_dim));
     att_sums = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
 
     bf_pre_ffw_rms_out = RowVectorBatch<BF16>(Extents2D(batch_size, model_dim));
@@ -122,7 +124,7 @@ struct Activations {
           RowVectorBatch<float>(Extents2D(batch_size, model_dim));
     }
 
-    inv_timescale = CreateInvTimescale(layer_config.qkv_dim, post_qk);
+    inv_timescale = CreateInvTimescale(qkv_dim, post_qk);
 
     env = std::make_unique<MatMulEnv>(pools);
   }
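The Allocate() call site above keeps the default base frequency. If a model configuration ever needed a different RoPE base, a caller could pass it as the third argument; the 1000000.0 below is purely illustrative and not something this commit introduces:

    inv_timescale = CreateInvTimescale(qkv_dim, post_qk,
                                       /*base_frequency=*/1000000.0);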
@@ -216,15 +216,17 @@ class GemmaAttention {
   template <typename U>
   HWY_INLINE void PositionalEncodingQK(const U* qk, size_t pos, size_t layer,
                                        const float mul, U* qk_out) {
+    // qk is either q or k, so qkv_dim is the length we operate on.
+    const size_t qkv_dim = layer_config_.qkv_dim;
     const float* inv_timescale = activations_.inv_timescale.Const();
     // PostQKType::Rope
     (void)layer;
     if (layer_weights_.layer_config.post_qk == PostQKType::HalfRope) {
-      hwy::CopyBytes(qk, qk_out, layer_config_.qkv_dim * sizeof(*qk));
-      Rope(qk_out, layer_config_.qkv_dim / 2, inv_timescale, pos);
-      MulByConst(mul, qk_out, layer_config_.qkv_dim);
+      hwy::CopyBytes(qk, qk_out, qkv_dim * sizeof(*qk));
+      Rope(qk_out, qkv_dim / 2, inv_timescale, pos);
+      MulByConst(mul, qk_out, qkv_dim);
     } else {
-      RopeAndMulBy(mul, qk, layer_config_.qkv_dim, inv_timescale, pos, qk_out);
+      RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, qk_out);
     }
   }
 
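Rope and RopeAndMulBy are the repo's vectorized rotary-embedding kernels. As a rough mental model only (not the actual Highway implementation), a scalar form of the rotation, assuming the half-split pairing convention in which dimension d is paired with d + rope_dim / 2, looks like this; the name RopeScalar is ours:

    #include <cmath>
    #include <cstddef>

    // Scalar sketch of a rotary positional embedding: each channel pair
    // (d, d + half) is rotated by angle pos * inv_timescale[d].
    void RopeScalar(float* x, size_t rope_dim, const float* inv_timescale,
                    size_t pos) {
      const size_t half = rope_dim / 2;
      for (size_t d = 0; d < half; ++d) {
        // Higher d => smaller inv_timescale => slower rotation per position.
        const float theta = static_cast<float>(pos) * inv_timescale[d];
        const float c = std::cos(theta), s = std::sin(theta);
        const float x0 = x[d], x1 = x[d + half];
        x[d] = x0 * c - x1 * s;
        x[d + half] = x0 * s + x1 * c;
      }
    }

RopeAndMulBy additionally scales the result by mul. In the HalfRope branch above, only the first qkv_dim / 2 dimensions are rotated and MulByConst then scales the full vector, which is why the code copies qk before calling Rope on half of it.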
@@ -334,13 +336,14 @@ class GemmaAttention {
   HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
                         const size_t head_offset, const float* HWY_RESTRICT q,
                         const KVCache& kv_cache, float* HWY_RESTRICT head_att) {
+    const size_t qkv_dim = layer_config_.qkv_dim;
     if (HWY_LIKELY(last_pos < activations_.seq_len)) {
       // Slightly faster: no wraparound.
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
         const size_t kv_offset =
             pos * cache_pos_size_ + layer_ * cache_layer_size_ + head_offset;
         const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset];
-        const float score = Dot(q, k, layer_config_.qkv_dim);
+        const float score = Dot(q, k, qkv_dim);
         head_att[pos] = score;
       }
     } else {
@@ -349,7 +352,7 @@ class GemmaAttention {
         const size_t kv_offset = cache_pos * cache_pos_size_ +
                                  layer_ * cache_layer_size_ + head_offset;
         const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset];
-        const float score = Dot(q, k, layer_config_.qkv_dim);
+        const float score = Dot(q, k, qkv_dim);
         head_att[pos % activations_.seq_len] = score;
       }
     }
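For orientation: the offsets imply the KV cache is one flat float array laid out position-major, then layer, then KV head, with each head's K vector immediately followed by its V vector (hence the "* 2" in head_offset), and the second branch treats positions as a ring buffer over seq_len. A standalone scalar sketch of the scoring loop under that assumed layout, with a plain dot product standing in for the vectorized Dot; the name QDotKSketch and the parameter list are ours:

    #include <cstddef>

    // head_att[pos % seq_len] = q . K(pos) for pos in [start_pos, last_pos].
    void QDotKSketch(const float* q, const float* kv_cache, float* head_att,
                     size_t start_pos, size_t last_pos, size_t seq_len,
                     size_t layer, size_t cache_layer_size,
                     size_t cache_pos_size, size_t head_offset,
                     size_t qkv_dim) {
      for (size_t pos = start_pos; pos <= last_pos; ++pos) {
        const size_t cache_pos = pos % seq_len;  // ring buffer over seq_len
        const size_t kv_offset = cache_pos * cache_pos_size +
                                 layer * cache_layer_size + head_offset;
        const float* k = kv_cache + kv_offset;
        float score = 0.0f;
        for (size_t d = 0; d < qkv_dim; ++d) score += q[d] * k[d];
        head_att[pos % seq_len] = score;
      }
    }

The HWY_LIKELY branch in the real code skips the modulo entirely when last_pos < seq_len, since no wraparound can occur there.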
@@ -364,7 +367,8 @@ class GemmaAttention {
                                 const hwy::Divisor& div_seq_len,
                                 const KVCache& kv_cache,
                                 float* HWY_RESTRICT att_out) const {
-    hwy::ZeroBytes(att_out, layer_config_.qkv_dim * sizeof(*att_out));
+    const size_t qkv_dim = layer_config_.qkv_dim;
+    hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
 
     if (HWY_LIKELY(last_pos < activations_.seq_len)) {
       // Slightly faster: no wraparound.
@@ -372,8 +376,8 @@ class GemmaAttention {
         const size_t kv_offset =
             pos * cache_pos_size_ + layer * cache_layer_size_ + head_offset;
         const float* HWY_RESTRICT v =
-            kv_cache.kv_cache.get() + kv_offset + layer_config_.qkv_dim;
-        MulByConstAndAdd(head_att[pos], v, att_out, layer_config_.qkv_dim);
+            kv_cache.kv_cache.get() + kv_offset + qkv_dim;
+        MulByConstAndAdd(head_att[pos], v, att_out, qkv_dim);
       }
     } else {
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
@@ -381,9 +385,9 @@ class GemmaAttention {
         const size_t kv_offset = cache_pos * cache_pos_size_ +
                                  layer * cache_layer_size_ + head_offset;
         const float* HWY_RESTRICT v =
-            kv_cache.kv_cache.get() + kv_offset + layer_config_.qkv_dim;
+            kv_cache.kv_cache.get() + kv_offset + qkv_dim;
         MulByConstAndAdd(head_att[pos % activations_.seq_len], v, att_out,
-                         layer_config_.qkv_dim);
+                         qkv_dim);
       }
     }
   }
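The value vector sits at kv_offset + qkv_dim because each head's cache slot stores K and V back to back. The accumulation itself is just att_out += head_att[pos] * v; a scalar stand-in for the repo's vectorized MulByConstAndAdd helper, assuming those semantics and for illustration only:

    #include <cstddef>

    // out[i] += c * v[i] for i in [0, size); scalar sketch, not the Highway kernel.
    inline void MulByConstAndAddScalar(float c, const float* v, float* out,
                                       size_t size) {
      for (size_t i = 0; i < size; ++i) out[i] += c * v[i];
    }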
@@ -403,8 +407,8 @@ class GemmaAttention {
         const size_t interleaved_idx = task / layer_config_.heads;
         const size_t query_idx = interleaved_idx % num_queries_;
         const size_t batch_idx = interleaved_idx / num_queries_;
-        const size_t head_offset =
-            (head / kHeadGroups) * layer_config_.qkv_dim * 2;
+        const size_t qkv_dim = layer_config_.qkv_dim;
+        const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
         KVCache& kv_cache = kv_caches_[query_idx];
         float* HWY_RESTRICT q =
             activations_.q.Batch(interleaved_idx) + head * q_stride_;
@@ -435,15 +439,14 @@ class GemmaAttention {
 
         float* HWY_RESTRICT att_out =
             activations_.att_out.Batch(interleaved_idx) +
-            head * layer_config_.qkv_dim;
+            head * qkv_dim;
         WeightedSumV(start_pos, last_pos, head_att, layer_, head_offset,
                      div_seq_len_, kv_cache, att_out);
       });
   }
 
   // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
-  // head_dim
-  // (`layer_config_.qkv_dim`) into output (`layer_out`).
+  // head_dim (`qkv_dim`) into output (`layer_out`).
   HWY_NOINLINE void SumHeads(const size_t num_interleaved) {
     PROFILER_ZONE("Gen.Attention.SumHeads");
     // att_weights and att_out are concatenated heads, each of length
@@ -630,13 +633,12 @@ class VitAttention {
   }
 
   // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
-  // head_dim
-  // (`layer_config_.qkv_dim`) into output (`att_sums`).
+  // head_dim (`qkv_dim`) into output (`att_sums`).
   HWY_NOINLINE void SumHeads() {
     PROFILER_ZONE("Gen.VitAttention.SumHeads");
     auto* bias = layer_weights_.vit.attn_out_b.data_scale1();
     // att_weights and att_out are concatenated heads, each of length
-    // layer_config_.qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
+    // qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
     // matmul output is the sum over heads.
    auto att_out = ConstMatFromBatch(num_tokens_, activations_.att_out);
    auto att_weights = ConstMatFromWeights(layer_weights_.vit.attn_out_w);
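The "sum over heads" falls out of the matmul shapes: att_out is [num_tokens, heads * qkv_dim] and the attention output weight spans the same heads * qkv_dim axis, so each output element already accumulates contributions from every head. A scalar sketch of that reduction under an assumed row-major layout, with plain arrays rather than the repo's ConstMat types; the name SumHeadsSketch is ours:

    #include <cstddef>

    // out[t][m] = bias[m] + sum_{h,d} att_out[t][h*qkv_dim+d] * w[h*qkv_dim+d][m]
    // Illustrates why one matmul over concatenated heads equals the per-head
    // projections summed together; not gemma.cpp's MatMul.
    void SumHeadsSketch(const float* att_out, const float* w, const float* bias,
                        float* out, size_t num_tokens, size_t heads,
                        size_t qkv_dim, size_t model_dim) {
      for (size_t t = 0; t < num_tokens; ++t) {
        for (size_t m = 0; m < model_dim; ++m) {
          float acc = bias[m];
          for (size_t h = 0; h < heads; ++h) {
            for (size_t d = 0; d < qkv_dim; ++d) {
              const size_t k = h * qkv_dim + d;
              acc += att_out[t * heads * qkv_dim + k] * w[k * model_dim + m];
            }
          }
          out[t * model_dim + m] = acc;
        }
      }
    }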