Added access to flash attention internals to TileFlashAttention4

PiperOrigin-RevId: 826011137
This commit is contained in:
Ray Smith 2025-10-30 06:49:33 -07:00 committed by Copybara-Service
parent ee7d79c0a6
commit 8a100c1e8d
5 changed files with 87 additions and 49 deletions

View File

@ -543,6 +543,7 @@ cc_library(
"gemma/activations.h",
"gemma/attention.h",
"gemma/flash_attention.h",
"gemma/flash_structs.h",
"gemma/gemma.h",
"gemma/vit.h",
],

View File

@ -82,6 +82,7 @@ set(SOURCES
gemma/configs.h
gemma/flash_attention.cc
gemma/flash_attention.h
gemma/flash_structs.h
gemma/gemma_args.h
gemma/gemma-inl.h
gemma/gemma.cc

View File

@ -21,6 +21,7 @@
#include <limits>
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
#include "gemma/flash_structs.h"
#include "util/threading_context.h"
#include "util/zones.h"
#ifndef HWY_DISABLED_TARGETS
@ -444,16 +445,14 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
// Sweeps a tile of 4 Q rows by NF K timesteps accumulators from start_pos to
// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
// max_last_pos]. Returns the per-row softmax state (max and d) as it stands
// once both sweeps have finished.
void TileFlashAttention4(const MatPtrT<float>& q,
const uint32_t* HWY_RESTRICT q_offsets,
const MatPtrT<KV_t>& k, const size_t start_pos,
const uint32_t* HWY_RESTRICT last_pos,
const size_t min_last_pos, const size_t max_last_pos,
const MatPtrT<KV_t>& v, const size_t layer_idx,
const AttentionActivationsPtrs& activations,
MatPtrT<float>& att_out,
const uint32_t* HWY_RESTRICT out_offsets,
ThreadingContext& ctx, const size_t worker) {
Tile4FlashState TileFlashAttention4(
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
const MatPtrT<KV_t>& k, const size_t start_pos,
const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out,
const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx,
const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention4);
using DF = hn::ScalableTag<float>;
const DF df;
@ -467,14 +466,7 @@ void TileFlashAttention4(const MatPtrT<float>& q,
hwy::ZeroBytes(att_out.Row(0) + out_offsets[i],
v.Cols() * sizeof(att_out.Row(0)[0]));
}
float old_m0 = -std::numeric_limits<float>::max() / 2.0f;
float old_m1 = -std::numeric_limits<float>::max() / 2.0f;
float old_m2 = -std::numeric_limits<float>::max() / 2.0f;
float old_m3 = -std::numeric_limits<float>::max() / 2.0f;
float old_d0 = 0.0f;
float old_d1 = 0.0f;
float old_d2 = 0.0f;
float old_d3 = 0.0f;
Tile4FlashState state;
size_t position = start_pos;
while (position + kHTileSize - 1 <= min_last_pos) {
int32_t k_offsets[kMaxNF];
@ -494,10 +486,14 @@ void TileFlashAttention4(const MatPtrT<float>& q,
x2 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x2, one_over_cap)));
x3 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x3, one_over_cap)));
}
scales[0] = SingleFlashAttentionRowVector(df, x0, old_m0, old_d0);
scales[1] = SingleFlashAttentionRowVector(df, x1, old_m1, old_d1);
scales[2] = SingleFlashAttentionRowVector(df, x2, old_m2, old_d2);
scales[3] = SingleFlashAttentionRowVector(df, x3, old_m3, old_d3);
scales[0] = SingleFlashAttentionRowVector(df, x0, state.row_states[0].max,
state.row_states[0].d);
scales[1] = SingleFlashAttentionRowVector(df, x1, state.row_states[1].max,
state.row_states[1].d);
scales[2] = SingleFlashAttentionRowVector(df, x2, state.row_states[2].max,
state.row_states[2].d);
scales[3] = SingleFlashAttentionRowVector(df, x3, state.row_states[3].max,
state.row_states[3].d);
MulByConstAndAddTile4(df, scales, x0, x1, x2, x3, v, v_pos, att_out.Row(0),
out_offsets, v.Cols());
position += kHTileSize;
@ -516,7 +512,8 @@ void TileFlashAttention4(const MatPtrT<float>& q,
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
float x0 =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
SingleFlashAttentionStep(x0, activations.config.att_cap, old_m0, old_d0,
SingleFlashAttentionStep(x0, activations.config.att_cap,
state.row_states[0].max, state.row_states[0].d,
v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[0]);
}
@ -526,7 +523,8 @@ void TileFlashAttention4(const MatPtrT<float>& q,
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
float x1 =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
SingleFlashAttentionStep(x1, activations.config.att_cap, old_m1, old_d1,
SingleFlashAttentionStep(x1, activations.config.att_cap,
state.row_states[1].max, state.row_states[1].d,
v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[1]);
}
@ -536,7 +534,8 @@ void TileFlashAttention4(const MatPtrT<float>& q,
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
float x2 =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
SingleFlashAttentionStep(x2, activations.config.att_cap, old_m2, old_d2,
SingleFlashAttentionStep(x2, activations.config.att_cap,
state.row_states[2].max, state.row_states[2].d,
v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[2]);
}
@ -546,12 +545,14 @@ void TileFlashAttention4(const MatPtrT<float>& q,
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
float x3 =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
SingleFlashAttentionStep(x3, activations.config.att_cap, old_m3, old_d3,
SingleFlashAttentionStep(x3, activations.config.att_cap,
state.row_states[3].max, state.row_states[3].d,
v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[3]);
}
++position;
}
return state;
}
// Rounds n to a number that can be used as the number of Q rows in a tile

View File

@ -20,35 +20,47 @@
#include <stddef.h>
#include <cstdint>
#include "gemma/flash_structs.h"
#include "gemma/gemma.h"
#include "hwy/highway.h"
namespace gcpp {
// Passed to HWY_VISIT_TARGETS; declares for one target.
#define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \
namespace NAMESPACE { \
void RMSNormAndPositionalEncoding( \
size_t num_tokens, const QBatch& qbatch, MatPtrT<float>& q, \
const MatPtr& query_norm_scale, size_t layer_idx, \
const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \
\
void SingleFlashAttention(size_t start_pos, size_t last_pos, \
const float* HWY_RESTRICT q, \
const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
size_t layer_idx, \
const AttentionActivationsPtrs& activations, \
float* HWY_RESTRICT att_out, \
ThreadingContext& ctx, size_t worker); \
\
size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \
size_t total_tasks, size_t target_parallelism); \
\
void FlashAttention(size_t num_tokens, size_t target_parallelism, \
size_t layer_idx, const MatPtr& query_norm_scale, \
AttentionActivationsPtrs& activations, QBatch& qbatch, \
ThreadingContext& ctx); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
#define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE)                          \
  namespace NAMESPACE {                                                        \
  void RMSNormAndPositionalEncoding(                                           \
      size_t num_tokens, const QBatch& qbatch, MatPtrT<float>& q,              \
      const MatPtr& query_norm_scale, size_t layer_idx,                        \
      const AttentionActivationsPtrs& activations, ThreadingContext& ctx);     \
                                                                               \
  void SingleFlashAttention(size_t start_pos, size_t last_pos,                 \
                            const float* HWY_RESTRICT q,                       \
                            const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,    \
                            size_t layer_idx,                                  \
                            const AttentionActivationsPtrs& activations,       \
                            float* HWY_RESTRICT att_out,                       \
                            ThreadingContext& ctx, size_t worker);             \
                                                                               \
  /* Signature must match the definition of TileFlashAttention4. */            \
  /* Returns the per-row state (max, d) for the 4 Q rows of the tile. */       \
  Tile4FlashState TileFlashAttention4(                                         \
      const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,         \
      const MatPtrT<KV_t>& k, size_t start_pos,                                \
      const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos,              \
      size_t max_last_pos, const MatPtrT<KV_t>& v, size_t layer_idx,           \
      const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out,    \
      const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx,         \
      const size_t worker);                                                    \
                                                                               \
  size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens,   \
                      size_t total_tasks, size_t target_parallelism);          \
                                                                               \
  void FlashAttention(size_t num_tokens, size_t target_parallelism,            \
                      size_t layer_idx, const MatPtr& query_norm_scale,        \
                      AttentionActivationsPtrs& activations, QBatch& qbatch,   \
                      ThreadingContext& ctx);                                  \
  /* NOLINTNEXTLINE(google-readability-namespace-comments) */                  \
  }  // namespace NAMESPACE
// Function declarations for each SIMD target. Allows direct call from the

23
gemma/flash_structs.h Normal file
View File

@ -0,0 +1,23 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_STRUCTS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_STRUCTS_H_
#include <stddef.h>
#include <limits>
namespace gcpp {

// Softmax bookkeeping carried for one query row while attention is
// accumulated. Presumably `max` is the running maximum and `d` the running
// denominator of an online softmax -- confirm against the code that updates
// them.
struct OnlineSoftmaxState {
  // Starts at half of the most negative finite float.
  float max{-0.5f * std::numeric_limits<float>::max()};
  // Starts at zero.
  float d{0.0f};
};

// Number of per-row states held by Tile4FlashState.
static constexpr size_t kVTileSize4 = 4;

// Bundles one OnlineSoftmaxState per row of a kVTileSize4-row tile.
struct Tile4FlashState {
  OnlineSoftmaxState row_states[kVTileSize4];
};

}  // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_STRUCTS_H_