diff --git a/BUILD.bazel b/BUILD.bazel index d14002f..aae230e 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -543,6 +543,7 @@ cc_library( "gemma/activations.h", "gemma/attention.h", "gemma/flash_attention.h", + "gemma/flash_structs.h", "gemma/gemma.h", "gemma/vit.h", ], diff --git a/CMakeLists.txt b/CMakeLists.txt index a707078..13398fc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -82,6 +82,7 @@ set(SOURCES gemma/configs.h gemma/flash_attention.cc gemma/flash_attention.h + gemma/flash_structs.h gemma/gemma_args.h gemma/gemma-inl.h gemma/gemma.cc diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index 5392ec0..d2d13f7 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -21,6 +21,7 @@ #include #include "compression/types.h" // GEMMA_DISABLED_TARGETS +#include "gemma/flash_structs.h" #include "util/threading_context.h" #include "util/zones.h" #ifndef HWY_DISABLED_TARGETS @@ -444,16 +445,14 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max, // Sweeps a tile of 4 Q rows by NF K timesteps accumulators from start_pos to // min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos, // max_last_pos]. 
-void TileFlashAttention4(const MatPtrT& q, - const uint32_t* HWY_RESTRICT q_offsets, - const MatPtrT& k, const size_t start_pos, - const uint32_t* HWY_RESTRICT last_pos, - const size_t min_last_pos, const size_t max_last_pos, - const MatPtrT& v, const size_t layer_idx, - const AttentionActivationsPtrs& activations, - MatPtrT& att_out, - const uint32_t* HWY_RESTRICT out_offsets, - ThreadingContext& ctx, const size_t worker) { +Tile4FlashState TileFlashAttention4( + const MatPtrT& q, const uint32_t* HWY_RESTRICT q_offsets, + const MatPtrT& k, const size_t start_pos, + const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos, + const size_t max_last_pos, const MatPtrT& v, const size_t layer_idx, + const AttentionActivationsPtrs& activations, MatPtrT& att_out, + const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx, + const size_t worker) { GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention4); using DF = hn::ScalableTag; const DF df; @@ -467,14 +466,7 @@ void TileFlashAttention4(const MatPtrT& q, hwy::ZeroBytes(att_out.Row(0) + out_offsets[i], v.Cols() * sizeof(att_out.Row(0)[0])); } - float old_m0 = -std::numeric_limits::max() / 2.0f; - float old_m1 = -std::numeric_limits::max() / 2.0f; - float old_m2 = -std::numeric_limits::max() / 2.0f; - float old_m3 = -std::numeric_limits::max() / 2.0f; - float old_d0 = 0.0f; - float old_d1 = 0.0f; - float old_d2 = 0.0f; - float old_d3 = 0.0f; + Tile4FlashState state; size_t position = start_pos; while (position + kHTileSize - 1 <= min_last_pos) { int32_t k_offsets[kMaxNF]; @@ -494,10 +486,14 @@ void TileFlashAttention4(const MatPtrT& q, x2 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x2, one_over_cap))); x3 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x3, one_over_cap))); } - scales[0] = SingleFlashAttentionRowVector(df, x0, old_m0, old_d0); - scales[1] = SingleFlashAttentionRowVector(df, x1, old_m1, old_d1); - scales[2] = SingleFlashAttentionRowVector(df, x2, old_m2, old_d2); - scales[3] = 
SingleFlashAttentionRowVector(df, x3, old_m3, old_d3); + scales[0] = SingleFlashAttentionRowVector(df, x0, state.row_states[0].max, + state.row_states[0].d); + scales[1] = SingleFlashAttentionRowVector(df, x1, state.row_states[1].max, + state.row_states[1].d); + scales[2] = SingleFlashAttentionRowVector(df, x2, state.row_states[2].max, + state.row_states[2].d); + scales[3] = SingleFlashAttentionRowVector(df, x3, state.row_states[3].max, + state.row_states[3].d); MulByConstAndAddTile4(df, scales, x0, x1, x2, x3, v, v_pos, att_out.Row(0), out_offsets, v.Cols()); position += kHTileSize; @@ -516,7 +512,8 @@ void TileFlashAttention4(const MatPtrT& q, qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0); float x0 = Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim); - SingleFlashAttentionStep(x0, activations.config.att_cap, old_m0, old_d0, + SingleFlashAttentionStep(x0, activations.config.att_cap, + state.row_states[0].max, state.row_states[0].d, v.Row(k_pos), v.Cols(), att_out.Row(0) + out_offsets[0]); } @@ -526,7 +523,8 @@ void TileFlashAttention4(const MatPtrT& q, qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0); float x1 = Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim); - SingleFlashAttentionStep(x1, activations.config.att_cap, old_m1, old_d1, + SingleFlashAttentionStep(x1, activations.config.att_cap, + state.row_states[1].max, state.row_states[1].d, v.Row(k_pos), v.Cols(), att_out.Row(0) + out_offsets[1]); } @@ -536,7 +534,8 @@ void TileFlashAttention4(const MatPtrT& q, qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0); float x2 = Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim); - SingleFlashAttentionStep(x2, activations.config.att_cap, old_m2, old_d2, + SingleFlashAttentionStep(x2, activations.config.att_cap, + state.row_states[2].max, state.row_states[2].d, v.Row(k_pos), v.Cols(), att_out.Row(0) + out_offsets[2]); } @@ -546,12 +545,14 @@ void TileFlashAttention4(const MatPtrT& q, qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0); float x3 = Dot(dbf, 
MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim); - SingleFlashAttentionStep(x3, activations.config.att_cap, old_m3, old_d3, + SingleFlashAttentionStep(x3, activations.config.att_cap, + state.row_states[3].max, state.row_states[3].d, v.Row(k_pos), v.Cols(), att_out.Row(0) + out_offsets[3]); } ++position; } + return state; } // Rounds n to a number that can be used as the number of Q rows in a tile diff --git a/gemma/flash_attention.h b/gemma/flash_attention.h index ab3a395..099fc69 100644 --- a/gemma/flash_attention.h +++ b/gemma/flash_attention.h @@ -20,35 +20,47 @@ #include +#include + +#include "gemma/flash_structs.h" #include "gemma/gemma.h" #include "hwy/highway.h" namespace gcpp { // Passed to HWY_VISIT_TARGETS; declares for one target. -#define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \ - namespace NAMESPACE { \ - void RMSNormAndPositionalEncoding( \ - size_t num_tokens, const QBatch& qbatch, MatPtrT& q, \ - const MatPtr& query_norm_scale, size_t layer_idx, \ - const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \ - \ - void SingleFlashAttention(size_t start_pos, size_t last_pos, \ - const float* HWY_RESTRICT q, \ - const MatPtrT& k, const MatPtrT& v, \ - size_t layer_idx, \ - const AttentionActivationsPtrs& activations, \ - float* HWY_RESTRICT att_out, \ - ThreadingContext& ctx, size_t worker); \ - \ - size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \ - size_t total_tasks, size_t target_parallelism); \ - \ - void FlashAttention(size_t num_tokens, size_t target_parallelism, \ - size_t layer_idx, const MatPtr& query_norm_scale, \ - AttentionActivationsPtrs& activations, QBatch& qbatch, \ - ThreadingContext& ctx); \ - /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ +#define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \ + namespace NAMESPACE { \ + void RMSNormAndPositionalEncoding( \ + size_t num_tokens, const QBatch& qbatch, MatPtrT& q, \ + const MatPtr& query_norm_scale, size_t layer_idx, 
\
+      const AttentionActivationsPtrs& activations, ThreadingContext& ctx);     \
+                                                                               \
+  void SingleFlashAttention(size_t start_pos, size_t last_pos,                 \
+                            const float* HWY_RESTRICT q,                       \
+                            const MatPtrT& k, const MatPtrT& v,                \
+                            size_t layer_idx,                                  \
+                            const AttentionActivationsPtrs& activations,       \
+                            float* HWY_RESTRICT att_out,                       \
+                            ThreadingContext& ctx, size_t worker);             \
+                                                                               \
+  Tile4FlashState TileFlashAttention4(                                         \
+      const MatPtrT& q, const uint32_t* HWY_RESTRICT q_offsets,                \
+      const MatPtrT& k, size_t start_pos,                                      \
+      const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos,              \
+      size_t max_last_pos, const MatPtrT& v, size_t layer_idx,                 \
+      const AttentionActivationsPtrs& activations,                             \
+      MatPtrT& att_out, const uint32_t* HWY_RESTRICT out_offsets,              \
+      ThreadingContext& ctx, const size_t worker);                             \
+                                                                               \
+  size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens,   \
+                      size_t total_tasks, size_t target_parallelism);          \
+                                                                               \
+  void FlashAttention(size_t num_tokens, size_t target_parallelism,            \
+                      size_t layer_idx, const MatPtr& query_norm_scale,        \
+                      AttentionActivationsPtrs& activations, QBatch& qbatch,   \
+                      ThreadingContext& ctx);                                  \
+  /* NOLINTNEXTLINE(google-readability-namespace-comments) */                  \
 }  // namespace NAMESPACE
 
 // Function declarations for each SIMD target. Allows direct call from the
diff --git a/gemma/flash_structs.h b/gemma/flash_structs.h
new file mode 100644
index 0000000..8edae11
--- /dev/null
+++ b/gemma/flash_structs.h
@@ -0,0 +1,30 @@
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_STRUCTS_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_STRUCTS_H_
+
+#include <stddef.h>
+
+#include <limits>
+
+namespace gcpp {
+
+// Per-row state carried between flash-attention softmax update steps.
+struct OnlineSoftmaxState {
+  // Starts at a large negative finite value. Passed by reference as `old_max`
+  // to SingleFlashAttentionRowVector.
+  float max = -std::numeric_limits<float>::max() / 2.0f;
+  // Starts at zero. NOTE(review): presumably the running softmax denominator;
+  // confirm against SingleFlashAttentionStep.
+  float d = 0.0f;
+};
+
+// Number of rows whose state is kept in Tile4FlashState.
+static constexpr size_t kVTileSize4 = 4;
+
+// One OnlineSoftmaxState per row; this is what TileFlashAttention4 returns.
+struct Tile4FlashState {
+  OnlineSoftmaxState row_states[kVTileSize4];
+};
+
+}  // namespace gcpp
+
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_STRUCTS_H_