diff --git a/gemma/activations.h b/gemma/activations.h
index 7a53aa1..9d6ccb5 100644
--- a/gemma/activations.h
+++ b/gemma/activations.h
@@ -74,17 +74,18 @@ struct Activations {
   size_t seq_len;
   size_t cache_pos_size = 0;
 
-  static RowVectorBatch<float> CreateInvTimescale(size_t qkv_dim,
-                                                  PostQKType post_qk) {
+  static RowVectorBatch<float> CreateInvTimescale(
+      size_t qkv_dim, PostQKType post_qk, double base_frequency = 10000.0) {
     const size_t rope_dim =
         post_qk == PostQKType::HalfRope ? qkv_dim / 2 : qkv_dim;
     RowVectorBatch<float> inv_timescale(Extents2D(1, rope_dim / 2));
     for (size_t dim = 0; dim < rope_dim / 2; ++dim) {
-      const float freq_exponents =
-          static_cast<float>(2 * dim) / static_cast<float>(rope_dim);
+      const double freq_exponents =
+          static_cast<double>(2 * dim) / static_cast<double>(rope_dim);
       // Replacing with expf(ln(1E4) * freq_exponents) changes results
       // noticeably.
-      inv_timescale.Batch(0)[dim] = 1.0f / std::pow(10000.0f, freq_exponents);
+      inv_timescale.Batch(0)[dim] =
+          static_cast<float>(1.0 / std::pow(base_frequency, freq_exponents));
     }
     return inv_timescale;
   }
@@ -94,19 +95,20 @@ struct Activations {
     const size_t model_dim = weights_config.model_dim;
     const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
     const size_t vocab_size = weights_config.vocab_size;
+    const size_t qkv_dim = layer_config.qkv_dim;
+    const size_t heads = layer_config.heads;
 
     x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
     q = RowVectorBatch<float>(
-        Extents2D(batch_size, layer_config.heads * layer_config.QStride()));
+        Extents2D(batch_size, heads * layer_config.QStride()));
     if (vocab_size > 0) {
       logits = RowVectorBatch<float>(Extents2D(batch_size, vocab_size));
     }
 
     pre_att_rms_out = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
     att = RowVectorBatch<float>(
-        Extents2D(batch_size, layer_config.heads * weights_config.seq_len));
-    att_out = RowVectorBatch<float>(
-        Extents2D(batch_size, layer_config.heads * layer_config.qkv_dim));
+        Extents2D(batch_size, heads * weights_config.seq_len));
+    att_out = RowVectorBatch<float>(Extents2D(batch_size, heads * qkv_dim));
     att_sums = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
 
     bf_pre_ffw_rms_out = RowVectorBatch<BF16>(Extents2D(batch_size, model_dim));
@@ -122,7 +124,7 @@ struct Activations {
           RowVectorBatch<float>(Extents2D(batch_size, model_dim));
     }
 
-    inv_timescale = CreateInvTimescale(layer_config.qkv_dim, post_qk);
+    inv_timescale = CreateInvTimescale(qkv_dim, post_qk);
 
     env = std::make_unique<MatMulEnv>(pools);
   }
diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index 820ee61..43d2f77 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -216,15 +216,17 @@ class GemmaAttention {
   template <typename U>
   HWY_INLINE void PositionalEncodingQK(const U* qk, size_t pos, size_t layer,
                                        const float mul, U* qk_out) {
+    // qk is either q or k, so qkv_dim is the length we operate on.
+    const size_t qkv_dim = layer_config_.qkv_dim;
     const float* inv_timescale = activations_.inv_timescale.Const();
     // PostQKType::Rope
     (void)layer;
     if (layer_weights_.layer_config.post_qk == PostQKType::HalfRope) {
-      hwy::CopyBytes(qk, qk_out, layer_config_.qkv_dim * sizeof(*qk));
-      Rope(qk_out, layer_config_.qkv_dim / 2, inv_timescale, pos);
-      MulByConst(mul, qk_out, layer_config_.qkv_dim);
+      hwy::CopyBytes(qk, qk_out, qkv_dim * sizeof(*qk));
+      Rope(qk_out, qkv_dim / 2, inv_timescale, pos);
+      MulByConst(mul, qk_out, qkv_dim);
     } else {
-      RopeAndMulBy(mul, qk, layer_config_.qkv_dim, inv_timescale, pos, qk_out);
+      RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, qk_out);
     }
   }
 
@@ -334,13 +336,14 @@ class GemmaAttention {
   HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
                         const size_t head_offset, const float* HWY_RESTRICT q,
                         const KVCache& kv_cache, float* HWY_RESTRICT head_att) {
+    const size_t qkv_dim = layer_config_.qkv_dim;
     if (HWY_LIKELY(last_pos < activations_.seq_len)) {
       // Slightly faster: no wraparound.
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
         const size_t kv_offset = pos * cache_pos_size_ +
                                  layer_ * cache_layer_size_ + head_offset;
         const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset];
-        const float score = Dot(q, k, layer_config_.qkv_dim);
+        const float score = Dot(q, k, qkv_dim);
         head_att[pos] = score;
       }
     } else {
@@ -349,7 +352,7 @@
         const size_t kv_offset = cache_pos * cache_pos_size_ +
                                  layer_ * cache_layer_size_ + head_offset;
         const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset];
-        const float score = Dot(q, k, layer_config_.qkv_dim);
+        const float score = Dot(q, k, qkv_dim);
         head_att[pos % activations_.seq_len] = score;
       }
     }
@@ -364,7 +367,8 @@
                                const hwy::Divisor& div_seq_len,
                                const KVCache& kv_cache,
                                float* HWY_RESTRICT att_out) const {
-    hwy::ZeroBytes(att_out, layer_config_.qkv_dim * sizeof(*att_out));
+    const size_t qkv_dim = layer_config_.qkv_dim;
+    hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
 
     if (HWY_LIKELY(last_pos < activations_.seq_len)) {
       // Slightly faster: no wraparound.
@@ -372,8 +376,8 @@ class GemmaAttention {
         const size_t kv_offset = pos * cache_pos_size_ +
                                  layer * cache_layer_size_ + head_offset;
         const float* HWY_RESTRICT v =
-            kv_cache.kv_cache.get() + kv_offset + layer_config_.qkv_dim;
-        MulByConstAndAdd(head_att[pos], v, att_out, layer_config_.qkv_dim);
+            kv_cache.kv_cache.get() + kv_offset + qkv_dim;
+        MulByConstAndAdd(head_att[pos], v, att_out, qkv_dim);
       }
     } else {
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
@@ -381,9 +385,9 @@
         const size_t kv_offset = cache_pos * cache_pos_size_ +
                                  layer * cache_layer_size_ + head_offset;
         const float* HWY_RESTRICT v =
-            kv_cache.kv_cache.get() + kv_offset + layer_config_.qkv_dim;
+            kv_cache.kv_cache.get() + kv_offset + qkv_dim;
         MulByConstAndAdd(head_att[pos % activations_.seq_len], v, att_out,
-                         layer_config_.qkv_dim);
+                         qkv_dim);
       }
     }
   }
@@ -403,8 +407,8 @@
       const size_t interleaved_idx = task / layer_config_.heads;
       const size_t query_idx = interleaved_idx % num_queries_;
       const size_t batch_idx = interleaved_idx / num_queries_;
-      const size_t head_offset =
-          (head / kHeadGroups) * layer_config_.qkv_dim * 2;
+      const size_t qkv_dim = layer_config_.qkv_dim;
+      const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
       KVCache& kv_cache = kv_caches_[query_idx];
       float* HWY_RESTRICT q =
           activations_.q.Batch(interleaved_idx) + head * q_stride_;
@@ -435,15 +439,14 @@
 
           float* HWY_RESTRICT att_out =
               activations_.att_out.Batch(interleaved_idx) +
-              head * layer_config_.qkv_dim;
+              head * qkv_dim;
           WeightedSumV(start_pos, last_pos, head_att, layer_, head_offset,
                        div_seq_len_, kv_cache, att_out);
         });
   }
 
   // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
-  // head_dim
-  // (`layer_config_.qkv_dim`) into output (`layer_out`).
+  // head_dim (`qkv_dim`) into output (`layer_out`).
   HWY_NOINLINE void SumHeads(const size_t num_interleaved) {
     PROFILER_ZONE("Gen.Attention.SumHeads");
     // att_weights and att_out are concatenated heads, each of length
@@ -630,13 +633,12 @@ class VitAttention {
   }
 
   // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
-  // head_dim
-  // (`layer_config_.qkv_dim`) into output (`att_sums`).
+  // head_dim (`qkv_dim`) into output (`att_sums`).
   HWY_NOINLINE void SumHeads() {
     PROFILER_ZONE("Gen.VitAttention.SumHeads");
     auto* bias = layer_weights_.vit.attn_out_b.data_scale1();
     // att_weights and att_out are concatenated heads, each of length
-    // layer_config_.qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
+    // qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
     // matmul output is the sum over heads.
     auto att_out = ConstMatFromBatch(num_tokens_, activations_.att_out);
    auto att_weights = ConstMatFromWeights(layer_weights_.vit.attn_out_w);
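
The substantive change in activations.h above is that CreateInvTimescale now accepts a configurable RoPE base frequency (default 10000.0) and evaluates the exponent in double before narrowing to float. The following is a minimal standalone sketch of that computation, not part of the patch: ComputeInvTimescales and the std::vector<float> return type are hypothetical stand-ins for gemma.cpp's RowVectorBatch<float>.

// Standalone sketch of the RoPE inverse-timescale computation the patch
// parameterizes. ComputeInvTimescales is a hypothetical helper, not gemma.cpp
// API.
#include <cmath>
#include <cstddef>
#include <vector>

// Returns rope_dim / 2 inverse timescales: base_frequency^(-2*dim/rope_dim).
// For PostQKType::HalfRope the caller passes rope_dim = qkv_dim / 2, since
// only half of each head vector is rotated.
std::vector<float> ComputeInvTimescales(size_t rope_dim,
                                        double base_frequency = 10000.0) {
  std::vector<float> inv_timescale(rope_dim / 2);
  for (size_t dim = 0; dim < rope_dim / 2; ++dim) {
    // Exponent 2*dim/rope_dim in [0, 1), computed in double as in the patch.
    const double freq_exponent =
        static_cast<double>(2 * dim) / static_cast<double>(rope_dim);
    inv_timescale[dim] =
        static_cast<float>(1.0 / std::pow(base_frequency, freq_exponent));
  }
  return inv_timescale;
}

Leaving base_frequency at 10000.0 should reproduce the previous 1.0f / std::pow(10000.0f, freq_exponents) values up to rounding from the switch to double arithmetic.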