mirror of https://github.com/google/gemma.cpp.git
Added access to flash attention internals to TileFlashAttention4
PiperOrigin-RevId: 826011137
This commit is contained in:
parent
ee7d79c0a6
commit
8a100c1e8d
|
|
@ -543,6 +543,7 @@ cc_library(
|
|||
"gemma/activations.h",
|
||||
"gemma/attention.h",
|
||||
"gemma/flash_attention.h",
|
||||
"gemma/flash_structs.h",
|
||||
"gemma/gemma.h",
|
||||
"gemma/vit.h",
|
||||
],
|
||||
|
|
|
|||
|
|
@ -82,6 +82,7 @@ set(SOURCES
|
|||
gemma/configs.h
|
||||
gemma/flash_attention.cc
|
||||
gemma/flash_attention.h
|
||||
gemma/flash_structs.h
|
||||
gemma/gemma_args.h
|
||||
gemma/gemma-inl.h
|
||||
gemma/gemma.cc
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@
|
|||
#include <limits>
|
||||
|
||||
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
|
||||
#include "gemma/flash_structs.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "util/zones.h"
|
||||
#ifndef HWY_DISABLED_TARGETS
|
||||
|
|
@ -444,16 +445,14 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
|
|||
// Sweeps a tile of accumulators (4 Q rows by NF K timesteps) from start_pos to
|
||||
// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
|
||||
// max_last_pos]. Returns the online softmax state (max and d) of the 4 Q rows.
|
||||
void TileFlashAttention4(const MatPtrT<float>& q,
|
||||
const uint32_t* HWY_RESTRICT q_offsets,
|
||||
const MatPtrT<KV_t>& k, const size_t start_pos,
|
||||
const uint32_t* HWY_RESTRICT last_pos,
|
||||
const size_t min_last_pos, const size_t max_last_pos,
|
||||
const MatPtrT<KV_t>& v, const size_t layer_idx,
|
||||
const AttentionActivationsPtrs& activations,
|
||||
MatPtrT<float>& att_out,
|
||||
const uint32_t* HWY_RESTRICT out_offsets,
|
||||
ThreadingContext& ctx, const size_t worker) {
|
||||
Tile4FlashState TileFlashAttention4(
|
||||
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
|
||||
const MatPtrT<KV_t>& k, const size_t start_pos,
|
||||
const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
|
||||
const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
|
||||
const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out,
|
||||
const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx,
|
||||
const size_t worker) {
|
||||
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention4);
|
||||
using DF = hn::ScalableTag<float>;
|
||||
const DF df;
|
||||
|
|
@ -467,14 +466,7 @@ void TileFlashAttention4(const MatPtrT<float>& q,
|
|||
hwy::ZeroBytes(att_out.Row(0) + out_offsets[i],
|
||||
v.Cols() * sizeof(att_out.Row(0)[0]));
|
||||
}
|
||||
float old_m0 = -std::numeric_limits<float>::max() / 2.0f;
|
||||
float old_m1 = -std::numeric_limits<float>::max() / 2.0f;
|
||||
float old_m2 = -std::numeric_limits<float>::max() / 2.0f;
|
||||
float old_m3 = -std::numeric_limits<float>::max() / 2.0f;
|
||||
float old_d0 = 0.0f;
|
||||
float old_d1 = 0.0f;
|
||||
float old_d2 = 0.0f;
|
||||
float old_d3 = 0.0f;
|
||||
Tile4FlashState state;
|
||||
size_t position = start_pos;
|
||||
while (position + kHTileSize - 1 <= min_last_pos) {
|
||||
int32_t k_offsets[kMaxNF];
|
||||
|
|
@ -494,10 +486,14 @@ void TileFlashAttention4(const MatPtrT<float>& q,
|
|||
x2 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x2, one_over_cap)));
|
||||
x3 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x3, one_over_cap)));
|
||||
}
|
||||
scales[0] = SingleFlashAttentionRowVector(df, x0, old_m0, old_d0);
|
||||
scales[1] = SingleFlashAttentionRowVector(df, x1, old_m1, old_d1);
|
||||
scales[2] = SingleFlashAttentionRowVector(df, x2, old_m2, old_d2);
|
||||
scales[3] = SingleFlashAttentionRowVector(df, x3, old_m3, old_d3);
|
||||
scales[0] = SingleFlashAttentionRowVector(df, x0, state.row_states[0].max,
|
||||
state.row_states[0].d);
|
||||
scales[1] = SingleFlashAttentionRowVector(df, x1, state.row_states[1].max,
|
||||
state.row_states[1].d);
|
||||
scales[2] = SingleFlashAttentionRowVector(df, x2, state.row_states[2].max,
|
||||
state.row_states[2].d);
|
||||
scales[3] = SingleFlashAttentionRowVector(df, x3, state.row_states[3].max,
|
||||
state.row_states[3].d);
|
||||
MulByConstAndAddTile4(df, scales, x0, x1, x2, x3, v, v_pos, att_out.Row(0),
|
||||
out_offsets, v.Cols());
|
||||
position += kHTileSize;
|
||||
|
|
@ -516,7 +512,8 @@ void TileFlashAttention4(const MatPtrT<float>& q,
|
|||
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
|
||||
float x0 =
|
||||
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
|
||||
SingleFlashAttentionStep(x0, activations.config.att_cap, old_m0, old_d0,
|
||||
SingleFlashAttentionStep(x0, activations.config.att_cap,
|
||||
state.row_states[0].max, state.row_states[0].d,
|
||||
v.Row(k_pos), v.Cols(),
|
||||
att_out.Row(0) + out_offsets[0]);
|
||||
}
|
||||
|
|
@ -526,7 +523,8 @@ void TileFlashAttention4(const MatPtrT<float>& q,
|
|||
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
|
||||
float x1 =
|
||||
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
|
||||
SingleFlashAttentionStep(x1, activations.config.att_cap, old_m1, old_d1,
|
||||
SingleFlashAttentionStep(x1, activations.config.att_cap,
|
||||
state.row_states[1].max, state.row_states[1].d,
|
||||
v.Row(k_pos), v.Cols(),
|
||||
att_out.Row(0) + out_offsets[1]);
|
||||
}
|
||||
|
|
@ -536,7 +534,8 @@ void TileFlashAttention4(const MatPtrT<float>& q,
|
|||
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
|
||||
float x2 =
|
||||
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
|
||||
SingleFlashAttentionStep(x2, activations.config.att_cap, old_m2, old_d2,
|
||||
SingleFlashAttentionStep(x2, activations.config.att_cap,
|
||||
state.row_states[2].max, state.row_states[2].d,
|
||||
v.Row(k_pos), v.Cols(),
|
||||
att_out.Row(0) + out_offsets[2]);
|
||||
}
|
||||
|
|
@ -546,12 +545,14 @@ void TileFlashAttention4(const MatPtrT<float>& q,
|
|||
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
|
||||
float x3 =
|
||||
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
|
||||
SingleFlashAttentionStep(x3, activations.config.att_cap, old_m3, old_d3,
|
||||
SingleFlashAttentionStep(x3, activations.config.att_cap,
|
||||
state.row_states[3].max, state.row_states[3].d,
|
||||
v.Row(k_pos), v.Cols(),
|
||||
att_out.Row(0) + out_offsets[3]);
|
||||
}
|
||||
++position;
|
||||
}
|
||||
return state;
|
||||
}
|
||||
|
||||
// Rounds n to a number that can be used as the number of Q rows in a tile
|
||||
|
|
|
|||
|
|
@ -20,35 +20,47 @@
|
|||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "gemma/flash_structs.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "hwy/highway.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Passed to HWY_VISIT_TARGETS; declares for one target.
|
||||
#define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \
|
||||
namespace NAMESPACE { \
|
||||
void RMSNormAndPositionalEncoding( \
|
||||
size_t num_tokens, const QBatch& qbatch, MatPtrT<float>& q, \
|
||||
const MatPtr& query_norm_scale, size_t layer_idx, \
|
||||
const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \
|
||||
\
|
||||
void SingleFlashAttention(size_t start_pos, size_t last_pos, \
|
||||
const float* HWY_RESTRICT q, \
|
||||
const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
|
||||
size_t layer_idx, \
|
||||
const AttentionActivationsPtrs& activations, \
|
||||
float* HWY_RESTRICT att_out, \
|
||||
ThreadingContext& ctx, size_t worker); \
|
||||
\
|
||||
size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \
|
||||
size_t total_tasks, size_t target_parallelism); \
|
||||
\
|
||||
void FlashAttention(size_t num_tokens, size_t target_parallelism, \
|
||||
size_t layer_idx, const MatPtr& query_norm_scale, \
|
||||
AttentionActivationsPtrs& activations, QBatch& qbatch, \
|
||||
ThreadingContext& ctx); \
|
||||
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
|
||||
// Declares the flash-attention entry points inside `NAMESPACE` (one SIMD
// target's namespace); `TARGET` is not referenced by the expansion.
//
// The TileFlashAttention4 declaration must match its definition exactly: the
// definition takes `const AttentionActivationsPtrs& activations` and has no
// `layer` parameter. A declaration with a different parameter list would name
// a separate overload that is never defined, so calls made through this
// header would fail to link.
#define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE)                          \
  namespace NAMESPACE {                                                        \
  void RMSNormAndPositionalEncoding(                                           \
      size_t num_tokens, const QBatch& qbatch, MatPtrT<float>& q,              \
      const MatPtr& query_norm_scale, size_t layer_idx,                        \
      const AttentionActivationsPtrs& activations, ThreadingContext& ctx);     \
                                                                               \
  void SingleFlashAttention(size_t start_pos, size_t last_pos,                 \
                            const float* HWY_RESTRICT q,                       \
                            const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,    \
                            size_t layer_idx,                                  \
                            const AttentionActivationsPtrs& activations,       \
                            float* HWY_RESTRICT att_out,                       \
                            ThreadingContext& ctx, size_t worker);             \
                                                                               \
  /* Returns the online softmax state of the 4 Q rows of the tile. */          \
  Tile4FlashState TileFlashAttention4(                                         \
      const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,         \
      const MatPtrT<KV_t>& k, size_t start_pos,                                \
      const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos,              \
      size_t max_last_pos, const MatPtrT<KV_t>& v, size_t layer_idx,           \
      const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out,    \
      const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx,         \
      const size_t worker);                                                    \
                                                                               \
  size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens,   \
                      size_t total_tasks, size_t target_parallelism);          \
                                                                               \
  void FlashAttention(size_t num_tokens, size_t target_parallelism,            \
                      size_t layer_idx, const MatPtr& query_norm_scale,        \
                      AttentionActivationsPtrs& activations, QBatch& qbatch,   \
                      ThreadingContext& ctx);                                  \
  /* NOLINTNEXTLINE(google-readability-namespace-comments) */                  \
  }  // namespace NAMESPACE
|
||||
|
||||
// Function declarations for each SIMD target. Allows direct call from the
|
||||
|
|
|
|||
|
|
@ -0,0 +1,23 @@
|
|||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_STRUCTS_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_STRUCTS_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <limits>
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Per-row running state of the flash-attention softmax. TileFlashAttention4
// passes `max` and `d` of each row to SingleFlashAttentionRowVector and
// SingleFlashAttentionStep in place of its former old_m* / old_d* locals.
// NOTE(review): `d` is presumably the running softmax denominator - confirm.
struct OnlineSoftmaxState {
|
||||
// Initial value is half of the most negative finite float.
float max = -std::numeric_limits<float>::max() / 2.0f;
|
||||
// Initial value is zero.
float d = 0.0f;
|
||||
};
|
||||
|
||||
// Number of row states stored in Tile4FlashState.
static constexpr size_t kVTileSize4 = 4;
|
||||
|
||||
// State returned by TileFlashAttention4: one OnlineSoftmaxState per row,
// indexed 0..3 by that function.
struct Tile4FlashState {
|
||||
OnlineSoftmaxState row_states[kVTileSize4];
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_STRUCTS_H_
|
||||
Loading…
Reference in New Issue