Add some comments.

PiperOrigin-RevId: 834173319
This commit is contained in:
Martin Stolle 2025-11-19 01:08:38 -08:00 committed by Copybara-Service
parent b8f6be72b1
commit 49d420aeaf
4 changed files with 57 additions and 3 deletions

View File

@@ -165,17 +165,32 @@ struct AttentionActivationsPtrs {
   }
   const ModelConfig& config;
+  // Query matrix of size batch_size x (q_heads * qkv_dim).
   MatPtrT<float> q;
+  // Query matrix of size batch_size x (q_heads * qkv_dim).
   MatPtrT<BF16> q_bf;
+  // Transposed query matrix for faster Q*K^T.
   MatPtrT<BF16> q_T;
+  // Output of RMSNorm before attention, size batch_size x model_dim.
   MatPtrT<float> pre_att_rms_out;
+  // Attention scores computed from Q*K^T, size batch_size x (q_heads *
+  // seq_len).
   MatPtrT<float> att;
+  // Attention output computed from att * V, size batch_size x (q_heads *
+  // qkv_dim).
   MatPtrT<float> att_out;
+  // Accumulation of attention outputs over heads, size batch_size x
+  // model_dim.
   MatPtrT<BF16> att_sums;
+  // Inverse timescales for RoPE computation.
   MatPtrT<float> inv_timescale;
+  // Inverse timescales for global RoPE computation.
   MatPtrT<float> inv_timescale_global;
+  // Divisor for faster division by sequence length.
   hwy::Divisor div_seq_len;
+  // Divisor for faster division by number of heads.
   hwy::Divisor div_heads;
+  // Query scaling factor for attention computation.
   float query_scale;
 };

View File

@@ -428,9 +428,37 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
   return scale;
 }
-// Sweeps a tile of 4 Q rows by NF K timesteps accumulators from start_pos to
-// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
-// max_last_pos].
+// Implements flash attention for a strip of 4 query vectors.
+// It iterates through timesteps in K from `start_pos` up to `max_last_pos`.
+// Timesteps up to `min_last_pos` (*) are processed in tiles of shape 4 Q rows
+// by NF timesteps in K for efficiency while timesteps between `min_last_pos +
+// 1` and `max_last_pos` are processed one-by-one to handle differing `last_pos`
+// values within the strip.
+// (*) Actually, it only iterates through
+// `min_last_pos - (min_last_pos + 1 - start_pos) % NF` in tiles, as the tiled
+// computation can, for obvious reasons, only process an integer number of
+// tiles.
+//
+// @param q The query matrix [batch_size * q_heads, qkv_dim] in BF16 format.
+// @param q_offsets Offsets from `q.Row(0)` to the start of the 4 query
+//   vectors to be processed in this tile.
+// @param k Key matrix [seq_len, qkv_dim] from KV cache.
+// @param start_pos The first token position in the KV cache to attend to.
+// @param last_pos An array of 4 indices giving the last token position
+//   (inclusive) that each of the 4 queries may attend to.
+// @param min_last_pos The minimum value in `last_pos`. Timesteps up to this
+//   position can be processed efficiently in batches.
+// @param max_last_pos The maximum value in `last_pos`. Timesteps between
+//   `min_last_pos + 1` and this position are processed individually to
+//   respect each query's `last_pos` limit.
+// @param v Value matrix [seq_len, qkv_dim] from KV cache.
+// @param layer_idx The index of the current transformer layer.
+// @param activations Attention configurations and buffers.
+// @param att_out Output buffer for attention results.
+// @param out_offsets Offsets from `att_out.Row(0)` to store the 4 output
+//   vectors.
+// @param ctx Threading context.
+// @param worker Worker thread index.
 Tile4FlashState TileFlashAttention4(
     const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets,
     const MatPtrT<KV_t>& k, const size_t start_pos,

View File

@@ -7,8 +7,16 @@
 namespace gcpp {
+// State for computing softmax in a streaming ("online") manner,
+// avoiding large intermediate values by subtracting the running maximum.
+// For a sequence x_1, ..., x_n:
+//   m_i = max(m_{i-1}, x_i)
+//   d_i = d_{i-1} * exp(m_{i-1} - m_i) + exp(x_i - m_i)
+//   softmax_i = exp(x_i - m_i) / d_i
 struct OnlineSoftmaxState {
+  // Maximum logit value encountered so far.
   float max = -std::numeric_limits<float>::max() / 2.0f;
+  // Sum of exponentials scaled by exp(-max).
   float d = 0.0f;
 };

View File

@@ -61,6 +61,9 @@ namespace gcpp {
 namespace HWY_NAMESPACE {
 namespace hn = hwy::HWY_NAMESPACE;
+// Computes C = A * B + add via MatMulStatic.
+// This function uses CallUpcasted to dispatch to the correct MatMulStatic
+// instantiation based on the runtime type of B.
 template <typename TA, typename TC>
 MMPerKey* CallMatMul(const MatPtrT<TA>& A, const MatPtr& B,
                      const float* HWY_RESTRICT add, MatMulEnv& env,