diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc
index 33ad725..77a4480 100644
--- a/gemma/flash_attention.cc
+++ b/gemma/flash_attention.cc
@@ -51,6 +51,8 @@ HWY_BEFORE_NAMESPACE();
 namespace gcpp {
 namespace HWY_NAMESPACE {
 
+static constexpr size_t kNFx8HTileSize = 8;
+
 // Transposes q into q_t.
 // Both are 4D tensors stuffed into a 2-D MatPtrT.
 // q has shape [batch, qbatch][head, qkv_dim].
@@ -191,7 +193,7 @@ void QDotKTileFloat(DF df, const float* HWY_RESTRICT q, const size_t q_stride,
                     hwy::Profiler& p, const size_t worker, VF& sum0, VF& sum1,
                     VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6,
                     VF& sum7) {
-  constexpr size_t kHTileSize = 8;
+  constexpr size_t kHTileSize = kNFx8HTileSize;
   sum0 = hn::Zero(df);
   sum1 = hn::Zero(df);
   sum2 = hn::Zero(df);
@@ -268,7 +270,7 @@ void TileFlashAttention(
     hwy::Profiler& p, const size_t worker) {
   static const auto zone = p.AddZone("Gen.Attention.TileFlashAttention");
   PROFILER_ZONE3(p, worker, zone);
-  constexpr int kHTileSize = 8;
+  constexpr int kHTileSize = kNFx8HTileSize;
   using DF = hn::ScalableTag<float>;
   const DF df;
   using VF = hn::Vec<DF>;
@@ -365,13 +367,12 @@ void TileFlashAttention(
 
 // Returns an 4 Q rows by NF K tile of Q.K dot products, in single precision.
 // This is the result of 4 rows of Q against NF K timesteps, with positions
-// given by k_offsets[0..NF]. Q has been transposed so that the 4 rows are read
-// in consecutive elements, and other columns by adding q_stride.
+// given by k_offsets[0..NF].
 template <class DF, class VF = hn::Vec<DF>>
-void QDotKTilex4(DF df, const float* HWY_RESTRICT q, const size_t q_stride,
-                 const MatPtrT& k, const int32_t* HWY_RESTRICT k_offsets,
-                 hwy::Profiler& p, const size_t worker, VF& sum0, VF& sum1,
-                 VF& sum2, VF& sum3) {
+void QDotKTilex4(DF df, const float* HWY_RESTRICT q,
+                 const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT& k,
+                 const int32_t* HWY_RESTRICT k_offsets, hwy::Profiler& p,
+                 const size_t worker, VF& sum0, VF& sum1, VF& sum2, VF& sum3) {
   sum0 = hn::Zero(df);
   sum1 = hn::Zero(df);
   sum2 = hn::Zero(df);
@@ -383,15 +384,14 @@ void QDotKTilex4(DF df, const float* HWY_RESTRICT q, const size_t q_stride,
   VI k_offsets_vec = hn::LoadU(di, k_offsets);
   for (size_t i = 0; i < k.Cols(); ++i) {
     VF k_vec = hn::GatherIndex(df, k_base + i, k_offsets_vec);
-    VF q_0 = hn::Set(df, q[0]);
+    VF q_0 = hn::Set(df, q[q_offsets[0] + i]);
     sum0 = hn::MulAdd(q_0, k_vec, sum0);
-    VF q_1 = hn::Set(df, q[1]);
+    VF q_1 = hn::Set(df, q[q_offsets[1] + i]);
     sum1 = hn::MulAdd(q_1, k_vec, sum1);
-    VF q_2 = hn::Set(df, q[2]);
+    VF q_2 = hn::Set(df, q[q_offsets[2] + i]);
     sum2 = hn::MulAdd(q_2, k_vec, sum2);
-    VF q_3 = hn::Set(df, q[3]);
+    VF q_3 = hn::Set(df, q[q_offsets[3] + i]);
     sum3 = hn::MulAdd(q_3, k_vec, sum3);
-    q += q_stride;
   }
 }
 
@@ -416,10 +416,9 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
 // max_last_pos].
 void TileFlashAttention4(
     const MatPtrT& q, const uint32_t* HWY_RESTRICT q_offsets,
-    const StridedView& qT, const MatPtrT& k,
-    const size_t start_pos, const uint32_t* HWY_RESTRICT last_pos,
-    const size_t min_last_pos, const size_t max_last_pos,
-    const MatPtrT& v, const size_t layer_idx,
+    const MatPtrT& k, const size_t start_pos,
+    const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
+    const size_t max_last_pos, const MatPtrT& v, const size_t layer_idx,
     const LayerWeightsPtrs& layer, const AttentionActivations& activations,
     MatPtrT& att_out, const uint32_t* HWY_RESTRICT out_offsets,
     hwy::Profiler& p, const size_t worker) {
@@ -445,8 +444,6 @@ void TileFlashAttention4(
   float old_d1 = 0.0f;
   float old_d2 = 0.0f;
   float old_d3 = 0.0f;
-  const float* HWY_RESTRICT qT_row = qT.Row(0);
-  const size_t qT_stride = qT.Stride();
   size_t position = start_pos;
   while (position + kHTileSize - 1 <= min_last_pos) {
     int32_t k_offsets[kMaxNF];
@@ -456,7 +453,8 @@ void TileFlashAttention4(
       k_offsets[i] = k.Row(v_pos[i]) - k.Row(0);
     }
     VF x0, x1, x2, x3;
-    QDotKTilex4(df, qT_row, qT_stride, k, k_offsets, p, worker, x0, x1, x2, x3);
+    QDotKTilex4(df, q.Row(0), q_offsets, k, k_offsets, p, worker, x0, x1, x2,
+                x3);
     if (activations.config.att_cap > 0.0f) {
       // Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile.
       VF cap = hn::Set(df, activations.config.att_cap);
@@ -608,12 +606,29 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
       (kNF <= kMaxEqualK && total_tasks / kNF >= target_parallelism)
           ? kNF
           : std::min(kMinTileSize, kMaxEqualK);
-  // q has shape [batch, qbatch][head, qkv_dim].
-  // We transpose it to [qkv_dim][qbatch, head, batch] in order to make the
-  // maximum possible number of consecutive columns have the same KV matrices.
-  // Each thread will process a tile of NF columns of QT so the starting column
-  // index of QT is just the task index * kVTileSize.
-  TransposeQ(activations.q, activations.q_T, qbatch.Size(), ctx);
+  // Only transpose Q if we are using tiling.
+  if (kVTileSize == kNF) {
+    size_t max_last = 0, min_start = std::numeric_limits<size_t>::max();
+    for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
+      size_t pos = qbatch.Pos(qi);
+      const size_t start = StartPos(pos, activations.config, layer_idx);
+      pos += num_tokens - 1;
+      const size_t end = qbatch.PrefixEnd(qi);
+      if (end > 0 && end - 1 > pos) {
+        pos = end - 1;
+      }
+      max_last = std::max(max_last, pos);
+      min_start = std::min(min_start, start);
+    }
+    if (max_last - min_start + 1 >= kNFx8HTileSize) {
+      // q has shape [batch, qbatch][head, qkv_dim].
+      // We transpose it to [qkv_dim][qbatch, head, batch] in order to make the
+      // maximum possible number of consecutive columns have the same KV
+      // matrices. Each thread will process a tile of NF columns of QT so the
+      // starting column index of QT is just the task index * kVTileSize.
+      TransposeQ(activations.q, activations.q_T, qbatch.Size(), ctx);
+    }
+  }
   const size_t num_thread_tasks = hwy::DivCeil(total_tasks, kVTileSize);
   const hwy::Divisor div_tokens(num_tokens);
   // All layers should have the same number of heads.
@@ -699,17 +714,23 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
             StridedView(activations.q_T.Row(0) + first_task, kVTileSize,
                         activations.q_T.Stride());
         if (kVTileSize == kNF) {
+          // We can still use TileFlashAttention even if we didn't transpose Q
+          // above. The condition used for transposing Q above is more general
+          // and easier to compute than the condition used within
+          // TileFlashAttention that min_last_pos - start_positions[offset] <
+          // kNFx8HTileSize. In this case, qT is never used. Some tasks might
+          // use qT and some might not, which is why the more general condition
+          // is used above to catch all cases where qT will be used.
           TileFlashAttention(activations.q, q_offsets, qT, k,
                              start_positions[offset], last_pos, min_last_pos,
                              max_last_pos, v, layer_idx, layer, activations,
                              activations.att_out, out_offsets, ctx.profiler,
                              worker);
         } else if (kVTileSize == 4) {
-          TileFlashAttention4(activations.q, q_offsets, qT, k,
-                              start_positions[offset], last_pos, min_last_pos,
-                              max_last_pos, v, layer_idx, layer, activations,
-                              activations.att_out, out_offsets, ctx.profiler,
-                              worker);
+          TileFlashAttention4(
+              activations.q, q_offsets, k, start_positions[offset], last_pos,
+              min_last_pos, max_last_pos, v, layer_idx, layer, activations,
+              activations.att_out, out_offsets, ctx.profiler, worker);
         } else {
           HWY_UNREACHABLE;
         }
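
For intuition, the new QDotKTilex4 indexing can be written as a plain scalar loop: each of the 4 Q rows is addressed through q_offsets into the untransposed q buffer, and each of the NF K timesteps through k_offsets into the K cache, so the transposed qT copy is no longer needed on this path. The sketch below is illustrative only, not the patched code: it assumes plain float arrays in place of MatPtrT and Highway vectors, and the names QDotKTilex4Scalar, qkv_dim and kNF are stand-ins for the real head dimension and tile width.

// Scalar sketch of the Q.K tile computed by QDotKTilex4 after this change.
// Assumption: q and k are flat float arrays; q_offsets/k_offsets are element
// offsets to the start of each Q row / K timestep, as in the vector code.
#include <cstddef>
#include <cstdint>

template <size_t kNF>
void QDotKTilex4Scalar(const float* q, const uint32_t* q_offsets,
                       const float* k, const int32_t* k_offsets,
                       size_t qkv_dim, float sums[4][kNF]) {
  for (size_t r = 0; r < 4; ++r) {
    for (size_t j = 0; j < kNF; ++j) sums[r][j] = 0.0f;
  }
  for (size_t i = 0; i < qkv_dim; ++i) {
    for (size_t j = 0; j < kNF; ++j) {
      const float k_val = k[k_offsets[j] + i];  // one gathered K element
      for (size_t r = 0; r < 4; ++r) {
        // q[q_offsets[r] + i] replaces the old q[r] / q += q_stride walk
        // over the transposed qT.
        sums[r][j] += q[q_offsets[r] + i] * k_val;
      }
    }
  }
}

In the vectorized version in the patch, the j loop is the lane dimension: hn::GatherIndex loads the NF K elements for column i in one vector, and each q[q_offsets[r] + i] is broadcast with hn::Set and accumulated with hn::MulAdd.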