diff --git a/gemma/activations.h b/gemma/activations.h
index 20d938c..28e48ca 100644
--- a/gemma/activations.h
+++ b/gemma/activations.h
@@ -165,17 +165,32 @@ struct AttentionActivationsPtrs {
   }
 
   const ModelConfig& config;
+  // Query matrix of size batch_size x (q_heads * qkv_dim).
   MatPtrT<float> q;
+  // Query matrix converted to BF16, size batch_size x (q_heads * qkv_dim).
   MatPtrT<BF16> q_bf;
+  // Transposed query matrix for faster Q*K^T.
   MatPtrT<BF16> q_T;
+  // Output of RMSNorm before attention, size batch_size x model_dim.
   MatPtrT<float> pre_att_rms_out;
+  // Attention scores computed from Q*K^T, size batch_size x (q_heads *
+  // seq_len).
   MatPtrT<float> att;
+  // Attention output computed from att * V, size batch_size x (q_heads *
+  // qkv_dim).
   MatPtrT<float> att_out;
+  // Accumulation of attention outputs over heads, size batch_size x
+  // model_dim.
   MatPtrT<BF16> att_sums;
+  // Inverse timescales for RoPE computation.
   MatPtrT<float> inv_timescale;
+  // Inverse timescales for global RoPE computation.
   MatPtrT<float> inv_timescale_global;
+  // Divisor for faster division by sequence length.
   hwy::Divisor div_seq_len;
+  // Divisor for faster division by number of heads.
   hwy::Divisor div_heads;
+  // Query scaling factor for attention computation.
   float query_scale;
 };
 
diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc
index 803067b..55d0522 100644
--- a/gemma/flash_attention.cc
+++ b/gemma/flash_attention.cc
@@ -428,9 +428,37 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
   return scale;
 }
 
-// Sweeps a tile of 4 Q rows by NF K timesteps accumulators from start_pos to
-// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
-// max_last_pos].
+// Implements flash attention for a strip of 4 query vectors.
+// It iterates through timesteps in K from `start_pos` up to `max_last_pos`.
+// Timesteps up to `min_last_pos` (*) are processed in tiles of 4 Q rows by
+// NF K timesteps for efficiency, while timesteps in the range
+// (`min_last_pos`, `max_last_pos`] are processed one-by-one to handle
+// differing `last_pos` values within the strip.
+// (*) More precisely, the tiled loop only covers timesteps up to
+// `min_last_pos - (min_last_pos + 1 - start_pos) % NF`, because it can only
+// process a whole number of NF-wide tiles; the remaining timesteps fall
+// through to the one-by-one loop.
+//
+// @param q The query matrix [batch_size * q_heads, qkv_dim] in BF16 format.
+// @param q_offsets Offsets from `q.Row(0)` to the start of the 4 query
+//     vectors to be processed in this tile.
+// @param k Key matrix [seq_len, qkv_dim] from the KV cache.
+// @param start_pos The first token position in the KV cache to attend to.
+// @param last_pos An array of 4 indices giving the last token position
+//     (inclusive) that each of the 4 queries may attend to.
+// @param min_last_pos The minimum value in `last_pos`. Timesteps up to this
+//     position can be processed efficiently in tiles.
+// @param max_last_pos The maximum value in `last_pos`. Timesteps between
+//     `min_last_pos + 1` and this position are processed individually to
+//     respect each query's `last_pos` limit.
+// @param v Value matrix [seq_len, qkv_dim] from the KV cache.
+// @param layer_idx The index of the current transformer layer.
+// @param activations Attention configuration and buffers.
+// @param att_out Output buffer for attention results.
+// @param out_offsets Offsets from `att_out.Row(0)` at which to store the 4
+//     output vectors.
+// @param ctx Threading context.
+// @param worker Worker thread index.
 Tile4FlashState TileFlashAttention4(
     const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets,
     const MatPtrT<KV_t>& k, const size_t start_pos,
diff --git a/gemma/flash_structs.h b/gemma/flash_structs.h
index 8edae11..73563fe 100644
--- a/gemma/flash_structs.h
+++ b/gemma/flash_structs.h
@@ -7,8 +7,16 @@
 
 namespace gcpp {
 
+// State for computing softmax in a streaming ("online") manner,
+// avoiding large intermediate values by subtracting the running maximum.
+// For a sequence x_1, ..., x_n:
+//   m_i = max(m_{i-1}, x_i)
+//   d_i = d_{i-1} * exp(m_{i-1} - m_i) + exp(x_i - m_i)
+//   softmax_i = exp(x_i - m_i) / d_i
 struct OnlineSoftmaxState {
+  // Maximum logit value encountered so far.
   float max = -std::numeric_limits<float>::max() / 2.0f;
+  // Sum of exponentials scaled by exp(-max).
   float d = 0.0f;
 };
 
diff --git a/ops/ops-inl.h b/ops/ops-inl.h
index 7ad8e20..6eac06f 100644
--- a/ops/ops-inl.h
+++ b/ops/ops-inl.h
@@ -61,6 +61,9 @@ namespace gcpp {
 namespace HWY_NAMESPACE {
 namespace hn = hwy::HWY_NAMESPACE;
 
+// Computes C = A * B + add via MatMulStatic.
+// Uses CallUpcasted to dispatch to the correct MatMulStatic instantiation
+// based on the runtime type of B.
 template <typename TA, typename TC>
 MMPerKey* CallMatMul(const MatPtrT<TA>& A, const MatPtr& B,
                      const float* HWY_RESTRICT add, MatMulEnv& env,
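A few stand-alone C++ sketches follow to illustrate the constructs documented in this patch; none of this code is part of gemma.cpp. First, a worked check of the tile-boundary expression from the TileFlashAttention4 comment. LastTiledPos is a hypothetical helper whose names mirror that comment.

#include <cstddef>
#include <cstdio>

// Last K timestep covered by whole 4 x NF tiles: tiles span an integer
// number of NF-wide steps starting at start_pos, so the remainder of the
// (min_last_pos + 1 - start_pos) timesteps is left for the one-by-one loop.
size_t LastTiledPos(size_t start_pos, size_t min_last_pos, size_t NF) {
  return min_last_pos - (min_last_pos + 1 - start_pos) % NF;
}

int main() {
  // E.g. NF = 8 lanes, attending to positions [0, 21]: 22 timesteps,
  // two full tiles cover [0, 15], and positions 16..21 are swept singly.
  std::printf("%zu\n", LastTiledPos(/*start_pos=*/0, /*min_last_pos=*/21,
                                    /*NF=*/8));  // prints 15
  return 0;
}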
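Next, a scalar model of the OnlineSoftmaxState recurrence documented in flash_structs.h. The Update helper is illustrative only; in gemma.cpp the equivalent update happens inside the vectorized kernels of flash_attention.cc (e.g. SingleFlashAttentionRowVector above).

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>

struct OnlineSoftmaxState {
  float max = -std::numeric_limits<float>::max() / 2.0f;
  float d = 0.0f;
};

// Folds one logit x_i into the state:
//   m_i = max(m_{i-1}, x_i)
//   d_i = d_{i-1} * exp(m_{i-1} - m_i) + exp(x_i - m_i)
void Update(OnlineSoftmaxState& s, float x) {
  const float new_max = std::max(s.max, x);
  s.d = s.d * std::exp(s.max - new_max) + std::exp(x - new_max);
  s.max = new_max;
}

int main() {
  const float logits[] = {1.0f, 3.0f, 2.0f};
  OnlineSoftmaxState s;
  for (float x : logits) Update(s, x);
  // A single pass over the logits, yet exp() never sees a value above 0,
  // so nothing overflows regardless of the logits' magnitude.
  for (float x : logits) {
    std::printf("%f\n", std::exp(x - s.max) / s.d);  // 0.090, 0.665, 0.245
  }
  return 0;
}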
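Finally, the CallMatMul comment describes a runtime-type dispatch: B is held through the type-erased MatPtr base, and CallUpcasted invokes a callback with the concrete MatPtrT<T>. Below is a minimal stand-in for that pattern; every name and type here is assumed for illustration rather than taken from gemma.cpp.

#include <cstdint>
#include <cstdio>

enum class Type { kF32, kBF16 };

// Type-erased matrix handle, standing in for gemma.cpp's MatPtr base class.
struct Mat {
  Type type;
  const void* data;
};

// Calls func with data cast to the concrete element type, mirroring how
// CallUpcasted lets one CallMatMul body reach each MatMulStatic
// instantiation. BF16 is modeled as uint16_t storage here.
template <class Func>
void CallUpcastedSketch(const Mat& m, const Func& func) {
  switch (m.type) {
    case Type::kF32:
      func(static_cast<const float*>(m.data));
      break;
    case Type::kBF16:
      func(static_cast<const uint16_t*>(m.data));
      break;
  }
}

int main() {
  const float weights[] = {1.0f, 2.0f};
  const Mat b{Type::kF32, weights};
  CallUpcastedSketch(b, [](const auto* ptr) {
    // In CallMatMul, this is where the MatMulStatic call would run.
    std::printf("dispatched to %zu-byte elements\n", sizeof(*ptr));
  });
  return 0;
}

The payoff of this design is that callers compile a single CallMatMul regardless of how B's weights are stored on disk or in memory.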