Avoid transposing Q when it isn't needed

PiperOrigin-RevId: 814187984
This commit is contained in:
Ray Smith 2025-10-02 05:16:03 -07:00 committed by Copybara-Service
parent fe5a39990e
commit 14244664c8
1 changed files with 52 additions and 31 deletions

View File

@ -51,6 +51,8 @@ HWY_BEFORE_NAMESPACE();
namespace gcpp { namespace gcpp {
namespace HWY_NAMESPACE { namespace HWY_NAMESPACE {
static constexpr size_t kNFx8HTileSize = 8;
// Transposes q into q_t. // Transposes q into q_t.
// Both are 4D tensors stuffed into a 2-D MatPtrT. // Both are 4D tensors stuffed into a 2-D MatPtrT.
// q has shape [batch, qbatch][head, qkv_dim]. // q has shape [batch, qbatch][head, qkv_dim].
@ -191,7 +193,7 @@ void QDotKTileFloat(DF df, const float* HWY_RESTRICT q, const size_t q_stride,
hwy::Profiler& p, const size_t worker, VF& sum0, VF& sum1, hwy::Profiler& p, const size_t worker, VF& sum0, VF& sum1,
VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6,
VF& sum7) { VF& sum7) {
constexpr size_t kHTileSize = 8; constexpr size_t kHTileSize = kNFx8HTileSize;
sum0 = hn::Zero(df); sum0 = hn::Zero(df);
sum1 = hn::Zero(df); sum1 = hn::Zero(df);
sum2 = hn::Zero(df); sum2 = hn::Zero(df);
@ -268,7 +270,7 @@ void TileFlashAttention(
hwy::Profiler& p, const size_t worker) { hwy::Profiler& p, const size_t worker) {
static const auto zone = p.AddZone("Gen.Attention.TileFlashAttention"); static const auto zone = p.AddZone("Gen.Attention.TileFlashAttention");
PROFILER_ZONE3(p, worker, zone); PROFILER_ZONE3(p, worker, zone);
constexpr int kHTileSize = 8; constexpr int kHTileSize = kNFx8HTileSize;
using DF = hn::ScalableTag<float>; using DF = hn::ScalableTag<float>;
const DF df; const DF df;
using VF = hn::Vec<DF>; using VF = hn::Vec<DF>;
@ -365,13 +367,12 @@ void TileFlashAttention(
// Returns an 4 Q rows by NF K tile of Q.K dot products, in single precision. // Returns an 4 Q rows by NF K tile of Q.K dot products, in single precision.
// This is the result of 4 rows of Q against NF K timesteps, with positions // This is the result of 4 rows of Q against NF K timesteps, with positions
// given by k_offsets[0..NF]. Q has been transposed so that the 4 rows are read // given by k_offsets[0..NF].
// in consecutive elements, and other columns by adding q_stride.
template <class DF, class VF = hn::Vec<DF>> template <class DF, class VF = hn::Vec<DF>>
void QDotKTilex4(DF df, const float* HWY_RESTRICT q, const size_t q_stride, void QDotKTilex4(DF df, const float* HWY_RESTRICT q,
const MatPtrT<KV_t>& k, const int32_t* HWY_RESTRICT k_offsets, const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT<KV_t>& k,
hwy::Profiler& p, const size_t worker, VF& sum0, VF& sum1, const int32_t* HWY_RESTRICT k_offsets, hwy::Profiler& p,
VF& sum2, VF& sum3) { const size_t worker, VF& sum0, VF& sum1, VF& sum2, VF& sum3) {
sum0 = hn::Zero(df); sum0 = hn::Zero(df);
sum1 = hn::Zero(df); sum1 = hn::Zero(df);
sum2 = hn::Zero(df); sum2 = hn::Zero(df);
@ -383,15 +384,14 @@ void QDotKTilex4(DF df, const float* HWY_RESTRICT q, const size_t q_stride,
VI k_offsets_vec = hn::LoadU(di, k_offsets); VI k_offsets_vec = hn::LoadU(di, k_offsets);
for (size_t i = 0; i < k.Cols(); ++i) { for (size_t i = 0; i < k.Cols(); ++i) {
VF k_vec = hn::GatherIndex(df, k_base + i, k_offsets_vec); VF k_vec = hn::GatherIndex(df, k_base + i, k_offsets_vec);
VF q_0 = hn::Set(df, q[0]); VF q_0 = hn::Set(df, q[q_offsets[0] + i]);
sum0 = hn::MulAdd(q_0, k_vec, sum0); sum0 = hn::MulAdd(q_0, k_vec, sum0);
VF q_1 = hn::Set(df, q[1]); VF q_1 = hn::Set(df, q[q_offsets[1] + i]);
sum1 = hn::MulAdd(q_1, k_vec, sum1); sum1 = hn::MulAdd(q_1, k_vec, sum1);
VF q_2 = hn::Set(df, q[2]); VF q_2 = hn::Set(df, q[q_offsets[2] + i]);
sum2 = hn::MulAdd(q_2, k_vec, sum2); sum2 = hn::MulAdd(q_2, k_vec, sum2);
VF q_3 = hn::Set(df, q[3]); VF q_3 = hn::Set(df, q[q_offsets[3] + i]);
sum3 = hn::MulAdd(q_3, k_vec, sum3); sum3 = hn::MulAdd(q_3, k_vec, sum3);
q += q_stride;
} }
} }
@ -416,10 +416,9 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
// max_last_pos]. // max_last_pos].
void TileFlashAttention4( void TileFlashAttention4(
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
const StridedView<float>& qT, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& k, const size_t start_pos,
const size_t start_pos, const uint32_t* HWY_RESTRICT last_pos, const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
const size_t min_last_pos, const size_t max_last_pos, const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
const MatPtrT<KV_t>& v, const size_t layer_idx,
const LayerWeightsPtrs& layer, const AttentionActivations& activations, const LayerWeightsPtrs& layer, const AttentionActivations& activations,
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets, MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets,
hwy::Profiler& p, const size_t worker) { hwy::Profiler& p, const size_t worker) {
@ -445,8 +444,6 @@ void TileFlashAttention4(
float old_d1 = 0.0f; float old_d1 = 0.0f;
float old_d2 = 0.0f; float old_d2 = 0.0f;
float old_d3 = 0.0f; float old_d3 = 0.0f;
const float* HWY_RESTRICT qT_row = qT.Row(0);
const size_t qT_stride = qT.Stride();
size_t position = start_pos; size_t position = start_pos;
while (position + kHTileSize - 1 <= min_last_pos) { while (position + kHTileSize - 1 <= min_last_pos) {
int32_t k_offsets[kMaxNF]; int32_t k_offsets[kMaxNF];
@ -456,7 +453,8 @@ void TileFlashAttention4(
k_offsets[i] = k.Row(v_pos[i]) - k.Row(0); k_offsets[i] = k.Row(v_pos[i]) - k.Row(0);
} }
VF x0, x1, x2, x3; VF x0, x1, x2, x3;
QDotKTilex4(df, qT_row, qT_stride, k, k_offsets, p, worker, x0, x1, x2, x3); QDotKTilex4(df, q.Row(0), q_offsets, k, k_offsets, p, worker, x0, x1, x2,
x3);
if (activations.config.att_cap > 0.0f) { if (activations.config.att_cap > 0.0f) {
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile. // Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile.
VF cap = hn::Set(df, activations.config.att_cap); VF cap = hn::Set(df, activations.config.att_cap);
@ -608,12 +606,29 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
(kNF <= kMaxEqualK && total_tasks / kNF >= target_parallelism) (kNF <= kMaxEqualK && total_tasks / kNF >= target_parallelism)
? kNF ? kNF
: std::min(kMinTileSize, kMaxEqualK); : std::min(kMinTileSize, kMaxEqualK);
// Only transpose Q if we are using tiling.
if (kVTileSize == kNF) {
size_t max_last = 0, min_start = std::numeric_limits<size_t>::max();
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
size_t pos = qbatch.Pos(qi);
const size_t start = StartPos(pos, activations.config, layer_idx);
pos += num_tokens - 1;
const size_t end = qbatch.PrefixEnd(qi);
if (end > 0 && end - 1 > pos) {
pos = end - 1;
}
max_last = std::max(max_last, pos);
min_start = std::min(min_start, start);
}
if (max_last - min_start + 1 >= kNFx8HTileSize) {
// q has shape [batch, qbatch][head, qkv_dim]. // q has shape [batch, qbatch][head, qkv_dim].
// We transpose it to [qkv_dim][qbatch, head, batch] in order to make the // We transpose it to [qkv_dim][qbatch, head, batch] in order to make the
// maximum possible number of consecutive columns have the same KV matrices. // maximum possible number of consecutive columns have the same KV
// Each thread will process a tile of NF columns of QT so the starting column // matrices. Each thread will process a tile of NF columns of QT so the
// index of QT is just the task index * kVTileSize. // starting column index of QT is just the task index * kVTileSize.
TransposeQ(activations.q, activations.q_T, qbatch.Size(), ctx); TransposeQ(activations.q, activations.q_T, qbatch.Size(), ctx);
}
}
const size_t num_thread_tasks = hwy::DivCeil(total_tasks, kVTileSize); const size_t num_thread_tasks = hwy::DivCeil(total_tasks, kVTileSize);
const hwy::Divisor div_tokens(num_tokens); const hwy::Divisor div_tokens(num_tokens);
// All layers should have the same number of heads. // All layers should have the same number of heads.
@ -699,17 +714,23 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
StridedView<float>(activations.q_T.Row(0) + first_task, kVTileSize, StridedView<float>(activations.q_T.Row(0) + first_task, kVTileSize,
activations.q_T.Stride()); activations.q_T.Stride());
if (kVTileSize == kNF) { if (kVTileSize == kNF) {
// We can still use TileFlashAttention even if we didn't transpose Q
// above. The condition used for transposing Q above is more general
// and easier to compute than the condition used within
// TileFlashAttention that min_last_pos - start_positions[offset] <
// kNFx8HTileSize. In this case, qT is never used. Some tasks might
// use qT and some might not, which is why the more general condition
// is used above to catch all cases where qT will be used.
TileFlashAttention(activations.q, q_offsets, qT, k, TileFlashAttention(activations.q, q_offsets, qT, k,
start_positions[offset], last_pos, min_last_pos, start_positions[offset], last_pos, min_last_pos,
max_last_pos, v, layer_idx, layer, activations, max_last_pos, v, layer_idx, layer, activations,
activations.att_out, out_offsets, ctx.profiler, activations.att_out, out_offsets, ctx.profiler,
worker); worker);
} else if (kVTileSize == 4) { } else if (kVTileSize == 4) {
TileFlashAttention4(activations.q, q_offsets, qT, k, TileFlashAttention4(
start_positions[offset], last_pos, min_last_pos, activations.q, q_offsets, k, start_positions[offset], last_pos,
max_last_pos, v, layer_idx, layer, activations, min_last_pos, max_last_pos, v, layer_idx, layer, activations,
activations.att_out, out_offsets, ctx.profiler, activations.att_out, out_offsets, ctx.profiler, worker);
worker);
} else { } else {
HWY_UNREACHABLE; HWY_UNREACHABLE;
} }