mirror of https://github.com/google/gemma.cpp.git
parent b8f6be72b1
commit 49d420aeaf
@@ -165,17 +165,32 @@ struct AttentionActivationsPtrs {
   }

   const ModelConfig& config;
+  // Query matrix of size batch_size x (q_heads * qkv_dim).
   MatPtrT<float> q;
+  // Query matrix of size batch_size x (q_heads * qkv_dim).
   MatPtrT<BF16> q_bf;
+  // Transposed query matrix for faster Q*K^T.
   MatPtrT<BF16> q_T;
+  // Output of RMSNorm before attention, size batch_size x model_dim.
   MatPtrT<float> pre_att_rms_out;
+  // Attention scores computed from Q*K^T, size batch_size x (q_heads *
+  // seq_len).
   MatPtrT<float> att;
+  // Attention output computed from att * V, size batch_size x (q_heads *
+  // qkv_dim).
   MatPtrT<float> att_out;
+  // Accumulation of attention outputs over heads, size batch_size x
+  // model_dim.
   MatPtrT<BF16> att_sums;
+  // Inverse timescales for RoPE computation.
   MatPtrT<float> inv_timescale;
+  // Inverse timescales for global RoPE computation.
   MatPtrT<float> inv_timescale_global;
+  // Divisor for faster division by sequence length.
   hwy::Divisor div_seq_len;
+  // Divisor for faster division by number of heads.
   hwy::Divisor div_heads;
+  // Query scaling factor for attention computation.
   float query_scale;
 };

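The shape comments above pin down how these buffers relate. A minimal standalone sketch, with made-up dimensions, assuming the common 1/sqrt(qkv_dim) query scaling (the diff does not show how `query_scale` is actually derived):

#include <cmath>
#include <cstddef>
#include <cstdio>

int main() {
  // Hypothetical dimensions, for illustration only.
  const std::size_t batch_size = 4, q_heads = 8, qkv_dim = 256;
  const std::size_t model_dim = 2048, seq_len = 1024;

  // Shapes implied by the comments above (rows x cols).
  std::printf("q / q_bf / att_out: %zu x %zu\n", batch_size, q_heads * qkv_dim);
  std::printf("att:                %zu x %zu\n", batch_size, q_heads * seq_len);
  std::printf("att_sums:           %zu x %zu\n", batch_size, model_dim);

  // A common choice for query scaling (an assumption, not from the diff).
  const float query_scale = 1.0f / std::sqrt(static_cast<float>(qkv_dim));
  std::printf("query_scale:        %g\n", query_scale);
  return 0;
}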
@@ -428,9 +428,37 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
   return scale;
 }

-// Sweeps a tile of 4 Q rows by NF K timesteps accumulators from start_pos to
-// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
-// max_last_pos].
+// Implements flash attention for a strip of 4 query vectors.
+// It iterates through timesteps in K from `start_pos` up to `max_last_pos`.
+// Timesteps up to `min_last_pos` (*) are processed in tiles of shape 4 Q rows
+// by NF timesteps in K for efficiency, while timesteps between `min_last_pos +
+// 1` and `max_last_pos` are processed one-by-one to handle differing `last_pos`
+// values within the strip.
+// (*) Actually, it only iterates through
+// `min_last_pos - (min_last_pos + 1 - start_pos) % NF` in tiles, as the tiled
+// computation can, for obvious reasons, only process an integer number of
+// tiles.
+//
+// @param q The query matrix [batch_size * q_heads, qkv_dim] in BF16 format.
+// @param q_offsets Offsets from `q.Row(0)` to the start of the 4 query
+//   vectors to be processed in this tile.
+// @param k Key matrix [seq_len, qkv_dim] from KV cache.
+// @param start_pos The first token position in the KV cache to attend to.
+// @param last_pos An array of 4 indices giving the last token position
+//   (inclusive) that each of the 4 queries may attend to.
+// @param min_last_pos The minimum value in `last_pos`. Timesteps up to this
+//   position can be processed efficiently in batches.
+// @param max_last_pos The maximum value in `last_pos`. Timesteps between
+//   `min_last_pos + 1` and this position are processed individually to
+//   respect each query's `last_pos` limit.
+// @param v Value matrix [seq_len, qkv_dim] from KV cache.
+// @param layer_idx The index of the current transformer layer.
+// @param activations Attention configurations and buffers.
+// @param att_out Output buffer for attention results.
+// @param out_offsets Offsets from `att_out.Row(0)` to store the 4 output
+//   vectors.
+// @param ctx Threading context.
+// @param worker Worker thread index.
 Tile4FlashState TileFlashAttention4(
     const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets,
     const MatPtrT<KV_t>& k, const size_t start_pos,

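The tiled/scalar split in the comment can be made concrete with a small sketch. This is not the real kernel: `NF` is taken as a plain parameter rather than a vector lane count, and the two `Process*` stubs are hypothetical stand-ins for the vectorized 4xNF tile and the per-timestep fallback.

#include <cstddef>
#include <cstdio>

// Stand-ins for the real tiled and per-timestep attention kernels.
void ProcessTile4xNF(std::size_t pos) { std::printf("tile  [%zu, +NF)\n", pos); }
void ProcessSingleTimestep(std::size_t pos) { std::printf("single %zu\n", pos); }

void SweepStrip(std::size_t start_pos, std::size_t min_last_pos,
                std::size_t max_last_pos, std::size_t NF) {
  // Largest position (inclusive) covered by whole tiles: drop the remainder
  // of (min_last_pos + 1 - start_pos) so the tiled range holds an integer
  // number of NF-wide tiles, exactly as the footnote (*) above describes.
  const std::size_t num_steps = min_last_pos + 1 - start_pos;
  const std::size_t tiled_end = min_last_pos - num_steps % NF;  // inclusive
  std::size_t pos = start_pos;
  for (; pos + NF <= tiled_end + 1; pos += NF) ProcessTile4xNF(pos);
  // Leftover timesteps, including those past min_last_pos, go one-by-one;
  // the real code additionally masks each query by its own last_pos.
  for (; pos <= max_last_pos; ++pos) ProcessSingleTimestep(pos);
}

int main() {
  // start_pos=0, min_last_pos=10, max_last_pos=13, NF=4: num_steps=11, so
  // tiles at 0 and 4 cover [0, 8); positions 8..13 are handled singly.
  SweepStrip(0, 10, 13, 4);
  return 0;
}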
@@ -7,8 +7,16 @@

 namespace gcpp {

+// State for computing softmax in a streaming ("online") manner,
+// avoiding large intermediate values by subtracting the running maximum.
+// For a sequence x_1, ..., x_n:
+//   m_i = max(m_{i-1}, x_i)
+//   d_i = d_{i-1} * exp(m_{i-1} - m_i) + exp(x_i - m_i)
+//   softmax_i = exp(x_i - m_i) / d_i
 struct OnlineSoftmaxState {
+  // Maximum logit value encountered so far.
   float max = -std::numeric_limits<float>::max() / 2.0f;
+  // Sum of exponentials scaled by exp(-max).
   float d = 0.0f;
 };

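The recurrence in the comment is easy to exercise end-to-end. A minimal sketch, with a hypothetical `Update` helper (the diff shows only the state struct, not how gemma.cpp folds logits into it):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>

struct OnlineSoftmaxState {
  float max = -std::numeric_limits<float>::max() / 2.0f;
  float d = 0.0f;
};

// Folds one logit x into the running state:
//   m_i = max(m_{i-1}, x_i)
//   d_i = d_{i-1} * exp(m_{i-1} - m_i) + exp(x_i - m_i)
void Update(OnlineSoftmaxState& s, float x) {
  const float new_max = std::max(s.max, x);
  s.d = s.d * std::exp(s.max - new_max) + std::exp(x - new_max);
  s.max = new_max;
}

int main() {
  // One streaming pass accumulates (max, d); then normalize every element.
  const float x[] = {1.0f, 3.0f, 2.0f};
  OnlineSoftmaxState s;
  for (float xi : x) Update(s, xi);
  for (float xi : x) std::printf("%f ", std::exp(xi - s.max) / s.d);
  std::printf("\n");  // Matches a conventional two-pass softmax.
  return 0;
}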
@@ -61,6 +61,9 @@ namespace gcpp {
 namespace HWY_NAMESPACE {
 namespace hn = hwy::HWY_NAMESPACE;

+// Computes C = A * B + add via MatMulStatic.
+// This function uses CallUpcasted to dispatch to the correct MatMulStatic
+// instantiation based on the runtime type of B.
 template <typename TA, typename TC>
 MMPerKey* CallMatMul(const MatPtrT<TA>& A, const MatPtr& B,
                      const float* HWY_RESTRICT add, MatMulEnv& env,
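The dispatch the comment describes can be sketched generically. This toy version replaces gemma.cpp's CallUpcasted (whose signature is not shown in this diff) with a plain switch on a runtime type tag; each branch recovers a statically typed pointer and instantiates the templated kernel, which is the essence of the pattern.

#include <cstdio>

// Toy stand-ins for a BF16 element and a type-erased matrix handle.
struct BF16Toy { unsigned short bits; };
enum class Type { kF32, kBF16 };
struct MatToy { Type type; const void* data; };

// One statically typed "kernel" instantiation per element type of B.
template <typename TB>
void MatMulToy(const float* a, const TB* b) {
  (void)a; (void)b;  // A real kernel would read A and B here.
  std::printf("kernel for %zu-byte B elements\n", sizeof(TB));
}

// Dispatch on the runtime type of B, in the spirit of CallUpcasted.
void CallMatMulToy(const float* a, const MatToy& b) {
  switch (b.type) {
    case Type::kF32:  MatMulToy(a, static_cast<const float*>(b.data)); break;
    case Type::kBF16: MatMulToy(a, static_cast<const BF16Toy*>(b.data)); break;
  }
}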