Added access to flash attention internals to TileFlashAttention4

PiperOrigin-RevId: 826011137
This commit is contained in:
Ray Smith 2025-10-30 06:49:33 -07:00 committed by Copybara-Service
parent ee7d79c0a6
commit 8a100c1e8d
5 changed files with 87 additions and 49 deletions

View File

@ -543,6 +543,7 @@ cc_library(
"gemma/activations.h", "gemma/activations.h",
"gemma/attention.h", "gemma/attention.h",
"gemma/flash_attention.h", "gemma/flash_attention.h",
"gemma/flash_structs.h",
"gemma/gemma.h", "gemma/gemma.h",
"gemma/vit.h", "gemma/vit.h",
], ],

View File

@ -82,6 +82,7 @@ set(SOURCES
gemma/configs.h gemma/configs.h
gemma/flash_attention.cc gemma/flash_attention.cc
gemma/flash_attention.h gemma/flash_attention.h
gemma/flash_structs.h
gemma/gemma_args.h gemma/gemma_args.h
gemma/gemma-inl.h gemma/gemma-inl.h
gemma/gemma.cc gemma/gemma.cc

View File

@ -21,6 +21,7 @@
#include <limits> #include <limits>
#include "compression/types.h" // GEMMA_DISABLED_TARGETS #include "compression/types.h" // GEMMA_DISABLED_TARGETS
#include "gemma/flash_structs.h"
#include "util/threading_context.h" #include "util/threading_context.h"
#include "util/zones.h" #include "util/zones.h"
#ifndef HWY_DISABLED_TARGETS #ifndef HWY_DISABLED_TARGETS
@ -444,16 +445,14 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
// Sweeps a tile of 4 Q rows by NF K timesteps accumulators from start_pos to // Sweeps a tile of 4 Q rows by NF K timesteps accumulators from start_pos to
// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos, // min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
// max_last_pos]. // max_last_pos].
void TileFlashAttention4(const MatPtrT<float>& q, Tile4FlashState TileFlashAttention4(
const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
const MatPtrT<KV_t>& k, const size_t start_pos, const MatPtrT<KV_t>& k, const size_t start_pos,
const uint32_t* HWY_RESTRICT last_pos, const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
const size_t min_last_pos, const size_t max_last_pos, const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
const MatPtrT<KV_t>& v, const size_t layer_idx, const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out,
const AttentionActivationsPtrs& activations, const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx,
MatPtrT<float>& att_out, const size_t worker) {
const uint32_t* HWY_RESTRICT out_offsets,
ThreadingContext& ctx, const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention4); GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention4);
using DF = hn::ScalableTag<float>; using DF = hn::ScalableTag<float>;
const DF df; const DF df;
@ -467,14 +466,7 @@ void TileFlashAttention4(const MatPtrT<float>& q,
hwy::ZeroBytes(att_out.Row(0) + out_offsets[i], hwy::ZeroBytes(att_out.Row(0) + out_offsets[i],
v.Cols() * sizeof(att_out.Row(0)[0])); v.Cols() * sizeof(att_out.Row(0)[0]));
} }
float old_m0 = -std::numeric_limits<float>::max() / 2.0f; Tile4FlashState state;
float old_m1 = -std::numeric_limits<float>::max() / 2.0f;
float old_m2 = -std::numeric_limits<float>::max() / 2.0f;
float old_m3 = -std::numeric_limits<float>::max() / 2.0f;
float old_d0 = 0.0f;
float old_d1 = 0.0f;
float old_d2 = 0.0f;
float old_d3 = 0.0f;
size_t position = start_pos; size_t position = start_pos;
while (position + kHTileSize - 1 <= min_last_pos) { while (position + kHTileSize - 1 <= min_last_pos) {
int32_t k_offsets[kMaxNF]; int32_t k_offsets[kMaxNF];
@ -494,10 +486,14 @@ void TileFlashAttention4(const MatPtrT<float>& q,
x2 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x2, one_over_cap))); x2 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x2, one_over_cap)));
x3 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x3, one_over_cap))); x3 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x3, one_over_cap)));
} }
scales[0] = SingleFlashAttentionRowVector(df, x0, old_m0, old_d0); scales[0] = SingleFlashAttentionRowVector(df, x0, state.row_states[0].max,
scales[1] = SingleFlashAttentionRowVector(df, x1, old_m1, old_d1); state.row_states[0].d);
scales[2] = SingleFlashAttentionRowVector(df, x2, old_m2, old_d2); scales[1] = SingleFlashAttentionRowVector(df, x1, state.row_states[1].max,
scales[3] = SingleFlashAttentionRowVector(df, x3, old_m3, old_d3); state.row_states[1].d);
scales[2] = SingleFlashAttentionRowVector(df, x2, state.row_states[2].max,
state.row_states[2].d);
scales[3] = SingleFlashAttentionRowVector(df, x3, state.row_states[3].max,
state.row_states[3].d);
MulByConstAndAddTile4(df, scales, x0, x1, x2, x3, v, v_pos, att_out.Row(0), MulByConstAndAddTile4(df, scales, x0, x1, x2, x3, v, v_pos, att_out.Row(0),
out_offsets, v.Cols()); out_offsets, v.Cols());
position += kHTileSize; position += kHTileSize;
@ -516,7 +512,8 @@ void TileFlashAttention4(const MatPtrT<float>& q,
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0); qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
float x0 = float x0 =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim); Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
SingleFlashAttentionStep(x0, activations.config.att_cap, old_m0, old_d0, SingleFlashAttentionStep(x0, activations.config.att_cap,
state.row_states[0].max, state.row_states[0].d,
v.Row(k_pos), v.Cols(), v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[0]); att_out.Row(0) + out_offsets[0]);
} }
@ -526,7 +523,8 @@ void TileFlashAttention4(const MatPtrT<float>& q,
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0); qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
float x1 = float x1 =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim); Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
SingleFlashAttentionStep(x1, activations.config.att_cap, old_m1, old_d1, SingleFlashAttentionStep(x1, activations.config.att_cap,
state.row_states[1].max, state.row_states[1].d,
v.Row(k_pos), v.Cols(), v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[1]); att_out.Row(0) + out_offsets[1]);
} }
@ -536,7 +534,8 @@ void TileFlashAttention4(const MatPtrT<float>& q,
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0); qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
float x2 = float x2 =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim); Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
SingleFlashAttentionStep(x2, activations.config.att_cap, old_m2, old_d2, SingleFlashAttentionStep(x2, activations.config.att_cap,
state.row_states[2].max, state.row_states[2].d,
v.Row(k_pos), v.Cols(), v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[2]); att_out.Row(0) + out_offsets[2]);
} }
@ -546,12 +545,14 @@ void TileFlashAttention4(const MatPtrT<float>& q,
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0); qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
float x3 = float x3 =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim); Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
SingleFlashAttentionStep(x3, activations.config.att_cap, old_m3, old_d3, SingleFlashAttentionStep(x3, activations.config.att_cap,
state.row_states[3].max, state.row_states[3].d,
v.Row(k_pos), v.Cols(), v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[3]); att_out.Row(0) + out_offsets[3]);
} }
++position; ++position;
} }
return state;
} }
// Rounds n to a number that can be used as the number of Q rows in a tile // Rounds n to a number that can be used as the number of Q rows in a tile

View File

@ -20,6 +20,9 @@
#include <stddef.h> #include <stddef.h>
#include <cstdint>
#include "gemma/flash_structs.h"
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "hwy/highway.h" #include "hwy/highway.h"
@ -41,6 +44,15 @@ namespace gcpp {
float* HWY_RESTRICT att_out, \ float* HWY_RESTRICT att_out, \
ThreadingContext& ctx, size_t worker); \ ThreadingContext& ctx, size_t worker); \
\ \
Tile4FlashState TileFlashAttention4( \
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets, \
const MatPtrT<KV_t>& k, size_t start_pos, \
const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos, \
size_t max_last_pos, const MatPtrT<KV_t>& v, size_t layer_idx, \
const LayerWeightsPtrs& layer, const AttentionActivations& activations, \
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets, \
ThreadingContext& ctx, const size_t worker); \
\
size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \ size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \
size_t total_tasks, size_t target_parallelism); \ size_t total_tasks, size_t target_parallelism); \
\ \

23
gemma/flash_structs.h Normal file
View File

@ -0,0 +1,23 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_STRUCTS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_STRUCTS_H_
#include <stddef.h>
#include <limits>
namespace gcpp {
struct OnlineSoftmaxState {
float max = -std::numeric_limits<float>::max() / 2.0f;
float d = 0.0f;
};
static constexpr size_t kVTileSize4 = 4;
struct Tile4FlashState {
OnlineSoftmaxState row_states[kVTileSize4];
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_STRUCTS_H_