diff --git a/BUILD.bazel b/BUILD.bazel index 5f3bf87..50db088 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -548,7 +548,6 @@ cc_library( deps = [ ":basics", ":configs", - ":flash_structs", ":gemma_args", ":kv_cache", ":mat", @@ -596,11 +595,6 @@ cc_test( INTERNAL_DEPS = [] -cc_library( - name = "flash_structs", - hdrs = ["gemma/flash_structs.h"], -) - cc_library( name = "attention", srcs = [ @@ -610,6 +604,7 @@ cc_library( hdrs = [ "gemma/attention.h", "gemma/flash_attention.h", + "gemma/flash_structs.h", ], textual_hdrs = [ "gemma/gemma-inl.h", @@ -618,7 +613,6 @@ cc_library( ":activations", ":basics", ":configs", - ":flash_structs", ":kv_cache", ":mat", ":matmul", diff --git a/gemma/activations.h b/gemma/activations.h index ba0ceaf..3df61c5 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -24,7 +24,6 @@ #include #include "gemma/configs.h" // ModelConfig -#include "gemma/flash_structs.h" #include "gemma/gemma_args.h" // AttentionImpl #include "gemma/kv_cache.h" #include "gemma/tensor_stats.h" @@ -53,13 +52,10 @@ struct AttentionActivations { AttentionActivations( const ModelConfig& config, const LayerConfig& layer_config, size_t batch_size, size_t seq_len, const RuntimeConfig& runtime_config, - size_t max_workers, const Allocator& allocator, + const Allocator& allocator, std::vector>& row_ptrs) - : rep_factor(max_workers * - AttentionActivations::kThreadReplicationFactor / - layer_config.heads), - // `vocab_size == 0` means it is for Vit part, VitAttention - // is still MHA and does not use an external KV cache. + : // `vocab_size == 0` means it is for Vit part, VitAttention is still + // MHA and does not use an external KV cache. q(MatFactory("q", batch_size, config.vocab_size == 0 ? layer_config.heads * 3 * layer_config.qkv_dim @@ -90,9 +86,6 @@ struct AttentionActivations { att_out(MatFactory("att_out", batch_size, layer_config.heads * layer_config.qkv_dim, allocator)), - att_out_reps(MatFactory("att_out", batch_size * rep_factor, - layer_config.heads * layer_config.qkv_dim, - allocator)), softmax_max(MatFactory("softmax_max", batch_size, layer_config.heads, allocator)), softmax_d( @@ -114,11 +107,6 @@ struct AttentionActivations { } return; } - // This is a guess at the maximum number of params we might need to avoid - // reallocations. The actual number of params is determined by the number of - // query tiles, which is not known here. - flash_params.reserve(batch_size * layer_config.heads); - split_flash_params.reserve(batch_size * layer_config.heads); // For MatMul outputs, precompute their row pointers. // If we forget any MatMul outputs here, debug builds print a warning but @@ -142,10 +130,6 @@ struct AttentionActivations { pre_att_rms_out.OverrideRows(batch_size); att.OverrideRows(batch_size); att_out.OverrideRows(batch_size); - att_out_reps.OverrideRows(batch_size * rep_factor); - // There is no override for [split_]flash_params, because we reserved an - // upper bound, and flash attention controls the actual size when it - // calculates the size and number of tiles. softmax_max.OverrideRows(batch_size); softmax_d.OverrideRows(batch_size); att_sums.OverrideRows(batch_size); @@ -153,15 +137,6 @@ struct AttentionActivations { // `inv_timescale*` are not batched. } - // Maximum factor by which we might scale-up work to maximize parallelism. - size_t rep_factor = 1; - // Parameters for flash attention. The size of the vector is somewhere between - // the number of query rows and 1/8th of that. 
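The members removed above follow a pattern worth spelling out: AttentionActivations owns the backing buffers, and AttentionActivationsPtrs (further below in this hunk) is a non-owning view over them. A minimal standalone sketch of that split, with hypothetical names and plain std::vector standing in for MatStorageT/MatPtrT:

#include <cstddef>
#include <vector>

// Owning storage: allocated once, sized for the maximum batch.
struct Storage {
  Storage(size_t batch, size_t dim) : att_out(batch * dim) {}
  std::vector<float> att_out;
};

// Non-owning view: holds a raw pointer into Storage, so it is cheap to
// copy but must not outlive the Storage it was built from.
struct View {
  explicit View(Storage& s) : att_out(s.att_out.data()) {}
  float* att_out;
};

int main() {
  Storage storage(/*batch=*/4, /*dim=*/8);
  View view(storage);
  view.att_out[0] = 1.0f;  // writes through to storage.att_out[0]
  return 0;
}

The diff also tightens this contract: the view constructor now takes const AttentionActivations&, since it only copies handles.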
- std::vector flash_params; - // Parameters for flash attention, split by k-position. May be significantly - // larger than flash_params in decode mode, when the number of query rows is - // small. - std::vector split_flash_params; MatStorageT q; // query MatStorageT q_bf; MatStorageT q_T; // Transposed to maximize attention speed. @@ -173,7 +148,6 @@ struct AttentionActivations { MatStorageT pre_att_rms_out; MatStorageT att; // attention vector MatStorageT att_out; // attention output - MatStorageT att_out_reps; // attention output for each thread. MatStorageT softmax_max; // see OnlineSoftmaxState MatStorageT softmax_d; // see OnlineSoftmaxState // Accumulation of attention outputs over heads @@ -190,26 +164,19 @@ struct AttentionActivations { // Rope MatStorageT inv_timescale; MatStorageT inv_timescale_global; - // Replication factor to help evenly share work over threads. - static constexpr size_t kThreadReplicationFactor = 4; }; // A non-owning view of AttentionActivations. struct AttentionActivationsPtrs { - AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len, - std::vector& flash_params, - std::vector& split_flash_params) + AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len) : config(config), - flash_params(flash_params), - split_flash_params(split_flash_params), div_seq_len(static_cast(seq_len)), div_heads(static_cast(config.layer_configs[0].heads)), query_scale(ChooseQueryScale(config)) {} AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len, - AttentionActivations& activations) - : AttentionActivationsPtrs(config, seq_len, activations.flash_params, - activations.split_flash_params) { + const AttentionActivations& activations) + : AttentionActivationsPtrs(config, seq_len) { q = activations.q; q_bf = activations.q_bf; q_T = activations.q_T; @@ -219,7 +186,6 @@ struct AttentionActivationsPtrs { pre_att_rms_out = activations.pre_att_rms_out; att = activations.att; att_out = activations.att_out; - att_out_reps = activations.att_out_reps; softmax_max = activations.softmax_max; softmax_d = activations.softmax_d; att_sums = activations.att_sums; @@ -250,9 +216,6 @@ struct AttentionActivationsPtrs { } const ModelConfig& config; - // Parameters for flash attention. - std::vector& flash_params; - std::vector& split_flash_params; // For the matrices below, the batch_size dimension is really qbatch.Size() * // token_batch_size, but in all known uses, one of those is 1. Specifically, @@ -278,7 +241,6 @@ struct AttentionActivationsPtrs { // Attention output computed from att * V, size batch_size x (q_heads * // qkv_dim). MatPtrT att_out; - MatPtrT att_out_reps; // The maximum logit value encountered when computing att_out from att, // size batch_size x q_heads . See OnlineSoftmaxState for details. // WARNING: Only filled in for AttentionImpl::kOld. @@ -343,8 +305,7 @@ struct Activations { s_w_linear_w(config.num_layers, max_workers), attention_impl(runtime_config.attention_impl), attention_storage(config, layer_config, batch_size, seq_len, - runtime_config, ctx.pools.MaxWorkers(), ctx.allocator, - row_ptrs), + runtime_config, ctx.allocator, row_ptrs), attention(config, seq_len, attention_storage) { HWY_ASSERT(batch_size != 0); diff --git a/gemma/attention.cc b/gemma/attention.cc index 570c4f4..8ea9b6d 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -49,39 +49,6 @@ HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { -// Returns the number of floats per vector (aka NF). 
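"NF", referenced throughout the remaining comments and in kNFx8HTileSize below, is the number of float lanes in one SIMD vector, which the real code obtains from hn::Lanes on a ScalableTag<float>. A fixed-width stand-in, assuming a 512-bit vector purely for illustration:

#include <cstddef>

// Assumption for illustration only: 512-bit vectors (e.g. AVX-512).
// Highway resolves the actual width per target at compile/dispatch time.
constexpr size_t kVectorBits = 512;
constexpr size_t kNF = kVectorBits / (8 * sizeof(float));
static_assert(kNF == 16, "a 512-bit vector holds 16 floats");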
-size_t FloatsPerVector() { - using DF = hn::ScalableTag; - const DF df; - return hn::Lanes(df); -} - -// The k-cache and v-cache are setup without knowing NF. So if it hasn't been -// done already, reshape it to take NF into account. -void MaybeReshapeCache(const MatPtrT& kv, MatPtrT& cache) { - if (kv.Cols() > cache.Cols()) { - cache.ReshapePackedRowsToCols(2 * FloatsPerVector()); - } -} - -// Transposes a single row of the kv cache into the k-cache and v-cache. -void TransposeKVCacheRow(const KV_t* HWY_RESTRICT kv, KV_t* HWY_RESTRICT k, - KV_t* HWY_RESTRICT v, size_t qkv_dim) { - // This is inefficient, as the writes are scattered over cache lines, but it - // is a tiny fraction of the overall computation, and it is linear in the - // token length. - const size_t kFloatsPerTile = 2 * FloatsPerVector(); - for (size_t i = 0; i < qkv_dim; i += 2) { - k[i * kFloatsPerTile] = kv[i]; - k[i * kFloatsPerTile + 1] = kv[i + 1]; - } - for (size_t i = 0; i < qkv_dim; i += kFloatsPerTile) { - for (size_t j = 0; j < kFloatsPerTile; j++) { - v[i * kFloatsPerTile + j] = kv[i + j + qkv_dim]; - } - } -} - // Computes Q.K scores, which are "logits" (or scores) stored to att. // `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim]. static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos, @@ -313,11 +280,6 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, kv_rows.AttachRowPtrs(env.row_ptrs[0].get()); CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2, /*add=*/nullptr, env, kv_rows); - for (size_t qi = 0; qi < qbatch.Size(); ++qi) { - MaybeReshapeCache(qbatch.KV(qi).kv_cache, qbatch.KV(qi).k_cache); - MaybeReshapeCache(qbatch.KV(qi).kv_cache, qbatch.KV(qi).v_cache); - } - const size_t kFloatsPerVector = FloatsPerVector(); // Apply positional encodings for K. // Note that 2D parallelism is not worth the fork/join overhead because the @@ -337,26 +299,6 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, KV_t* HWY_RESTRICT kv = kv_cache.Row(cache_pos) + layer_idx * cache_layer_size + head * qkv_dim * 2; - // Note that k_cache and v_cache are different shapes. - // The innermost dimension of k is 2 values from qkv_dim because they - // are going to be used in a BF16 dot product involving pairs of - // values over NF k positions. - // The innermost dimension of v is 2NF values from qkv_dim because they - // will be loaded into a BF16 vector to be scaled and added to the - // cached attention output in 2 NF-sized registers. - // TODO(rays): factor out these calculations into functions. 
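The deleted block below computed offsets into the separate transposed k/v caches. What remains is the unified kv_cache addressing a few lines up: each row holds one token position, laid out as [layer][kv_head][qkv_dim K values, then qkv_dim V values]. A standalone model of that addressing, assuming cache_layer_size == kv_heads * qkv_dim * 2 (implied by the surrounding code but not spelled out in this hunk):

#include <cassert>
#include <cstddef>

struct KVOffsets {
  size_t k;  // offset of this head's K values within the row
  size_t v;  // offset of this head's V values within the row
};

KVOffsets CacheOffsets(size_t layer_idx, size_t head, size_t kv_heads,
                       size_t qkv_dim) {
  const size_t cache_layer_size = kv_heads * qkv_dim * 2;  // assumption
  const size_t base = layer_idx * cache_layer_size + head * qkv_dim * 2;
  return {base, base + qkv_dim};
}

int main() {
  // Layer 1, head 2, with 4 kv heads of qkv_dim 256:
  const KVOffsets o = CacheOffsets(1, 2, 4, 256);
  assert(o.k == 2048 + 1024);  // one full layer, then two heads
  assert(o.v == o.k + 256);    // V directly follows K
  return 0;
}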
- auto& k_cache = qbatch.KV(qi).k_cache; - KV_t* HWY_RESTRICT k = - k_cache.Row(cache_pos / (2 * kFloatsPerVector)) + - (layer_idx * cache_layer_size + head * qkv_dim * 2) * - kFloatsPerVector + - (cache_pos % (2 * kFloatsPerVector)) * 2; - auto& v_cache = qbatch.KV(qi).v_cache; - KV_t* HWY_RESTRICT v = - v_cache.Row(cache_pos / (2 * kFloatsPerVector)) + - (layer_idx * cache_layer_size + head * qkv_dim * 2) * - kFloatsPerVector + - (cache_pos % (2 * kFloatsPerVector)) * 2 * kFloatsPerVector; HWY_ALIGN float kv_f32[2 * kMaxQKVDim]; const hn::ScalableTag df; @@ -377,10 +319,6 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, /*mul=*/1.0f); CompressPerThread tls; Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0); - // This is inefficient, as multiple threads are writing the same K - // cache line, but the input is generated by a matmul, so it is - // difficult to change, and it probably isn't significant. - TransposeKVCacheRow(kv, k, v, qkv_dim); }); } @@ -403,8 +341,7 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx, } else { // * 2 does not help on Turin. FlashAttention(num_tokens, - /*target_parallelism=*/env.ctx.pools.MaxWorkers() * - AttentionActivations::kThreadReplicationFactor, + /*target_parallelism=*/env.ctx.pools.MaxWorkers() * 1, layer_idx, layer.query_norm_scale, activations, qbatch, env.ctx, attention_impl); } diff --git a/gemma/attention.h b/gemma/attention.h index 14870de..7fb958f 100644 --- a/gemma/attention.h +++ b/gemma/attention.h @@ -31,13 +31,6 @@ namespace gcpp { // Passed to HWY_VISIT_TARGETS; declares for one target. #define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \ namespace NAMESPACE { \ - size_t FloatsPerVector(); \ - \ - void MaybeReshapeCache(const MatPtrT& kv, MatPtrT& cache); \ - \ - void TransposeKVCacheRow(const KV_t* HWY_RESTRICT kv, KV_t* HWY_RESTRICT k, \ - KV_t* HWY_RESTRICT v, size_t qkv_dim); \ - \ void PositionalEncodingQK(float* qk, size_t layer_idx, \ const AttentionActivationsPtrs& activations, \ ThreadingContext& ctx, size_t worker, size_t pos, \ diff --git a/gemma/attention_test.cc b/gemma/attention_test.cc index f7193ac..fe57a5b 100644 --- a/gemma/attention_test.cc +++ b/gemma/attention_test.cc @@ -1,10 +1,8 @@ #include -#include #include // strcmp #include #include #include -#include #include #include "gtest/gtest.h" @@ -107,8 +105,7 @@ struct TestAttentionState { tokens(num_tokens), attention_storage_(model_state.config, model_state.layer_config, batch_size, num_tokens, runtime_config, - state.ctx.pools.MaxWorkers(), state.ctx.allocator, - row_ptrs_), + state.ctx.allocator, row_ptrs_), attention(model_state.config, num_tokens, attention_storage_) { for (size_t i = 0; i < qbatch_size; ++i) { kv_caches.emplace_back(model_state.config, inference_args, @@ -146,7 +143,6 @@ struct TestAttentionState { }; double GetTolerance() { - if (IsBF16()) return 1e-2; const char* target_name = hwy::TargetName(HWY_TARGET); if (strncmp(target_name, "AVX2", 4) == 0) { return 2e-2; @@ -159,20 +155,6 @@ double GetTolerance() { } } -template -bool CompareArraySimilar(const T* expected, const T* actual, size_t count, - const char* target_name, const char* filename, - int line) { - if constexpr (IsBF16()) { - constexpr double kTolerance = 3e-2; - return hwy::CompareArraySimilar(expected, actual, count, kTolerance, - target_name, filename, line); - } else { - return hwy::CompareArraySimilar(expected, actual, count, GetTolerance(), - target_name, filename, line); - } -} - template void CompareAttSumsWithGolden( 
const AttentionActivationsPtrs& attention, @@ -188,9 +170,9 @@ void CompareAttSumsWithGolden( for (size_t j = 0; j < kDims; ++j) { actual_row[j] = hwy::F32FromBF16(attention.att_sums.Row(i)[j]); } - EXPECT_TRUE(CompareArraySimilar(golden[token_idx][qi], actual_row.get(), - kDims, hwy::TargetName(HWY_TARGET), - __FILE__, __LINE__)) + EXPECT_TRUE(hwy::CompareArraySimilar( + golden[token_idx][qi], actual_row.get(), kDims, GetTolerance(), + hwy::TargetName(HWY_TARGET), __FILE__, __LINE__)) << "att_sums mismatch for token_idx=" << token_idx << " qi=" << qi; } } @@ -218,20 +200,19 @@ void CompareKVCacheWithGolden( for (size_t token_idx = 0; token_idx < kNumTokens; ++token_idx) { for (size_t qi = 0; qi < kQBatchSize; ++qi) { - const BF16* cache_row = + const float* cache_row = kv_caches[qi].kv_cache.Row(start_offset + token_idx); for (size_t j = 0; j < kDims; ++j) { - actual_k_row[j] = hwy::ConvertScalarTo(cache_row[kv_offset + j]); - actual_v_row[j] = - hwy::ConvertScalarTo(cache_row[kv_offset + qkv_dim + j]); + actual_k_row[j] = cache_row[kv_offset + j]; + actual_v_row[j] = cache_row[kv_offset + qkv_dim + j]; } - EXPECT_TRUE(CompareArraySimilar( - k_golden[token_idx][qi], actual_k_row.get(), kDims, + EXPECT_TRUE(hwy::CompareArraySimilar( + k_golden[token_idx][qi], actual_k_row.get(), kDims, GetTolerance(), hwy::TargetName(HWY_TARGET), __FILE__, __LINE__)) << "K cache mismatch for token_idx=" << token_idx << " qi=" << qi << " kv_head=" << kv_head; - EXPECT_TRUE(CompareArraySimilar( - v_golden[token_idx][qi], actual_v_row.get(), kDims, + EXPECT_TRUE(hwy::CompareArraySimilar( + v_golden[token_idx][qi], actual_v_row.get(), kDims, GetTolerance(), hwy::TargetName(HWY_TARGET), __FILE__, __LINE__)) << "V cache mismatch for token_idx=" << token_idx << " qi=" << qi << " kv_head=" << kv_head; @@ -257,8 +238,8 @@ void CompareQVecsWithGolden( for (size_t j = 0; j < kDims; ++j) { actual_q_row[j] = q_row[head_offset + j]; } - EXPECT_TRUE(CompareArraySimilar( - q_golden[token_idx][qi], actual_q_row.get(), kDims, + EXPECT_TRUE(hwy::CompareArraySimilar( + q_golden[token_idx][qi], actual_q_row.get(), kDims, GetTolerance(), hwy::TargetName(HWY_TARGET), __FILE__, __LINE__)) << "Q vec mismatch for token_idx=" << token_idx << " qi=" << qi << " q_head=" << q_head; @@ -282,46 +263,46 @@ const size_t kDimsToCompare = 17; // greater than AVX-512 vector of floats // Layer 0 const float kGoldenAttSums[kNumTokens][kQBatchSize][kDimsToCompare] = { - {{46.5, 56.5, 10.0625, 65.5, -2.239375, 135, 15.8125, 51, -100, 52.5, + {{46.5, 56.5, 10.0625, 65.5, -2.109375, 135, 15.8125, 51, -100, 52.5, 26.875, 63, 3.34375, -67.5, 31.125, -190, 125}, {-30.375, -17.875, 51.75, -78, -84, 6.40625, 15.375, 70, -22.875, 20.125, -14.9375, -109.5, 76, 9.25, -142, 29.5, -105}}, - {{-32.75, 38.25, 78.5, 107.5, 20.25, 197, -136, 42.5, -84, 25.625, 5.35875, + {{-32.75, 38.25, 78.5, 107.5, 20.25, 197, -136, 42.5, -84, 25.625, 4.96875, 128, 27.25, -161, 19.125, -58, 97.5}, - {-17.625, -15.375, 135, -13.4375, -3.343, -45.75, 29.625, 93, 18.625, 75.5, + {-18.5, -18, 135, -13.4375, -6.625, -45.75, 29.625, 93, 18.625, 75.5, 102.5, -184, 52.75, 83.5, -71, 46.5, -52}}, - {{-16.375, -61.5, -58.25, -27.375, -28, 71, -109.5, 60.25, 3.625, -29.125, - 6.4625, 150, 144, -155, -47.25, -98.5, 3.5625}, - {-19, -16.75, 129, 0.628925, -82, 123.5, 60.75, -36.75, -77, 26.625, 51, - -66.5, -0.62165625, -46.5, -152, -2.9375, -81}}, - {{3.684375, 83, -41.75, 39.5, -203, 110, -76, 131, 1.0069375, -44.5, -63.75, + {{-16.375, -61.5, -58.25, -27.375, -28, 71, -109.5, 60.25, 
3.125, -29.125, + 6.90625, 150, 144, -155, -47.25, -98.5, 3.5625}, + {-19, -16.75, 129, 0.59765625, -82, 123.5, 60.75, -36.75, -77, 26.625, 51, + -66.5, -0.84765625, -46.5, -152, -2.9375, -81}}, + {{3.984375, 83, -41.75, 39.5, -203, 110, -76, 131, 0.4609375, -44.5, -63.75, -46, -22, -19.375, -16.125, -148, 20.875}, - {-47, -17.5, 58, 81.5, 23.35, -30, -118, 44.25, -149, 22.5, 188, -66.5, 33, + {-47, -19.5, 58, 81.5, 21.75, -30, -118, 44.25, -149, 22.5, 188, -66.5, 33, 10.9375, -52.5, 23.25, 75}}, - {{64, -31, -89, -92.5, -11.1875, -54.75, -302, 4.213125, -108, 39.25, + {{64, -31, -89, -92.5, -11.1875, -54.75, -302, 3.453125, -108, 39.25, -34.75, 18, -52, 100, -186, -75.5, 50.75}, - {7.1875, -80, -40, 32.25, -30.25, 90, -41, 44.25, -140, -2.2675, 82.5, + {7.6875, -80, -40, 32.25, -30.25, 90, -41, 44.25, -140, -2.4375, 82.5, 39.25, 65, 47.25, -89.5, -34.25, 137}}, - {{39.75, 17.875, 115, 38.75, -44, 139, -53.25, -23.875, -12.625, 38.5, 32.5, - 53.75, 109, 4.62375, 57.5, -20.5, 132}, - {143, 249, 4.9375, 1.33984375, 27.875, -5.84375, 30.25, -101.5, 65.5, 13.5, - 195, -10.0625, 97.5, 1.903125, -97.5, -100, -19.25}}, + {{39.75, 17.875, 115, 38.75, -44, 139, -53.25, -23.875, -13.0625, 38.5, + 32.5, 53.75, 109, 4.09375, 57.5, -20.5, 132}, + {143, 249, 5.09375, 0.83984375, 27.875, -5.84375, 30.25, -101.5, 65.5, + 13.5, 195, -10.0625, 97.5, 2.203125, -97.5, -100, -19.25}}, {{-30.125, -169, -150, 58, -35.75, 22.75, 36.5, -32.25, -8.9375, 55.25, -117, 26.375, 39.5, 125, 66, 48.75, 20.75}, - {137, 3.85, 61.25, 37, -42.75, 240, 62, -164, 10.3125, 173, 174, 23.5, - 88.5, 48.5, -46.25, -35.5, 101.5}}, - {{-103, -41.5, 39, -52, -62.7, 121, -136, 99, 80, -47.5, 107.5, 43.75, 97.5, - 125, -53.5, -11.625, 262}, - {28.075, 6.64375, -36.75, -13.35, -27.5, 44.75, -67.5, -40.75, 71.5, 172, - 81, -28.5, -3.875, 111, -167, 59, 176}}, + {137, 5.25, 61.25, 37, -42.75, 240, 62, -164, 11.3125, 173, 174, 23.5, + 88.5, 48.5, -46.25, -36.75, 101.5}}, + {{-103, -47.5, 39, -48, -67.5, 121, -136, 99, 80, -47.5, 107.5, 48.75, 97.5, + 125, -53.5, -14.625, 262}, + {29.875, 7.34375, -36.75, -14.5, -27.5, 44.75, -67.5, -40.75, 71.5, 172, + 81, -27.25, -3.03125, 111, -167, 59, 176}}, {{-37.25, 109.5, -26.125, -115.5, 108, 57.25, 1.3671875, 72, -122.5, 59.25, -52, -12.625, 43.25, 16.25, -41.75, 26.5, 70.5}, - {40.25, 53.25, -142, 78.5, 38, 4.625, -27.75, -134, -85, 107.5, 2.5, 93.5, + {40.25, 53.25, -142, 78.5, 38, 4.3125, -27.75, -134, -85, 107.5, 2.5, 93.5, 58.25, 173, -53.5, 25.125, 4.8125}}, {{-8.4375, -35, -35.5, 131, -33.25, 106, 109.5, -92, -135, 80, 21.5, -17.125, 15.25, 143, -27, 103, 101}, - {-77, 40.75, -10.5, 33.25, -33, 104, -7.6875, 85.5, -40, 93, 61, 14.5625, - 8.55, -99.5, 14.6875, -12.25, 33}}, + {-77, 40.75, -10.125, 33.25, -33, 104, -7.6875, 85.5, -40, 93, 61, 14.5625, + 8.125, -99.5, 13.6875, -11.6875, 33}}, }; // Layer 0, *K*V Head 0 diff --git a/gemma/configs.h b/gemma/configs.h index 3811e98..53b19b9 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -81,8 +81,8 @@ static inline bool EnumValid(LayerAttentionType type) { } enum class AttentionImpl { - kOld, // Previous Attention implementation - kFlash, // Flash Attention (default) + kOld, + kFlash, kFlashTransposedQs, kFlashTransposedQsBF16, kSentinel, diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index 360dc9d..ebe8ee1 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -17,7 +17,6 @@ #include #include -#include #include #include #include @@ -59,7 +58,43 @@ HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace 
HWY_NAMESPACE { +static constexpr size_t kNFx8HTileSize = 8; static constexpr float kNegInf = -std::numeric_limits::max() / 64.0f; +// Transposes q into q_t. +// Both are 4D tensors stuffed into a 2-D MatPtrT. +// q has shape [batch, qbatch][head, qkv_dim]. +// q_t has shape [qkv_dim][qbatch, head, batch] in order to make the maximum +// possible consecutive elements have the same KV. +static void TransposeQ(const MatPtrT& q, MatPtrT& q_t, + const size_t qbatch_size, ThreadingContext& ctx) { + // Group floats by the number of floats in a cache line. + const size_t kNF = ctx.cache_info.LineBytes() / sizeof(float); + const size_t num_heads = q.Cols() / q_t.Rows(); + const size_t batch_size = q.Rows() / qbatch_size; + const auto func = [&](const size_t task, size_t worker) HWY_ATTR { + GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTransposeQ); + for (size_t lane = 0; lane < kNF; ++lane) { + size_t q_row = task * kNF + lane; + if (q_row >= q_t.Rows()) break; + BF16* HWY_RESTRICT qt_row = q_t.Row(q_row); + for (size_t qi = 0; qi < qbatch_size; ++qi) { + for (size_t h = 0; h < num_heads; ++h) { + for (size_t b = 0; b < batch_size; ++b) { + qt_row[(qi * num_heads + h) * batch_size + b] = + hwy::ConvertScalarTo( + q.Row(b * qbatch_size + qi)[h * q_t.Rows() + q_row]); + } + } + } + } + }; + { + const size_t num_tasks = hwy::DivCeil(q_t.Rows(), kNF); + // Better than kFlat. + ParallelFor(Parallelism::kHierarchical, num_tasks, ctx, + /*cluster_idx=*/0, Callers::kFlashTransposeQ, func); + } +} // Updates q in place for RMSNorm and positional encoding. void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch, @@ -102,390 +137,292 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch, } } -// Zeroes out kVTileSize of the given vectors. -template > -HWY_INLINE void ZeroResults(DF df, VF& sum0, VF& HWY_MAYBE_UNUSED sum1, - VF& HWY_MAYBE_UNUSED sum2, - VF& HWY_MAYBE_UNUSED sum3, - VF& HWY_MAYBE_UNUSED sum4, - VF& HWY_MAYBE_UNUSED sum5, - VF& HWY_MAYBE_UNUSED sum6, - VF& HWY_MAYBE_UNUSED sum7) { +// Handles a single v row of flash attention for a single q.k dot product. +HWY_INLINE void SingleFlashAttentionStep(float x, float cap, float& old_max, + float& old_d, + const float* HWY_RESTRICT v, + const size_t v_cols, + float* HWY_RESTRICT att_out) { + if (cap > 0.0f) { + // Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x. + x = cap * std::tanh(x / cap); + } + float m = std::max(x, old_max); + x = std::exp(x - m); + float scale = old_d * std::exp(old_max - m); + old_d = x + scale; + old_max = m; + float one_over_d = 1.0f / old_d; + scale *= one_over_d; + x *= one_over_d; + MulByConst(scale, att_out, v_cols); + MulByConstAndAdd(x, v, att_out, v_cols); +} + +// Calculates the complete attention outputs for a single row of q. +void SingleFlashAttention(const size_t start_pos, const size_t last_pos, + const BF16* HWY_RESTRICT q, const MatPtrT& k, + const MatPtrT& v, const size_t layer_idx, + const AttentionActivationsPtrs& activations, + float* HWY_RESTRICT att_out, ThreadingContext& ctx, + const size_t worker) { + GCPP_ZONE(ctx, worker, Zones::kFlashAttentionSingleFlashAttention); + const hn::ScalableTag dbf; + const size_t qkv_dim = k.Cols(); + + const size_t pos_mod = activations.div_seq_len.Remainder(start_pos); + // TODO: Mixed-mode can be further improved for Turin: we can demote right + // before we do the dot product instruction, rather than promote both to f32. + // But some potential accuracy loss there, needs evaluation first. 
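Before the dot-product loop below, it may help to see why the running update in SingleFlashAttentionStep above is safe: keeping att_out fully normalized after every step reproduces an ordinary softmax-weighted sum over all positions seen so far. A scalar model with a check against the direct computation (names are hypothetical; the soft cap is omitted):

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Scalar model of SingleFlashAttentionStep: att_out is renormalized on
// every step, so after each position it equals softmax(logits so far) . V.
void Step(float x, float& m, float& d, const std::vector<float>& v,
          std::vector<float>& out) {
  const float m_new = std::max(x, m);
  const float x_exp = std::exp(x - m_new);
  float scale = d * std::exp(m - m_new);
  d = x_exp + scale;
  m = m_new;
  scale /= d;
  const float w = x_exp / d;
  for (size_t i = 0; i < out.size(); ++i) out[i] = out[i] * scale + w * v[i];
}

int main() {
  const std::vector<float> logits = {0.5f, 2.0f, -1.0f};
  const std::vector<std::vector<float>> v = {{1, 0}, {0, 1}, {1, 1}};
  // Online pass, seeded with the first position as SingleFlashAttention does.
  float m = logits[0], d = 1.0f;
  std::vector<float> out = v[0];
  for (size_t p = 1; p < logits.size(); ++p) Step(logits[p], m, d, v[p], out);
  // Reference: one-shot softmax over all logits.
  const float mx = *std::max_element(logits.begin(), logits.end());
  float denom = 0.0f;
  for (float x : logits) denom += std::exp(x - mx);
  std::vector<float> ref(2, 0.0f);
  for (size_t p = 0; p < logits.size(); ++p) {
    const float w = std::exp(logits[p] - mx) / denom;
    for (size_t i = 0; i < 2; ++i) ref[i] += w * v[p][i];
  }
  for (size_t i = 0; i < 2; ++i) assert(std::fabs(out[i] - ref[i]) < 1e-6f);
  return 0;
}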
+ float m = Dot(dbf, MakeConstSpan(q, qkv_dim), 0, k.Row(pos_mod), qkv_dim); + if (float cap = activations.config.att_cap; cap > 0.0f) { + // Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x. + m = cap * std::tanh(m / cap); + } + float d = 1.0f; + // This is just a copy of the first token. + MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), ctx, worker); + for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) { + const size_t pos_mod = activations.div_seq_len.Remainder(pos); + float x = Dot(dbf, MakeConstSpan(q, qkv_dim), 0, k.Row(pos_mod), qkv_dim); + SingleFlashAttentionStep(x, activations.config.att_cap, m, d, + v.Row(pos_mod), v.Cols(), att_out); + } +} + +// Computes and returns a single vector of NF Q.K dot products, which represents +// the dot products of NF rows of Q for a single K timestep. +template > +VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets, + const size_t k_pos, const MatPtrT& q, + const MatPtrT& k) { + const hn::ScalableTag dbf; + const size_t qkv_dim = k.Cols(); + + hn::TFromD results[hn::MaxLanes(df)]; + for (size_t i = 0; i < hn::Lanes(df); ++i) { + results[i] = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[i], qkv_dim), 0, + k.Row(k_pos), qkv_dim); + } + return hn::LoadU(df, results); +} + +// Returns an NF Q rows by 8 K rows tile of Q.K dot products. +// This is the result of NF rows of Q against 8 K timesteps, with positions +// given by k_pos[0..7]. Q has been transposed so that the NF rows are read in +// consecutive elements, and other columns by adding q_stride. +template > +void QDotKTile(DF df, const BF16* HWY_RESTRICT q, const size_t q_stride, + const MatPtrT& k, const size_t* k_pos, VF& sum0, VF& sum1, + VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7) { + constexpr size_t kHTileSize = kNFx8HTileSize; sum0 = hn::Zero(df); - if constexpr (kVTileSize >= 4) { - sum1 = hn::Zero(df); - sum2 = hn::Zero(df); - sum3 = hn::Zero(df); + sum1 = hn::Zero(df); + sum2 = hn::Zero(df); + sum3 = hn::Zero(df); + sum4 = hn::Zero(df); + sum5 = hn::Zero(df); + sum6 = hn::Zero(df); + sum7 = hn::Zero(df); + const float* HWY_RESTRICT k_row[kHTileSize]; + for (size_t i = 0; i < kHTileSize; ++i) { + k_row[i] = k.Row(k_pos[i]); } - if constexpr (kVTileSize >= 8) { - sum4 = hn::Zero(df); - sum5 = hn::Zero(df); - sum6 = hn::Zero(df); - sum7 = hn::Zero(df); + + const hn::Rebind dbfh; + using VBF = hn::Vec; + + for (size_t i = 0; i < k.Cols(); ++i) { + const VBF q_vec_bf = hn::Load(dbfh, q); + const VF q_vec = hn::PromoteTo(df, q_vec_bf); + VF k_0 = hn::Set(df, k_row[0][i]); + sum0 = hn::MulAdd(q_vec, k_0, sum0); + VF k_1 = hn::Set(df, k_row[1][i]); + sum1 = hn::MulAdd(q_vec, k_1, sum1); + VF k_2 = hn::Set(df, k_row[2][i]); + sum2 = hn::MulAdd(q_vec, k_2, sum2); + VF k_3 = hn::Set(df, k_row[3][i]); + sum3 = hn::MulAdd(q_vec, k_3, sum3); + VF k_4 = hn::Set(df, k_row[4][i]); + sum4 = hn::MulAdd(q_vec, k_4, sum4); + VF k_5 = hn::Set(df, k_row[5][i]); + sum5 = hn::MulAdd(q_vec, k_5, sum5); + VF k_6 = hn::Set(df, k_row[6][i]); + sum6 = hn::MulAdd(q_vec, k_6, sum6); + VF k_7 = hn::Set(df, k_row[7][i]); + sum7 = hn::MulAdd(q_vec, k_7, sum7); + q += q_stride; } } -// Returns a tile of 1, 4 or 8 Q rows by 2NF K Q.K dot products, in float32. -// K is always pre-transposed to shape: -// [seq_len / 2kNF, layers * kv_heads * qkv_dim/2 * 2kNF * 2], where the /2, *2 -// represents that pairs of qkv_dim elements are kept together to make best use -// of BF16 dot product instructions. 
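The shape of QDotKTile above is easier to see stripped of intrinsics: one pass over the qkv_dim features fills an NF-queries-by-8-timesteps tile of dot products, with all eight running sums kept live throughout (the vector code keeps them in registers). A scalar sketch with hypothetical names, where qT uses the TransposeQ layout of NF consecutive queries per feature:

#include <array>
#include <cstddef>
#include <vector>

template <size_t kNF>
std::vector<std::array<float, 8>> DotTile(
    const std::vector<float>& qT,  // [qkv_dim][kNF], transposed queries
    const std::array<const float*, 8>& k_rows, size_t qkv_dim) {
  std::vector<std::array<float, 8>> sums(kNF, std::array<float, 8>{});
  for (size_t i = 0; i < qkv_dim; ++i) {
    const float* q_col = &qT[i * kNF];   // kNF consecutive query values
    for (size_t t = 0; t < 8; ++t) {     // 8 K timesteps
      const float k_val = k_rows[t][i];  // broadcast, like hn::Set
      for (size_t lane = 0; lane < kNF; ++lane) {
        sums[lane][t] += q_col[lane] * k_val;  // like hn::MulAdd
      }
    }
  }
  return sums;
}

int main() {
  constexpr size_t kNF = 4;
  const std::vector<float> qT(2 * kNF, 1.0f);  // qkv_dim == 2, all ones
  const std::vector<float> k(8 * 2, 0.5f);     // 8 timesteps of 0.5
  std::array<const float*, 8> k_rows;
  for (size_t t = 0; t < 8; ++t) k_rows[t] = &k[t * 2];
  const auto sums = DotTile<kNF>(qT, k_rows, /*qkv_dim=*/2);
  return sums[0][0] == 1.0f ? 0 : 1;  // 2 features * (1 * 0.5) == 1
}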
-// Note that this version assumes that Q is float32, and not transposed, and -// HWY_NATIVE_DOT_BF16 is false. -template > -HWY_INLINE void QDotKTile148FloatNotNative( - DF df, const float* HWY_RESTRICT q, const uint32_t* HWY_RESTRICT q_offsets, - size_t half_cols, const MatPtrT& k, size_t pos, VF& sum00, VF& sum01, - VF& HWY_MAYBE_UNUSED sum10, VF& HWY_MAYBE_UNUSED sum11, - VF& HWY_MAYBE_UNUSED sum20, VF& HWY_MAYBE_UNUSED sum21, - VF& HWY_MAYBE_UNUSED sum30, VF& HWY_MAYBE_UNUSED sum31, - VF& HWY_MAYBE_UNUSED sum40, VF& HWY_MAYBE_UNUSED sum41, - VF& HWY_MAYBE_UNUSED sum50, VF& HWY_MAYBE_UNUSED sum51, - VF& HWY_MAYBE_UNUSED sum60, VF& HWY_MAYBE_UNUSED sum61, - VF& HWY_MAYBE_UNUSED sum70, VF& HWY_MAYBE_UNUSED sum71) { - ZeroResults(df, sum00, sum10, sum20, sum30, sum40, sum50, sum60, - sum70); - ZeroResults(df, sum01, sum11, sum21, sum31, sum41, sum51, sum61, - sum71); - using DBF = hn::ScalableTag; - const DBF dbf; - using VBF = hn::Vec; - const size_t kNF = hn::Lanes(df); - const float* HWY_RESTRICT q_base[kVTileSize]; - for (size_t i = 0; i < kVTileSize; ++i) { - q_base[i] = q + q_offsets[i]; - } - const BF16* HWY_RESTRICT k_base = k.Row(pos / (2 * kNF)); - for (size_t i = 0; i < half_cols; ++i, k_base += kNF * 4) { - // TODO(rays): Replace with decompress2. - VBF k0_vec = hn::LoadU(dbf, k_base); - VBF k1_vec = hn::LoadU(dbf, k_base + kNF * 2); - VF k0_even = hn::PromoteEvenTo(df, k0_vec); - VF k0_odd = hn::PromoteOddTo(df, k0_vec); - VF k1_even = hn::PromoteEvenTo(df, k1_vec); - VF k1_odd = hn::PromoteOddTo(df, k1_vec); - VF q0_even = hn::Set(df, q_base[0][i * 2]); - VF q0_odd = hn::Set(df, q_base[0][i * 2 + 1]); - sum00 = hn::MulAdd(q0_even, k0_even, sum00); - sum01 = hn::MulAdd(q0_even, k1_even, sum01); - sum00 = hn::MulAdd(q0_odd, k0_odd, sum00); - sum01 = hn::MulAdd(q0_odd, k1_odd, sum01); - if constexpr (kVTileSize >= 4) { - VF q1_even = hn::Set(df, q_base[1][i * 2]); - VF q1_odd = hn::Set(df, q_base[1][i * 2 + 1]); - sum10 = hn::MulAdd(q1_even, k0_even, sum10); - sum11 = hn::MulAdd(q1_even, k1_even, sum11); - sum10 = hn::MulAdd(q1_odd, k0_odd, sum10); - sum11 = hn::MulAdd(q1_odd, k1_odd, sum11); - VF q2_even = hn::Set(df, q_base[2][i * 2]); - VF q2_odd = hn::Set(df, q_base[2][i * 2 + 1]); - sum20 = hn::MulAdd(q2_even, k0_even, sum20); - sum21 = hn::MulAdd(q2_even, k1_even, sum21); - sum20 = hn::MulAdd(q2_odd, k0_odd, sum20); - sum21 = hn::MulAdd(q2_odd, k1_odd, sum21); - VF q3_even = hn::Set(df, q_base[3][i * 2]); - VF q3_odd = hn::Set(df, q_base[3][i * 2 + 1]); - sum30 = hn::MulAdd(q3_even, k0_even, sum30); - sum31 = hn::MulAdd(q3_even, k1_even, sum31); - sum30 = hn::MulAdd(q3_odd, k0_odd, sum30); - sum31 = hn::MulAdd(q3_odd, k1_odd, sum31); - } - if constexpr (kVTileSize >= 8) { - VF q4_even = hn::Set(df, q_base[4][i * 2]); - VF q4_odd = hn::Set(df, q_base[4][i * 2 + 1]); - sum40 = hn::MulAdd(q4_even, k0_even, sum40); - sum41 = hn::MulAdd(q4_even, k1_even, sum41); - sum40 = hn::MulAdd(q4_odd, k0_odd, sum40); - sum41 = hn::MulAdd(q4_odd, k1_odd, sum41); - VF q5_even = hn::Set(df, q_base[5][i * 2]); - VF q5_odd = hn::Set(df, q_base[5][i * 2 + 1]); - sum50 = hn::MulAdd(q5_even, k0_even, sum50); - sum51 = hn::MulAdd(q5_even, k1_even, sum51); - sum50 = hn::MulAdd(q5_odd, k0_odd, sum50); - sum51 = hn::MulAdd(q5_odd, k1_odd, sum51); - VF q6_even = hn::Set(df, q_base[6][i * 2]); - VF q6_odd = hn::Set(df, q_base[6][i * 2 + 1]); - sum60 = hn::MulAdd(q6_even, k0_even, sum60); - sum61 = hn::MulAdd(q6_even, k1_even, sum61); - sum60 = hn::MulAdd(q6_odd, k0_odd, sum60); - sum61 = 
hn::MulAdd(q6_odd, k1_odd, sum61); - VF q7_even = hn::Set(df, q_base[7][i * 2]); - VF q7_odd = hn::Set(df, q_base[7][i * 2 + 1]); - sum70 = hn::MulAdd(q7_even, k0_even, sum70); - sum71 = hn::MulAdd(q7_even, k1_even, sum71); - sum70 = hn::MulAdd(q7_odd, k0_odd, sum70); - sum71 = hn::MulAdd(q7_odd, k1_odd, sum71); - } - } +// Returns the element-wise maximum of 8 vectors, in a single vector. +template > +VF HWY_INLINE ElementwiseMaxOf8(DF df, const VF& x0, const VF& x1, const VF& x2, + const VF& x3, const VF& x4, const VF& x5, + const VF& x6, const VF& x7) { + VF m0 = hn::Max(x0, x1); + VF m1 = hn::Max(x2, x3); + VF m2 = hn::Max(x4, x5); + VF m3 = hn::Max(x6, x7); + m0 = hn::Max(m0, m1); + m2 = hn::Max(m2, m3); + return hn::Max(m0, m2); } -// Loads an adjacent pair of floats, converts them to BF16, and broadcasts them -// across a vector of BF16 as alternating odd and even elements. -// hn::ReorderDemote2To(dbf, q_1_float, q_1_float); with q1_float containing -// alternating odd and even floats appears not to do this. -HWY_INLINE hn::Vec> DemoteAndBroadcast2ToBF16( - const float* HWY_RESTRICT base) { +// Returns the element-wise sum of 8 vectors, in a single vector. +template > +VF HWY_INLINE ElementwiseSumOf8(DF df, const VF& x0, const VF& x1, const VF& x2, + const VF& x3, const VF& x4, const VF& x5, + const VF& x6, const VF& x7) { + VF sum0 = hn::Add(x0, x1); + VF sum1 = hn::Add(x2, x3); + VF sum2 = hn::Add(x4, x5); + VF sum3 = hn::Add(x6, x7); + sum0 = hn::Add(sum0, sum1); + sum2 = hn::Add(sum2, sum3); + return hn::Add(sum0, sum2); +} + +// Sweeps a tile of NF Q rows by 8 K timesteps accumulators from start_pos to +// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos, +// max_last_pos]. +void TileFlashAttention( + const MatPtrT& q, const uint32_t* HWY_RESTRICT q_offsets, + const StridedView& qT, const MatPtrT& k, const size_t start_pos, + const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos, + const size_t max_last_pos, const MatPtrT& v, const size_t layer_idx, + const AttentionActivationsPtrs& activations, MatPtrT& att_out, + const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx, + const size_t worker) { + GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention); + constexpr size_t kHTileSize = kNFx8HTileSize; using DF = hn::ScalableTag; const DF df; using VF = hn::Vec; - VF v_even = hn::Set(df, base[0]); - VF v_odd = hn::Set(df, base[1]); - VF interleaved = hn::OddEven(v_odd, v_even); - return hn::OrderedDemote2To(hn::ScalableTag(), interleaved, - interleaved); -} - -// Returns a tile of 1, 4 or 8 Q rows by 2NF K Q.K dot products, in float32. -// K is always pre-transposed to shape: -// [seq_len / 2kNF, layers * kv_heads * qkv_dim/2 * 2kNF * 2], where the /2, *2 -// represents that pairs of qkv_dim elements are kept together to make best use -// of BF16 dot product instructions. -// Note that this version assumes that Q is float32, and not transposed, and -// HWY_NATIVE_DOT_BF16 is true. 
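ElementwiseMaxOf8 and ElementwiseSumOf8 above are written as balanced trees rather than a running fold; the payoff is a dependency depth of 3 instead of 7, so wide cores can overlap the operations. The same shape in scalar form:

#include <algorithm>

float TreeMax8(const float x[8]) {
  const float m01 = std::max(x[0], x[1]);  // level 1: four independent maxes
  const float m23 = std::max(x[2], x[3]);
  const float m45 = std::max(x[4], x[5]);
  const float m67 = std::max(x[6], x[7]);
  const float m0123 = std::max(m01, m23);  // level 2: two independent maxes
  const float m4567 = std::max(m45, m67);
  return std::max(m0123, m4567);           // level 3: final combine
}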
-template > -HWY_INLINE void QDotKTile148FloatNative( - DF df, const float* HWY_RESTRICT q, const uint32_t* HWY_RESTRICT q_offsets, - size_t half_cols, const MatPtrT& k, size_t pos, VF& sum00, VF& sum01, - VF& HWY_MAYBE_UNUSED sum10, VF& HWY_MAYBE_UNUSED sum11, - VF& HWY_MAYBE_UNUSED sum20, VF& HWY_MAYBE_UNUSED sum21, - VF& HWY_MAYBE_UNUSED sum30, VF& HWY_MAYBE_UNUSED sum31, - VF& HWY_MAYBE_UNUSED sum40, VF& HWY_MAYBE_UNUSED sum41, - VF& HWY_MAYBE_UNUSED sum50, VF& HWY_MAYBE_UNUSED sum51, - VF& HWY_MAYBE_UNUSED sum60, VF& HWY_MAYBE_UNUSED sum61, - VF& HWY_MAYBE_UNUSED sum70, VF& HWY_MAYBE_UNUSED sum71) { - ZeroResults(df, sum00, sum10, sum20, sum30, sum40, sum50, sum60, - sum70); - ZeroResults(df, sum01, sum11, sum21, sum31, sum41, sum51, sum61, - sum71); - VF unused = hn::Zero(df); - using DBF = hn::ScalableTag; - const DBF dbf; - using VBF = hn::Vec; - const size_t kNF = hn::Lanes(df); - const float* HWY_RESTRICT q_base[kVTileSize]; + using DI = hn::ScalableTag; + const DI di; + using VI = hn::Vec; + const size_t kVTileSize = hn::Lanes(df); for (size_t i = 0; i < kVTileSize; ++i) { - q_base[i] = q + q_offsets[i]; + hwy::ZeroBytes(att_out.Row(0) + out_offsets[i], + v.Cols() * sizeof(att_out.Row(0)[0])); } - const BF16* HWY_RESTRICT k_base = k.Row(pos / (2 * kNF)); - for (size_t i = 0; i < half_cols; ++i, k_base += kNF * 4) { - VBF kvec0 = hn::LoadU(dbf, k_base); - VBF kvec1 = hn::LoadU(dbf, k_base + kNF * 2); - VBF q0_bf16 = DemoteAndBroadcast2ToBF16(q_base[0] + i * 2); - sum00 = hn::ReorderWidenMulAccumulate(df, q0_bf16, kvec0, sum00, unused); - sum01 = hn::ReorderWidenMulAccumulate(df, q0_bf16, kvec1, sum01, unused); - if constexpr (kVTileSize >= 4) { - VBF q1_bf16 = DemoteAndBroadcast2ToBF16(q_base[1] + i * 2); - sum10 = hn::ReorderWidenMulAccumulate(df, q1_bf16, kvec0, sum10, unused); - sum11 = hn::ReorderWidenMulAccumulate(df, q1_bf16, kvec1, sum11, unused); - VBF q2_bf16 = DemoteAndBroadcast2ToBF16(q_base[2] + i * 2); - sum20 = hn::ReorderWidenMulAccumulate(df, q2_bf16, kvec0, sum20, unused); - sum21 = hn::ReorderWidenMulAccumulate(df, q2_bf16, kvec1, sum21, unused); - VBF q3_bf16 = DemoteAndBroadcast2ToBF16(q_base[3] + i * 2); - sum30 = hn::ReorderWidenMulAccumulate(df, q3_bf16, kvec0, sum30, unused); - sum31 = hn::ReorderWidenMulAccumulate(df, q3_bf16, kvec1, sum31, unused); + VI lasts = hn::LoadU(di, last_pos); + VF old_m = hn::Set(df, -std::numeric_limits::max() / 2.0f); + VF old_d = hn::Zero(df); + const BF16* HWY_RESTRICT qT_row = qT.Row(0); + const size_t qT_stride = qT.Stride(); + size_t position = start_pos; + while (position + kHTileSize - 1 <= min_last_pos) { + size_t k_pos[kHTileSize]; + for (size_t i = 0; i < kHTileSize; ++i) { + k_pos[i] = activations.div_seq_len.Remainder(position + i); } - if constexpr (kVTileSize >= 8) { - VBF q4_bf16 = DemoteAndBroadcast2ToBF16(q_base[4] + i * 2); - sum40 = hn::ReorderWidenMulAccumulate(df, q4_bf16, kvec0, sum40, unused); - sum41 = hn::ReorderWidenMulAccumulate(df, q4_bf16, kvec1, sum41, unused); - VBF q5_bf16 = DemoteAndBroadcast2ToBF16(q_base[5] + i * 2); - sum50 = hn::ReorderWidenMulAccumulate(df, q5_bf16, kvec0, sum50, unused); - sum51 = hn::ReorderWidenMulAccumulate(df, q5_bf16, kvec1, sum51, unused); - VBF q6_bf16 = DemoteAndBroadcast2ToBF16(q_base[6] + i * 2); - sum60 = hn::ReorderWidenMulAccumulate(df, q6_bf16, kvec0, sum60, unused); - sum61 = hn::ReorderWidenMulAccumulate(df, q6_bf16, kvec1, sum61, unused); - VBF q7_bf16 = DemoteAndBroadcast2ToBF16(q_base[7] + i * 2); - sum70 = hn::ReorderWidenMulAccumulate(df, q7_bf16, 
kvec0, sum70, unused); - sum71 = hn::ReorderWidenMulAccumulate(df, q7_bf16, kvec1, sum71, unused); + VF x0, x1, x2, x3, x4, x5, x6, x7; + QDotKTile(df, qT_row, qT_stride, k, k_pos, x0, x1, x2, x3, x4, x5, x6, x7); + if (activations.config.att_cap > 0.0f) { + // Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile. + VF cap = hn::Set(df, activations.config.att_cap); + VF one_over_cap = hn::Div(hn::Set(df, 1.0f), cap); + x0 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x0, one_over_cap))); + x1 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x1, one_over_cap))); + x2 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x2, one_over_cap))); + x3 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x3, one_over_cap))); + x4 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x4, one_over_cap))); + x5 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x5, one_over_cap))); + x6 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x6, one_over_cap))); + x7 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x7, one_over_cap))); } + VF m = ElementwiseMaxOf8(df, x0, x1, x2, x3, x4, x5, x6, x7); + m = hn::Max(old_m, m); + x0 = hn::Exp(df, hn::Sub(x0, m)); + x1 = hn::Exp(df, hn::Sub(x1, m)); + x2 = hn::Exp(df, hn::Sub(x2, m)); + x3 = hn::Exp(df, hn::Sub(x3, m)); + x4 = hn::Exp(df, hn::Sub(x4, m)); + x5 = hn::Exp(df, hn::Sub(x5, m)); + x6 = hn::Exp(df, hn::Sub(x6, m)); + x7 = hn::Exp(df, hn::Sub(x7, m)); + VF scale = hn::Mul(old_d, hn::Exp(df, hn::Sub(old_m, m))); + old_d = ElementwiseSumOf8(df, x0, x1, x2, x3, x4, x5, x6, x7); + old_d = hn::Add(scale, old_d); + old_m = m; + VF one_over_d = hn::Div(hn::Set(df, 1.0f), old_d); + scale = hn::Mul(scale, one_over_d); + x0 = hn::Mul(x0, one_over_d); + x1 = hn::Mul(x1, one_over_d); + x2 = hn::Mul(x2, one_over_d); + x3 = hn::Mul(x3, one_over_d); + x4 = hn::Mul(x4, one_over_d); + x5 = hn::Mul(x5, one_over_d); + x6 = hn::Mul(x6, one_over_d); + x7 = hn::Mul(x7, one_over_d); + MulByConstAndAddTile(df, scale, x0, x1, x2, x3, x4, x5, x6, x7, v, k_pos, + att_out.Row(0), out_offsets, v.Cols()); + position += kHTileSize; + } + while (position <= max_last_pos) { + size_t k_pos = activations.div_seq_len.Remainder(position); + VF x0 = QDotKVector(df, q_offsets, k_pos, q, k); + if (activations.config.att_cap > 0.0f) { + // Compute tanh(x / cap) * cap, being LogitsSoftCap on the vector. + VF cap = hn::Set(df, activations.config.att_cap); + VF one_over_cap = hn::Div(hn::Set(df, 1.0f), cap); + x0 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x0, one_over_cap))); + } + // Past the last position, x0 doesn't count. + auto mask = hn::Gt(hn::Set(di, position), lasts); + VF causal_offset = hn::MaskedSet(df, RebindMask(df, mask), + std::numeric_limits::max() / 2.0f); + x0 = hn::Sub(x0, causal_offset); + VF m = hn::Max(old_m, x0); + x0 = hn::Exp(df, hn::Sub(x0, m)); + VF scale = hn::Mul(old_d, hn::Exp(df, hn::Sub(old_m, m))); + old_m = m; + old_d = hn::Add(scale, x0); + VF one_over_d = hn::Div(hn::Set(df, 1.0f), old_d); + x0 = hn::Mul(x0, one_over_d); + scale = hn::Mul(scale, one_over_d); + MulByConstAndAddVector(df, scale, x0, v, k_pos, att_out.Row(0), out_offsets, + v.Cols()); + ++position; } } -// Returns a tile of 1, 4 or 8 Q rows by 2NF K Q.K dot products, in float32. -// K is always pre-transposed to shape: -// [seq_len / 2kNF, layers * kv_heads * qkv_dim/2 * 2kNF * 2], where the /2, *2 -// represents that pairs of qkv_dim elements are kept together to make best use -// of BF16 dot product instructions. -// Note that this is optimized for the case where q and k are bf16, but there is -// no native_bf16 instruction. 
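The recurring "tanh(x / cap) * cap" above is the whole of LogitsSoftCap: tanh saturates, so logits are squashed into (-cap, cap) while remaining nearly linear for |x| much smaller than cap. Scalar form, with the same hoisted-reciprocal trick the tile loop uses:

#include <cmath>

float SoftCap(float x, float cap) {
  const float one_over_cap = 1.0f / cap;  // hoisted, as in the vector code
  return cap * std::tanh(x * one_over_cap);
}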
-template > -HWY_INLINE void QDotKTile148BF16NotNative( - DF df, const BF16* HWY_RESTRICT q, const uint32_t* HWY_RESTRICT q_offsets, - size_t half_cols, const MatPtrT& k, size_t pos, VF& sum00, VF& sum01, - VF& HWY_MAYBE_UNUSED sum10, VF& HWY_MAYBE_UNUSED sum11, - VF& HWY_MAYBE_UNUSED sum20, VF& HWY_MAYBE_UNUSED sum21, - VF& HWY_MAYBE_UNUSED sum30, VF& HWY_MAYBE_UNUSED sum31, - VF& HWY_MAYBE_UNUSED sum40, VF& HWY_MAYBE_UNUSED sum41, - VF& HWY_MAYBE_UNUSED sum50, VF& HWY_MAYBE_UNUSED sum51, - VF& HWY_MAYBE_UNUSED sum60, VF& HWY_MAYBE_UNUSED sum61, - VF& HWY_MAYBE_UNUSED sum70, VF& HWY_MAYBE_UNUSED sum71) { - ZeroResults(df, sum00, sum10, sum20, sum30, sum40, sum50, sum60, - sum70); - ZeroResults(df, sum01, sum11, sum21, sum31, sum41, sum51, sum61, - sum71); - using DBF = hn::ScalableTag; - const DBF dbf; - using VBF = hn::Vec; - const size_t kNF = hn::Lanes(df); - const float* HWY_RESTRICT q_base[kVTileSize]; - for (size_t i = 0; i < kVTileSize; ++i) { - q_base[i] = reinterpret_cast(q + q_offsets[i]); - } - const BF16* HWY_RESTRICT k_base = k.Row(pos / (2 * kNF)); - for (size_t i = 0; i < half_cols; ++i, k_base += kNF * 4) { - VBF kvec0 = hn::LoadU(dbf, k_base); - VBF kvec1 = hn::LoadU(dbf, k_base + kNF * 2); - VBF q0 = hn::BitCast(dbf, hn::Set(df, q_base[0][i])); - VF k0_even = hn::PromoteEvenTo(df, kvec0); - VF k0_odd = hn::PromoteOddTo(df, kvec0); - VF k1_even = hn::PromoteEvenTo(df, kvec1); - VF k1_odd = hn::PromoteOddTo(df, kvec1); - VF q0_even = hn::PromoteEvenTo(df, q0); - sum00 = hn::MulAdd(q0_even, k0_even, sum00); - sum01 = hn::MulAdd(q0_even, k1_even, sum01); - VF q0_odd = hn::PromoteOddTo(df, q0); - sum00 = hn::MulAdd(q0_odd, k0_odd, sum00); - sum01 = hn::MulAdd(q0_odd, k1_odd, sum01); - if constexpr (kVTileSize >= 4) { - VBF q1 = hn::BitCast(dbf, hn::Set(df, q_base[1][i])); - VF q1_even = hn::PromoteEvenTo(df, q1); - sum10 = hn::MulAdd(q1_even, k0_even, sum10); - sum11 = hn::MulAdd(q1_even, k1_even, sum11); - VF q1_odd = hn::PromoteOddTo(df, q1); - sum10 = hn::MulAdd(q1_odd, k0_odd, sum10); - sum11 = hn::MulAdd(q1_odd, k1_odd, sum11); - VBF q2 = hn::BitCast(dbf, hn::Set(df, q_base[2][i])); - VF q2_even = hn::PromoteEvenTo(df, q2); - sum20 = hn::MulAdd(q2_even, k0_even, sum20); - sum21 = hn::MulAdd(q2_even, k1_even, sum21); - VF q2_odd = hn::PromoteOddTo(df, q2); - sum20 = hn::MulAdd(q2_odd, k0_odd, sum20); - sum21 = hn::MulAdd(q2_odd, k1_odd, sum21); - VBF q3 = hn::BitCast(dbf, hn::Set(df, q_base[3][i])); - VF q3_even = hn::PromoteEvenTo(df, q3); - sum30 = hn::MulAdd(q3_even, k0_even, sum30); - sum31 = hn::MulAdd(q3_even, k1_even, sum31); - VF q3_odd = hn::PromoteOddTo(df, q3); - sum30 = hn::MulAdd(q3_odd, k0_odd, sum30); - sum31 = hn::MulAdd(q3_odd, k1_odd, sum31); - } - if constexpr (kVTileSize >= 8) { - VBF q4 = hn::BitCast(dbf, hn::Set(df, q_base[4][i])); - VF q4_even = hn::PromoteEvenTo(df, q4); - sum40 = hn::MulAdd(q4_even, k0_even, sum40); - sum41 = hn::MulAdd(q4_even, k1_even, sum41); - VF q4_odd = hn::PromoteOddTo(df, q4); - sum40 = hn::MulAdd(q4_odd, k0_odd, sum40); - sum41 = hn::MulAdd(q4_odd, k1_odd, sum41); - VBF q5 = hn::BitCast(dbf, hn::Set(df, q_base[5][i])); - VF q5_even = hn::PromoteEvenTo(df, q5); - sum50 = hn::MulAdd(q5_even, k0_even, sum50); - sum51 = hn::MulAdd(q5_even, k1_even, sum51); - VF q5_odd = hn::PromoteOddTo(df, q5); - sum50 = hn::MulAdd(q5_odd, k0_odd, sum50); - sum51 = hn::MulAdd(q5_odd, k1_odd, sum51); - VBF q6 = hn::BitCast(dbf, hn::Set(df, q_base[6][i])); - VF q6_even = hn::PromoteEvenTo(df, q6); - sum60 = hn::MulAdd(q6_even, k0_even, sum60); 
- sum61 = hn::MulAdd(q6_even, k1_even, sum61); - VF q6_odd = hn::PromoteOddTo(df, q6); - sum60 = hn::MulAdd(q6_odd, k0_odd, sum60); - sum61 = hn::MulAdd(q6_odd, k1_odd, sum61); - VBF q7 = hn::BitCast(dbf, hn::Set(df, q_base[7][i])); - VF q7_even = hn::PromoteEvenTo(df, q7); - sum70 = hn::MulAdd(q7_even, k0_even, sum70); - sum71 = hn::MulAdd(q7_even, k1_even, sum71); - VF q7_odd = hn::PromoteOddTo(df, q7); - sum70 = hn::MulAdd(q7_odd, k0_odd, sum70); - sum71 = hn::MulAdd(q7_odd, k1_odd, sum71); - } - } -} - -// Returns a tile of 1, 4 or 8 Q rows by 2NF K Q.K dot products, in float32. -// K is always pre-transposed to shape: -// [seq_len / 2kNF, layers * kv_heads * qkv_dim/2 * 2kNF * 2], where the /2, *2 -// represents that pairs of qkv_dim elements are kept together to make best use -// of BF16 dot product instructions. -// Note that this is optimized for the case where q and k are bf16, and there is -// a native_bf16 instruction. -template > -HWY_INLINE void QDotKTile148BF16Native( - DF df, const BF16* HWY_RESTRICT q, const uint32_t* HWY_RESTRICT q_offsets, - size_t half_cols, const MatPtrT& k, size_t pos, VF& sum00, VF& sum01, - VF& HWY_MAYBE_UNUSED sum10, VF& HWY_MAYBE_UNUSED sum11, - VF& HWY_MAYBE_UNUSED sum20, VF& HWY_MAYBE_UNUSED sum21, - VF& HWY_MAYBE_UNUSED sum30, VF& HWY_MAYBE_UNUSED sum31, - VF& HWY_MAYBE_UNUSED sum40, VF& HWY_MAYBE_UNUSED sum41, - VF& HWY_MAYBE_UNUSED sum50, VF& HWY_MAYBE_UNUSED sum51, - VF& HWY_MAYBE_UNUSED sum60, VF& HWY_MAYBE_UNUSED sum61, - VF& HWY_MAYBE_UNUSED sum70, VF& HWY_MAYBE_UNUSED sum71) { - ZeroResults(df, sum00, sum10, sum20, sum30, sum40, sum50, sum60, - sum70); - ZeroResults(df, sum01, sum11, sum21, sum31, sum41, sum51, sum61, - sum71); - VF unused_sum1 = hn::Zero(df); - using DBF = hn::ScalableTag; - const DBF dbf; - using VBF = hn::Vec; - const size_t kNF = hn::Lanes(df); - const float* HWY_RESTRICT q_base[kVTileSize]; - for (size_t i = 0; i < kVTileSize; ++i) { - q_base[i] = reinterpret_cast(q + q_offsets[i]); - } - const BF16* HWY_RESTRICT k_base = k.Row(pos / (2 * kNF)); - for (size_t i = 0; i < half_cols; ++i, k_base += kNF * 4) { - VBF k0_vec = hn::LoadU(dbf, k_base); - VBF k1_vec = hn::LoadU(dbf, k_base + kNF * 2); - VBF q0 = hn::BitCast(dbf, hn::Set(df, q_base[0][i])); - sum00 = hn::ReorderWidenMulAccumulate(df, q0, k0_vec, sum00, unused_sum1); - sum01 = hn::ReorderWidenMulAccumulate(df, q0, k1_vec, sum01, unused_sum1); - if constexpr (kVTileSize >= 4) { - VBF q1 = hn::BitCast(dbf, hn::Set(df, q_base[1][i])); - sum10 = hn::ReorderWidenMulAccumulate(df, q1, k0_vec, sum10, unused_sum1); - sum11 = hn::ReorderWidenMulAccumulate(df, q1, k1_vec, sum11, unused_sum1); - VBF q2 = hn::BitCast(dbf, hn::Set(df, q_base[2][i])); - sum20 = hn::ReorderWidenMulAccumulate(df, q2, k0_vec, sum20, unused_sum1); - sum21 = hn::ReorderWidenMulAccumulate(df, q2, k1_vec, sum21, unused_sum1); - VBF q3 = hn::BitCast(dbf, hn::Set(df, q_base[3][i])); - sum30 = hn::ReorderWidenMulAccumulate(df, q3, k0_vec, sum30, unused_sum1); - sum31 = hn::ReorderWidenMulAccumulate(df, q3, k1_vec, sum31, unused_sum1); - } - if constexpr (kVTileSize >= 8) { - VBF q4 = hn::BitCast(dbf, hn::Set(df, q_base[4][i])); - sum40 = hn::ReorderWidenMulAccumulate(df, q4, k0_vec, sum40, unused_sum1); - sum41 = hn::ReorderWidenMulAccumulate(df, q4, k1_vec, sum41, unused_sum1); - VBF q5 = hn::BitCast(dbf, hn::Set(df, q_base[5][i])); - sum50 = hn::ReorderWidenMulAccumulate(df, q5, k0_vec, sum50, unused_sum1); - sum51 = hn::ReorderWidenMulAccumulate(df, q5, k1_vec, sum51, unused_sum1); - VBF q6 = 
hn::BitCast(dbf, hn::Set(df, q_base[6][i])); - sum60 = hn::ReorderWidenMulAccumulate(df, q6, k0_vec, sum60, unused_sum1); - sum61 = hn::ReorderWidenMulAccumulate(df, q6, k1_vec, sum61, unused_sum1); - VBF q7 = hn::BitCast(dbf, hn::Set(df, q_base[7][i])); - sum70 = hn::ReorderWidenMulAccumulate(df, q7, k0_vec, sum70, unused_sum1); - sum71 = hn::ReorderWidenMulAccumulate(df, q7, k1_vec, sum71, unused_sum1); - } +// Returns an 4 Q rows by NF K tile of Q.K dot products, in single precision. +// This is the result of 4 rows of Q against NF K timesteps, with positions +// given by k_offsets[0..NF]. +template > +void QDotKTilex4(DF df, const BF16* HWY_RESTRICT q, + const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT& k, + const int32_t* HWY_RESTRICT k_offsets, VF& sum0, VF& sum1, + VF& sum2, VF& sum3) { + sum0 = hn::Zero(df); + sum1 = hn::Zero(df); + sum2 = hn::Zero(df); + sum3 = hn::Zero(df); + const float* HWY_RESTRICT k_base = k.Row(0); + using DI = hn::ScalableTag; + const DI di; + using VI = hn::Vec; + VI k_offsets_vec = hn::LoadU(di, k_offsets); + for (size_t i = 0; i < k.Cols(); ++i) { + VF k_vec = hn::GatherIndex(df, k_base + i, k_offsets_vec); + VF q_0 = hn::Set(df, hwy::ConvertScalarTo(q[q_offsets[0] + i])); + sum0 = hn::MulAdd(q_0, k_vec, sum0); + VF q_1 = hn::Set(df, hwy::ConvertScalarTo(q[q_offsets[1] + i])); + sum1 = hn::MulAdd(q_1, k_vec, sum1); + VF q_2 = hn::Set(df, hwy::ConvertScalarTo(q[q_offsets[2] + i])); + sum2 = hn::MulAdd(q_2, k_vec, sum2); + VF q_3 = hn::Set(df, hwy::ConvertScalarTo(q[q_offsets[3] + i])); + sum3 = hn::MulAdd(q_3, k_vec, sum3); } } // Handles NF v rows of flash attention for NF q.k dot products from one q row. -// Automatically handles masking for causal attention and different start_pos -// and last_pos values. template > -HWY_INLINE float SingleFlashAttentionRowVector(DF df, size_t start_pos, - size_t pos, size_t last_pos, - VF& x, float& old_max, +float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max, float& old_d) { - if (pos < start_pos) { - size_t mask_size = start_pos - pos; - const VF neg_inf = hn::Neg(hn::Inf(df)); - x = hn::IfThenElse(hn::FirstN(df, mask_size), neg_inf, x); - } - if (pos + hn::Lanes(df) > last_pos) { - size_t mask_size = pos <= last_pos ? last_pos + 1 - pos : 0; - const VF neg_inf = hn::Neg(hn::Inf(df)); - x = hn::IfThenElse(hn::FirstN(df, mask_size), x, neg_inf); - } float m = hn::ReduceMax(df, x); m = std::max(m, old_max); x = hn::Exp(df, hn::Sub(x, hn::Set(df, m))); @@ -503,60 +440,6 @@ HWY_INLINE float SingleFlashAttentionRowVector(DF df, size_t start_pos, return scale; } -// Handles 2NF v rows of flash attention for 2NF q.k dot products from 1 q row. -// Automatically handles masking for causal attention and different start_pos -// and last_pos values. -template > -HWY_INLINE float DoubleFlashAttentionRowVector(DF df, size_t start_pos, - size_t pos, size_t last_pos, - VF& x0, VF& x1, float& old_max, - float& old_d) { - const size_t kNF = hn::Lanes(df); - const VF neg_inf = hn::Neg(hn::Inf(df)); - if (pos < start_pos) { - if (pos + kNF <= start_pos) { - x0 = neg_inf; - size_t mask_size = start_pos - (pos + kNF); - x1 = hn::IfThenElse(hn::FirstN(df, mask_size), neg_inf, x1); - } else { - size_t mask_size = start_pos - pos; - x0 = hn::IfThenElse(hn::FirstN(df, mask_size), neg_inf, x0); - } - } - if (pos + 2 * kNF > last_pos) { - if (pos + kNF > last_pos) { - size_t mask_size = pos <= last_pos ? 
last_pos + 1 - pos : 0; - x0 = hn::IfThenElse(hn::FirstN(df, mask_size), x0, neg_inf); - x1 = neg_inf; - } else { - size_t mask_size = last_pos + 1 - (pos + kNF); - x1 = hn::IfThenElse(hn::FirstN(df, mask_size), x1, neg_inf); - } - } - VF x_max = hn::Max(x0, x1); - float m = hn::ReduceMax(df, x_max); - m = std::max(m, old_max); - VF m_vec = hn::Set(df, m); - x0 = hn::Exp(df, hn::Sub(x0, m_vec)); - x1 = hn::Exp(df, hn::Sub(x1, m_vec)); - float scale = old_d * std::exp(old_max - m); - VF x_sum = hn::Add(x0, x1); - old_d = hn::ReduceSum(df, x_sum) + scale; - old_max = m; - if (old_d > 0.0f) { - const float one_over_d = 1.0f / old_d; - scale *= one_over_d; - VF one_over_d_vec = hn::Set(df, one_over_d); - x0 = hn::Mul(x0, one_over_d_vec); - x1 = hn::Mul(x1, one_over_d_vec); - } else { - scale = 0.0f; - x0 = hn::Zero(df); - x1 = hn::Zero(df); - } - return scale; -} - // Reduces each of x and stores in following lanes of max (tested with float32) template , class DF4 = hn::CappedTag, class VF4 = hn::Vec, @@ -908,6 +791,136 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap( } } +// Implements flash attention for a strip of 4 query vectors. +// It iterates through timesteps in K from `start_pos` up to `max_last_pos`. +// Timesteps up to `min_last_pos` (*) are processed in tiles of shape 4 Q rows +// by NF timesteps in K for efficiency while timesteps between `min_last_pos + +// 1` and `max_last_pos` are processed one-by-one to handle differing `last_pos` +// values within the strip. +// (*) Actually, it only iterates through +// `min_last_pos - (min_last_pos + 1 - start_pos) % NF` in tiles, as the tiled +// computation can, for obvious reasons, only process an integer number of +// tiles. +// +// @param q The query matrix [batch_size * q_heads, qkv_dim] in BF16 format. +// @param q_offsets Offsets from `q.Row(0)` to the start of the 4 query +// vectors to be processed in this tile. +// @param k Key matrix [seq_len, qkv_dim] from KV cache. +// @param start_pos The first token position in the KV cache to attend to. +// @param last_pos An array of 4 indices giving the last token position +// (inclusive) that each of the 4 queries may attend to. +// @param min_last_pos The minimum value in `last_pos`. Timesteps up to this +// position can be processed efficiently in batches. +// @param max_last_pos The maximum value in `last_pos`. Timesteps between +// `min_last_pos + 1` and this position are processed individually to +// respect each query's `last_pos` limit. +// @param v Value matrix [seq_len, qkv_dim] from KV cache. +// @param layer_idx The index of the current transformer layer. +// @param activations Attention configurations and buffers. +// @param att_out Output buffer for attention results. +// @param out_offsets Offsets from `att_out.Row(0)` to store the 4 output +// vectors. +// @param ctx Threading context. +// @param worker Worker thread index. 
+Tile4FlashState TileFlashAttention4( + const MatPtrT& q, const uint32_t* HWY_RESTRICT q_offsets, + const MatPtrT& k, const size_t start_pos, + const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos, + const size_t max_last_pos, const MatPtrT& v, const size_t layer_idx, + const AttentionActivationsPtrs& activations, MatPtrT& att_out, + const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx, + const size_t worker) { + GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention4); + using DF = hn::ScalableTag; + const DF df; + using VF = hn::Vec; + constexpr size_t kMaxNF = hn::MaxLanes(df); + const size_t kHTileSize = hn::Lanes(df); + HWY_DASSERT(kHTileSize <= kMaxNF); + constexpr size_t kVTileSize = 4; + float scales[kVTileSize]; + for (size_t i = 0; i < kVTileSize; ++i) { + hwy::ZeroBytes(att_out.Row(0) + out_offsets[i], + v.Cols() * sizeof(att_out.Row(0)[0])); + } + Tile4FlashState state; + size_t position = start_pos; + while (position + kHTileSize - 1 <= min_last_pos) { + int32_t k_offsets[kMaxNF]; + size_t v_pos[kMaxNF]; + for (size_t i = 0; i < kHTileSize; ++i) { + v_pos[i] = activations.div_seq_len.Remainder(position + i); + k_offsets[i] = k.Row(v_pos[i]) - k.Row(0); + } + VF x0, x1, x2, x3; + QDotKTilex4(df, q.Row(0), q_offsets, k, k_offsets, x0, x1, x2, x3); + if (activations.config.att_cap > 0.0f) { + // Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile. + VF cap = hn::Set(df, activations.config.att_cap); + VF one_over_cap = hn::Div(hn::Set(df, 1.0f), cap); + x0 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x0, one_over_cap))); + x1 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x1, one_over_cap))); + x2 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x2, one_over_cap))); + x3 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x3, one_over_cap))); + } + scales[0] = SingleFlashAttentionRowVector(df, x0, state.row_states[0].max, + state.row_states[0].d); + scales[1] = SingleFlashAttentionRowVector(df, x1, state.row_states[1].max, + state.row_states[1].d); + scales[2] = SingleFlashAttentionRowVector(df, x2, state.row_states[2].max, + state.row_states[2].d); + scales[3] = SingleFlashAttentionRowVector(df, x3, state.row_states[3].max, + state.row_states[3].d); + MulByConstAndAddTile4(df, scales, x0, x1, x2, x3, v, v_pos, att_out.Row(0), + out_offsets, v.Cols()); + position += kHTileSize; + } + const hn::ScalableTag dbf; + const size_t qkv_dim = k.Cols(); + + while (position <= max_last_pos) { + size_t k_pos = activations.div_seq_len.Remainder(position); + if (position <= last_pos[0]) { + // Past the last position, x0 doesn't count. + float x0 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[0], qkv_dim), 0, + k.Row(k_pos), qkv_dim); + SingleFlashAttentionStep(x0, activations.config.att_cap, + state.row_states[0].max, state.row_states[0].d, + v.Row(k_pos), v.Cols(), + att_out.Row(0) + out_offsets[0]); + } + if (position <= last_pos[1]) { + // Past the last position, x1 doesn't count. + float x1 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[1], qkv_dim), 0, + k.Row(k_pos), qkv_dim); + SingleFlashAttentionStep(x1, activations.config.att_cap, + state.row_states[1].max, state.row_states[1].d, + v.Row(k_pos), v.Cols(), + att_out.Row(0) + out_offsets[1]); + } + if (position <= last_pos[2]) { + // Past the last position, x2 doesn't count. 
+ float x2 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[2], qkv_dim), 0, + k.Row(k_pos), qkv_dim); + SingleFlashAttentionStep(x2, activations.config.att_cap, + state.row_states[2].max, state.row_states[2].d, + v.Row(k_pos), v.Cols(), + att_out.Row(0) + out_offsets[2]); + } + if (position <= last_pos[3]) { + // Past the last position, x3 doesn't count. + float x3 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[3], qkv_dim), 0, + k.Row(k_pos), qkv_dim); + SingleFlashAttentionStep(x3, activations.config.att_cap, + state.row_states[3].max, state.row_states[3].d, + v.Row(k_pos), v.Cols(), + att_out.Row(0) + out_offsets[3]); + } + ++position; + } + return state; +} + template , typename T> static HWY_INLINE void QDotKTilexUpTo8TransposedKDoubleWidth( @@ -1522,581 +1535,29 @@ void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16( }); } -// Implements flash attention for a strip of tiles of size 1, 4 or 8 query -// vectors by 2NF positions in K. -// It iterates through tiles in K from `params.min_start_pos / 2NF * 2NF` up to -// `params.max_last_pos` (rounded up to the nearest multiple of 2NF). -// Masking allows each row within a tile to have a different start and end -// position. -// -// @param params Tile148Params containing the extent of the strip and -// size of the tiles. -// @param q The query matrix [batch_size * q_heads, qkv_dim] in BF16 format. -// @param k Key matrix from KV cache. K is always pre-transposed to shape: -// [seq_len / 2kNF, layers * kv_heads * qkv_dim/2 * 2kNF * 2], -// where the /2, *2 represents that pairs of qkv_dim elements are kept -// together to make best use of BF16 dot product instructions. -// @param v Value matrix [seq_len, qkv_dim] from KV cache. -// @param layer_idx The index of the current transformer layer. -// @param activations Attention configurations and buffers. -// @param att_out Output buffer for attention results. -// @param ctx Threading context. -// @param worker Worker thread index. -template -Tile4FlashState TileFlashAttention148( - const Tile148Params& params, const MatPtrT& q, - const MatPtrT& k, const MatPtrT& v, const size_t layer_idx, - const AttentionActivationsPtrs& activations, MatPtrT& att_out, - size_t qkv_dim, ThreadingContext& ctx, const size_t worker, - AttentionImpl attention_impl) { - constexpr Zones kZone = - kVTileSize == 8 - ? Zones::kFlashAttentionTileFlashAttention8 - : (kVTileSize == 4 ? Zones::kFlashAttentionTileFlashAttention4 - : Zones::kFlashAttentionTileFlashAttention1); - GCPP_ZONE(ctx, worker, kZone); - using DF = hn::ScalableTag; - const DF df; - using VF = hn::Vec; - float att_cap = activations.config.att_cap; - float one_over_cap = att_cap > 0.0f ? 1.0f / att_cap : 0.0f; - const size_t kHTileSize = 2 * hn::Lanes(df); - float scales[kVTileSize]; - for (size_t i = 0; i < kVTileSize; ++i) { - hwy::ZeroBytes(att_out.Row(0) + params.out_offsets[i], - qkv_dim * sizeof(att_out.Row(0)[0])); - } - Tile4FlashState state; - size_t position = params.min_start_pos / kHTileSize * kHTileSize; - while (position <= params.max_last_pos) { - // Each pair of vectors covers 2NF positions in K, with up to 8 pairs of - // vectors covering 1, 4 or 8 queries. 
- VF x00, x01; - VF HWY_MAYBE_UNUSED x10, x11; - VF HWY_MAYBE_UNUSED x20, x21; - VF HWY_MAYBE_UNUSED x30, x31; - VF HWY_MAYBE_UNUSED x40, x41; - VF HWY_MAYBE_UNUSED x50, x51; - VF HWY_MAYBE_UNUSED x60, x61; - VF HWY_MAYBE_UNUSED x70, x71; - constexpr size_t kMaxNF = hn::MaxLanes(df); - size_t v_pos[2 * kMaxNF]; - for (size_t i = 0; i < kHTileSize; ++i) { - v_pos[i] = activations.div_seq_len.Remainder(position + i); - } - if constexpr (IsF32()) { - if constexpr (HWY_NATIVE_DOT_BF16) { - QDotKTile148FloatNative(df, q.Row(0), params.out_offsets, - qkv_dim / 2, k, position, x00, x01, - x10, x11, x20, x21, x30, x31, x40, - x41, x50, x51, x60, x61, x70, x71); - } else { - QDotKTile148FloatNotNative( - df, q.Row(0), params.out_offsets, qkv_dim / 2, k, position, x00, - x01, x10, x11, x20, x21, x30, x31, x40, x41, x50, x51, x60, x61, - x70, x71); - } - } else { - if constexpr (HWY_NATIVE_DOT_BF16) { - QDotKTile148BF16Native(df, q.Row(0), params.q_offsets, - qkv_dim / 2, k, position, x00, x01, - x10, x11, x20, x21, x30, x31, x40, - x41, x50, x51, x60, x61, x70, x71); - } else { - QDotKTile148BF16NotNative( - df, q.Row(0), params.q_offsets, qkv_dim / 2, k, position, x00, x01, - x10, x11, x20, x21, x30, x31, x40, x41, x50, x51, x60, x61, x70, - x71); - } - } - if (att_cap > 0.0f) { - // Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile. - ApplySoftCap(df, att_cap, one_over_cap, x00, x10, x20, x30, - x40, x50, x60, x70); - ApplySoftCap(df, att_cap, one_over_cap, x01, x11, x21, x31, - x41, x51, x61, x71); - } - scales[0] = DoubleFlashAttentionRowVector( - df, params.start_pos[0], position, params.last_pos[0], x00, x01, - state.row_states[0].max, state.row_states[0].d); - if constexpr (kVTileSize >= 4) { - scales[1] = DoubleFlashAttentionRowVector( - df, params.start_pos[1], position, params.last_pos[1], x10, x11, - state.row_states[1].max, state.row_states[1].d); - scales[2] = DoubleFlashAttentionRowVector( - df, params.start_pos[2], position, params.last_pos[2], x20, x21, - state.row_states[2].max, state.row_states[2].d); - scales[3] = DoubleFlashAttentionRowVector( - df, params.start_pos[3], position, params.last_pos[3], x30, x31, - state.row_states[3].max, state.row_states[3].d); - MulByConstAndAddVT4Mem(df, scales, x00, x01, x10, x11, x20, x21, x30, x31, - v, v_pos, params.max_last_pos + 1 - position, - att_out.Row(0), params.out_offsets, qkv_dim); - } else { - MulByConstAndAddVT1Mem(df, scales, x00, x01, v, v_pos, - params.max_last_pos + 1 - position, att_out.Row(0), - params.out_offsets, qkv_dim); - } - if constexpr (kVTileSize >= 8) { - scales[4] = DoubleFlashAttentionRowVector( - df, params.start_pos[4], position, params.last_pos[4], x40, x41, - state.row_states[4].max, state.row_states[4].d); - scales[5] = DoubleFlashAttentionRowVector( - df, params.start_pos[5], position, params.last_pos[5], x50, x51, - state.row_states[5].max, state.row_states[5].d); - scales[6] = DoubleFlashAttentionRowVector( - df, params.start_pos[6], position, params.last_pos[6], x60, x61, - state.row_states[6].max, state.row_states[6].d); - scales[7] = DoubleFlashAttentionRowVector( - df, params.start_pos[7], position, params.last_pos[7], x70, x71, - state.row_states[7].max, state.row_states[7].d); - MulByConstAndAddVT4Mem(df, scales + 4, x40, x41, x50, x51, x60, x61, x70, - x71, v, v_pos, params.max_last_pos + 1 - position, - att_out.Row(0), params.out_offsets + 4, qkv_dim); - } - position += kHTileSize; - } - return state; -} - -HWY_INLINE void DispatchTileFlashAttention148( - Tile148Params& params, const 
MatPtrT<BF16>& q, const MatPtrT<KV_t>& k,
-    const MatPtrT<KV_t>& v, const size_t layer_idx,
-    const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out,
-    size_t qkv_dim, ThreadingContext& ctx, const size_t worker,
-    AttentionImpl attention_impl) {
-  if (params.v_tile_size == k8xNFVTileSize) {
-    params.end_state = TileFlashAttention148<k8xNFVTileSize>(
-        params, q, k, v, layer_idx, activations, att_out, qkv_dim, ctx, worker,
-        attention_impl);
-  } else if (params.v_tile_size == k4xNFVTileSize) {
-    params.end_state = TileFlashAttention148<k4xNFVTileSize>(
-        params, q, k, v, layer_idx, activations, att_out, qkv_dim, ctx, worker,
-        attention_impl);
-  } else {
-    params.end_state =
-        TileFlashAttention148<1>(params, q, k, v, layer_idx, activations,
-                                 att_out, qkv_dim, ctx, worker, attention_impl);
-  }
+// Rounds n to a number that can be used as the number of Q rows in a tile
+// of flash attention.
+static size_t RoundToSuitablePowerOf2(size_t n) {
+  if (n < 4) return 1;
+  if (n < 8) return 4;
+  if (n < 16) return 8;
+  if (n < 32) return 16;
+  return 32;
+}
 
 // The vertical tile size is determined by the ability to use tiling and the
 // target_parallelism. In practice the possible tile sizes in order of
-// preference for efficiency are 8, 4, 1. The final tile size is chosen to be
-// the largest possible that allows for target_parallelism parallel tasks.
+// preference for efficiency are kNF, 4, 1, where kNF is likely to be 4, 8, or
+// 16. The final tile size is chosen to be the largest possible that allows
+// for target_parallelism parallel tasks.
 size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens,
                     size_t total_tasks, size_t target_parallelism) {
-  const size_t kMaxEqualK = num_head_groups * num_tokens;
-  if (total_tasks / k8xNFVTileSize >= target_parallelism &&
-      kMaxEqualK >= k8xNFVTileSize && kNF >= k8xNFVTileSize) {
-    return k8xNFVTileSize;
-  }
-  if (total_tasks / k4xNFVTileSize >= target_parallelism &&
-      kMaxEqualK >= k4xNFVTileSize && kNF >= k4xNFVTileSize) {
-    return k4xNFVTileSize;
-  }
-  return 1;
-}
-
-// Clears and fills the params vector with Tile148Params for the given
-// num_tokens, target_parallelism, and layer_idx. Computes tile sizes and
-// offsets for each tile to achieve target_parallelism.
-void ComputeFlashParams(size_t num_tokens, const size_t target_parallelism,
-                        size_t layer_idx, AttentionActivationsPtrs& activations,
-                        QBatch& qbatch, AttentionImpl attention_impl,
-                        std::vector<Tile148Params>& params) {
-  const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
-  const hwy::Divisor div_qbatch(qbatch.Size());
-  const size_t qkv_dim = layer_config.qkv_dim;
-  using DF = hn::ScalableTag<float>;
-  const DF df;
-  const size_t kNF = hn::Lanes(df);
-
-  // A "head group" in the context of GQA refers to a collection of query
-  // heads that share the same key and value heads.
-  const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
-  const size_t cache_layer_size = layer_config.CacheLayerSize();
-  const size_t token_batch = num_tokens * div_qbatch.GetDivisor();
-  const size_t total_tasks = token_batch * layer_config.heads;
-  size_t kVTileSize = GetVTileSize(kNF, kHeadGroups, num_tokens, total_tasks,
-                                   target_parallelism);
-  // All layers should have the same number of heads.
-  HWY_DASSERT(activations.div_heads.GetDivisor() == layer_config.heads);
-  // To maximize adjacent tasks with the same kv matrices, task index is encoded
-  // thus: [qi][kv_head][batch_idx][head_group]. 
Note that the head index is - // split into kv_head and head_group, since the head_group does not affect - // the KV matrices, and kv_head does. batch_idx does not affect the KV - // matrices, but does affect the last position in the sequence. qi affects - // everything. - params.clear(); - for (uint32_t qi = 0; qi < div_qbatch.GetDivisor(); ++qi) { - for (uint32_t kv_head = 0; kv_head < layer_config.kv_heads; ++kv_head) { - const size_t head_offset = kv_head * qkv_dim * 2; - const uint32_t kv_offset = layer_idx * cache_layer_size + head_offset; - params.push_back(Tile148Params{ - .qi_index = qi, - .kv_offset = kv_offset, - }); - for (uint32_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { - const size_t pos = qbatch.Pos(qi) + batch_idx; - const size_t start_pos = StartPos(pos, activations.config, layer_idx); - size_t last = pos; - const size_t prefix_end = qbatch.PrefixEnd(qi); - if (prefix_end > 0 && prefix_end - 1 > last) { - // last_pos is inclusive. - last = prefix_end - 1; - } - for (size_t head_group = 0; head_group < kHeadGroups; ++head_group) { - size_t tasks_remaining = kHeadGroups - head_group + - kHeadGroups * (num_tokens - 1 - batch_idx); - // We want to fill a tile of size kVTileSize or k4xNFVTileSize if - // smaller, otherwise everything is singles to the next head group. - size_t tasks_required = params.back().v_tile_size < k4xNFVTileSize - ? k4xNFVTileSize - : kVTileSize; - if (params.back().v_tile_size + tasks_remaining < tasks_required || - params.back().v_tile_size == kVTileSize) { - // We don't have enough tasks remaining to fill a tile, or the - // current tile is full so start new tile. - params.push_back(Tile148Params{ - .qi_index = qi, - .kv_offset = kv_offset, - }); - } - const size_t head = head_group + kHeadGroups * kv_head; - const size_t tq_idx = div_qbatch.GetDivisor() * batch_idx + qi; - auto& param = params.back(); - size_t offset = param.v_tile_size; - param.q_offsets[offset] = activations.q_bf.Row(tq_idx) + - head * qkv_dim - activations.q_bf.Row(0); - param.out_offsets[offset] = activations.att_out.Row(tq_idx) + - head * qkv_dim - - activations.att_out.Row(0); - param.tq_idx[offset] = tq_idx; - param.start_pos[offset] = start_pos; - param.min_start_pos = HWY_MIN(param.min_start_pos, start_pos); - param.last_pos[offset] = last; - param.max_last_pos = HWY_MAX(param.max_last_pos, last); - ++param.v_tile_size; - } - } - } - } -} - -// Returns the maximum number of tiles needed for any query in the batch. -size_t GetMaxTiles(const std::vector& params, - const size_t kHTileSize) { - size_t max_tiles = 0; - for (const auto& param : params) { - size_t start = param.min_start_pos / kHTileSize; - size_t last = param.max_last_pos / kHTileSize; - max_tiles = HWY_MAX(last + 1 - start, max_tiles); - } - return max_tiles; -} - -// Splits params into smaller k-strips to allow for more parallelism. -// The strips are of size num_tiles_per_task * kHTileSize. -// split_params is cleared and filled with the split tasks. 
-void SplitTasksByKPos(std::vector& params, - const size_t kHTileSize, const size_t num_tiles_per_task, - const size_t out_stride, - std::vector& split_params) { - split_params.clear(); - for (auto& param : params) { - param.split_index = split_params.size(); - size_t start = param.min_start_pos / kHTileSize; - size_t last = param.max_last_pos / kHTileSize; - for (size_t tile_pos = start; tile_pos <= last; - tile_pos += num_tiles_per_task) { - auto& split_param = split_params.emplace_back(param); - split_param.i_of_n = (tile_pos - start) / num_tiles_per_task; - split_param.n_of_n = hwy::DivCeil(last - start, num_tiles_per_task); - uint32_t tile_last = (tile_pos + num_tiles_per_task) * kHTileSize - 1; - if (tile_last < param.max_last_pos) { - split_param.max_last_pos = tile_last; - for (auto& last_pos : split_param.last_pos) { - last_pos = std::min(last_pos, tile_last); - } - } - uint32_t tile_start = tile_pos * kHTileSize; - if (tile_start > param.min_start_pos) { - split_param.min_start_pos = tile_start; - for (auto& start_pos : split_param.start_pos) { - start_pos = std::max(start_pos, tile_start); - } - } - if (split_param.i_of_n > 0) { - for (size_t i = 0; i < split_param.v_tile_size; ++i) { - split_param.tq_idx[i] = - param.tq_idx[i] * split_param.n_of_n + split_param.i_of_n - 1; - split_param.out_offsets[i] = - param.out_offsets[i] + - (split_param.tq_idx[i] - param.tq_idx[i]) * out_stride; - } - } - } - } -} - -// Clears and fills activations.flash_params with Tile148Params for the -// given num_tokens, target_parallelism, and layer_idx. Computes tile sizes and -// offsets for each tile to achieve target_parallelism. -// If the parallelism is insufficient for this processor type, and the sequence -// length is sufficient, the tiles are upgraded to k4xNFVTileSize and the tasks -// are split along the k positions to achieve the desired parallelism. -// If splitting was required, returns that factor by which the tiles were -// upgraded, k4xNFVTileSize, otherwise returns 0. -uint32_t ComputeAndSplitFlashParams(const size_t kNF, const size_t num_tokens, - const size_t target_parallelism, - size_t layer_idx, - AttentionActivationsPtrs& activations, - QBatch& qbatch, ThreadingContext& ctx, - AttentionImpl attention_impl) { - ComputeFlashParams(num_tokens, target_parallelism, layer_idx, activations, - qbatch, attention_impl, activations.flash_params); - size_t target_workers = std::min(ctx.pools.MaxWorkers(), target_parallelism); - if (activations.flash_params.size() < target_workers) { - // Insufficient parallelism for this processor type. Try splitting along the - // k positions. - size_t max_tiles = GetMaxTiles(activations.flash_params, 2 * kNF); - size_t desired_tiles_per_task = hwy::DivCeil( - activations.flash_params.size() * max_tiles, target_workers); - // The cost of combining split tasks is significant, so we want a minimum - // number of tiles per task, and we want to use k4xNFVTileSize if possible. - constexpr size_t kMinTilesPerTask = 4; - if (desired_tiles_per_task >= k4xNFVTileSize * kMinTilesPerTask) { - // We can afford to use k4xNFVTileSize vertically, so recompute params. 
- ComputeFlashParams(num_tokens, - activations.flash_params.size() / k4xNFVTileSize, - layer_idx, activations, qbatch, attention_impl, - activations.flash_params); - desired_tiles_per_task = - hwy::DivCeil(desired_tiles_per_task, k4xNFVTileSize); - SplitTasksByKPos( - activations.flash_params, 2 * kNF, desired_tiles_per_task, - activations.att_out_reps.Stride(), activations.split_flash_params); - return k4xNFVTileSize; - } - } - return 0; -} - -// Combines results from split tasks, processing kNumNF * NF qkv values where -// kNumNF can be 1 4 or 16. This enables the intermediate results to be held in -// registers, which speeds up the combination step significantly. -template -void CombineSplitTasks1416(hwy::Span params, - size_t tile_pos, size_t qkv_offset, - AttentionActivationsPtrs& activations) { - using DF = hn::ScalableTag; - const DF df; - using VF = hn::Vec; - const size_t kNF = hn::Lanes(df); - float overall_m = params[0].end_state.row_states[tile_pos].max; - float overall_d = params[0].end_state.row_states[tile_pos].d; - float* HWY_RESTRICT att_out = - activations.att_out.Row(0) + params[0].out_offsets[tile_pos] + qkv_offset; - VF result_0 = hn::Load(df, att_out); - VF result_1, result_2, result_3, result_4, result_5, result_6, result_7; - VF result_8, result_9, result_10, result_11, result_12, result_13, result_14; - VF result_15; - if constexpr (kNumNF > 1) { - result_1 = hn::Load(df, att_out + kNF); - result_2 = hn::Load(df, att_out + 2 * kNF); - result_3 = hn::Load(df, att_out + 3 * kNF); - } - if constexpr (kNumNF == 16) { - result_4 = hn::Load(df, att_out + 4 * kNF); - result_5 = hn::Load(df, att_out + 5 * kNF); - result_6 = hn::Load(df, att_out + 6 * kNF); - result_7 = hn::Load(df, att_out + 7 * kNF); - result_8 = hn::Load(df, att_out + 8 * kNF); - result_9 = hn::Load(df, att_out + 9 * kNF); - result_10 = hn::Load(df, att_out + 10 * kNF); - result_11 = hn::Load(df, att_out + 11 * kNF); - result_12 = hn::Load(df, att_out + 12 * kNF); - result_13 = hn::Load(df, att_out + 13 * kNF); - result_14 = hn::Load(df, att_out + 14 * kNF); - result_15 = hn::Load(df, att_out + 15 * kNF); - } - for (size_t i = 1; i < params.size() && params[i].i_of_n > 0; ++i) { - float m = params[i].end_state.row_states[tile_pos].max; - float d = params[i].end_state.row_states[tile_pos].d; - float new_m = std::max(overall_m, m); - // Scale factor for existing total given the change in max. - float old_scale = overall_d * std::exp(overall_m - new_m); - // Scale factor for new group to add. 
- float new_scale = d * std::exp(m - new_m); - float new_d = old_scale + new_scale; - float one_over_d = 1.0f / new_d; - old_scale *= one_over_d; - new_scale *= one_over_d; - overall_m = new_m; - overall_d = new_d; - float* HWY_RESTRICT att_in = activations.att_out_reps.Row(0) + - params[i].out_offsets[tile_pos] + qkv_offset; - VF old_scale_vec = hn::Set(df, old_scale); - VF new_scale_vec = hn::Set(df, new_scale); - result_0 = hn::Mul(result_0, old_scale_vec); - result_0 = hn::MulAdd(hn::Load(df, att_in), new_scale_vec, result_0); - if constexpr (kNumNF > 1) { - result_1 = hn::Mul(result_1, old_scale_vec); - result_2 = hn::Mul(result_2, old_scale_vec); - result_3 = hn::Mul(result_3, old_scale_vec); - result_1 = - hn::MulAdd(hn::Load(df, att_in + kNF), new_scale_vec, result_1); - result_2 = - hn::MulAdd(hn::Load(df, att_in + 2 * kNF), new_scale_vec, result_2); - result_3 = - hn::MulAdd(hn::Load(df, att_in + 3 * kNF), new_scale_vec, result_3); - } - if constexpr (kNumNF == 16) { - result_4 = hn::Mul(result_4, old_scale_vec); - result_5 = hn::Mul(result_5, old_scale_vec); - result_6 = hn::Mul(result_6, old_scale_vec); - result_7 = hn::Mul(result_7, old_scale_vec); - result_8 = hn::Mul(result_8, old_scale_vec); - result_9 = hn::Mul(result_9, old_scale_vec); - result_10 = hn::Mul(result_10, old_scale_vec); - result_11 = hn::Mul(result_11, old_scale_vec); - result_12 = hn::Mul(result_12, old_scale_vec); - result_13 = hn::Mul(result_13, old_scale_vec); - result_14 = hn::Mul(result_14, old_scale_vec); - result_15 = hn::Mul(result_15, old_scale_vec); - result_4 = - hn::MulAdd(hn::Load(df, att_in + 4 * kNF), new_scale_vec, result_4); - result_5 = - hn::MulAdd(hn::Load(df, att_in + 5 * kNF), new_scale_vec, result_5); - result_6 = - hn::MulAdd(hn::Load(df, att_in + 6 * kNF), new_scale_vec, result_6); - result_7 = - hn::MulAdd(hn::Load(df, att_in + 7 * kNF), new_scale_vec, result_7); - result_8 = - hn::MulAdd(hn::Load(df, att_in + 8 * kNF), new_scale_vec, result_8); - result_9 = - hn::MulAdd(hn::Load(df, att_in + 9 * kNF), new_scale_vec, result_9); - result_10 = - hn::MulAdd(hn::Load(df, att_in + 10 * kNF), new_scale_vec, result_10); - result_11 = - hn::MulAdd(hn::Load(df, att_in + 11 * kNF), new_scale_vec, result_11); - result_12 = - hn::MulAdd(hn::Load(df, att_in + 12 * kNF), new_scale_vec, result_12); - result_13 = - hn::MulAdd(hn::Load(df, att_in + 13 * kNF), new_scale_vec, result_13); - result_14 = - hn::MulAdd(hn::Load(df, att_in + 14 * kNF), new_scale_vec, result_14); - result_15 = - hn::MulAdd(hn::Load(df, att_in + 15 * kNF), new_scale_vec, result_15); - } - } - hn::Store(result_0, df, att_out); - if constexpr (kNumNF > 1) { - hn::Store(result_1, df, att_out + kNF); - hn::Store(result_2, df, att_out + 2 * kNF); - hn::Store(result_3, df, att_out + 3 * kNF); - } - if constexpr (kNumNF == 16) { - hn::Store(result_4, df, att_out + 4 * kNF); - hn::Store(result_5, df, att_out + 5 * kNF); - hn::Store(result_6, df, att_out + 6 * kNF); - hn::Store(result_7, df, att_out + 7 * kNF); - hn::Store(result_8, df, att_out + 8 * kNF); - hn::Store(result_9, df, att_out + 9 * kNF); - hn::Store(result_10, df, att_out + 10 * kNF); - hn::Store(result_11, df, att_out + 11 * kNF); - hn::Store(result_12, df, att_out + 12 * kNF); - hn::Store(result_13, df, att_out + 13 * kNF); - hn::Store(result_14, df, att_out + 14 * kNF); - hn::Store(result_15, df, att_out + 15 * kNF); - } -} - -void CombineSplitTasksScalar(hwy::Span params, - size_t tile_pos, size_t qkv_offset, - AttentionActivationsPtrs& activations) { - float 
overall_m = params[0].end_state.row_states[tile_pos].max; - float overall_d = params[0].end_state.row_states[tile_pos].d; - float* HWY_RESTRICT att_out = - activations.att_out.Row(0) + params[0].out_offsets[tile_pos] + qkv_offset; - float result = att_out[0]; - for (size_t i = 1; i < params.size() && params[i].i_of_n > 0; ++i) { - float m = params[i].end_state.row_states[tile_pos].max; - float d = params[i].end_state.row_states[tile_pos].d; - float new_m = std::max(overall_m, m); - // Scale factor for existing total given the change in max. - float old_scale = overall_d * std::exp(overall_m - new_m); - // Scale factor for new group to add. - float new_scale = d * std::exp(m - new_m); - float new_d = old_scale + new_scale; - float one_over_d = 1.0f / new_d; - old_scale *= one_over_d; - new_scale *= one_over_d; - overall_m = new_m; - overall_d = new_d; - float* HWY_RESTRICT att_in = activations.att_out_reps.Row(0) + - params[i].out_offsets[tile_pos] + qkv_offset; - result *= old_scale; - result += att_in[0] * new_scale; - } - att_out[0] = result; -} - -// Recombines results from split tasks, activations.att_out_reps -> -// activations.att_out. Instead of repeatedly calling MultiplyByConstAndAdd, -// which reads/writes the sum each time, the result is kept entirely in -// registers, and the task is split into 16NF, 4NF, and NF chunks, so that there -// are enough registers to hold the intermediate results. -void CombineSplitTasks(size_t qkv_dim, uint32_t tile_factor, - AttentionActivationsPtrs& activations, - ThreadingContext& ctx) { - GCPP_ZONE(ctx, 0, Zones::kFlashAttentionCombineSplit); - using DF = hn::ScalableTag; - const DF df; - const size_t kNF = hn::Lanes(df); - uint32_t num_16 = qkv_dim / (16 * kNF); - uint32_t num_4 = (qkv_dim - kNF * 16 * num_16) / (4 * kNF); - uint32_t num_1 = (qkv_dim - kNF * (16 * num_16 + 4 * num_4)) / kNF; - uint32_t num_0 = qkv_dim % kNF; - uint32_t tasks_per_qkv = num_16 + num_4 + num_1 + num_0; - ParallelFor( - Parallelism::kFlat, - activations.flash_params.size() * tasks_per_qkv * tile_factor, ctx, - /*cluster_idx=*/0, Callers::kFlashAttention, - [&](size_t p, size_t worker) { - uint32_t tile = p / tasks_per_qkv; - uint32_t p_idx = - activations.flash_params[tile / tile_factor].split_index; - const auto& param = activations.split_flash_params[p_idx]; - size_t remaining_params = activations.split_flash_params.size() - p_idx; - tile %= tile_factor; - if (tile >= param.v_tile_size) return; - int32_t qkv_task = p % tasks_per_qkv; - if (qkv_task < num_16) { - uint32_t qkv_offset = qkv_task * 16 * kNF; - CombineSplitTasks1416<16>( - hwy::Span(¶m, remaining_params), tile, - qkv_offset, activations); - } else if (qkv_task < num_16 + num_4) { - uint32_t qkv_offset = (num_16 * 16 + (qkv_task - num_16) * 4) * kNF; - CombineSplitTasks1416<4>( - hwy::Span(¶m, remaining_params), tile, - qkv_offset, activations); - } else if (qkv_task < num_16 + num_4 + num_1) { - uint32_t qkv_offset = - (num_16 * 16 + num_4 * 4 + (qkv_task - num_16 - num_4)) * kNF; - CombineSplitTasks1416<1>( - hwy::Span(¶m, remaining_params), tile, - qkv_offset, activations); - } else { - uint32_t qkv_offset = (num_16 * 16 + num_4 * 4 + num_1) * kNF + - (qkv_task - num_16 - num_4 - num_1); - CombineSplitTasksScalar( - hwy::Span(¶m, remaining_params), tile, - qkv_offset, activations); - } - }); + const size_t kMaxEqualK = + RoundToSuitablePowerOf2(num_head_groups * num_tokens); + const size_t kMinTileSize = (total_tasks / 4 >= target_parallelism) ? 
4 : 1;
+  return (kNF <= kMaxEqualK && total_tasks / kNF >= target_parallelism)
+             ? kNF
+             : std::min(kMinTileSize, kMaxEqualK);
 }
 
 // The nominal aim of attention is to combine 3 inputs Q[L,D], K[L,D], V[L,D]
@@ -2110,28 +1571,49 @@ void CombineSplitTasks(size_t qkv_dim, uint32_t tile_factor,
 // the one row of O takes L(4D+3) reads and L(D+3) writes.
 // For the whole of Q, this is L^2(4D+3) reads and L^2(D+3) writes.
 //
-// Flash attention fuses these operations together, and operates on tiles of
-// n Q rows x NF K positions, accumulated in n registers, where n is in
-// {1, 4, 8} and NF is the number of float lanes in a register, being 16 for
-// AVX3. This reduces the number of reads of Q by NF and reads of K by n. The
-// softmax is converted to streaming form using the algorithm from:
-// https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf,
-// which eliminates the need to store A to memory. The accumulated Q.KT result
-// is passed via the streaming softmax directly to the A.V step.
-// To make the dot product computation more efficient, Q, K, and V are stored
-// as BF16 and K is transposed to shape:
-// [seq_len / NF, layers * kv_heads * qkv_dim/2 * NF * 2], where the /2, *2
-// represents that pairs of qkv_dim elements are kept together to make best
-// use of BF16 dot product instructions, where each pair of adjacent BF16
-// values from Q and K are mul-added into a single F32 result.
+// Flash attention fuses these operations together, and has 3 operating modes:
+// 1. NF rows of the result computed using tiles of registers of shape NFx8.
+// 2. 4 rows of the result computed using tiles of registers of shape 4xNF.
+// 3. One row (of Q and the result) at a time.
+// In all cases the intermediate result (Q.KT) is never stored to memory.
+// NF is the number of float lanes in a register, being 16 for AVX3. The
+// softmax is converted to streaming form using the algorithm from:
+// https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf.
+// Q is transposed to Q_T[D,L] to make the dot product computation efficient.
+//
+// In mode 1:
+// QDotKTileFloat computes NF Q rows x 8 K timesteps of Q.K dot products in one
+// go, reducing reads of Q by 8 and reads of K by NF. The streaming softmax is
+// computed entirely in registers, and a further NF registers accumulate the
+// product of the softmax and V, reducing the number of reads of V by NF and
+// the reads/writes of O by 8.
+// The reads are thus reduced to 2DL^2(1/8+1/NF) and the writes to DL^2/8,
+// which on AVX3 is an overall reduction by about a factor of 10.
+// Mode 1 can only be used if there is a large Qbatch size, or in multi-turn
+// prefill, since in other cases there is either a single K timestep (prefill)
+// or a single num_heads set of Q rows (decode).
+//
+// In mode 2, the 4 rows of Q are computed against NF K timesteps in a tile,
+// reducing the reads of Q by NF, and the reads of K by 4. The softmax and
+// accumulation of the result are done in registers, cutting the reads of V by
+// 4. The reads/writes of O are reduced by a factor of NF.
+// The overall reduction is limited by the need to use gather to load K.
+// Transposing K would be possible, but is complicated by the wraparound.
+// Mode 2 can be used in all cases when there are at least 4 attention heads,
+// but it may be preferable to use mode 3 when the batch size is small, to
+// maximize parallelism.
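+//
+// As an illustrative example of how GetVTileSize() picks between these modes
+// (hypothetical sizes, assuming NF = 16): with 8 heads sharing 2 KV heads and
+// 64 tokens, total_tasks = 512 and kMaxEqualK = RoundToSuitablePowerOf2(256)
+// = 32. Then target_parallelism = 32 returns kNF = 16 (mode 1, since
+// 512 / 16 = 32 tasks suffice); target_parallelism = 64 returns 4 (mode 2,
+// since 512 / 4 = 128 >= 64); and target_parallelism = 256 returns 1 (mode 3).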
+//
+// In mode 3, a single row of Q is computed against a single K timestep at a
+// time, using SingleFlashAttention. In this case there is no reduction in the
+// reads of Q, K, V, or O, but the reads/writes of the intermediate A are
+// still eliminated.
 //
 // A further complication is that real attention is not as simple as documented
 // in the paper and above. There are multiple query heads, differing KV, and
-// different sequence lengths, and the difference between prefill and decode,
-// so a lot of the work in FlashAttention is making sure that a collection of q
-// rows with the same KV and sequence length are grouped together so that the
-// largest possible tiles can be used. This is dealt with by the
-// ComputeAndSplitFlashParams() function.
+// different sequence lengths, so a lot of the work in FlashAttention is making
+// sure that a collection of q rows with the same KV and sequence length is
+// grouped together so that mode 1 or 2 can be used, and choosing which of the
+// 3 modes to use for best efficiency.
 void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
                     const size_t layer_idx, const MatPtr& query_norm_scale,
                     AttentionActivationsPtrs& activations, QBatch& qbatch,
@@ -2139,16 +1621,8 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
   GCPP_ZONE(ctx, 0, Zones::kFlashAttentionInclusive);
   RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q,
                                query_norm_scale, layer_idx, activations, ctx);
-  const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
-  const size_t qkv_dim = layer_config.qkv_dim;
-  const size_t seq_len =
-      static_cast<size_t>(activations.div_seq_len.GetDivisor());
-
-  using DF = hn::ScalableTag<float>;
-  const DF df;
-  const size_t kNF = hn::Lanes(df);
+  const hwy::Divisor div_qbatch(qbatch.Size());
   // Compress q to q_bf.
-  // TODO(rays): Move this into RMSNormAndPositionalEncoding().
   ParallelFor(
       Parallelism::kWithinCluster, activations.q.Rows(), ctx,
       /*cluster_idx=*/0, Callers::kFlashAttention,
@@ -2159,40 +1633,168 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
             df, activations.q.Row(row), activations.q.Cols(), tls,
             MakeSpan(activations.q_bf.Row(row), activations.q_bf.Cols()), 0);
       });
-  int tile_factor =
-      ComputeAndSplitFlashParams(kNF, num_tokens, target_parallelism, layer_idx,
-                                 activations, qbatch, ctx, attention_impl);
-  auto& params = tile_factor >= 1 ? activations.split_flash_params
-                                  : activations.flash_params;
-  size_t num_tasks = params.size();
+  const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
+  const size_t qkv_dim = layer_config.qkv_dim;
+
+  // A "head group" in the context of GQA refers to a collection of query
+  // heads that share the same key and value heads.
+  const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
+  const size_t cache_layer_size = layer_config.CacheLayerSize();
+  const size_t seq_len =
+      static_cast<size_t>(activations.div_seq_len.GetDivisor());
+  const size_t token_batch = num_tokens * div_qbatch.GetDivisor();
+  const size_t total_tasks = token_batch * layer_config.heads;
+
+  using DF = hn::ScalableTag<float>;
+  const DF df;
+  const size_t kNF = hn::Lanes(df);
+  constexpr size_t kMaxNF = hn::MaxLanes(df);
+  HWY_DASSERT(kNF <= kMaxNF);
+  const size_t kVTileSize = GetVTileSize(kNF, kHeadGroups, num_tokens,
                                         total_tasks, target_parallelism);
+  // Only transpose Q if we are using tiling.
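+  // (The scan below finds the widest [min_start, max_last] strip across the
+  // qbatch; transposing only pays off when at least one strip spans a full
+  // kNFx8HTileSize worth of K timesteps, i.e. when the NFx8 tiled path can
+  // actually run.)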
+  if (kVTileSize == kNF) {
+    size_t max_last = 0, min_start = std::numeric_limits<size_t>::max();
+    for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
+      size_t pos = qbatch.Pos(qi);
+      const size_t start = StartPos(pos, activations.config, layer_idx);
+      pos += num_tokens - 1;
+      const size_t end = qbatch.PrefixEnd(qi);
+      if (end > 0 && end - 1 > pos) {
+        pos = end - 1;
+      }
+      max_last = std::max(max_last, pos);
+      min_start = std::min(min_start, start);
+    }
+    if (max_last - min_start + 1 >= kNFx8HTileSize) {
+      // q has shape [batch, qbatch][head, qkv_dim].
+      // We transpose it to [qkv_dim][qbatch, head, batch] in order to make the
+      // maximum possible number of consecutive columns have the same KV
+      // matrices. Each thread will process a tile of NF columns of QT so the
+      // starting column index of QT is just the task index * kVTileSize.
+      TransposeQ(activations.q, activations.q_T, qbatch.Size(), ctx);
+    }
+  }
+  const size_t num_thread_tasks = hwy::DivCeil(total_tasks, kVTileSize);
+  const hwy::Divisor div_tokens(num_tokens);
+  // All layers should have the same number of heads.
+  HWY_DASSERT(activations.div_heads.GetDivisor() == layer_config.heads);
   // For each head/token/query, compute fused flash Q.K, softmax and weighted V.
   const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
     GCPP_ZONE(ctx, worker, Zones::kFlashAttentionFlashAttention);
-    auto& param = params[task];
-    auto& kT_cache = qbatch.KV(param.qi_index).k_cache;
-    MatPtrT<KV_t> kT("k_T_view", Extents2D(hwy::DivCeil(seq_len, 2 * kNF),
-                                           qkv_dim * 2 * kNF));
-    kT.SetPtr(kT_cache.Row(0) + param.kv_offset * kNF, kT_cache.Stride());
-    auto& vT_cache = qbatch.KV(param.qi_index).v_cache;
-    MatPtrT<KV_t> vT("v_T_view", Extents2D(hwy::DivCeil(seq_len, 2 * kNF),
-                                           qkv_dim * 2 * kNF));
-    vT.SetPtr(vT_cache.Row(0) + param.kv_offset * kNF, vT_cache.Stride());
-    MatPtrT<float>& att_out =
-        param.i_of_n == 0 ? activations.att_out : activations.att_out_reps;
-    DispatchTileFlashAttention148(param, activations.q_bf, kT, vT, layer_idx,
-                                  activations, att_out, qkv_dim, ctx, worker,
-                                  attention_impl);
+    // Offsets into original Q for each row in the tile.
+    uint32_t q_offsets[kMaxNF];
+    // Offsets into att_out for each row in the tile.
+    uint32_t out_offsets[kMaxNF];
+    // Start positions for each row in the tile.
+    size_t start_positions[kMaxNF];
+    // Last positions for each row in the tile. Inclusive.
+    uint32_t last_pos[kMaxNF];
+    // The min and max last positions across all rows in the tile determine
+    // when TileFlashAttention switches to single vector mode to handle the
+    // ragged sequence lengths.
+    size_t min_last_pos = std::numeric_limits<size_t>::max();
+    size_t max_last_pos = 0;
+    // Indices into the qbatch.KV for each row in the tile.
+    size_t qi_indices[kMaxNF];
+    // Indices into the kv_cache for each row in the tile.
+    size_t kv_offsets[kMaxNF];
+    // first_task is [qbatch, head, token].
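+    // (Decoded below: with T = num_tokens and H = heads, task t maps to
+    // batch_idx = t % T, head = (t / T) % H, and qi = t / (T * H), so
+    // consecutive tasks share a head and hence the same KV matrices.)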
+    const size_t first_task = task * kVTileSize;
+    const size_t last_task = first_task + kVTileSize - 1;
+    bool use_tile_attention = kVTileSize > 1 && last_task < total_tasks;
+    for (size_t offset = 0;
+         offset < kVTileSize && first_task + offset < total_tasks; ++offset) {
+      const size_t batch_idx = div_tokens.Remainder(first_task + offset);
+      const size_t qh = div_tokens.Divide(first_task + offset);
+      const size_t head = activations.div_heads.Remainder(qh);
+      const size_t qi = activations.div_heads.Divide(qh);
+      const size_t tq_idx = div_qbatch.GetDivisor() * batch_idx + qi;
+      qi_indices[offset] = qi;
+
+      // Find the token position in the query and calculate
+      // the range of cache positions to attend to.
+      const size_t pos = qbatch.Pos(qi) + batch_idx;
+      const size_t start_pos = StartPos(pos, activations.config, layer_idx);
+      start_positions[offset] = start_pos;
+      size_t last = pos;
+      const size_t prefix_end = qbatch.PrefixEnd(qi);
+      if (prefix_end > 0 && prefix_end - 1 > last) {
+        // last_pos in `TileFlashAttention` is inclusive.
+        last = prefix_end - 1;
+      }
+      last_pos[offset] = last;
+      min_last_pos = HWY_MIN(min_last_pos, last);
+      max_last_pos = HWY_MAX(max_last_pos, last);
+      q_offsets[offset] = activations.q_bf.Row(tq_idx) + head * qkv_dim -
+                          activations.q_bf.Row(0);
+      out_offsets[offset] = activations.att_out.Row(tq_idx) + head * qkv_dim -
+                            activations.att_out.Row(0);
+      const size_t kv_index = head / kHeadGroups;
+      const size_t head_offset = kv_index * qkv_dim * 2;
+      kv_offsets[offset] = layer_idx * cache_layer_size + head_offset;
+      // If any of the parameters in this if statement differ within this task,
+      // then we can't use TileFlashAttention. TileFlashAttention requires that
+      // all rows in the tile have the same K and V matrices, and that Q starts
+      // at the same position. The end positions do not have to be equal.
+      if (start_positions[offset] != start_positions[0] ||
+          qi_indices[offset] != qi_indices[0] ||
+          kv_offsets[offset] != kv_offsets[0]) {
+        use_tile_attention = false;
+      }
+    }
+    for (size_t offset = 0;
+         offset < kVTileSize && first_task + offset < total_tasks; ++offset) {
+      auto& kv_cache = qbatch.KV(qi_indices[offset]).kv_cache;
+      MatPtrT<KV_t> k("k_view", Extents2D(seq_len, qkv_dim));
+      k.SetPtr(kv_cache.Row(0) + kv_offsets[offset], kv_cache.Stride());
+      MatPtrT<KV_t> v("v_view", Extents2D(seq_len, qkv_dim));
+      v.SetPtr(kv_cache.Row(0) + kv_offsets[offset] + qkv_dim,
+               kv_cache.Stride());
+      if (use_tile_attention) {
+        // To avoid duplicating the code to set up K and V, the call to
+        // TileFlashAttention is inside the loop over tasks, even though it
+        // handles all rows in the task at once.
+        StridedView<BF16> qT =
+            StridedView<BF16>(activations.q_T.Row(0) + first_task, kVTileSize,
+                              activations.q_T.Stride());
+        if (kVTileSize == kNF) {
+          // We can still use TileFlashAttention even if we didn't transpose Q
+          // above. The condition used for transposing Q above is more general
+          // and easier to compute than the condition used within
+          // TileFlashAttention that min_last_pos - start_positions[offset] <
+          // kNFx8HTileSize, in which case qT is never used. Some tasks might
+          // use qT and some might not, which is why the more general condition
+          // is used above to catch all cases where qT will be used.
+          TileFlashAttention(activations.q_bf, q_offsets, qT, k,
+                             start_positions[offset], last_pos, min_last_pos,
+                             max_last_pos, v, layer_idx, activations,
+                             activations.att_out, out_offsets, ctx, worker);
+        } else if (kVTileSize == 4) {
+          TileFlashAttention4(activations.q_bf, q_offsets, k,
+                              start_positions[offset], last_pos, min_last_pos,
+                              max_last_pos, v, layer_idx, activations,
+                              activations.att_out, out_offsets, ctx, worker);
+        } else {
+          HWY_UNREACHABLE;
+        }
+        break;
+      } else {
+        SingleFlashAttention(start_positions[offset], last_pos[offset],
+                             activations.q_bf.Row(0) + q_offsets[offset], k, v,
+                             layer_idx, activations,
+                             activations.att_out.Row(0) + out_offsets[offset],
+                             ctx, worker);
+      }
+    }
   };
   {
     PROFILER_ZONE("Gen.FlashAttention.ForkJoin");
     // Full parallelism is helpful, SmallParallelFor is insufficient.
-    HierarchicalParallelFor(num_tasks, ctx, Callers::kFlashAttention, func);
-  }
-  if (tile_factor >= 1) {
-    // Run the flash attention correction on the partial outputs.
-    CombineSplitTasks(qkv_dim, tile_factor, activations, ctx);
+    HierarchicalParallelFor(num_thread_tasks, ctx, Callers::kFlashAttention,
+                            func);
   }
 }
 
diff --git a/gemma/flash_attention.h b/gemma/flash_attention.h
index 7d06af9..5529d9f 100644
--- a/gemma/flash_attention.h
+++ b/gemma/flash_attention.h
@@ -50,6 +50,25 @@ namespace gcpp {
                             float* HWY_RESTRICT att_out,                      \
                             ThreadingContext& ctx, size_t worker);            \
                                                                               \
+  Tile4FlashState TileFlashAttention4(                                        \
+      const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets,         \
+      const MatPtrT<KV_t>& k, size_t start_pos,                               \
+      const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos,             \
+      size_t max_last_pos, const MatPtrT<KV_t>& v, size_t layer_idx,          \
+      const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out,   \
+      const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx,        \
+      const size_t worker);                                                   \
+                                                                              \
+  void TileFlashAttention(                                                    \
+      const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets,         \
+      const StridedView<BF16>& qT, const MatPtrT<KV_t>& k,                    \
+      const size_t start_pos, const uint32_t* HWY_RESTRICT last_pos,          \
+      const size_t min_last_pos, const size_t max_last_pos,                   \
+      const MatPtrT<KV_t>& v, const size_t layer_idx,                         \
+      const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out,   \
+      const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx,        \
+      const size_t worker);                                                   \
+                                                                              \
   size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens,  \
                       size_t total_tasks, size_t target_parallelism);         \
                                                                               \
@@ -73,6 +92,7 @@ namespace gcpp {
       hwy::Span<const size_t> last_pos_per_query, const float att_cap,        \
       MatPtrT<float>& att_out, float* HWY_RESTRICT exp_denominator_sums,      \
       float* HWY_RESTRICT max_logits);                                        \
+                                                                              \
   /* NOLINTNEXTLINE(google-readability-namespace-comments) */                 \
   }  // namespace NAMESPACE

diff --git a/gemma/flash_attention_test.cc b/gemma/flash_attention_test.cc
index 782e613..bbb63f5 100644
--- a/gemma/flash_attention_test.cc
+++ b/gemma/flash_attention_test.cc
@@ -62,17 +62,16 @@ namespace HWY_NAMESPACE {

 using FloatPtr = hwy::AlignedFreeUniquePtr<float[]>;

-template <typename T>
-void SetMat(const size_t offset, MatPtrT<T>& mat) {
+void SetMat(const size_t offset, MatPtrT<float>& mat) {
   const size_t kOuter = mat.Extents().rows;
   const size_t kInner = mat.Extents().cols;
   const float i_scale = 1.0f / kInner;
   const float j_scale = 1.0f / kOuter;
   for (size_t i = 0; i < kOuter; ++i) {
-    T* HWY_RESTRICT row = mat.Row(i);
+    float* HWY_RESTRICT row = mat.Row(i);
     for (size_t j = 0; j < kInner; ++j) {
-      row[j] = hwy::ConvertScalarTo<T>(
-          static_cast<float>((i * kInner * i_scale + (j + offset) * j_scale)));
+      row[j] =
+          static_cast<float>((i * kInner * i_scale + (j + 
offset) * j_scale)); } } } @@ -95,15 +94,14 @@ void AssertClose(const MatPtrT& a, const MatPtrT& b) { if (rel_abs_delta > 0.0f) { rel_abs_delta /= std::max(std::abs(a_row[c]), std::abs(b_row[c])); } - EXPECT_LT(rel_abs_delta, 1e-3) + EXPECT_LT(rel_abs_delta, 1e-5) << "a[" << r << "," << c << "]=" << a_row[c] << ", b[" << r << "," << c << "]=" << b_row[c]; } } } -void TestFlashAttention(size_t target_parallelism, - AttentionImpl attention_impl) { +void TestFlashAttention(size_t target_parallelism) { ThreadingArgs threading_args; ThreadingContext ctx(threading_args); constexpr size_t kOuter = 1024; @@ -132,9 +130,9 @@ void TestFlashAttention(size_t target_parallelism, QBatch qbatch(/*start=*/0, /*max_size=*/kOuter, all_queries); const size_t batch_size = kOuter; std::vector> row_ptrs; - AttentionActivations attention_storage( - config, layer_config, batch_size, kOuter, runtime_config, - ctx.pools.MaxWorkers(), ctx.allocator, row_ptrs); + AttentionActivations attention_storage(config, layer_config, batch_size, + kOuter, runtime_config, ctx.allocator, + row_ptrs); AttentionActivationsPtrs attention(config, kOuter, attention_storage); const size_t qkv_dim = layer_config.qkv_dim; ASSERT_EQ(qkv_dim, kInner); @@ -144,10 +142,7 @@ void TestFlashAttention(size_t target_parallelism, const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads; const size_t seq_len = static_cast(attention.div_seq_len.GetDivisor()); - MaybeReshapeCache(qbatch.KV(0).kv_cache, qbatch.KV(0).k_cache); - MaybeReshapeCache(qbatch.KV(0).kv_cache, qbatch.KV(0).v_cache); auto& kvc = qbatch.KV(0).kv_cache; - const size_t kFloatsPerTile = 2 * FloatsPerVector(); for (size_t h = 0; h < layer_config.heads; ++h) { // Make strided views into the kv cache for // this query and head. @@ -158,17 +153,6 @@ void TestFlashAttention(size_t target_parallelism, v.SetPtr(kvc.Row(0) + head_offset + qkv_dim, kvc.Stride()); SetMat(h + layer_config.heads, k); SetMat(h + layer_config.heads * 2, v); - for (size_t p = 0; p < tokens.size(); ++p) { - KV_t* HWY_RESTRICT k_src = k.Row(p); - KV_t* HWY_RESTRICT k_dest = qbatch.KV(0).k_cache.Row(p / kFloatsPerTile) + - head_offset * kFloatsPerTile / 2 + - p % kFloatsPerTile * 2; - KV_t* HWY_RESTRICT v_dest = qbatch.KV(0).v_cache.Row(p / kFloatsPerTile) + - head_offset * kFloatsPerTile / 2 + - p % kFloatsPerTile * kFloatsPerTile; - - TransposeKVCacheRow(k_src, k_dest, v_dest, qkv_dim); - } } SetMat(1, attention.q); DotSoftmaxWeightedSum(tokens.size(), 0, layers.query_norm_scale, attention, @@ -183,19 +167,18 @@ void TestFlashAttention(size_t target_parallelism, tokens.size() * div_qbatch.GetDivisor() * layer_config.heads; const size_t kVTileSize = GetVTileSize(kNF, kHeadGroups, tokens.size(), total_tasks, target_parallelism); - printf("FlashAttention: parallelism=%zu, kNF=%zu, kVTileSize=%zu, mode %s\n", - target_parallelism, kNF, kVTileSize, - GetAttentionImplName(attention_impl).c_str()); + printf("FlashAttention: target_parallelism=%zu, kNF=%zu, kVTileSize=%zu\n", + target_parallelism, kNF, kVTileSize); FlashAttention(tokens.size(), target_parallelism, 0, layers.query_norm_scale, - attention, qbatch, ctx, attention_impl); + attention, qbatch, ctx, AttentionImpl::kFlash); AssertClose(attention.att_out, *saved_att); ctx.profiler.PrintResults(); } void TestAttention() { - TestFlashAttention(8192, AttentionImpl::kFlash); - TestFlashAttention(2048, AttentionImpl::kFlash); - TestFlashAttention(256, AttentionImpl::kFlash); + TestFlashAttention(8192); + TestFlashAttention(2048); + TestFlashAttention(256); 
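+  // The decreasing target_parallelism values above are intended to exercise
+  // different kVTileSize choices in GetVTileSize (tiled vs. single-row
+  // paths); the printf in TestFlashAttention reports which one was picked.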
} const std::vector exp_denominator_sums_gold = { diff --git a/gemma/flash_structs.h b/gemma/flash_structs.h index 8c446e0..73563fe 100644 --- a/gemma/flash_structs.h +++ b/gemma/flash_structs.h @@ -2,19 +2,11 @@ #define THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_STRUCTS_H_ #include -#include #include namespace gcpp { -// The vertical tile size in flash attention when register lanes correspond to -// K-timesteps, and the number of registers is 4 for 4 Q-rows. -static constexpr size_t k4xNFVTileSize = 4; -// The vertical tile size in flash attention when register lanes correspond to -// K-timesteps, and the number of registers is 8 for 8 Q-rows. -static constexpr size_t k8xNFVTileSize = 8; - // State for computing softmax in a streaming ("online") manner, // avoiding large intermediate values by subtracting the running maximum. // For a sequence x_1, ..., x_n: @@ -28,46 +20,10 @@ struct OnlineSoftmaxState { float d = 0.0f; }; -struct Tile4FlashState { - OnlineSoftmaxState row_states[k8xNFVTileSize]; -}; +static constexpr size_t kVTileSize4 = 4; -// Parameters for a strip of tiles of flash attention. For processing a strip -// of tiles, each of 1, k4xNFVTileSize, or k8xNFVTileSize Q-rows, by NF -// k-positions. The total width of the strip might cover the entire sequence, -// or a part of it, depending on whether the strip has been split. -struct Tile148Params { - // Vertical tile size gives the number used in the k8xNFVTileSize arrays. - // It is the number of Q rows in the tile. - uint32_t v_tile_size = 0; - // min start position across all rows in the tile determines the - // mask used for the tile. - uint32_t min_start_pos = std::numeric_limits::max(); - // max last position across all rows in the tile determines the mask - // used for the tile. - uint32_t max_last_pos = 0; - // Index into the qbatch.KV is the same for each row in the tile. - uint32_t qi_index; - // Index into the kv_cache is the same for each row in the tile. - uint32_t kv_offset; - // In the original task, the index to the split tasks of the first split task. - uint32_t split_index = 0; - // The index of the split for running split attention. - uint32_t i_of_n = 0; - // The number of splits for running split attention. - uint32_t n_of_n = 0; - // Offsets into original Q for each row in the tile. - uint32_t q_offsets[k8xNFVTileSize]; - // Offsets into att_out for each row in the tile. - uint32_t out_offsets[k8xNFVTileSize]; - // Start k-positions for each row in the tile. - uint32_t start_pos[k8xNFVTileSize]; - // Last k-positions for each row in the tile. Inclusive. - uint32_t last_pos[k8xNFVTileSize]; - // Row index to att_out. - uint32_t tq_idx[k8xNFVTileSize]; - // Flash attention state for the tile. - Tile4FlashState end_state; +struct Tile4FlashState { + OnlineSoftmaxState row_states[kVTileSize4]; }; } // namespace gcpp diff --git a/gemma/kv_cache.cc b/gemma/kv_cache.cc index f33cd21..d225f52 100644 --- a/gemma/kv_cache.cc +++ b/gemma/kv_cache.cc @@ -29,11 +29,6 @@ namespace gcpp { -// TODO: rays - Remove this once hwy is updated. -#ifndef HWY_ARCH_MAX_BYTES -#define HWY_ARCH_MAX_BYTES 256 -#endif - // Number of rows for KV cache. Note that both rows and cols are u32, and // the total number of elements can exceed 2^32. 
static size_t CappedSeqLen(const ModelConfig& config, @@ -48,23 +43,6 @@ static size_t CappedSeqLen(const ModelConfig& config, KVCache::KVCache(const Extents2D& kv_extents, const Allocator& allocator) : kv_cache("kv", kv_extents, allocator, MatPadding::kOdd), - // WARNING: the rows and cols of k_cache and v_cache will be modified - // before use! - // The rows will be reduced by a factor of 2xkFloatsPerVector, and the - // cols will be increased by 2xkFloatsPerVector on first use. This is to - // avoid making KVCache another class that has to be duplicated for each - // machine architecture, since kFloatsPerVector is architecture dependent. - // The change is shape is safe only if the padding is kPacked. - k_cache("k", - Extents2D(HWY_MAX(kv_extents.rows, - 2 * HWY_ARCH_MAX_BYTES / sizeof(float)), - kv_extents.cols / 2), - allocator, MatPadding::kPacked), - v_cache("v", - Extents2D(HWY_MAX(kv_extents.rows, - 2 * HWY_ARCH_MAX_BYTES / sizeof(float)), - kv_extents.cols / 2), - allocator, MatPadding::kPacked), allocator_(allocator) {} KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args, @@ -79,16 +57,14 @@ KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args, : allocator_(allocator) { if (runtime_config.attention_impl == AttentionImpl::kFlashTransposedQs || runtime_config.attention_impl == AttentionImpl::kFlashTransposedQsBF16 - || ((runtime_config.attention_impl == AttentionImpl::kFlashTransposedQs - ) && - hwy::IsSame())) { + ) { const size_t num_tiles = hwy::DivCeil(CappedSeqLen(config, inference_args), kTileSize); tiled_seq_len = num_tiles * kTileSize; int tile_length = 2 * config.layer_configs[0].qkv_dim * kTileSize; Type kv_cache_type; if (runtime_config.attention_impl == AttentionImpl::kFlashTransposedQsBF16 - || hwy::IsSame()) { + ) { kv_cache_type = runtime_config.kv_cache_type.value_or(Type::kBF16); } else { kv_cache_type = runtime_config.kv_cache_type.value_or(Type::kF32); @@ -138,6 +114,9 @@ KVCache KVCache::Copy() { KVCache copy(kv_cache.Extents(), allocator_); CopyMat(kv_cache, copy.kv_cache); + + CopyMat(compact_kv_cache_ptr, copy.compact_kv_cache_ptr); + copy.tiled_seq_len = tiled_seq_len; return copy; } diff --git a/gemma/kv_cache.h b/gemma/kv_cache.h index 5fe1f1e..91b6b7f 100644 --- a/gemma/kv_cache.h +++ b/gemma/kv_cache.h @@ -30,7 +30,7 @@ namespace gcpp { -using KV_t = BF16; +using KV_t = float; struct KVCache; // A non-owning view of a KVCache. @@ -40,8 +40,6 @@ struct KVCachePtr { bool IsTiled() const; MatPtrT kv_cache; - MatPtrT k_cache; - MatPtrT v_cache; KVCache* cache = nullptr; }; @@ -125,33 +123,11 @@ struct KVCache { // kv_head_ptrs[...].Rows(). std::vector kv_head_ptrs; MatStorageT kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2] - // The format of k_cache indicates that there are pairs of values from - // qkv_dim in groups of 2x kFloatsPerVector(=NF) elements from the sequence, - // in groups of qkv_dim/2 elements in groups of kv_heads elements. - // This enables sequential loading of the data when filling 2 vectors with - // NF sequence elements of pairs of BF16 qkv values. The next vector then - // continues reading the rest of qkv. - // [seq_len / 2NF, layers * kv_heads * qkv_dim/2 * 2NF * 2] - MatStorageT k_cache; - // v_cache is formatted to allow sequential access to V during scaling and - // update of att_out. 
- // Originally [seq_len, layers * kv_heads * qkv_dim] - // v_cache is transposed to: - // [layers, kv_heads, seq_len, qkv_dim], reshaped to: - // [layers, kv_heads, seq_len/(2NF), 2NF, qkv_dim/(2NF), 2NF] - // then transposed to: - // [seq_len/(2NF), layers, kv_heads, qkv_dim/(2NF), 2NF, 2NF] - // and finally packed in a 2D MatStorageT as: - // [seq_len/(2NF), layers * kv_heads * qkv_dim/(2NF) * 2NF * 2NF] - // This allows sequential reads of 2NF registers each of 2NF BF16 values, - // repeatedly until all of qkv_dim is read. - MatStorageT v_cache; KVCachePtr ToPtr() { return KVCachePtr{ .kv_cache = kv_cache, - .k_cache = k_cache, - .v_cache = v_cache, + .cache = this, }; } diff --git a/gemma/tiled_attention_test.cc b/gemma/tiled_attention_test.cc index 46d276f..7f9c8ca 100644 --- a/gemma/tiled_attention_test.cc +++ b/gemma/tiled_attention_test.cc @@ -98,7 +98,7 @@ struct AttentionTestEnv { } } } else if (kv_caches.back().compact_kv_cache_ptr.HasPtr()) { - MatPtrT compact_kv_cache = kv_caches.back().compact_kv_cache_ptr; + MatPtrT compact_kv_cache = kv_caches.back().compact_kv_cache_ptr; FillMatPtrT(compact_kv_cache); } else { FillMatPtrT(kv_caches.back().kv_cache); @@ -735,13 +735,12 @@ HWY_AFTER_NAMESPACE(); namespace gcpp { HWY_BEFORE_TEST(TiledAttentionTest); HWY_EXPORT_AND_TEST_P(TiledAttentionTest, TestTransposeStridedQueries); -// TODO() Fix the goldens for the change in KV_t to BF16 -// HWY_EXPORT_AND_TEST_P(TiledAttentionTest, -// TestLocalAttentionForAllHeadsTokensAndBatch); +HWY_EXPORT_AND_TEST_P(TiledAttentionTest, + TestLocalAttentionForAllHeadsTokensAndBatch); HWY_EXPORT_AND_TEST_P(TiledAttentionTest, TestAttentionMultipleTokens); HWY_EXPORT_AND_TEST_P(TiledAttentionTest, TestAttentionMultipleTokensBF16); -// HWY_EXPORT_AND_TEST_P(TiledAttentionTest, -// TestAttentionMultipleTokensAttentionWindowSizeEdgeCase); +HWY_EXPORT_AND_TEST_P(TiledAttentionTest, + TestAttentionMultipleTokensAttentionWindowSizeEdgeCase); HWY_AFTER_TEST(); diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 8ce45ab..affde22 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -613,6 +613,267 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(const float c, }); } +template , HWY_IF_V_SIZE_GT_D(DF, 63)> +HWY_INLINE HWY_MAYBE_UNUSED void Mul16(DF df, const VF scale, VF& sum0, + VF& sum1, VF& sum2, VF& sum3, VF& sum4, + VF& sum5, VF& sum6, VF& sum7, VF& sum8, + VF& sum9, VF& sum10, VF& sum11, + VF& sum12, VF& sum13, VF& sum14, + VF& sum15) { + sum0 = hn::Mul(sum0, hn::BroadcastLane<0>(scale)); + sum1 = hn::Mul(sum1, hn::BroadcastLane<1>(scale)); + sum2 = hn::Mul(sum2, hn::BroadcastLane<2>(scale)); + sum3 = hn::Mul(sum3, hn::BroadcastLane<3>(scale)); + sum4 = hn::Mul(sum4, hn::BroadcastLane<4>(scale)); + sum5 = hn::Mul(sum5, hn::BroadcastLane<5>(scale)); + sum6 = hn::Mul(sum6, hn::BroadcastLane<6>(scale)); + sum7 = hn::Mul(sum7, hn::BroadcastLane<7>(scale)); + sum8 = hn::Mul(sum8, hn::BroadcastLane<8>(scale)); + sum9 = hn::Mul(sum9, hn::BroadcastLane<9>(scale)); + sum10 = hn::Mul(sum10, hn::BroadcastLane<10>(scale)); + sum11 = hn::Mul(sum11, hn::BroadcastLane<11>(scale)); + sum12 = hn::Mul(sum12, hn::BroadcastLane<12>(scale)); + sum13 = hn::Mul(sum13, hn::BroadcastLane<13>(scale)); + sum14 = hn::Mul(sum14, hn::BroadcastLane<14>(scale)); + sum15 = hn::Mul(sum15, hn::BroadcastLane<15>(scale)); +} + +template , HWY_IF_V_SIZE_LE_D(DF, 63)> +HWY_INLINE HWY_MAYBE_UNUSED void Mul16(DF df, const VF scale, VF& sum0, + VF& sum1, VF& sum2, VF& sum3, VF& sum4, + VF& sum5, VF& sum6, VF& sum7, VF& sum8, + VF& sum9, VF& 
sum10, VF& sum11, + VF& sum12, VF& sum13, VF& sum14, + VF& sum15) {} + +template , HWY_IF_V_SIZE_GT_D(DF, 31)> +HWY_INLINE HWY_MAYBE_UNUSED void Mul8(DF df, const VF scale, VF& sum0, VF& sum1, + VF& sum2, VF& sum3, VF& sum4, VF& sum5, + VF& sum6, VF& sum7) { + sum0 = hn::Mul(sum0, hn::BroadcastLane<0>(scale)); + sum1 = hn::Mul(sum1, hn::BroadcastLane<1>(scale)); + sum2 = hn::Mul(sum2, hn::BroadcastLane<2>(scale)); + sum3 = hn::Mul(sum3, hn::BroadcastLane<3>(scale)); + sum4 = hn::Mul(sum4, hn::BroadcastLane<4>(scale)); + sum5 = hn::Mul(sum5, hn::BroadcastLane<5>(scale)); + sum6 = hn::Mul(sum6, hn::BroadcastLane<6>(scale)); + sum7 = hn::Mul(sum7, hn::BroadcastLane<7>(scale)); +} + +template , HWY_IF_V_SIZE_LE_D(DF, 31)> +HWY_INLINE HWY_MAYBE_UNUSED void Mul8(DF df, const VF scale, VF& sum0, VF& sum1, + VF& sum2, VF& sum3, VF& sum4, VF& sum5, + VF& sum6, VF& sum7) {} + +template , HWY_IF_V_SIZE_GT_D(DF, 63)> +HWY_INLINE HWY_MAYBE_UNUSED void MulAdd16( + DF df, const VF common, const VF split, VF& sum0, VF& sum1, VF& sum2, + VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7, VF& sum8, VF& sum9, + VF& sum10, VF& sum11, VF& sum12, VF& sum13, VF& sum14, VF& sum15) { + sum0 = hn::MulAdd(common, hn::BroadcastLane<0>(split), sum0); + sum1 = hn::MulAdd(common, hn::BroadcastLane<1>(split), sum1); + sum2 = hn::MulAdd(common, hn::BroadcastLane<2>(split), sum2); + sum3 = hn::MulAdd(common, hn::BroadcastLane<3>(split), sum3); + sum4 = hn::MulAdd(common, hn::BroadcastLane<4>(split), sum4); + sum5 = hn::MulAdd(common, hn::BroadcastLane<5>(split), sum5); + sum6 = hn::MulAdd(common, hn::BroadcastLane<6>(split), sum6); + sum7 = hn::MulAdd(common, hn::BroadcastLane<7>(split), sum7); + sum8 = hn::MulAdd(common, hn::BroadcastLane<8>(split), sum8); + sum9 = hn::MulAdd(common, hn::BroadcastLane<9>(split), sum9); + sum10 = hn::MulAdd(common, hn::BroadcastLane<10>(split), sum10); + sum11 = hn::MulAdd(common, hn::BroadcastLane<11>(split), sum11); + sum12 = hn::MulAdd(common, hn::BroadcastLane<12>(split), sum12); + sum13 = hn::MulAdd(common, hn::BroadcastLane<13>(split), sum13); + sum14 = hn::MulAdd(common, hn::BroadcastLane<14>(split), sum14); + sum15 = hn::MulAdd(common, hn::BroadcastLane<15>(split), sum15); +} + +template , HWY_IF_V_SIZE_LE_D(DF, 63)> +HWY_INLINE HWY_MAYBE_UNUSED void MulAdd16( + DF df, const VF common, const VF split, VF& sum0, VF& sum1, VF& sum2, + VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7, VF& sum8, VF& sum9, + VF& sum10, VF& sum11, VF& sum12, VF& sum13, VF& sum14, VF& sum15) {} + +template , HWY_IF_V_SIZE_GT_D(DF, 31)> +HWY_INLINE HWY_MAYBE_UNUSED void MulAdd8(DF df, const VF common, const VF split, + VF& sum0, VF& sum1, VF& sum2, VF& sum3, + VF& sum4, VF& sum5, VF& sum6, + VF& sum7) { + sum0 = hn::MulAdd(common, hn::BroadcastLane<0>(split), sum0); + sum1 = hn::MulAdd(common, hn::BroadcastLane<1>(split), sum1); + sum2 = hn::MulAdd(common, hn::BroadcastLane<2>(split), sum2); + sum3 = hn::MulAdd(common, hn::BroadcastLane<3>(split), sum3); + sum4 = hn::MulAdd(common, hn::BroadcastLane<4>(split), sum4); + sum5 = hn::MulAdd(common, hn::BroadcastLane<5>(split), sum5); + sum6 = hn::MulAdd(common, hn::BroadcastLane<6>(split), sum6); + sum7 = hn::MulAdd(common, hn::BroadcastLane<7>(split), sum7); +} + +template , HWY_IF_V_SIZE_LE_D(DF, 31)> +HWY_INLINE HWY_MAYBE_UNUSED void MulAdd8(DF df, const VF common, const VF split, + VF& sum0, VF& sum1, VF& sum2, VF& sum3, + VF& sum4, VF& sum5, VF& sum6, + VF& sum7) {} + +template > +HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4(DF df, const VF common, const VF 
+
+template <class DF, class VF = hn::Vec<DF>>
+HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4(DF df, const VF common, const VF split,
+                                         VF& sum0, VF& sum1, VF& sum2,
+                                         VF& sum3) {
+  sum0 = hn::MulAdd(common, hn::BroadcastLane<0>(split), sum0);
+  sum1 = hn::MulAdd(common, hn::BroadcastLane<1>(split), sum1);
+  sum2 = hn::MulAdd(common, hn::BroadcastLane<2>(split), sum2);
+  sum3 = hn::MulAdd(common, hn::BroadcastLane<3>(split), sum3);
+}
+
+// For an 8xNF tile of float values held in 8 NF-lane registers, multiplies 8
+// rows of V by the corresponding values in c0-c7 and adds them to NF rows of
+// out, after first prescaling out by scale.
+// The depth (size) must be a multiple of NF.
+template <class DF, class VF = hn::Vec<DF>>
+HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile(
+    DF df, const VF scale, const VF c0, const VF c1, const VF c2, const VF c3,
+    const VF c4, const VF c5, const VF c6, const VF c7,
+    const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
+    float* HWY_RESTRICT out, const uint32_t* HWY_RESTRICT out_offsets,
+    const size_t size) {
+  namespace hn = hwy::HWY_NAMESPACE;
+  HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
+
+  size_t i = 0;
+  while (i + NF <= size) {
+    if HWY_LANES_CONSTEXPR (NF == 16) {
+      VF out0, out1, out2, out3, out4, out5, out6, out7;
+      VF out8, out9, out10, out11, out12, out13, out14, out15;
+      out0 = hn::Load(df, out + i + out_offsets[0]);
+      out1 = hn::Load(df, out + i + out_offsets[1]);
+      out2 = hn::Load(df, out + i + out_offsets[2]);
+      out3 = hn::Load(df, out + i + out_offsets[3]);
+      out4 = hn::Load(df, out + i + out_offsets[4]);
+      out5 = hn::Load(df, out + i + out_offsets[5]);
+      out6 = hn::Load(df, out + i + out_offsets[6]);
+      out7 = hn::Load(df, out + i + out_offsets[7]);
+      out8 = hn::Load(df, out + i + out_offsets[8]);
+      out9 = hn::Load(df, out + i + out_offsets[9]);
+      out10 = hn::Load(df, out + i + out_offsets[10]);
+      out11 = hn::Load(df, out + i + out_offsets[11]);
+      out12 = hn::Load(df, out + i + out_offsets[12]);
+      out13 = hn::Load(df, out + i + out_offsets[13]);
+      out14 = hn::Load(df, out + i + out_offsets[14]);
+      out15 = hn::Load(df, out + i + out_offsets[15]);
+      Mul16(df, scale, out0, out1, out2, out3, out4, out5, out6, out7, out8,
+            out9, out10, out11, out12, out13, out14, out15);
+      VF x0 = hn::Load(df, v.Row(pos[0]) + i);
+      MulAdd16(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7,
+               out8, out9, out10, out11, out12, out13, out14, out15);
+      VF x1 = hn::Load(df, v.Row(pos[1]) + i);
+      MulAdd16(df, x1, c1, out0, out1, out2, out3, out4, out5, out6, out7,
+               out8, out9, out10, out11, out12, out13, out14, out15);
+      VF x2 = hn::Load(df, v.Row(pos[2]) + i);
+      MulAdd16(df, x2, c2, out0, out1, out2, out3, out4, out5, out6, out7,
+               out8, out9, out10, out11, out12, out13, out14, out15);
+      VF x3 = hn::Load(df, v.Row(pos[3]) + i);
+      MulAdd16(df, x3, c3, out0, out1, out2, out3, out4, out5, out6, out7,
+               out8, out9, out10, out11, out12, out13, out14, out15);
+      VF x4 = hn::Load(df, v.Row(pos[4]) + i);
+      MulAdd16(df, x4, c4, out0, out1, out2, out3, out4, out5, out6, out7,
+               out8, out9, out10, out11, out12, out13, out14, out15);
+      VF x5 = hn::Load(df, v.Row(pos[5]) + i);
+      MulAdd16(df, x5, c5, out0, out1, out2, out3, out4, out5, out6, out7,
+               out8, out9, out10, out11, out12, out13, out14, out15);
+      VF x6 = hn::Load(df, v.Row(pos[6]) + i);
+      MulAdd16(df, x6, c6, out0, out1, out2, out3, out4, out5, out6, out7,
+               out8, out9, out10, out11, out12, out13, out14, out15);
+      VF x7 = hn::Load(df, v.Row(pos[7]) + i);
+      MulAdd16(df, x7, c7, out0, out1, out2, out3, out4, out5, out6, out7,
+               out8, out9, out10, out11, out12, out13, out14, out15);
+      hn::Store(out0, df, out + i + out_offsets[0]);
+      hn::Store(out1, df, out + i + out_offsets[1]);
+      hn::Store(out2, df, out + i + out_offsets[2]);
+      hn::Store(out3, df, out + i + out_offsets[3]);
+      hn::Store(out4, df, out + i + out_offsets[4]);
+      hn::Store(out5, df, out + i + out_offsets[5]);
+      hn::Store(out6, df, out + i + out_offsets[6]);
+      hn::Store(out7, df, out + i + out_offsets[7]);
+      hn::Store(out8, df, out + i + out_offsets[8]);
+      hn::Store(out9, df, out + i + out_offsets[9]);
+      hn::Store(out10, df, out + i + out_offsets[10]);
+      hn::Store(out11, df, out + i + out_offsets[11]);
+      hn::Store(out12, df, out + i + out_offsets[12]);
+      hn::Store(out13, df, out + i + out_offsets[13]);
+      hn::Store(out14, df, out + i + out_offsets[14]);
+      hn::Store(out15, df, out + i + out_offsets[15]);
+    }
+    if HWY_LANES_CONSTEXPR (NF == 8) {
+      VF out0, out1, out2, out3, out4, out5, out6, out7;
+      out0 = hn::Load(df, out + i + out_offsets[0]);
+      out1 = hn::Load(df, out + i + out_offsets[1]);
+      out2 = hn::Load(df, out + i + out_offsets[2]);
+      out3 = hn::Load(df, out + i + out_offsets[3]);
+      out4 = hn::Load(df, out + i + out_offsets[4]);
+      out5 = hn::Load(df, out + i + out_offsets[5]);
+      out6 = hn::Load(df, out + i + out_offsets[6]);
+      out7 = hn::Load(df, out + i + out_offsets[7]);
+      Mul8(df, scale, out0, out1, out2, out3, out4, out5, out6, out7);
+      VF x0 = hn::Load(df, v.Row(pos[0]) + i);
+      MulAdd8(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7);
+      VF x1 = hn::Load(df, v.Row(pos[1]) + i);
+      MulAdd8(df, x1, c1, out0, out1, out2, out3, out4, out5, out6, out7);
+      VF x2 = hn::Load(df, v.Row(pos[2]) + i);
+      MulAdd8(df, x2, c2, out0, out1, out2, out3, out4, out5, out6, out7);
+      VF x3 = hn::Load(df, v.Row(pos[3]) + i);
+      MulAdd8(df, x3, c3, out0, out1, out2, out3, out4, out5, out6, out7);
+      VF x4 = hn::Load(df, v.Row(pos[4]) + i);
+      MulAdd8(df, x4, c4, out0, out1, out2, out3, out4, out5, out6, out7);
+      VF x5 = hn::Load(df, v.Row(pos[5]) + i);
+      MulAdd8(df, x5, c5, out0, out1, out2, out3, out4, out5, out6, out7);
+      VF x6 = hn::Load(df, v.Row(pos[6]) + i);
+      MulAdd8(df, x6, c6, out0, out1, out2, out3, out4, out5, out6, out7);
+      VF x7 = hn::Load(df, v.Row(pos[7]) + i);
+      MulAdd8(df, x7, c7, out0, out1, out2, out3, out4, out5, out6, out7);
+      hn::Store(out0, df, out + i + out_offsets[0]);
+      hn::Store(out1, df, out + i + out_offsets[1]);
+      hn::Store(out2, df, out + i + out_offsets[2]);
+      hn::Store(out3, df, out + i + out_offsets[3]);
+      hn::Store(out4, df, out + i + out_offsets[4]);
+      hn::Store(out5, df, out + i + out_offsets[5]);
+      hn::Store(out6, df, out + i + out_offsets[6]);
+      hn::Store(out7, df, out + i + out_offsets[7]);
+    }
+    if HWY_LANES_CONSTEXPR (NF == 4) {
+      VF out0, out1, out2, out3;
+      out0 = hn::Load(df, out + i + out_offsets[0]);
+      out1 = hn::Load(df, out + i + out_offsets[1]);
+      out2 = hn::Load(df, out + i + out_offsets[2]);
+      out3 = hn::Load(df, out + i + out_offsets[3]);
+      out0 = hn::Mul(out0, hn::BroadcastLane<0>(scale));
+      out1 = hn::Mul(out1, hn::BroadcastLane<1>(scale));
+      out2 = hn::Mul(out2, hn::BroadcastLane<2>(scale));
+      out3 = hn::Mul(out3, hn::BroadcastLane<3>(scale));
+      VF x0 = hn::Load(df, v.Row(pos[0]) + i);
+      MulAdd4(df, x0, c0, out0, out1, out2, out3);
+      VF x1 = hn::Load(df, v.Row(pos[1]) + i);
+      MulAdd4(df, x1, c1, out0, out1, out2, out3);
+      VF x2 = hn::Load(df, v.Row(pos[2]) + i);
+      MulAdd4(df, x2, c2, out0, out1, out2, out3);
+      VF x3 = hn::Load(df, v.Row(pos[3]) + i);
+      MulAdd4(df, x3, c3, out0, out1, out2, out3);
+      VF x4 = hn::Load(df, v.Row(pos[4]) + i);
+      MulAdd4(df, x4, c4, out0, out1, out2, out3);
+      VF x5 = hn::Load(df, v.Row(pos[5]) + i);
+      MulAdd4(df, x5, c5, out0, out1, out2, out3);
+      VF x6 = hn::Load(df, v.Row(pos[6]) + i);
+      MulAdd4(df, x6, c6, out0, out1, out2, out3);
+      VF x7 = hn::Load(df, v.Row(pos[7]) + i);
+      MulAdd4(df, x7, c7, out0, out1, out2, out3);
+      hn::Store(out0, df, out + i + out_offsets[0]);
+      hn::Store(out1, df, out + i + out_offsets[1]);
+      hn::Store(out2, df, out + i + out_offsets[2]);
+      hn::Store(out3, df, out + i + out_offsets[3]);
+    }
+    i += NF;
+  }
+  HWY_DASSERT(size == i);
+}
+
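To make the lane bookkeeping in MulByConstAndAddTile easier to verify, here is a scalar model of the update it performs (a hypothetical reference helper, not part of this change): output row k is rescaled by lane k of scale, then accumulates the eight V rows, each weighted by lane k of the corresponding c0-c7 register.

#include <cstddef>

// Scalar reference: out_k[i] = out_k[i] * scale[k] + sum_j c[j][k] * v[j][i],
// with k in [0, NF) indexing output rows and j in [0, 8) indexing V rows.
void MulByConstAndAddTileScalar(const float* scale,            // NF values
                                const float* const c[8],       // 8 x NF weights
                                const float* const v_rows[8],  // v.Row(pos[j])
                                float* const* out_rows,  // out + out_offsets[k]
                                size_t nf, size_t size) {
  for (size_t k = 0; k < nf; ++k) {
    for (size_t i = 0; i < size; ++i) {
      float acc = out_rows[k][i] * scale[k];
      for (size_t j = 0; j < 8; ++j) {
        acc += c[j][k] * v_rows[j][i];
      }
      out_rows[k][i] = acc;
    }
  }
}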
 template <class DF, class VF = hn::Vec<DF>>
 HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4(DF df, const VF common, const VF c0,
                                          const VF c1, const VF c2, const VF c3,
@@ -625,147 +886,145 @@ HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4(DF df, const VF common, const VF c0,
 }
 
 template <class DF, class VF = hn::Vec<DF>>
-HWY_INLINE HWY_MAYBE_UNUSED void MulAddNLanesVT4(
-    DF df, const BF16* HWY_RESTRICT v, const float* HWY_RESTRICT c,
-    const size_t num_lanes, VF& sum0a, VF& sum1a, VF& sum2a, VF& sum3a,
-    VF& sum0b, VF& sum1b, VF& sum2b, VF& sum3b) {
-  using DBF = hn::ScalableTag<BF16>;
-  const DBF dbf;
-  using VBF = hn::Vec<DBF>;
-  const size_t kNF = hn::Lanes(df);
-  for (size_t lane = 0; lane < num_lanes; ++lane, v += 2 * kNF) {
-    VBF v0 = hn::Load(dbf, v);
-    VF c0 = hn::Set(df, *c++);
-    VF c1 = hn::Set(df, *c++);
-    VF c2 = hn::Set(df, *c++);
-    VF c3 = hn::Set(df, *c++);
-    VF v0a = hn::PromoteLowerTo(df, v0);
-    VF v0b = hn::PromoteUpperTo(df, v0);
-    MulAdd4(df, v0a, c0, c1, c2, c3, sum0a, sum1a, sum2a, sum3a);
-    MulAdd4(df, v0b, c0, c1, c2, c3, sum0b, sum1b, sum2b, sum3b);
-  }
+HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4Lanes(DF df, const MatPtrT<float>& v,
+                                              const size_t* HWY_RESTRICT pos,
+                                              const size_t offset, const VF c0,
+                                              const VF c1, const VF c2,
+                                              const VF c3, VF& sum0, VF& sum1,
+                                              VF& sum2, VF& sum3) {
+  // TODO(rays): Check whether a transpose of c0-c3 is applicable and faster.
+  VF x0 = hn::Load(df, v.Row(pos[0]) + offset);
+  MulAdd4(df, x0, hn::BroadcastLane<0>(c0), hn::BroadcastLane<0>(c1),
+          hn::BroadcastLane<0>(c2), hn::BroadcastLane<0>(c3), sum0, sum1, sum2,
+          sum3);
+  VF x1 = hn::Load(df, v.Row(pos[1]) + offset);
+  MulAdd4(df, x1, hn::BroadcastLane<1>(c0), hn::BroadcastLane<1>(c1),
+          hn::BroadcastLane<1>(c2), hn::BroadcastLane<1>(c3), sum0, sum1, sum2,
+          sum3);
+  VF x2 = hn::Load(df, v.Row(pos[2]) + offset);
+  MulAdd4(df, x2, hn::BroadcastLane<2>(c0), hn::BroadcastLane<2>(c1),
+          hn::BroadcastLane<2>(c2), hn::BroadcastLane<2>(c3), sum0, sum1, sum2,
+          sum3);
+  VF x3 = hn::Load(df, v.Row(pos[3]) + offset);
+  MulAdd4(df, x3, hn::BroadcastLane<3>(c0), hn::BroadcastLane<3>(c1),
+          hn::BroadcastLane<3>(c2), hn::BroadcastLane<3>(c3), sum0, sum1, sum2,
+          sum3);
 }
 
-// For a 2NFx4 tile of float values in 8xNF-lane registers, multiplies 2NF rows
-// of V by the corresponding values in c00-c31 and adds them to 2NF rows of out,
+template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 31)>
+HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond4Lanes(
+    DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
+    const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
+    VF& sum0, VF& sum1, VF& sum2, VF& sum3) {
+  VF x4 = hn::Load(df, v.Row(pos[4]) + offset);
+  MulAdd4(df, x4, hn::BroadcastLane<4>(c0), hn::BroadcastLane<4>(c1),
+          hn::BroadcastLane<4>(c2), hn::BroadcastLane<4>(c3), sum0, sum1, sum2,
+          sum3);
+  VF x5 = hn::Load(df, v.Row(pos[5]) + offset);
+  MulAdd4(df, x5, hn::BroadcastLane<5>(c0), hn::BroadcastLane<5>(c1),
+          hn::BroadcastLane<5>(c2), hn::BroadcastLane<5>(c3), sum0, sum1, sum2,
+          sum3);
+  VF x6 = hn::Load(df, v.Row(pos[6]) + offset);
+  MulAdd4(df, x6, hn::BroadcastLane<6>(c0), hn::BroadcastLane<6>(c1),
+          hn::BroadcastLane<6>(c2), hn::BroadcastLane<6>(c3), sum0, sum1, sum2,
+          sum3);
+  VF x7 = hn::Load(df, v.Row(pos[7]) + offset);
+  MulAdd4(df, x7, hn::BroadcastLane<7>(c0), hn::BroadcastLane<7>(c1),
+          hn::BroadcastLane<7>(c2), hn::BroadcastLane<7>(c3), sum0, sum1, sum2,
+          sum3);
+}
+
+template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 31)>
+HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond4Lanes(
+    DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
+    const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
+    VF& sum0, VF& sum1, VF& sum2, VF& sum3) {}
+
+template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 63)>
+HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond8Lanes(
+    DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
+    const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
+    VF& sum0, VF& sum1, VF& sum2, VF& sum3) {
+  VF x8 = hn::Load(df, v.Row(pos[8]) + offset);
+  MulAdd4(df, x8, hn::BroadcastLane<8>(c0), hn::BroadcastLane<8>(c1),
+          hn::BroadcastLane<8>(c2), hn::BroadcastLane<8>(c3), sum0, sum1, sum2,
+          sum3);
+  VF x9 = hn::Load(df, v.Row(pos[9]) + offset);
+  MulAdd4(df, x9, hn::BroadcastLane<9>(c0), hn::BroadcastLane<9>(c1),
+          hn::BroadcastLane<9>(c2), hn::BroadcastLane<9>(c3), sum0, sum1, sum2,
+          sum3);
+  VF x10 = hn::Load(df, v.Row(pos[10]) + offset);
+  MulAdd4(df, x10, hn::BroadcastLane<10>(c0), hn::BroadcastLane<10>(c1),
+          hn::BroadcastLane<10>(c2), hn::BroadcastLane<10>(c3), sum0, sum1,
+          sum2, sum3);
+  VF x11 = hn::Load(df, v.Row(pos[11]) + offset);
+  MulAdd4(df, x11, hn::BroadcastLane<11>(c0), hn::BroadcastLane<11>(c1),
+          hn::BroadcastLane<11>(c2), hn::BroadcastLane<11>(c3), sum0, sum1,
+          sum2, sum3);
+  VF x12 = hn::Load(df, v.Row(pos[12]) + offset);
+  MulAdd4(df, x12, hn::BroadcastLane<12>(c0), hn::BroadcastLane<12>(c1),
+          hn::BroadcastLane<12>(c2), hn::BroadcastLane<12>(c3), sum0, sum1,
+          sum2, sum3);
+  VF x13 = hn::Load(df, v.Row(pos[13]) + offset);
+  MulAdd4(df, x13, hn::BroadcastLane<13>(c0), hn::BroadcastLane<13>(c1),
+          hn::BroadcastLane<13>(c2), hn::BroadcastLane<13>(c3), sum0, sum1,
+          sum2, sum3);
+  VF x14 = hn::Load(df, v.Row(pos[14]) + offset);
+  MulAdd4(df, x14, hn::BroadcastLane<14>(c0), hn::BroadcastLane<14>(c1),
+          hn::BroadcastLane<14>(c2), hn::BroadcastLane<14>(c3), sum0, sum1,
+          sum2, sum3);
+  VF x15 = hn::Load(df, v.Row(pos[15]) + offset);
+  MulAdd4(df, x15, hn::BroadcastLane<15>(c0), hn::BroadcastLane<15>(c1),
+          hn::BroadcastLane<15>(c2), hn::BroadcastLane<15>(c3), sum0, sum1,
+          sum2, sum3);
+}
+
+template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 63)>
+HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond8Lanes(
+    DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
+    const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
+    VF& sum0, VF& sum1, VF& sum2, VF& sum3) {}
+
+// For an NFx4 tile of float values held in 4 NF-lane registers, multiplies NF
+// rows of V by the corresponding values in c0-c3 and adds them to the 4 rows of out,
 // after first prescaling out by scale.
-// The depth (size) must be a multiple of 2NF.
+// The depth (size) must be a multiple of NF.
 template <class DF, class VF = hn::Vec<DF>>
-HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddVT4Mem(
-    DF df, const float* HWY_RESTRICT scales, const VF c00, const VF c01,
-    const VF c10, const VF c11, const VF c20, const VF c21, const VF c30,
-    const VF c31, const MatPtrT<BF16>& v, const size_t* HWY_RESTRICT pos,
-    size_t num_lanes, float* HWY_RESTRICT out,
+HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile4(
+    DF df, const float* HWY_RESTRICT scales, const VF c0, const VF c1,
+    const VF c2, const VF c3, const MatPtrT<float>& v,
+    const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out,
     const uint32_t* HWY_RESTRICT out_offsets, const size_t size) {
   namespace hn = hwy::HWY_NAMESPACE;
   HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
-  constexpr size_t kMaxNF = hn::MaxLanes(df);
-  const BF16* HWY_RESTRICT v_bf = v.Row(pos[0] / (2 * NF));
-  HWY_DASSERT(pos[0] % (2 * NF) == 0);
-  HWY_ALIGN float c_mem[8 * kMaxNF];
-  hn::StoreInterleaved4(c00, c10, c20, c30, df, c_mem);
-  hn::StoreInterleaved4(c01, c11, c21, c31, df, c_mem + 4 * NF);
 
   size_t i = 0;
-  while (i + NF * 2 <= size) {
-    VF out0a, out1a, out2a, out3a, out0b, out1b, out2b, out3b;
-    out0a = hn::Load(df, out + i + out_offsets[0]);
-    out1a = hn::Load(df, out + i + out_offsets[1]);
-    out2a = hn::Load(df, out + i + out_offsets[2]);
-    out3a = hn::Load(df, out + i + out_offsets[3]);
-    VF scale0 = hn::Set(df, scales[0]);
-    VF scale1 = hn::Set(df, scales[1]);
-    VF scale2 = hn::Set(df, scales[2]);
-    VF scale3 = hn::Set(df, scales[3]);
-    out0a = hn::Mul(out0a, scale0);
-    out1a = hn::Mul(out1a, scale1);
-    out2a = hn::Mul(out2a, scale2);
-    out3a = hn::Mul(out3a, scale3);
-    out0b = hn::Load(df, out + i + NF + out_offsets[0]);
-    out1b = hn::Load(df, out + i + NF + out_offsets[1]);
-    out2b = hn::Load(df, out + i + NF + out_offsets[2]);
-    out3b = hn::Load(df, out + i + NF + out_offsets[3]);
-    out0b = hn::Mul(out0b, scale0);
-    out1b = hn::Mul(out1b, scale1);
-    out2b = hn::Mul(out2b, scale2);
-    out3b = hn::Mul(out3b, scale3);
-    MulAddNLanesVT4(df, v_bf, c_mem, HWY_MIN(num_lanes, 2 * NF), out0a, out1a,
-                    out2a, out3a, out0b, out1b, out2b, out3b);
-    hn::Store(out0a, df, out + i + out_offsets[0]);
-    hn::Store(out1a, df, out + i + out_offsets[1]);
-    hn::Store(out2a, df, out + i + out_offsets[2]);
-    hn::Store(out3a, df, out + i + out_offsets[3]);
-    hn::Store(out0b, df, out + i + NF + out_offsets[0]);
-    hn::Store(out1b, df, out + i + NF + out_offsets[1]);
-    hn::Store(out2b, df, out + i + NF + out_offsets[2]);
-    hn::Store(out3b, df, out + i + NF + out_offsets[3]);
-    i += NF * 2;
-    v_bf += 4 * NF * NF;
+  while (i + NF <= size) {
+    VF out0, out1, out2, out3;
+    out0 = hn::Load(df, out + i + out_offsets[0]);
+    out1 = hn::Load(df, out + i + out_offsets[1]);
+    out2 = hn::Load(df, out + i + out_offsets[2]);
+    out3 = hn::Load(df, out + i + out_offsets[3]);
+    out0 = hn::Mul(out0, hn::Set(df, scales[0]));
+    out1 = hn::Mul(out1, hn::Set(df, scales[1]));
+    out2 = hn::Mul(out2, hn::Set(df, scales[2]));
+    out3 = hn::Mul(out3, hn::Set(df, scales[3]));
+    MulAdd4Lanes(df, v, pos, i, c0, c1, c2, c3, out0, out1, out2, out3);
+    if HWY_LANES_CONSTEXPR (NF >= 8) {
+      MulAddSecond4Lanes(df, v, pos, i, c0, c1, c2, c3, out0, out1, out2, out3);
+      if HWY_LANES_CONSTEXPR (NF >= 16) {
+        MulAddSecond8Lanes(df, v, pos, i, c0, c1, c2, c3, out0, out1, out2,
+                           out3);
+      }
+    }
+    hn::Store(out0, df, out + i + out_offsets[0]);
+    hn::Store(out1, df, out + i + out_offsets[1]);
+    hn::Store(out2, df, out + i + out_offsets[2]);
+    hn::Store(out3, df, out + i + out_offsets[3]);
+    i += NF;
   }
   HWY_DASSERT(size == i);
 }
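MulByConstAndAddTile4 computes the same kind of update for 4 output rows and NF V rows, but with the coefficient layout transposed relative to MulByConstAndAddTile: lane j of c_k now weights V row pos[j] for output row k, which appears to be what the TODO(rays) above about transposing c0-c3 refers to. A scalar model under the same caveats as before (hypothetical reference helper, not part of this change):

#include <cstddef>

// Scalar reference: out_k[i] = out_k[i] * scales[k] + sum_j c[k][j] * v[j][i],
// with k in [0, 4) indexing output rows and j in [0, NF) indexing V rows.
void MulByConstAndAddTile4Scalar(const float* scales,         // 4 values
                                 const float* const c[4],     // 4 x NF weights
                                 const float* const* v_rows,  // NF rows of V
                                 float* const out_rows[4], size_t nf,
                                 size_t size) {
  for (size_t k = 0; k < 4; ++k) {
    for (size_t i = 0; i < size; ++i) {
      float acc = out_rows[k][i] * scales[k];
      for (size_t j = 0; j < nf; ++j) {
        acc += c[k][j] * v_rows[j][i];
      }
      out_rows[k][i] = acc;
    }
  }
}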
-
-template <class DF, class VF = hn::Vec<DF>>
-HWY_INLINE HWY_MAYBE_UNUSED void MulAddNLanesVT1(DF df,
-                                                 const BF16* HWY_RESTRICT v,
-                                                 const float* HWY_RESTRICT c,
-                                                 const size_t num_lanes,
-                                                 VF& sum0a, VF& sum0b) {
-  using DBF = hn::ScalableTag<BF16>;
-  const DBF dbf;
-  using VBF = hn::Vec<DBF>;
-  const size_t kNF = hn::Lanes(df);
-  for (size_t lane = 0; lane < num_lanes; ++lane, v += 2 * kNF) {
-    VBF v0 = hn::Load(dbf, v);
-    VF c0 = hn::Set(df, *c++);
-    VF v0a = hn::PromoteLowerTo(df, v0);
-    VF v0b = hn::PromoteUpperTo(df, v0);
-    sum0a = hn::MulAdd(v0a, c0, sum0a);
-    sum0b = hn::MulAdd(v0b, c0, sum0b);
-  }
-}
-
-template <class DF, class VF = hn::Vec<DF>>
-HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddVT1Mem(
-    DF df, const float* HWY_RESTRICT scales, const VF c00, const VF c01,
-    const MatPtrT<BF16>& v, const size_t* HWY_RESTRICT pos, size_t num_lanes,
-    float* HWY_RESTRICT out, const uint32_t* HWY_RESTRICT out_offsets,
-    const size_t size) {
-  namespace hn = hwy::HWY_NAMESPACE;
-  HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
-  constexpr size_t kMaxNF = hn::MaxLanes(df);
-  const BF16* HWY_RESTRICT v_bf = v.Row(pos[0] / (2 * NF));
-  HWY_DASSERT(pos[0] % (2 * NF) == 0);
-  HWY_ALIGN float c_mem[2 * kMaxNF];
-  hn::Store(c00, df, c_mem);
-  hn::Store(c01, df, c_mem + NF);
-
-  size_t i = 0;
-  while (i + NF * 2 <= size) {
-    VF out0a, out0b;
-    out0a = hn::Load(df, out + i + out_offsets[0]);
-    VF scale0 = hn::Set(df, scales[0]);
-    out0a = hn::Mul(out0a, scale0);
-    out0b = hn::Load(df, out + i + NF + out_offsets[0]);
-    out0b = hn::Mul(out0b, scale0);
-    MulAddNLanesVT1(df, v_bf, c_mem, HWY_MIN(num_lanes, 2 * NF), out0a, out0b);
-    hn::Store(out0a, df, out + i + out_offsets[0]);
-    hn::Store(out0b, df, out + i + NF + out_offsets[0]);
-    i += NF * 2;
-    v_bf += 4 * NF * NF;
-  }
-  while (i < size) {
-    float sum = out[i + out_offsets[0]] * scales[0];
-    const BF16* HWY_RESTRICT v_local = v_bf;
-    for (size_t lane = 0; lane < HWY_MIN(num_lanes, 2 * NF);
-         ++lane, v_local += 2 * NF) {
-      sum += hwy::ConvertScalarTo<float>(*v_local) * c_mem[lane];
-    }
-    ++i;
-    ++v_bf;
-  }
-}
-
 template <class DF, class VF = hn::Vec<DF>>
 static HWY_INLINE void StoreUpTo8Times2(DF df, MatPtrT<BF16>& out,
                                         size_t start_col, VF out0_0, VF out0_1,
@@ -1211,6 +1470,104 @@ HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddTileUpTo8_BF16(
   HWY_DASSERT(qkv_dim == i);
 }
 
+// Prescales NF rows of out by scale, then multiplies 1 row of V by the
+// corresponding values in c0 and adds them to the NF rows of out.
+// The depth (size) must be a multiple of NF.
+template <class DF, class VF = hn::Vec<DF>>
+HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
+    DF df, const VF scale, const VF c0, const MatPtrT<float>& v,
+    const size_t pos, float* HWY_RESTRICT out,
+    const uint32_t* HWY_RESTRICT out_offsets, const size_t size) {
+  namespace hn = hwy::HWY_NAMESPACE;
+  HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
+
+  size_t i = 0;
+  while (i + NF <= size) {
+    if HWY_LANES_CONSTEXPR (NF == 16) {
+      VF out0, out1, out2, out3, out4, out5, out6, out7;
+      VF out8, out9, out10, out11, out12, out13, out14, out15;
+      out0 = hn::Load(df, out + i + out_offsets[0]);
+      out1 = hn::Load(df, out + i + out_offsets[1]);
+      out2 = hn::Load(df, out + i + out_offsets[2]);
+      out3 = hn::Load(df, out + i + out_offsets[3]);
+      out4 = hn::Load(df, out + i + out_offsets[4]);
+      out5 = hn::Load(df, out + i + out_offsets[5]);
+      out6 = hn::Load(df, out + i + out_offsets[6]);
+      out7 = hn::Load(df, out + i + out_offsets[7]);
+      out8 = hn::Load(df, out + i + out_offsets[8]);
+      out9 = hn::Load(df, out + i + out_offsets[9]);
+      out10 = hn::Load(df, out + i + out_offsets[10]);
+      out11 = hn::Load(df, out + i + out_offsets[11]);
+      out12 = hn::Load(df, out + i + out_offsets[12]);
+      out13 = hn::Load(df, out + i + out_offsets[13]);
+      out14 = hn::Load(df, out + i + out_offsets[14]);
+      out15 = hn::Load(df, out + i + out_offsets[15]);
+      Mul16(df, scale, out0, out1, out2, out3, out4, out5, out6, out7, out8,
+            out9, out10, out11, out12, out13, out14, out15);
+      VF x0 = hn::Load(df, v.Row(pos) + i);
+      MulAdd16(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7,
+               out8, out9, out10, out11, out12, out13, out14, out15);
+      hn::Store(out0, df, out + i + out_offsets[0]);
+      hn::Store(out1, df, out + i + out_offsets[1]);
+      hn::Store(out2, df, out + i + out_offsets[2]);
+      hn::Store(out3, df, out + i + out_offsets[3]);
+      hn::Store(out4, df, out + i + out_offsets[4]);
+      hn::Store(out5, df, out + i + out_offsets[5]);
+      hn::Store(out6, df, out + i + out_offsets[6]);
+      hn::Store(out7, df, out + i + out_offsets[7]);
+      hn::Store(out8, df, out + i + out_offsets[8]);
+      hn::Store(out9, df, out + i + out_offsets[9]);
+      hn::Store(out10, df, out + i + out_offsets[10]);
+      hn::Store(out11, df, out + i + out_offsets[11]);
+      hn::Store(out12, df, out + i + out_offsets[12]);
+      hn::Store(out13, df, out + i + out_offsets[13]);
+      hn::Store(out14, df, out + i + out_offsets[14]);
+      hn::Store(out15, df, out + i + out_offsets[15]);
+    }
+    if HWY_LANES_CONSTEXPR (NF == 8) {
+      VF out0, out1, out2, out3, out4, out5, out6, out7;
+      out0 = hn::Load(df, out + i + out_offsets[0]);
+      out1 = hn::Load(df, out + i + out_offsets[1]);
+      out2 = hn::Load(df, out + i + out_offsets[2]);
+      out3 = hn::Load(df, out + i + out_offsets[3]);
+      out4 = hn::Load(df, out + i + out_offsets[4]);
+      out5 = hn::Load(df, out + i + out_offsets[5]);
+      out6 = hn::Load(df, out + i + out_offsets[6]);
+      out7 = hn::Load(df, out + i + out_offsets[7]);
+      Mul8(df, scale, out0, out1, out2, out3, out4, out5, out6, out7);
+      VF x0 = hn::Load(df, v.Row(pos) + i);
+      MulAdd8(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7);
+      hn::Store(out0, df, out + i + out_offsets[0]);
+      hn::Store(out1, df, out + i + out_offsets[1]);
+      hn::Store(out2, df, out + i + out_offsets[2]);
+      hn::Store(out3, df, out + i + out_offsets[3]);
+      hn::Store(out4, df, out + i + out_offsets[4]);
+      hn::Store(out5, df, out + i + out_offsets[5]);
+      hn::Store(out6, df, out + i + out_offsets[6]);
+      hn::Store(out7, df, out + i + out_offsets[7]);
+    }
+    if HWY_LANES_CONSTEXPR (NF == 4) {
+      VF out0, out1, out2, out3;
+      out0 = hn::Load(df, out + i + out_offsets[0]);
+      out1 = hn::Load(df, out + i + out_offsets[1]);
+      out2 = hn::Load(df, out + i + out_offsets[2]);
+      out3 = hn::Load(df, out + i + out_offsets[3]);
+      out0 = hn::Mul(out0, hn::BroadcastLane<0>(scale));
+      out1 = hn::Mul(out1, hn::BroadcastLane<1>(scale));
+      out2 = hn::Mul(out2, hn::BroadcastLane<2>(scale));
+      out3 = hn::Mul(out3, hn::BroadcastLane<3>(scale));
+      VF x0 = hn::Load(df, v.Row(pos) + i);
+      MulAdd4(df, x0, c0, out0, out1, out2, out3);
+      hn::Store(out0, df, out + i + out_offsets[0]);
+      hn::Store(out1, df, out + i + out_offsets[1]);
+      hn::Store(out2, df, out + i + out_offsets[2]);
+      hn::Store(out3, df, out + i + out_offsets[3]);
+    }
+    i += NF;
+  }
+  HWY_DASSERT(size == i);
+}
+
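MulByConstAndAddVector is the single-position tail case of the same family: all NF output rows are rescaled by their lane of scale and then accumulate one V row, weighted per output row by the lanes of c0. The scalar model (again a hypothetical helper, not part of this change):

#include <cstddef>

// Scalar reference: out_k[i] = out_k[i] * scale[k] + c0[k] * v_row[i].
void MulByConstAndAddVectorScalar(const float* scale, const float* c0,
                                  const float* v_row, float* const* out_rows,
                                  size_t nf, size_t size) {
  for (size_t k = 0; k < nf; ++k) {
    for (size_t i = 0; i < size; ++i) {
      out_rows[k][i] = out_rows[k][i] * scale[k] + c0[k] * v_row[i];
    }
  }
}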
 // See below for a specialized version for top-1 sampling.
 // TODO: support bf16 logits using Decompress2.
 // Computes softmax probabilities for the given logits, normalizing in-place.
diff --git a/util/mat.h b/util/mat.h
index 5b30778..25f2cb2 100644
--- a/util/mat.h
+++ b/util/mat.h
@@ -202,16 +202,6 @@ class MatPtr : public IFields {
     override_rows_ = static_cast<uint32_t>(rows);
   }
 
-  // Changes the number of rows and columns without reallocating the memory.
-  // Increases cols by factor and reduces rows by factor.
-  // The rows must be divisible by factor and the matrix must be packed.
-  void ReshapePackedRowsToCols(size_t factor) {
-    HWY_ASSERT(IsPacked());
-    private_rows_ = hwy::DivCeil(private_rows_, factor);
-    cols_ *= factor;
-    stride_ *= factor;
-  }
-
   // Offset by which to advance pointers to the next row.
   size_t Stride() const { return stride_; }
 
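For reference, the removed ReshapePackedRowsToCols only rewrote the bookkeeping of a packed matrix: groups of `factor` consecutive rows become one wider row over the same memory. A standalone sketch of that arithmetic (not the MatPtr API, and assuming exact divisibility as the removed comment requires):

#include <cassert>
#include <cstddef>

struct Shape {
  size_t rows, cols, stride;  // Packed means stride == cols.
};

// Rows k*factor .. k*factor+factor-1 of a packed [rows, cols] matrix become
// consecutive slices of row k in a [rows / factor, cols * factor] view.
Shape ReshapePackedRowsToCols(Shape s, size_t factor) {
  assert(s.stride == s.cols && s.rows % factor == 0);
  return Shape{s.rows / factor, s.cols * factor, s.stride * factor};
}
// E.g. {8, 4, 4} with factor 2 becomes {4, 8, 8}: same 32 elements, new shape.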
diff --git a/util/test_util.h b/util/test_util.h
index 443990f..19342e4 100644
--- a/util/test_util.h
+++ b/util/test_util.h
@@ -106,8 +106,7 @@ template <typename T>
 void FillMatPtrT(MatPtrT<T>& mat) {
   for (int i = 0; i < mat.Rows(); ++i) {
     for (int j = 0; j < mat.Cols(); ++j) {
-      mat.Row(i)[j] =
-          hwy::ConvertScalarTo<T>(hwy::Unpredictable1() * 0.01f * (i + j + 1));
+      mat.Row(i)[j] = hwy::Unpredictable1() * 0.01f * (i + j + 1);
     }
   }
 }
diff --git a/util/zones.cc b/util/zones.cc
index 78a5fd8..d1f9b8c 100644
--- a/util/zones.cc
+++ b/util/zones.cc
@@ -17,14 +17,14 @@ const char* ZoneName(Zones zone) {
       return "FlashAttention.Inclusive";
     case Zones::kFlashAttentionRmsNormAndPositionalEncoding:
       return "FlashAttention.RMSNormAndPositionalEncoding";
-    case Zones::kFlashAttentionTileFlashAttention1:
-      return "FlashAttention.TileFlashAttention1";
+    case Zones::kFlashAttentionSingleFlashAttention:
+      return "FlashAttention.SingleFlashAttention";
+    case Zones::kFlashAttentionTileFlashAttention:
+      return "FlashAttention.TileFlashAttention";
     case Zones::kFlashAttentionTileFlashAttention4:
       return "FlashAttention.TileFlashAttention4";
-    case Zones::kFlashAttentionTileFlashAttention8:
-      return "FlashAttention.TileFlashAttention8";
-    case Zones::kFlashAttentionCombineSplit:
-      return "FlashAttention.CombineSplit";
+    case Zones::kFlashAttentionTransposeQ:
+      return "FlashAttention.TransposeQ";
     case Zones::kGenActivation:
       return "Gen.Activation";
     case Zones::kGenActivationFused:
diff --git a/util/zones.h b/util/zones.h
index 6f1a68c..a6d40fa 100644
--- a/util/zones.h
+++ b/util/zones.h
@@ -14,10 +14,10 @@ enum class Zones {  // Keep sorted
   kFlashAttentionFlashAttention,
   kFlashAttentionInclusive,
   kFlashAttentionRmsNormAndPositionalEncoding,
-  kFlashAttentionTileFlashAttention1,
+  kFlashAttentionSingleFlashAttention,
+  kFlashAttentionTileFlashAttention,
   kFlashAttentionTileFlashAttention4,
-  kFlashAttentionTileFlashAttention8,
-  kFlashAttentionCombineSplit,
+  kFlashAttentionTransposeQ,
   kGenActivation,
   kGenActivationFused,
   kGenAttention,