mirror of https://github.com/google/gemma.cpp.git
Add a base_frequency parameter to CreateInvTimescale().

Extract a few local variables to make the code easier to read (hopefully).

PiperOrigin-RevId: 718749053

commit f37402da57
parent a133b3d062
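
For context, the values produced here follow the standard RoPE frequency schedule, inv_timescale[d] = base_frequency^(-2d / rope_dim). Below is a minimal standalone sketch of the same computation in plain C++ (std::vector instead of RowVectorBatch; function and variable names are illustrative, not the gemma.cpp API):

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Sketch: RoPE inverse timescales for one head. rope_dim is the (possibly
// halved) rotary dimension; base_frequency defaults to the classic 10000.0
// that the existing code hard-coded.
std::vector<float> InvTimescaleSketch(std::size_t rope_dim,
                                      double base_frequency = 10000.0) {
  std::vector<float> inv_timescale(rope_dim / 2);
  for (std::size_t dim = 0; dim < rope_dim / 2; ++dim) {
    const double freq_exponent =
        static_cast<double>(2 * dim) / static_cast<double>(rope_dim);
    inv_timescale[dim] =
        static_cast<float>(1.0 / std::pow(base_frequency, freq_exponent));
  }
  return inv_timescale;
}

int main() {
  // Example: qkv_dim = 256 with PostQKType::HalfRope -> rope_dim = 128.
  for (float t : InvTimescaleSketch(/*rope_dim=*/128)) std::printf("%g\n", t);
  return 0;
}
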
@@ -74,17 +74,18 @@ struct Activations {
   size_t seq_len;
   size_t cache_pos_size = 0;
 
-  static RowVectorBatch<float> CreateInvTimescale(size_t qkv_dim,
-                                                  PostQKType post_qk) {
+  static RowVectorBatch<float> CreateInvTimescale(
+      size_t qkv_dim, PostQKType post_qk, double base_frequency = 10000.0) {
     const size_t rope_dim =
         post_qk == PostQKType::HalfRope ? qkv_dim / 2 : qkv_dim;
     RowVectorBatch<float> inv_timescale(Extents2D(1, rope_dim / 2));
     for (size_t dim = 0; dim < rope_dim / 2; ++dim) {
-      const float freq_exponents =
-          static_cast<float>(2 * dim) / static_cast<float>(rope_dim);
+      const double freq_exponents =
+          static_cast<double>(2 * dim) / static_cast<double>(rope_dim);
       // Replacing with expf(ln(1E4) * freq_exponents) changes results
       // noticeably.
-      inv_timescale.Batch(0)[dim] = 1.0f / std::pow(10000.0f, freq_exponents);
+      inv_timescale.Batch(0)[dim] =
+          static_cast<float>(1.0 / std::pow(base_frequency, freq_exponents));
     }
     return inv_timescale;
   }
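
The defaulted argument keeps existing call sites unchanged; a caller that wants a different RoPE base, say a hypothetical long-context configuration, could pass it explicitly:

  inv_timescale = CreateInvTimescale(qkv_dim, post_qk, /*base_frequency=*/1000000.0);

(The value 1000000.0 is illustrative only, not taken from this commit.)
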
@@ -94,19 +95,20 @@ struct Activations {
     const size_t model_dim = weights_config.model_dim;
     const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
     const size_t vocab_size = weights_config.vocab_size;
+    const size_t qkv_dim = layer_config.qkv_dim;
+    const size_t heads = layer_config.heads;
 
     x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
     q = RowVectorBatch<float>(
-        Extents2D(batch_size, layer_config.heads * layer_config.QStride()));
+        Extents2D(batch_size, heads * layer_config.QStride()));
     if (vocab_size > 0) {
       logits = RowVectorBatch<float>(Extents2D(batch_size, vocab_size));
     }
 
     pre_att_rms_out = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
     att = RowVectorBatch<float>(
-        Extents2D(batch_size, layer_config.heads * weights_config.seq_len));
-    att_out = RowVectorBatch<float>(
-        Extents2D(batch_size, layer_config.heads * layer_config.qkv_dim));
+        Extents2D(batch_size, heads * weights_config.seq_len));
+    att_out = RowVectorBatch<float>(Extents2D(batch_size, heads * qkv_dim));
     att_sums = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
 
     bf_pre_ffw_rms_out = RowVectorBatch<BF16>(Extents2D(batch_size, model_dim));
@@ -122,7 +124,7 @@ struct Activations {
           RowVectorBatch<float>(Extents2D(batch_size, model_dim));
     }
 
-    inv_timescale = CreateInvTimescale(layer_config.qkv_dim, post_qk);
+    inv_timescale = CreateInvTimescale(qkv_dim, post_qk);
 
     env = std::make_unique<MatMulEnv>(pools);
   }
@@ -216,15 +216,17 @@ class GemmaAttention {
   template <typename U>
   HWY_INLINE void PositionalEncodingQK(const U* qk, size_t pos, size_t layer,
                                        const float mul, U* qk_out) {
+    // qk is either q or k, so qkv_dim is the length we operate on.
+    const size_t qkv_dim = layer_config_.qkv_dim;
     const float* inv_timescale = activations_.inv_timescale.Const();
     // PostQKType::Rope
     (void)layer;
     if (layer_weights_.layer_config.post_qk == PostQKType::HalfRope) {
-      hwy::CopyBytes(qk, qk_out, layer_config_.qkv_dim * sizeof(*qk));
-      Rope(qk_out, layer_config_.qkv_dim / 2, inv_timescale, pos);
-      MulByConst(mul, qk_out, layer_config_.qkv_dim);
+      hwy::CopyBytes(qk, qk_out, qkv_dim * sizeof(*qk));
+      Rope(qk_out, qkv_dim / 2, inv_timescale, pos);
+      MulByConst(mul, qk_out, qkv_dim);
     } else {
-      RopeAndMulBy(mul, qk, layer_config_.qkv_dim, inv_timescale, pos, qk_out);
+      RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, qk_out);
     }
   }
 
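For readers unfamiliar with the Rope / RopeAndMulBy helpers called above: they apply rotary position embedding to the first qkv_dim (or qkv_dim / 2 for HalfRope) elements of q or k, using the precomputed inv_timescale. A scalar reference sketch follows; it assumes the common pairing of element i with element i + dim/2 and is illustrative only (the actual gemma.cpp kernels are SIMD code elsewhere in the repository):

#include <cmath>
#include <cstddef>

// Scalar sketch of a RoPE rotation: rotate the pair (x[i], x[i + half]) by
// angle pos * inv_timescale[i], then scale by `mul` (RopeAndMulBy-style;
// pass mul = 1.0f for plain Rope). The pairing convention is an assumption,
// not the library's vectorized implementation.
void RopeSketch(float* x, std::size_t dim, const float* inv_timescale,
                std::size_t pos, float mul = 1.0f) {
  const std::size_t half = dim / 2;
  for (std::size_t i = 0; i < half; ++i) {
    const float theta = static_cast<float>(pos) * inv_timescale[i];
    const float c = std::cos(theta);
    const float s = std::sin(theta);
    const float x0 = x[i];
    const float x1 = x[i + half];
    x[i] = mul * (x0 * c - x1 * s);
    x[i + half] = mul * (x0 * s + x1 * c);
  }
}
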
@@ -334,13 +336,14 @@ class GemmaAttention {
   HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
                         const size_t head_offset, const float* HWY_RESTRICT q,
                         const KVCache& kv_cache, float* HWY_RESTRICT head_att) {
+    const size_t qkv_dim = layer_config_.qkv_dim;
     if (HWY_LIKELY(last_pos < activations_.seq_len)) {
       // Slightly faster: no wraparound.
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
         const size_t kv_offset =
             pos * cache_pos_size_ + layer_ * cache_layer_size_ + head_offset;
         const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset];
-        const float score = Dot(q, k, layer_config_.qkv_dim);
+        const float score = Dot(q, k, qkv_dim);
         head_att[pos] = score;
       }
     } else {
@@ -349,7 +352,7 @@ class GemmaAttention {
         const size_t kv_offset = cache_pos * cache_pos_size_ +
                                  layer_ * cache_layer_size_ + head_offset;
         const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset];
-        const float score = Dot(q, k, layer_config_.qkv_dim);
+        const float score = Dot(q, k, qkv_dim);
         head_att[pos % activations_.seq_len] = score;
       }
     }
@@ -364,7 +367,8 @@ class GemmaAttention {
                              const hwy::Divisor& div_seq_len,
                              const KVCache& kv_cache,
                              float* HWY_RESTRICT att_out) const {
-    hwy::ZeroBytes(att_out, layer_config_.qkv_dim * sizeof(*att_out));
+    const size_t qkv_dim = layer_config_.qkv_dim;
+    hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
 
     if (HWY_LIKELY(last_pos < activations_.seq_len)) {
       // Slightly faster: no wraparound.
@@ -372,8 +376,8 @@ class GemmaAttention {
         const size_t kv_offset =
             pos * cache_pos_size_ + layer * cache_layer_size_ + head_offset;
         const float* HWY_RESTRICT v =
-            kv_cache.kv_cache.get() + kv_offset + layer_config_.qkv_dim;
-        MulByConstAndAdd(head_att[pos], v, att_out, layer_config_.qkv_dim);
+            kv_cache.kv_cache.get() + kv_offset + qkv_dim;
+        MulByConstAndAdd(head_att[pos], v, att_out, qkv_dim);
       }
     } else {
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
@@ -381,9 +385,9 @@ class GemmaAttention {
         const size_t kv_offset = cache_pos * cache_pos_size_ +
                                  layer * cache_layer_size_ + head_offset;
         const float* HWY_RESTRICT v =
-            kv_cache.kv_cache.get() + kv_offset + layer_config_.qkv_dim;
+            kv_cache.kv_cache.get() + kv_offset + qkv_dim;
         MulByConstAndAdd(head_att[pos % activations_.seq_len], v, att_out,
-                         layer_config_.qkv_dim);
+                         qkv_dim);
       }
     }
   }
@@ -403,8 +407,8 @@ class GemmaAttention {
           const size_t interleaved_idx = task / layer_config_.heads;
           const size_t query_idx = interleaved_idx % num_queries_;
           const size_t batch_idx = interleaved_idx / num_queries_;
-          const size_t head_offset =
-              (head / kHeadGroups) * layer_config_.qkv_dim * 2;
+          const size_t qkv_dim = layer_config_.qkv_dim;
+          const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
           KVCache& kv_cache = kv_caches_[query_idx];
           float* HWY_RESTRICT q =
               activations_.q.Batch(interleaved_idx) + head * q_stride_;
@@ -435,15 +439,14 @@ class GemmaAttention {
 
           float* HWY_RESTRICT att_out =
               activations_.att_out.Batch(interleaved_idx) +
-              head * layer_config_.qkv_dim;
+              head * qkv_dim;
           WeightedSumV(start_pos, last_pos, head_att, layer_, head_offset,
                        div_seq_len_, kv_cache, att_out);
         });
   }
 
   // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
-  // head_dim
-  // (`layer_config_.qkv_dim`) into output (`layer_out`).
+  // head_dim (`qkv_dim`) into output (`layer_out`).
   HWY_NOINLINE void SumHeads(const size_t num_interleaved) {
     PROFILER_ZONE("Gen.Attention.SumHeads");
     // att_weights and att_out are concatenated heads, each of length
@@ -630,13 +633,12 @@ class VitAttention {
   }
 
   // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
-  // head_dim
-  // (`layer_config_.qkv_dim`) into output (`att_sums`).
+  // head_dim (`qkv_dim`) into output (`att_sums`).
   HWY_NOINLINE void SumHeads() {
     PROFILER_ZONE("Gen.VitAttention.SumHeads");
     auto* bias = layer_weights_.vit.attn_out_b.data_scale1();
     // att_weights and att_out are concatenated heads, each of length
-    // layer_config_.qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
+    // qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
     // matmul output is the sum over heads.
     auto att_out = ConstMatFromBatch(num_tokens_, activations_.att_out);
     auto att_weights = ConstMatFromWeights(layer_weights_.vit.attn_out_w);