diff --git a/gemma/attention.cc b/gemma/attention.cc
index 39c75e8..936db08 100644
--- a/gemma/attention.cc
+++ b/gemma/attention.cc
@@ -52,7 +52,7 @@ namespace HWY_NAMESPACE {
 static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
                              const hwy::Divisor& div_seq_len,
                              const float* HWY_RESTRICT q,
-                             const MatPtrT<BF16>& k, float* HWY_RESTRICT att,
+                             const MatPtrT<KV_t>& k, float* HWY_RESTRICT att,
                              const size_t worker) {
   PROFILER_ZONE2(worker, "Gen.Attention.QDotK");
   if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
@@ -100,7 +100,7 @@ static void PositionalEncodingQK(float* qk, const size_t layer_idx,
 static HWY_INLINE void WeightedSumV(
     const size_t start_pos, const size_t last_pos,
     const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT att,
-    const MatPtrT<BF16>& v, float* HWY_RESTRICT att_out, const size_t worker) {
+    const MatPtrT<KV_t>& v, float* HWY_RESTRICT att_out, const size_t worker) {
   if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
     // Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if
     // we supported non-transposed B.
@@ -125,7 +125,7 @@ static HWY_INLINE void WeightedSumV(
 // in place for RMSNorm.
 void SingleDotSoftmaxWeightedSum(
     const size_t pos, const size_t start_pos, const size_t last_pos,
-    float* HWY_RESTRICT q, const MatPtrT<BF16>& k, const MatPtrT<BF16>& v,
+    float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
     const size_t layer_idx, const LayerWeightsPtrs& layer,
     const AttentionActivations& activations, float* HWY_RESTRICT att,
     float* HWY_RESTRICT att_out, const size_t worker) {
@@ -218,9 +218,9 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
       // this query and head.
       const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
       const size_t kv_head_offset = layer_idx * cache_layer_size + head_offset;
-      MatPtrT<BF16> k("k_view", Extents2D(seq_len, qkv_dim));
+      MatPtrT<KV_t> k("k_view", Extents2D(seq_len, qkv_dim));
       k.SetPtr(kv_cache.Row(0) + kv_head_offset, kv_cache.Stride());
-      MatPtrT<BF16> v("v_view", Extents2D(seq_len, qkv_dim));
+      MatPtrT<KV_t> v("v_view", Extents2D(seq_len, qkv_dim));
       v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());
 
       SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx,
@@ -259,7 +259,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
   // Set up MatMul row pointers for writing to KV, which consists of
   // `kv_heads` pairs of (k, v) vectors. This safely handles wraparound
   // because rows are computed modulo seq_len.
-  MatPtrT<BF16> kv_rows("kv", Extents2D(activations.pre_att_rms_out.Rows(),
+  MatPtrT<KV_t> kv_rows("kv", Extents2D(activations.pre_att_rms_out.Rows(),
                                         layer.qkv_einsum_w2.Rows()));
   for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
        ++interleaved_idx) {
@@ -287,7 +287,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
     const size_t pos = qbatch.Pos(qi) + batch_idx;
     const size_t cache_pos = activations.div_seq_len.Remainder(pos);
     auto& kv_cache = qbatch.KV(qi).kv_cache;
-    BF16* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
+    KV_t* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
                             layer_idx * cache_layer_size + head * qkv_dim * 2;
diff --git a/gemma/attention.h b/gemma/attention.h
index 5419b7f..42b2be1 100644
--- a/gemma/attention.h
+++ b/gemma/attention.h
@@ -30,7 +30,7 @@ namespace gcpp {
   namespace NAMESPACE {                                                       \
   void SingleDotSoftmaxWeightedSum(                                           \
       const size_t pos, const size_t start_pos, const size_t last_pos,        \
-      float* HWY_RESTRICT q, const MatPtrT<BF16>& k, const MatPtrT<BF16>& v,  \
+      float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,  \
      size_t layer_idx, const LayerWeightsPtrs& layer,                         \
      const AttentionActivations& activations, float* HWY_RESTRICT att,        \
      float* HWY_RESTRICT att_out, size_t worker);                             \
diff --git a/gemma/kv_cache.h b/gemma/kv_cache.h
index aea5120..3de9e7d 100644
--- a/gemma/kv_cache.h
+++ b/gemma/kv_cache.h
@@ -25,6 +25,8 @@
 
 namespace gcpp {
 
+using KV_t = float;
+
 struct KVCache {
   KVCache(const ModelConfig& config, const InferenceArgs& inference_args);
 
@@ -42,7 +44,7 @@ struct KVCache {
   MatStorageT<float> conv1d_cache;
   MatStorageT<float> rglru_cache;  // [griffin_layers, model_dim]
 
-  MatStorageT<BF16> kv_cache;  // [seq_len, layers * kv_heads * qkv_dim * 2]
+  MatStorageT<KV_t> kv_cache;  // [seq_len, layers * kv_heads * qkv_dim * 2]
 
  private:
  // For use by other ctor and Copy()
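
The net effect of the patch is that the KV cache element type is named in exactly one place (`using KV_t = float;` in kv_cache.h), and both the storage (`MatStorageT<KV_t>`) and the attention-side views (`MatPtrT<KV_t>`) refer to that alias, so switching the element type later is a one-line change. Below is a minimal standalone sketch of the same pattern; `KvStorage` and `KvView` are hypothetical stand-ins for illustration only, not the gemma.cpp `MatStorageT`/`MatPtrT` API.

#include <cstddef>
#include <vector>

// Single point of truth for the cache element type, mirroring kv_cache.h.
using KV_t = float;

// Hypothetical stand-in for MatStorageT<KV_t>: owns a [rows, cols] buffer.
struct KvStorage {
  KvStorage(std::size_t rows, std::size_t cols)
      : cols(cols), data(rows * cols) {}
  KV_t* Row(std::size_t r) { return data.data() + r * cols; }
  std::size_t cols;
  std::vector<KV_t> data;
};

// Hypothetical stand-in for MatPtrT<KV_t>: a non-owning view into the cache.
struct KvView {
  KV_t* row0;
  std::size_t stride;
  KV_t* Row(std::size_t r) const { return row0 + r * stride; }
};

int main() {
  const std::size_t seq_len = 4, qkv_dim = 8;
  // Each cache row holds a (k, v) pair; here a single head for brevity.
  KvStorage cache(seq_len, 2 * qkv_dim);
  KvView k_view{cache.Row(0), cache.cols};            // k half of every row.
  KvView v_view{cache.Row(0) + qkv_dim, cache.cols};  // v half of every row.
  k_view.Row(1)[0] = KV_t(0.5f);  // Writes through the view land in the cache.
  (void)v_view;
  return 0;
}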