diff --git a/BUILD.bazel b/BUILD.bazel index cb6ca50..6979bdc 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -547,6 +547,7 @@ cc_library( deps = [ ":basics", ":configs", + ":flash_structs", ":gemma_args", ":kv_cache", ":mat", @@ -594,6 +595,11 @@ cc_test( INTERNAL_DEPS = [] +cc_library( + name = "flash_structs", + hdrs = ["gemma/flash_structs.h"], +) + cc_library( name = "attention", srcs = [ @@ -603,7 +609,6 @@ cc_library( hdrs = [ "gemma/attention.h", "gemma/flash_attention.h", - "gemma/flash_structs.h", ], textual_hdrs = [ "gemma/gemma-inl.h", @@ -612,6 +617,7 @@ cc_library( ":activations", ":basics", ":configs", + ":flash_structs", ":kv_cache", ":mat", ":matmul", diff --git a/gemma/activations.h b/gemma/activations.h index 1e0a56a..dac5a36 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -24,6 +24,7 @@ #include #include "gemma/configs.h" // ModelConfig +#include "gemma/flash_structs.h" #include "gemma/gemma_args.h" // AttentionImpl #include "gemma/kv_cache.h" #include "gemma/tensor_stats.h" @@ -52,10 +53,13 @@ struct AttentionActivations { AttentionActivations( const ModelConfig& config, const LayerConfig& layer_config, size_t batch_size, size_t seq_len, const RuntimeConfig& runtime_config, - const Allocator& allocator, + size_t max_workers, const Allocator& allocator, std::vector>& row_ptrs) - : // `vocab_size == 0` means it is for Vit part, VitAttention is still - // MHA and does not use an external KV cache. + : rep_factor(max_workers * + AttentionActivations::kThreadReplicationFactor / + layer_config.heads), + // `vocab_size == 0` means it is for Vit part, VitAttention + // is still MHA and does not use an external KV cache. q(MatFactory("q", batch_size, config.vocab_size == 0 ? layer_config.heads * 3 * layer_config.qkv_dim @@ -86,6 +90,9 @@ struct AttentionActivations { att_out(MatFactory("att_out", batch_size, layer_config.heads * layer_config.qkv_dim, allocator)), + att_out_reps(MatFactory("att_out", batch_size * rep_factor, + layer_config.heads * layer_config.qkv_dim, + allocator)), softmax_max(MatFactory("softmax_max", batch_size, layer_config.heads, allocator)), softmax_d( @@ -107,6 +114,11 @@ struct AttentionActivations { } return; } + // This is a guess at the maximum number of params we might need to avoid + // reallocations. The actual number of params is determined by the number of + // query tiles, which is not known here. + flash_params.reserve(batch_size * layer_config.heads); + split_flash_params.reserve(batch_size * layer_config.heads); // For MatMul outputs, precompute their row pointers. // If we forget any MatMul outputs here, debug builds print a warning but @@ -130,6 +142,7 @@ struct AttentionActivations { pre_att_rms_out.OverrideRows(batch_size); att.OverrideRows(batch_size); att_out.OverrideRows(batch_size); + att_out_reps.OverrideRows(batch_size * rep_factor); softmax_max.OverrideRows(batch_size); softmax_d.OverrideRows(batch_size); att_sums.OverrideRows(batch_size); @@ -137,6 +150,15 @@ struct AttentionActivations { // `inv_timescale*` are not batched. } + // Maximum factor by which we might scale-up work to maximize parallelism. + size_t rep_factor = 1; + // Parameters for flash attention. The size of the vector is somewhere between + // the number of query rows and 1/8th of that. + std::vector flash_params; + // Parameters for flash attention, split by k-position. May be significantly + // larger than flash_params in decode mode, when the number of query rows is + // small. 
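+  // Illustrative example (not exact): in decode mode there is only one query
+  // row per head, so the k-range of each query tile may be split into up to
+  // roughly rep_factor pieces to keep all workers busy; the reserve() calls
+  // above are only a heuristic and these vectors can still grow on demand.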
+ std::vector split_flash_params; MatStorageT q; // query MatStorageT q_bf; MatStorageT q_T; // Transposed to maximize attention speed. @@ -148,6 +170,7 @@ struct AttentionActivations { MatStorageT pre_att_rms_out; MatStorageT att; // attention vector MatStorageT att_out; // attention output + MatStorageT att_out_reps; // attention output for each thread. MatStorageT softmax_max; // see OnlineSoftmaxState MatStorageT softmax_d; // see OnlineSoftmaxState // Accumulation of attention outputs over heads @@ -156,19 +179,27 @@ struct AttentionActivations { // Rope MatStorageT inv_timescale; MatStorageT inv_timescale_global; + // Replication factor to help evenly share work over threads. + static constexpr size_t kThreadReplicationFactor = 4; }; // A non-owning view of AttentionActivations. struct AttentionActivationsPtrs { - AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len) + AttentionActivationsPtrs( + const ModelConfig& config, size_t seq_len, + std::vector& flash_params, + std::vector& split_flash_params) : config(config), + flash_params(flash_params), + split_flash_params(split_flash_params), div_seq_len(static_cast(seq_len)), div_heads(static_cast(config.layer_configs[0].heads)), query_scale(ChooseQueryScale(config)) {} AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len, - const AttentionActivations& activations) - : AttentionActivationsPtrs(config, seq_len) { + AttentionActivations& activations) + : AttentionActivationsPtrs(config, seq_len, activations.flash_params, + activations.split_flash_params) { q = activations.q; q_bf = activations.q_bf; q_T = activations.q_T; @@ -178,6 +209,7 @@ struct AttentionActivationsPtrs { pre_att_rms_out = activations.pre_att_rms_out; att = activations.att; att_out = activations.att_out; + att_out_reps = activations.att_out_reps; softmax_max = activations.softmax_max; softmax_d = activations.softmax_d; att_sums = activations.att_sums; @@ -208,6 +240,9 @@ struct AttentionActivationsPtrs { } const ModelConfig& config; + // Parameters for flash attention. + std::vector& flash_params; + std::vector& split_flash_params; // For the matrices below, the batch_size dimension is really qbatch.Size() * // token_batch_size, but in all known uses, one of those is 1. Specifically, @@ -233,6 +268,7 @@ struct AttentionActivationsPtrs { // Attention output computed from att * V, size batch_size x (q_heads * // qkv_dim). MatPtrT att_out; + MatPtrT att_out_reps; // The maximum logit value encountered when computing att_out from att, // size batch_size x q_heads . See OnlineSoftmaxState for details. // WARNING: Only filled in for AttentionImpl::kOld. @@ -287,7 +323,8 @@ struct Activations { s_w_linear_w(config.num_layers, max_workers), attention_impl(runtime_config.attention_impl), attention_storage(config, layer_config, batch_size, seq_len, - runtime_config, ctx.allocator, row_ptrs), + runtime_config, ctx.pools.MaxWorkers(), ctx.allocator, + row_ptrs), attention(config, seq_len, attention_storage) { HWY_ASSERT(batch_size != 0); diff --git a/gemma/attention.cc b/gemma/attention.cc index 8ea9b6d..570c4f4 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -49,6 +49,39 @@ HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { +// Returns the number of floats per vector (aka NF). +size_t FloatsPerVector() { + using DF = hn::ScalableTag; + const DF df; + return hn::Lanes(df); +} + +// The k-cache and v-cache are setup without knowing NF. So if it hasn't been +// done already, reshape it to take NF into account. 
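+// NF is the number of float lanes per vector (e.g. 8 on AVX2, 16 on AVX-512;
+// the exact count depends on the compiled target), so the reshape below
+// regroups the cache rows into blocks of 2*NF k-positions.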
+void MaybeReshapeCache(const MatPtrT& kv, MatPtrT& cache) { + if (kv.Cols() > cache.Cols()) { + cache.ReshapePackedRowsToCols(2 * FloatsPerVector()); + } +} + +// Transposes a single row of the kv cache into the k-cache and v-cache. +void TransposeKVCacheRow(const KV_t* HWY_RESTRICT kv, KV_t* HWY_RESTRICT k, + KV_t* HWY_RESTRICT v, size_t qkv_dim) { + // This is inefficient, as the writes are scattered over cache lines, but it + // is a tiny fraction of the overall computation, and it is linear in the + // token length. + const size_t kFloatsPerTile = 2 * FloatsPerVector(); + for (size_t i = 0; i < qkv_dim; i += 2) { + k[i * kFloatsPerTile] = kv[i]; + k[i * kFloatsPerTile + 1] = kv[i + 1]; + } + for (size_t i = 0; i < qkv_dim; i += kFloatsPerTile) { + for (size_t j = 0; j < kFloatsPerTile; j++) { + v[i * kFloatsPerTile + j] = kv[i + j + qkv_dim]; + } + } +} + // Computes Q.K scores, which are "logits" (or scores) stored to att. // `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim]. static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos, @@ -280,6 +313,11 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, kv_rows.AttachRowPtrs(env.row_ptrs[0].get()); CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2, /*add=*/nullptr, env, kv_rows); + for (size_t qi = 0; qi < qbatch.Size(); ++qi) { + MaybeReshapeCache(qbatch.KV(qi).kv_cache, qbatch.KV(qi).k_cache); + MaybeReshapeCache(qbatch.KV(qi).kv_cache, qbatch.KV(qi).v_cache); + } + const size_t kFloatsPerVector = FloatsPerVector(); // Apply positional encodings for K. // Note that 2D parallelism is not worth the fork/join overhead because the @@ -299,6 +337,26 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, KV_t* HWY_RESTRICT kv = kv_cache.Row(cache_pos) + layer_idx * cache_layer_size + head * qkv_dim * 2; + // Note that k_cache and v_cache are different shapes. + // The innermost dimension of k is 2 values from qkv_dim because they + // are going to be used in a BF16 dot product involving pairs of + // values over NF k positions. + // The innermost dimension of v is 2NF values from qkv_dim because they + // will be loaded into a BF16 vector to be scaled and added to the + // cached attention output in 2 NF-sized registers. + // TODO(rays): factor out these calculations into functions. + auto& k_cache = qbatch.KV(qi).k_cache; + KV_t* HWY_RESTRICT k = + k_cache.Row(cache_pos / (2 * kFloatsPerVector)) + + (layer_idx * cache_layer_size + head * qkv_dim * 2) * + kFloatsPerVector + + (cache_pos % (2 * kFloatsPerVector)) * 2; + auto& v_cache = qbatch.KV(qi).v_cache; + KV_t* HWY_RESTRICT v = + v_cache.Row(cache_pos / (2 * kFloatsPerVector)) + + (layer_idx * cache_layer_size + head * qkv_dim * 2) * + kFloatsPerVector + + (cache_pos % (2 * kFloatsPerVector)) * 2 * kFloatsPerVector; HWY_ALIGN float kv_f32[2 * kMaxQKVDim]; const hn::ScalableTag df; @@ -319,6 +377,10 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, /*mul=*/1.0f); CompressPerThread tls; Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0); + // This is inefficient, as multiple threads are writing the same K + // cache line, but the input is generated by a matmul, so it is + // difficult to change, and it probably isn't significant. + TransposeKVCacheRow(kv, k, v, qkv_dim); }); } @@ -341,7 +403,8 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx, } else { // * 2 does not help on Turin. 
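+    // target_parallelism is now MaxWorkers() * kThreadReplicationFactor (4),
+    // e.g. 32 workers request 128-way parallelism, so that tiles can also be
+    // split by k-position when there are few query rows.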
FlashAttention(num_tokens, - /*target_parallelism=*/env.ctx.pools.MaxWorkers() * 1, + /*target_parallelism=*/env.ctx.pools.MaxWorkers() * + AttentionActivations::kThreadReplicationFactor, layer_idx, layer.query_norm_scale, activations, qbatch, env.ctx, attention_impl); } diff --git a/gemma/attention.h b/gemma/attention.h index 7fb958f..14870de 100644 --- a/gemma/attention.h +++ b/gemma/attention.h @@ -31,6 +31,13 @@ namespace gcpp { // Passed to HWY_VISIT_TARGETS; declares for one target. #define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \ namespace NAMESPACE { \ + size_t FloatsPerVector(); \ + \ + void MaybeReshapeCache(const MatPtrT& kv, MatPtrT& cache); \ + \ + void TransposeKVCacheRow(const KV_t* HWY_RESTRICT kv, KV_t* HWY_RESTRICT k, \ + KV_t* HWY_RESTRICT v, size_t qkv_dim); \ + \ void PositionalEncodingQK(float* qk, size_t layer_idx, \ const AttentionActivationsPtrs& activations, \ ThreadingContext& ctx, size_t worker, size_t pos, \ diff --git a/gemma/attention_test.cc b/gemma/attention_test.cc index fe57a5b..46214fd 100644 --- a/gemma/attention_test.cc +++ b/gemma/attention_test.cc @@ -1,8 +1,10 @@ #include +#include #include // strcmp #include #include #include +#include #include #include "gtest/gtest.h" @@ -105,7 +107,8 @@ struct TestAttentionState { tokens(num_tokens), attention_storage_(model_state.config, model_state.layer_config, batch_size, num_tokens, runtime_config, - state.ctx.allocator, row_ptrs_), + state.ctx.pools.MaxWorkers(), state.ctx.allocator, + row_ptrs_), attention(model_state.config, num_tokens, attention_storage_) { for (size_t i = 0; i < qbatch_size; ++i) { kv_caches.emplace_back(model_state.config, inference_args, @@ -143,6 +146,7 @@ struct TestAttentionState { }; double GetTolerance() { + if (IsBF16()) return 1e-2; const char* target_name = hwy::TargetName(HWY_TARGET); if (strncmp(target_name, "AVX2", 4) == 0) { return 2e-2; @@ -155,6 +159,57 @@ double GetTolerance() { } } +// Comparison function for computations that used BF16, whether the result is +// stored in BF16 or F32. +// Compare with absolute tolerance for values with small magnitudes. +// Compare with relative tolerance for values with larger magnitudes. +template +bool CompareArraySimilarBF16(const T* expected, const T* actual, size_t count, + const char* target_name, const char* filename, + int line) { + constexpr double kTolerance = 3e-2; + for (size_t i = 0; i < count; ++i) { + const double exp = hwy::ConvertScalarTo(expected[i]); + const double act = hwy::ConvertScalarTo(actual[i]); + const double l1 = std::abs(act - exp); + // Cannot divide, so check absolute error. 
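+    // (The 3e-2 tolerance is looser than a single BF16 rounding step of about
+    // 2^-8, presumably because rounding errors accumulate over dot products.)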
+ if (std::abs(exp) <= 1.0) { + if (l1 > kTolerance) { + std::string array_values = hwy::detail::FormatMismatchedArrays( + expected, actual, count, kTolerance); + HWY_WARN("%s %s:%d %s mismatch %zu of %zu: %E %E l1 %E tol %E%s\n", + target_name, filename, line, "BF16", i, count, exp, act, l1, + kTolerance, array_values.c_str()); + return false; + } + } else { // relative + const double rel = l1 / exp; + if (rel > kTolerance) { + std::string array_values = hwy::detail::FormatMismatchedArrays( + expected, actual, count, kTolerance); + HWY_WARN("%s %s:%d %s mismatch %zu of %zu: %E %E rel %E tol %E%s\n", + target_name, filename, line, "BF16", i, count, exp, act, rel, + kTolerance, array_values.c_str()); + return false; + } + } + } + return true; +} + +template +bool CompareArraySimilar(const T* expected, const T* actual, size_t count, + const char* target_name, const char* filename, + int line) { + if constexpr (IsBF16()) { + return CompareArraySimilarBF16(expected, actual, count, target_name, + filename, line); + } else { + return hwy::CompareArraySimilar(expected, actual, count, GetTolerance(), + target_name, filename, line); + } +} + template void CompareAttSumsWithGolden( const AttentionActivationsPtrs& attention, @@ -170,9 +225,9 @@ void CompareAttSumsWithGolden( for (size_t j = 0; j < kDims; ++j) { actual_row[j] = hwy::F32FromBF16(attention.att_sums.Row(i)[j]); } - EXPECT_TRUE(hwy::CompareArraySimilar( - golden[token_idx][qi], actual_row.get(), kDims, GetTolerance(), - hwy::TargetName(HWY_TARGET), __FILE__, __LINE__)) + EXPECT_TRUE(CompareArraySimilar(golden[token_idx][qi], actual_row.get(), + kDims, hwy::TargetName(HWY_TARGET), + __FILE__, __LINE__)) << "att_sums mismatch for token_idx=" << token_idx << " qi=" << qi; } } @@ -200,19 +255,20 @@ void CompareKVCacheWithGolden( for (size_t token_idx = 0; token_idx < kNumTokens; ++token_idx) { for (size_t qi = 0; qi < kQBatchSize; ++qi) { - const float* cache_row = + const BF16* cache_row = kv_caches[qi].kv_cache.Row(start_offset + token_idx); for (size_t j = 0; j < kDims; ++j) { - actual_k_row[j] = cache_row[kv_offset + j]; - actual_v_row[j] = cache_row[kv_offset + qkv_dim + j]; + actual_k_row[j] = hwy::ConvertScalarTo(cache_row[kv_offset + j]); + actual_v_row[j] = + hwy::ConvertScalarTo(cache_row[kv_offset + qkv_dim + j]); } - EXPECT_TRUE(hwy::CompareArraySimilar( - k_golden[token_idx][qi], actual_k_row.get(), kDims, GetTolerance(), + EXPECT_TRUE(CompareArraySimilar( + k_golden[token_idx][qi], actual_k_row.get(), kDims, hwy::TargetName(HWY_TARGET), __FILE__, __LINE__)) << "K cache mismatch for token_idx=" << token_idx << " qi=" << qi << " kv_head=" << kv_head; - EXPECT_TRUE(hwy::CompareArraySimilar( - v_golden[token_idx][qi], actual_v_row.get(), kDims, GetTolerance(), + EXPECT_TRUE(CompareArraySimilar( + v_golden[token_idx][qi], actual_v_row.get(), kDims, hwy::TargetName(HWY_TARGET), __FILE__, __LINE__)) << "V cache mismatch for token_idx=" << token_idx << " qi=" << qi << " kv_head=" << kv_head; @@ -238,8 +294,8 @@ void CompareQVecsWithGolden( for (size_t j = 0; j < kDims; ++j) { actual_q_row[j] = q_row[head_offset + j]; } - EXPECT_TRUE(hwy::CompareArraySimilar( - q_golden[token_idx][qi], actual_q_row.get(), kDims, GetTolerance(), + EXPECT_TRUE(CompareArraySimilar( + q_golden[token_idx][qi], actual_q_row.get(), kDims, hwy::TargetName(HWY_TARGET), __FILE__, __LINE__)) << "Q vec mismatch for token_idx=" << token_idx << " qi=" << qi << " q_head=" << q_head; @@ -267,42 +323,42 @@ const float 
kGoldenAttSums[kNumTokens][kQBatchSize][kDimsToCompare] = { 26.875, 63, 3.34375, -67.5, 31.125, -190, 125}, {-30.375, -17.875, 51.75, -78, -84, 6.40625, 15.375, 70, -22.875, 20.125, -14.9375, -109.5, 76, 9.25, -142, 29.5, -105}}, - {{-32.75, 38.25, 78.5, 107.5, 20.25, 197, -136, 42.5, -84, 25.625, 4.96875, + {{-32.75, 38.25, 78.5, 107.5, 20.25, 197, -136, 42.5, -84, 25.625, 5.35875, 128, 27.25, -161, 19.125, -58, 97.5}, {-18.5, -18, 135, -13.4375, -6.625, -45.75, 29.625, 93, 18.625, 75.5, 102.5, -184, 52.75, 83.5, -71, 46.5, -52}}, - {{-16.375, -61.5, -58.25, -27.375, -28, 71, -109.5, 60.25, 3.125, -29.125, - 6.90625, 150, 144, -155, -47.25, -98.5, 3.5625}, - {-19, -16.75, 129, 0.59765625, -82, 123.5, 60.75, -36.75, -77, 26.625, 51, - -66.5, -0.84765625, -46.5, -152, -2.9375, -81}}, - {{3.984375, 83, -41.75, 39.5, -203, 110, -76, 131, 0.4609375, -44.5, -63.75, + {{-16.375, -61.5, -58.25, -27.375, -28, 71, -109.5, 60.25, 3.625, -29.125, + 6.4625, 150, 144, -155, -47.25, -98.5, 3.5625}, + {-19, -16.75, 129, 0.628925, -82, 123.5, 60.75, -36.75, -77, 26.625, 51, + -66.5, -0.62165625, -46.5, -152, -2.9375, -81}}, + {{3.684375, 83, -41.75, 39.5, -203, 110, -76, 131, 1.0069375, -44.5, -63.75, -46, -22, -19.375, -16.125, -148, 20.875}, - {-47, -19.5, 58, 81.5, 21.75, -30, -118, 44.25, -149, 22.5, 188, -66.5, 33, + {-47, -19.5, 58, 81.5, 23.35, -30, -118, 44.25, -149, 22.5, 188, -66.5, 33, 10.9375, -52.5, 23.25, 75}}, - {{64, -31, -89, -92.5, -11.1875, -54.75, -302, 3.453125, -108, 39.25, + {{64, -31, -89, -92.5, -11.1875, -54.75, -302, 4.213125, -108, 39.25, -34.75, 18, -52, 100, -186, -75.5, 50.75}, - {7.6875, -80, -40, 32.25, -30.25, 90, -41, 44.25, -140, -2.4375, 82.5, + {7.1875, -80, -40, 32.25, -30.25, 90, -41, 44.25, -140, -2.4375, 82.5, 39.25, 65, 47.25, -89.5, -34.25, 137}}, {{39.75, 17.875, 115, 38.75, -44, 139, -53.25, -23.875, -13.0625, 38.5, - 32.5, 53.75, 109, 4.09375, 57.5, -20.5, 132}, - {143, 249, 5.09375, 0.83984375, 27.875, -5.84375, 30.25, -101.5, 65.5, - 13.5, 195, -10.0625, 97.5, 2.203125, -97.5, -100, -19.25}}, + 32.5, 53.75, 109, 4.62375, 57.5, -20.5, 132}, + {143, 249, 4.9375, 1.33984375, 27.875, -5.84375, 30.25, -101.5, 65.5, 13.5, + 195, -10.0625, 97.5, 1.903125, -97.5, -100, -19.25}}, {{-30.125, -169, -150, 58, -35.75, 22.75, 36.5, -32.25, -8.9375, 55.25, -117, 26.375, 39.5, 125, 66, 48.75, 20.75}, - {137, 5.25, 61.25, 37, -42.75, 240, 62, -164, 11.3125, 173, 174, 23.5, + {137, 3.85, 61.25, 37, -42.75, 240, 62, -164, 10.3125, 173, 174, 23.5, 88.5, 48.5, -46.25, -36.75, 101.5}}, - {{-103, -47.5, 39, -48, -67.5, 121, -136, 99, 80, -47.5, 107.5, 48.75, 97.5, + {{-103, -47.5, 39, -48, -67.5, 121, -136, 99, 80, -47.5, 107.5, 43.75, 97.5, 125, -53.5, -14.625, 262}, - {29.875, 7.34375, -36.75, -14.5, -27.5, 44.75, -67.5, -40.75, 71.5, 172, + {28.075, 6.64375, -36.75, -14.5, -27.5, 44.75, -67.5, -40.75, 71.5, 172, 81, -27.25, -3.03125, 111, -167, 59, 176}}, {{-37.25, 109.5, -26.125, -115.5, 108, 57.25, 1.3671875, 72, -122.5, 59.25, -52, -12.625, 43.25, 16.25, -41.75, 26.5, 70.5}, - {40.25, 53.25, -142, 78.5, 38, 4.3125, -27.75, -134, -85, 107.5, 2.5, 93.5, + {40.25, 53.25, -142, 78.5, 38, 4.625, -27.75, -134, -85, 107.5, 2.5, 93.5, 58.25, 173, -53.5, 25.125, 4.8125}}, {{-8.4375, -35, -35.5, 131, -33.25, 106, 109.5, -92, -135, 80, 21.5, -17.125, 15.25, 143, -27, 103, 101}, {-77, 40.75, -10.125, 33.25, -33, 104, -7.6875, 85.5, -40, 93, 61, 14.5625, - 8.125, -99.5, 13.6875, -11.6875, 33}}, + 8.55, -99.5, 14.6875, -11.6875, 33}}, }; // Layer 0, *K*V Head 0 diff --git 
a/gemma/configs.h b/gemma/configs.h index 2f15ee8..803a48a 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -81,8 +81,8 @@ static inline bool EnumValid(LayerAttentionType type) { } enum class AttentionImpl { - kOld, - kFlash, + kOld, // Previous Attention implementation + kFlash, // Flash Attention (default) kSentinel, }; diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index 835cd15..e6b81b1 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -57,43 +58,7 @@ HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { -static constexpr size_t kNFx8HTileSize = 8; static constexpr float kNegInf = -std::numeric_limits::max() / 64.0f; -// Transposes q into q_t. -// Both are 4D tensors stuffed into a 2-D MatPtrT. -// q has shape [batch, qbatch][head, qkv_dim]. -// q_t has shape [qkv_dim][qbatch, head, batch] in order to make the maximum -// possible consecutive elements have the same KV. -static void TransposeQ(const MatPtrT& q, MatPtrT& q_t, - const size_t qbatch_size, ThreadingContext& ctx) { - // Group floats by the number of floats in a cache line. - const size_t kNF = ctx.cache_info.LineBytes() / sizeof(float); - const size_t num_heads = q.Cols() / q_t.Rows(); - const size_t batch_size = q.Rows() / qbatch_size; - const auto func = [&](const size_t task, size_t worker) HWY_ATTR { - GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTransposeQ); - for (size_t lane = 0; lane < kNF; ++lane) { - size_t q_row = task * kNF + lane; - if (q_row >= q_t.Rows()) break; - BF16* HWY_RESTRICT qt_row = q_t.Row(q_row); - for (size_t qi = 0; qi < qbatch_size; ++qi) { - for (size_t h = 0; h < num_heads; ++h) { - for (size_t b = 0; b < batch_size; ++b) { - qt_row[(qi * num_heads + h) * batch_size + b] = - hwy::ConvertScalarTo( - q.Row(b * qbatch_size + qi)[h * q_t.Rows() + q_row]); - } - } - } - } - }; - { - const size_t num_tasks = hwy::DivCeil(q_t.Rows(), kNF); - // Better than kFlat. - ParallelFor(Parallelism::kHierarchical, num_tasks, ctx, - /*cluster_idx=*/0, Callers::kFlashTransposeQ, func); - } -} // Updates q in place for RMSNorm and positional encoding. void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch, @@ -136,292 +101,390 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch, } } -// Handles a single v row of flash attention for a single q.k dot product. -HWY_INLINE void SingleFlashAttentionStep(float x, float cap, float& old_max, - float& old_d, - const float* HWY_RESTRICT v, - const size_t v_cols, - float* HWY_RESTRICT att_out) { - if (cap > 0.0f) { - // Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x. - x = cap * std::tanh(x / cap); - } - float m = std::max(x, old_max); - x = std::exp(x - m); - float scale = old_d * std::exp(old_max - m); - old_d = x + scale; - old_max = m; - float one_over_d = 1.0f / old_d; - scale *= one_over_d; - x *= one_over_d; - MulByConst(scale, att_out, v_cols); - MulByConstAndAdd(x, v, att_out, v_cols); -} - -// Calculates the complete attention outputs for a single row of q. 
-void SingleFlashAttention(const size_t start_pos, const size_t last_pos, - const BF16* HWY_RESTRICT q, const MatPtrT& k, - const MatPtrT& v, const size_t layer_idx, - const AttentionActivationsPtrs& activations, - float* HWY_RESTRICT att_out, ThreadingContext& ctx, - const size_t worker) { - GCPP_ZONE(ctx, worker, Zones::kFlashAttentionSingleFlashAttention); - const hn::ScalableTag dbf; - const size_t qkv_dim = k.Cols(); - - const size_t pos_mod = activations.div_seq_len.Remainder(start_pos); - // TODO: Mixed-mode can be further improved for Turin: we can demote right - // before we do the dot product instruction, rather than promote both to f32. - // But some potential accuracy loss there, needs evaluation first. - float m = Dot(dbf, MakeConstSpan(q, qkv_dim), 0, k.Row(pos_mod), qkv_dim); - if (float cap = activations.config.att_cap; cap > 0.0f) { - // Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x. - m = cap * std::tanh(m / cap); - } - float d = 1.0f; - // This is just a copy of the first token. - MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), ctx, worker); - for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) { - const size_t pos_mod = activations.div_seq_len.Remainder(pos); - float x = Dot(dbf, MakeConstSpan(q, qkv_dim), 0, k.Row(pos_mod), qkv_dim); - SingleFlashAttentionStep(x, activations.config.att_cap, m, d, - v.Row(pos_mod), v.Cols(), att_out); - } -} - -// Computes and returns a single vector of NF Q.K dot products, which represents -// the dot products of NF rows of Q for a single K timestep. -template > -VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets, - const size_t k_pos, const MatPtrT& q, - const MatPtrT& k) { - const hn::ScalableTag dbf; - const size_t qkv_dim = k.Cols(); - - hn::TFromD results[hn::MaxLanes(df)]; - for (size_t i = 0; i < hn::Lanes(df); ++i) { - results[i] = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[i], qkv_dim), 0, - k.Row(k_pos), qkv_dim); - } - return hn::LoadU(df, results); -} - -// Returns an NF Q rows by 8 K rows tile of Q.K dot products. -// This is the result of NF rows of Q against 8 K timesteps, with positions -// given by k_pos[0..7]. Q has been transposed so that the NF rows are read in -// consecutive elements, and other columns by adding q_stride. -template > -void QDotKTile(DF df, const BF16* HWY_RESTRICT q, const size_t q_stride, - const MatPtrT& k, const size_t* k_pos, VF& sum0, VF& sum1, - VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7) { - constexpr size_t kHTileSize = kNFx8HTileSize; +// Zeroes out kVTileSize of the given vectors. 
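+// Used by the QDotKTile148* variants below to initialize the accumulator
+// vectors for a tile of 1, 4 or 8 query rows; accumulators that the smaller
+// tile sizes do not use are left untouched.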
+template > +HWY_INLINE void ZeroResults(DF df, VF& sum0, VF& HWY_MAYBE_UNUSED sum1, + VF& HWY_MAYBE_UNUSED sum2, + VF& HWY_MAYBE_UNUSED sum3, + VF& HWY_MAYBE_UNUSED sum4, + VF& HWY_MAYBE_UNUSED sum5, + VF& HWY_MAYBE_UNUSED sum6, + VF& HWY_MAYBE_UNUSED sum7) { sum0 = hn::Zero(df); - sum1 = hn::Zero(df); - sum2 = hn::Zero(df); - sum3 = hn::Zero(df); - sum4 = hn::Zero(df); - sum5 = hn::Zero(df); - sum6 = hn::Zero(df); - sum7 = hn::Zero(df); - const float* HWY_RESTRICT k_row[kHTileSize]; - for (size_t i = 0; i < kHTileSize; ++i) { - k_row[i] = k.Row(k_pos[i]); + if constexpr (kVTileSize >= 4) { + sum1 = hn::Zero(df); + sum2 = hn::Zero(df); + sum3 = hn::Zero(df); } - - const hn::Rebind dbfh; - using VBF = hn::Vec; - - for (size_t i = 0; i < k.Cols(); ++i) { - const VBF q_vec_bf = hn::Load(dbfh, q); - const VF q_vec = hn::PromoteTo(df, q_vec_bf); - VF k_0 = hn::Set(df, k_row[0][i]); - sum0 = hn::MulAdd(q_vec, k_0, sum0); - VF k_1 = hn::Set(df, k_row[1][i]); - sum1 = hn::MulAdd(q_vec, k_1, sum1); - VF k_2 = hn::Set(df, k_row[2][i]); - sum2 = hn::MulAdd(q_vec, k_2, sum2); - VF k_3 = hn::Set(df, k_row[3][i]); - sum3 = hn::MulAdd(q_vec, k_3, sum3); - VF k_4 = hn::Set(df, k_row[4][i]); - sum4 = hn::MulAdd(q_vec, k_4, sum4); - VF k_5 = hn::Set(df, k_row[5][i]); - sum5 = hn::MulAdd(q_vec, k_5, sum5); - VF k_6 = hn::Set(df, k_row[6][i]); - sum6 = hn::MulAdd(q_vec, k_6, sum6); - VF k_7 = hn::Set(df, k_row[7][i]); - sum7 = hn::MulAdd(q_vec, k_7, sum7); - q += q_stride; + if constexpr (kVTileSize >= 8) { + sum4 = hn::Zero(df); + sum5 = hn::Zero(df); + sum6 = hn::Zero(df); + sum7 = hn::Zero(df); } } -// Returns the element-wise maximum of 8 vectors, in a single vector. -template > -VF HWY_INLINE ElementwiseMaxOf8(DF df, const VF& x0, const VF& x1, const VF& x2, - const VF& x3, const VF& x4, const VF& x5, - const VF& x6, const VF& x7) { - VF m0 = hn::Max(x0, x1); - VF m1 = hn::Max(x2, x3); - VF m2 = hn::Max(x4, x5); - VF m3 = hn::Max(x6, x7); - m0 = hn::Max(m0, m1); - m2 = hn::Max(m2, m3); - return hn::Max(m0, m2); +// Returns a tile of 1, 4 or 8 Q rows by 2NF K Q.K dot products, in float32. +// K is always pre-transposed to shape: +// [seq_len / 2kNF, layers * kv_heads * qkv_dim/2 * 2kNF * 2], where the /2, *2 +// represents that pairs of qkv_dim elements are kept together to make best use +// of BF16 dot product instructions. +// Note that this version assumes that Q is float32, and not transposed, and +// HWY_NATIVE_DOT_BF16 is false. 
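+// Layout example (illustrative): with NF float lanes per vector, the K tile
+// holding absolute position `pos` lives in k.Row(pos / (2*NF)); within that
+// row the data is ordered as [qkv pair][position within the 2*NF block][2],
+// which is what TransposeKVCacheRow in attention.cc writes.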
+template > +HWY_INLINE void QDotKTile148FloatNotNative( + DF df, const float* HWY_RESTRICT q, const uint32_t* HWY_RESTRICT q_offsets, + size_t half_cols, const MatPtrT& k, size_t pos, VF& sum00, VF& sum01, + VF& HWY_MAYBE_UNUSED sum10, VF& HWY_MAYBE_UNUSED sum11, + VF& HWY_MAYBE_UNUSED sum20, VF& HWY_MAYBE_UNUSED sum21, + VF& HWY_MAYBE_UNUSED sum30, VF& HWY_MAYBE_UNUSED sum31, + VF& HWY_MAYBE_UNUSED sum40, VF& HWY_MAYBE_UNUSED sum41, + VF& HWY_MAYBE_UNUSED sum50, VF& HWY_MAYBE_UNUSED sum51, + VF& HWY_MAYBE_UNUSED sum60, VF& HWY_MAYBE_UNUSED sum61, + VF& HWY_MAYBE_UNUSED sum70, VF& HWY_MAYBE_UNUSED sum71) { + ZeroResults(df, sum00, sum10, sum20, sum30, sum40, sum50, sum60, + sum70); + ZeroResults(df, sum01, sum11, sum21, sum31, sum41, sum51, sum61, + sum71); + using DBF = hn::ScalableTag; + const DBF dbf; + using VBF = hn::Vec; + const size_t kNF = hn::Lanes(df); + const float* HWY_RESTRICT q_base[kVTileSize]; + for (size_t i = 0; i < kVTileSize; ++i) { + q_base[i] = q + q_offsets[i]; + } + const BF16* HWY_RESTRICT k_base = k.Row(pos / (2 * kNF)); + for (size_t i = 0; i < half_cols; ++i, k_base += kNF * 4) { + // TODO(rays): Replace with decompress2. + VBF k0_vec = hn::LoadU(dbf, k_base); + VBF k1_vec = hn::LoadU(dbf, k_base + kNF * 2); + VF k0_even = hn::PromoteEvenTo(df, k0_vec); + VF k0_odd = hn::PromoteOddTo(df, k0_vec); + VF k1_even = hn::PromoteEvenTo(df, k1_vec); + VF k1_odd = hn::PromoteOddTo(df, k1_vec); + VF q0_even = hn::Set(df, q_base[0][i * 2]); + VF q0_odd = hn::Set(df, q_base[0][i * 2 + 1]); + sum00 = hn::MulAdd(q0_even, k0_even, sum00); + sum01 = hn::MulAdd(q0_even, k1_even, sum01); + sum00 = hn::MulAdd(q0_odd, k0_odd, sum00); + sum01 = hn::MulAdd(q0_odd, k1_odd, sum01); + if constexpr (kVTileSize >= 4) { + VF q1_even = hn::Set(df, q_base[1][i * 2]); + VF q1_odd = hn::Set(df, q_base[1][i * 2 + 1]); + sum10 = hn::MulAdd(q1_even, k0_even, sum10); + sum11 = hn::MulAdd(q1_even, k1_even, sum11); + sum10 = hn::MulAdd(q1_odd, k0_odd, sum10); + sum11 = hn::MulAdd(q1_odd, k1_odd, sum11); + VF q2_even = hn::Set(df, q_base[2][i * 2]); + VF q2_odd = hn::Set(df, q_base[2][i * 2 + 1]); + sum20 = hn::MulAdd(q2_even, k0_even, sum20); + sum21 = hn::MulAdd(q2_even, k1_even, sum21); + sum20 = hn::MulAdd(q2_odd, k0_odd, sum20); + sum21 = hn::MulAdd(q2_odd, k1_odd, sum21); + VF q3_even = hn::Set(df, q_base[3][i * 2]); + VF q3_odd = hn::Set(df, q_base[3][i * 2 + 1]); + sum30 = hn::MulAdd(q3_even, k0_even, sum30); + sum31 = hn::MulAdd(q3_even, k1_even, sum31); + sum30 = hn::MulAdd(q3_odd, k0_odd, sum30); + sum31 = hn::MulAdd(q3_odd, k1_odd, sum31); + } + if constexpr (kVTileSize >= 8) { + VF q4_even = hn::Set(df, q_base[4][i * 2]); + VF q4_odd = hn::Set(df, q_base[4][i * 2 + 1]); + sum40 = hn::MulAdd(q4_even, k0_even, sum40); + sum41 = hn::MulAdd(q4_even, k1_even, sum41); + sum40 = hn::MulAdd(q4_odd, k0_odd, sum40); + sum41 = hn::MulAdd(q4_odd, k1_odd, sum41); + VF q5_even = hn::Set(df, q_base[5][i * 2]); + VF q5_odd = hn::Set(df, q_base[5][i * 2 + 1]); + sum50 = hn::MulAdd(q5_even, k0_even, sum50); + sum51 = hn::MulAdd(q5_even, k1_even, sum51); + sum50 = hn::MulAdd(q5_odd, k0_odd, sum50); + sum51 = hn::MulAdd(q5_odd, k1_odd, sum51); + VF q6_even = hn::Set(df, q_base[6][i * 2]); + VF q6_odd = hn::Set(df, q_base[6][i * 2 + 1]); + sum60 = hn::MulAdd(q6_even, k0_even, sum60); + sum61 = hn::MulAdd(q6_even, k1_even, sum61); + sum60 = hn::MulAdd(q6_odd, k0_odd, sum60); + sum61 = hn::MulAdd(q6_odd, k1_odd, sum61); + VF q7_even = hn::Set(df, q_base[7][i * 2]); + VF q7_odd = hn::Set(df, q_base[7][i * 2 + 
1]); + sum70 = hn::MulAdd(q7_even, k0_even, sum70); + sum71 = hn::MulAdd(q7_even, k1_even, sum71); + sum70 = hn::MulAdd(q7_odd, k0_odd, sum70); + sum71 = hn::MulAdd(q7_odd, k1_odd, sum71); + } + } } -// Returns the element-wise sum of 8 vectors, in a single vector. -template > -VF HWY_INLINE ElementwiseSumOf8(DF df, const VF& x0, const VF& x1, const VF& x2, - const VF& x3, const VF& x4, const VF& x5, - const VF& x6, const VF& x7) { - VF sum0 = hn::Add(x0, x1); - VF sum1 = hn::Add(x2, x3); - VF sum2 = hn::Add(x4, x5); - VF sum3 = hn::Add(x6, x7); - sum0 = hn::Add(sum0, sum1); - sum2 = hn::Add(sum2, sum3); - return hn::Add(sum0, sum2); -} - -// Sweeps a tile of NF Q rows by 8 K timesteps accumulators from start_pos to -// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos, -// max_last_pos]. -void TileFlashAttention( - const MatPtrT& q, const uint32_t* HWY_RESTRICT q_offsets, - const StridedView& qT, const MatPtrT& k, const size_t start_pos, - const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos, - const size_t max_last_pos, const MatPtrT& v, const size_t layer_idx, - const AttentionActivationsPtrs& activations, MatPtrT& att_out, - const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx, - const size_t worker) { - GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention); - constexpr size_t kHTileSize = kNFx8HTileSize; +// Loads an adjacent pair of floats, converts them to BF16, and broadcasts them +// across a vector of BF16 as alternating odd and even elements. +// hn::ReorderDemote2To(dbf, q_1_float, q_1_float); with q1_float containing +// alternating odd and even floats appears not to do this. +HWY_INLINE hn::Vec> DemoteAndBroadcast2ToBF16( + const float* HWY_RESTRICT base) { using DF = hn::ScalableTag; const DF df; using VF = hn::Vec; - using DI = hn::ScalableTag; - const DI di; - using VI = hn::Vec; - const size_t kVTileSize = hn::Lanes(df); + VF v_even = hn::Set(df, base[0]); + VF v_odd = hn::Set(df, base[1]); + VF interleaved = hn::OddEven(v_odd, v_even); + return hn::OrderedDemote2To(hn::ScalableTag(), interleaved, + interleaved); +} + +// Returns a tile of 1, 4 or 8 Q rows by 2NF K Q.K dot products, in float32. +// K is always pre-transposed to shape: +// [seq_len / 2kNF, layers * kv_heads * qkv_dim/2 * 2kNF * 2], where the /2, *2 +// represents that pairs of qkv_dim elements are kept together to make best use +// of BF16 dot product instructions. +// Note that this version assumes that Q is float32, and not transposed, and +// HWY_NATIVE_DOT_BF16 is true. 
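+// Here each adjacent pair of f32 Q elements is demoted and broadcast to a BF16
+// vector (DemoteAndBroadcast2ToBF16) so that ReorderWidenMulAccumulate can use
+// the native BF16 dot product against the pre-interleaved K pairs.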
+template > +HWY_INLINE void QDotKTile148FloatNative( + DF df, const float* HWY_RESTRICT q, const uint32_t* HWY_RESTRICT q_offsets, + size_t half_cols, const MatPtrT& k, size_t pos, VF& sum00, VF& sum01, + VF& HWY_MAYBE_UNUSED sum10, VF& HWY_MAYBE_UNUSED sum11, + VF& HWY_MAYBE_UNUSED sum20, VF& HWY_MAYBE_UNUSED sum21, + VF& HWY_MAYBE_UNUSED sum30, VF& HWY_MAYBE_UNUSED sum31, + VF& HWY_MAYBE_UNUSED sum40, VF& HWY_MAYBE_UNUSED sum41, + VF& HWY_MAYBE_UNUSED sum50, VF& HWY_MAYBE_UNUSED sum51, + VF& HWY_MAYBE_UNUSED sum60, VF& HWY_MAYBE_UNUSED sum61, + VF& HWY_MAYBE_UNUSED sum70, VF& HWY_MAYBE_UNUSED sum71) { + ZeroResults(df, sum00, sum10, sum20, sum30, sum40, sum50, sum60, + sum70); + ZeroResults(df, sum01, sum11, sum21, sum31, sum41, sum51, sum61, + sum71); + VF unused = hn::Zero(df); + using DBF = hn::ScalableTag; + const DBF dbf; + using VBF = hn::Vec; + const size_t kNF = hn::Lanes(df); + const float* HWY_RESTRICT q_base[kVTileSize]; for (size_t i = 0; i < kVTileSize; ++i) { - hwy::ZeroBytes(att_out.Row(0) + out_offsets[i], - v.Cols() * sizeof(att_out.Row(0)[0])); + q_base[i] = q + q_offsets[i]; } - VI lasts = hn::LoadU(di, last_pos); - VF old_m = hn::Set(df, -std::numeric_limits::max() / 2.0f); - VF old_d = hn::Zero(df); - const BF16* HWY_RESTRICT qT_row = qT.Row(0); - const size_t qT_stride = qT.Stride(); - size_t position = start_pos; - while (position + kHTileSize - 1 <= min_last_pos) { - size_t k_pos[kHTileSize]; - for (size_t i = 0; i < kHTileSize; ++i) { - k_pos[i] = activations.div_seq_len.Remainder(position + i); + const BF16* HWY_RESTRICT k_base = k.Row(pos / (2 * kNF)); + for (size_t i = 0; i < half_cols; ++i, k_base += kNF * 4) { + VBF kvec0 = hn::LoadU(dbf, k_base); + VBF kvec1 = hn::LoadU(dbf, k_base + kNF * 2); + VBF q0_bf16 = DemoteAndBroadcast2ToBF16(q_base[0] + i * 2); + sum00 = hn::ReorderWidenMulAccumulate(df, q0_bf16, kvec0, sum00, unused); + sum01 = hn::ReorderWidenMulAccumulate(df, q0_bf16, kvec1, sum01, unused); + if constexpr (kVTileSize >= 4) { + VBF q1_bf16 = DemoteAndBroadcast2ToBF16(q_base[1] + i * 2); + sum10 = hn::ReorderWidenMulAccumulate(df, q1_bf16, kvec0, sum10, unused); + sum11 = hn::ReorderWidenMulAccumulate(df, q1_bf16, kvec1, sum11, unused); + VBF q2_bf16 = DemoteAndBroadcast2ToBF16(q_base[2] + i * 2); + sum20 = hn::ReorderWidenMulAccumulate(df, q2_bf16, kvec0, sum20, unused); + sum21 = hn::ReorderWidenMulAccumulate(df, q2_bf16, kvec1, sum21, unused); + VBF q3_bf16 = DemoteAndBroadcast2ToBF16(q_base[3] + i * 2); + sum30 = hn::ReorderWidenMulAccumulate(df, q3_bf16, kvec0, sum30, unused); + sum31 = hn::ReorderWidenMulAccumulate(df, q3_bf16, kvec1, sum31, unused); } - VF x0, x1, x2, x3, x4, x5, x6, x7; - QDotKTile(df, qT_row, qT_stride, k, k_pos, x0, x1, x2, x3, x4, x5, x6, x7); - if (activations.config.att_cap > 0.0f) { - // Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile. 
- VF cap = hn::Set(df, activations.config.att_cap); - VF one_over_cap = hn::Div(hn::Set(df, 1.0f), cap); - x0 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x0, one_over_cap))); - x1 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x1, one_over_cap))); - x2 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x2, one_over_cap))); - x3 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x3, one_over_cap))); - x4 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x4, one_over_cap))); - x5 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x5, one_over_cap))); - x6 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x6, one_over_cap))); - x7 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x7, one_over_cap))); + if constexpr (kVTileSize >= 8) { + VBF q4_bf16 = DemoteAndBroadcast2ToBF16(q_base[4] + i * 2); + sum40 = hn::ReorderWidenMulAccumulate(df, q4_bf16, kvec0, sum40, unused); + sum41 = hn::ReorderWidenMulAccumulate(df, q4_bf16, kvec1, sum41, unused); + VBF q5_bf16 = DemoteAndBroadcast2ToBF16(q_base[5] + i * 2); + sum50 = hn::ReorderWidenMulAccumulate(df, q5_bf16, kvec0, sum50, unused); + sum51 = hn::ReorderWidenMulAccumulate(df, q5_bf16, kvec1, sum51, unused); + VBF q6_bf16 = DemoteAndBroadcast2ToBF16(q_base[6] + i * 2); + sum60 = hn::ReorderWidenMulAccumulate(df, q6_bf16, kvec0, sum60, unused); + sum61 = hn::ReorderWidenMulAccumulate(df, q6_bf16, kvec1, sum61, unused); + VBF q7_bf16 = DemoteAndBroadcast2ToBF16(q_base[7] + i * 2); + sum70 = hn::ReorderWidenMulAccumulate(df, q7_bf16, kvec0, sum70, unused); + sum71 = hn::ReorderWidenMulAccumulate(df, q7_bf16, kvec1, sum71, unused); } - VF m = ElementwiseMaxOf8(df, x0, x1, x2, x3, x4, x5, x6, x7); - m = hn::Max(old_m, m); - x0 = hn::Exp(df, hn::Sub(x0, m)); - x1 = hn::Exp(df, hn::Sub(x1, m)); - x2 = hn::Exp(df, hn::Sub(x2, m)); - x3 = hn::Exp(df, hn::Sub(x3, m)); - x4 = hn::Exp(df, hn::Sub(x4, m)); - x5 = hn::Exp(df, hn::Sub(x5, m)); - x6 = hn::Exp(df, hn::Sub(x6, m)); - x7 = hn::Exp(df, hn::Sub(x7, m)); - VF scale = hn::Mul(old_d, hn::Exp(df, hn::Sub(old_m, m))); - old_d = ElementwiseSumOf8(df, x0, x1, x2, x3, x4, x5, x6, x7); - old_d = hn::Add(scale, old_d); - old_m = m; - VF one_over_d = hn::Div(hn::Set(df, 1.0f), old_d); - scale = hn::Mul(scale, one_over_d); - x0 = hn::Mul(x0, one_over_d); - x1 = hn::Mul(x1, one_over_d); - x2 = hn::Mul(x2, one_over_d); - x3 = hn::Mul(x3, one_over_d); - x4 = hn::Mul(x4, one_over_d); - x5 = hn::Mul(x5, one_over_d); - x6 = hn::Mul(x6, one_over_d); - x7 = hn::Mul(x7, one_over_d); - MulByConstAndAddTile(df, scale, x0, x1, x2, x3, x4, x5, x6, x7, v, k_pos, - att_out.Row(0), out_offsets, v.Cols()); - position += kHTileSize; - } - while (position <= max_last_pos) { - size_t k_pos = activations.div_seq_len.Remainder(position); - VF x0 = QDotKVector(df, q_offsets, k_pos, q, k); - if (activations.config.att_cap > 0.0f) { - // Compute tanh(x / cap) * cap, being LogitsSoftCap on the vector. - VF cap = hn::Set(df, activations.config.att_cap); - VF one_over_cap = hn::Div(hn::Set(df, 1.0f), cap); - x0 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x0, one_over_cap))); - } - // Past the last position, x0 doesn't count. 
- auto mask = hn::Gt(hn::Set(di, position), lasts); - VF causal_offset = hn::MaskedSet(df, RebindMask(df, mask), - std::numeric_limits::max() / 2.0f); - x0 = hn::Sub(x0, causal_offset); - VF m = hn::Max(old_m, x0); - x0 = hn::Exp(df, hn::Sub(x0, m)); - VF scale = hn::Mul(old_d, hn::Exp(df, hn::Sub(old_m, m))); - old_m = m; - old_d = hn::Add(scale, x0); - VF one_over_d = hn::Div(hn::Set(df, 1.0f), old_d); - x0 = hn::Mul(x0, one_over_d); - scale = hn::Mul(scale, one_over_d); - MulByConstAndAddVector(df, scale, x0, v, k_pos, att_out.Row(0), out_offsets, - v.Cols()); - ++position; } } -// Returns an 4 Q rows by NF K tile of Q.K dot products, in single precision. -// This is the result of 4 rows of Q against NF K timesteps, with positions -// given by k_offsets[0..NF]. -template > -void QDotKTilex4(DF df, const BF16* HWY_RESTRICT q, - const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT& k, - const int32_t* HWY_RESTRICT k_offsets, VF& sum0, VF& sum1, - VF& sum2, VF& sum3) { - sum0 = hn::Zero(df); - sum1 = hn::Zero(df); - sum2 = hn::Zero(df); - sum3 = hn::Zero(df); - const float* HWY_RESTRICT k_base = k.Row(0); - using DI = hn::ScalableTag; - const DI di; - using VI = hn::Vec; - VI k_offsets_vec = hn::LoadU(di, k_offsets); - for (size_t i = 0; i < k.Cols(); ++i) { - VF k_vec = hn::GatherIndex(df, k_base + i, k_offsets_vec); - VF q_0 = hn::Set(df, hwy::ConvertScalarTo(q[q_offsets[0] + i])); - sum0 = hn::MulAdd(q_0, k_vec, sum0); - VF q_1 = hn::Set(df, hwy::ConvertScalarTo(q[q_offsets[1] + i])); - sum1 = hn::MulAdd(q_1, k_vec, sum1); - VF q_2 = hn::Set(df, hwy::ConvertScalarTo(q[q_offsets[2] + i])); - sum2 = hn::MulAdd(q_2, k_vec, sum2); - VF q_3 = hn::Set(df, hwy::ConvertScalarTo(q[q_offsets[3] + i])); - sum3 = hn::MulAdd(q_3, k_vec, sum3); +// Returns a tile of 1, 4 or 8 Q rows by 2NF K Q.K dot products, in float32. +// K is always pre-transposed to shape: +// [seq_len / 2kNF, layers * kv_heads * qkv_dim/2 * 2kNF * 2], where the /2, *2 +// represents that pairs of qkv_dim elements are kept together to make best use +// of BF16 dot product instructions. +// Note that this is optimized for the case where q and k are bf16, but there is +// no native_bf16 instruction. 
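+// Q is already BF16 here, so each adjacent pair of BF16 Q values is read as a
+// single float, broadcast, bit-cast back to BF16, and then split into even/odd
+// f32 lanes with PromoteEvenTo/PromoteOddTo to match the interleaved K layout.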
+template > +HWY_INLINE void QDotKTile148BF16NotNative( + DF df, const BF16* HWY_RESTRICT q, const uint32_t* HWY_RESTRICT q_offsets, + size_t half_cols, const MatPtrT& k, size_t pos, VF& sum00, VF& sum01, + VF& HWY_MAYBE_UNUSED sum10, VF& HWY_MAYBE_UNUSED sum11, + VF& HWY_MAYBE_UNUSED sum20, VF& HWY_MAYBE_UNUSED sum21, + VF& HWY_MAYBE_UNUSED sum30, VF& HWY_MAYBE_UNUSED sum31, + VF& HWY_MAYBE_UNUSED sum40, VF& HWY_MAYBE_UNUSED sum41, + VF& HWY_MAYBE_UNUSED sum50, VF& HWY_MAYBE_UNUSED sum51, + VF& HWY_MAYBE_UNUSED sum60, VF& HWY_MAYBE_UNUSED sum61, + VF& HWY_MAYBE_UNUSED sum70, VF& HWY_MAYBE_UNUSED sum71) { + ZeroResults(df, sum00, sum10, sum20, sum30, sum40, sum50, sum60, + sum70); + ZeroResults(df, sum01, sum11, sum21, sum31, sum41, sum51, sum61, + sum71); + using DBF = hn::ScalableTag; + const DBF dbf; + using VBF = hn::Vec; + const size_t kNF = hn::Lanes(df); + const float* HWY_RESTRICT q_base[kVTileSize]; + for (size_t i = 0; i < kVTileSize; ++i) { + q_base[i] = reinterpret_cast(q + q_offsets[i]); + } + const BF16* HWY_RESTRICT k_base = k.Row(pos / (2 * kNF)); + for (size_t i = 0; i < half_cols; ++i, k_base += kNF * 4) { + VBF kvec0 = hn::LoadU(dbf, k_base); + VBF kvec1 = hn::LoadU(dbf, k_base + kNF * 2); + VBF q0 = hn::BitCast(dbf, hn::Set(df, q_base[0][i])); + VF k0_even = hn::PromoteEvenTo(df, kvec0); + VF k0_odd = hn::PromoteOddTo(df, kvec0); + VF k1_even = hn::PromoteEvenTo(df, kvec1); + VF k1_odd = hn::PromoteOddTo(df, kvec1); + VF q0_even = hn::PromoteEvenTo(df, q0); + sum00 = hn::MulAdd(q0_even, k0_even, sum00); + sum01 = hn::MulAdd(q0_even, k1_even, sum01); + VF q0_odd = hn::PromoteOddTo(df, q0); + sum00 = hn::MulAdd(q0_odd, k0_odd, sum00); + sum01 = hn::MulAdd(q0_odd, k1_odd, sum01); + if constexpr (kVTileSize >= 4) { + VBF q1 = hn::BitCast(dbf, hn::Set(df, q_base[1][i])); + VF q1_even = hn::PromoteEvenTo(df, q1); + sum10 = hn::MulAdd(q1_even, k0_even, sum10); + sum11 = hn::MulAdd(q1_even, k1_even, sum11); + VF q1_odd = hn::PromoteOddTo(df, q1); + sum10 = hn::MulAdd(q1_odd, k0_odd, sum10); + sum11 = hn::MulAdd(q1_odd, k1_odd, sum11); + VBF q2 = hn::BitCast(dbf, hn::Set(df, q_base[2][i])); + VF q2_even = hn::PromoteEvenTo(df, q2); + sum20 = hn::MulAdd(q2_even, k0_even, sum20); + sum21 = hn::MulAdd(q2_even, k1_even, sum21); + VF q2_odd = hn::PromoteOddTo(df, q2); + sum20 = hn::MulAdd(q2_odd, k0_odd, sum20); + sum21 = hn::MulAdd(q2_odd, k1_odd, sum21); + VBF q3 = hn::BitCast(dbf, hn::Set(df, q_base[3][i])); + VF q3_even = hn::PromoteEvenTo(df, q3); + sum30 = hn::MulAdd(q3_even, k0_even, sum30); + sum31 = hn::MulAdd(q3_even, k1_even, sum31); + VF q3_odd = hn::PromoteOddTo(df, q3); + sum30 = hn::MulAdd(q3_odd, k0_odd, sum30); + sum31 = hn::MulAdd(q3_odd, k1_odd, sum31); + } + if constexpr (kVTileSize >= 8) { + VBF q4 = hn::BitCast(dbf, hn::Set(df, q_base[4][i])); + VF q4_even = hn::PromoteEvenTo(df, q4); + sum40 = hn::MulAdd(q4_even, k0_even, sum40); + sum41 = hn::MulAdd(q4_even, k1_even, sum41); + VF q4_odd = hn::PromoteOddTo(df, q4); + sum40 = hn::MulAdd(q4_odd, k0_odd, sum40); + sum41 = hn::MulAdd(q4_odd, k1_odd, sum41); + VBF q5 = hn::BitCast(dbf, hn::Set(df, q_base[5][i])); + VF q5_even = hn::PromoteEvenTo(df, q5); + sum50 = hn::MulAdd(q5_even, k0_even, sum50); + sum51 = hn::MulAdd(q5_even, k1_even, sum51); + VF q5_odd = hn::PromoteOddTo(df, q5); + sum50 = hn::MulAdd(q5_odd, k0_odd, sum50); + sum51 = hn::MulAdd(q5_odd, k1_odd, sum51); + VBF q6 = hn::BitCast(dbf, hn::Set(df, q_base[6][i])); + VF q6_even = hn::PromoteEvenTo(df, q6); + sum60 = hn::MulAdd(q6_even, k0_even, sum60); 
+ sum61 = hn::MulAdd(q6_even, k1_even, sum61); + VF q6_odd = hn::PromoteOddTo(df, q6); + sum60 = hn::MulAdd(q6_odd, k0_odd, sum60); + sum61 = hn::MulAdd(q6_odd, k1_odd, sum61); + VBF q7 = hn::BitCast(dbf, hn::Set(df, q_base[7][i])); + VF q7_even = hn::PromoteEvenTo(df, q7); + sum70 = hn::MulAdd(q7_even, k0_even, sum70); + sum71 = hn::MulAdd(q7_even, k1_even, sum71); + VF q7_odd = hn::PromoteOddTo(df, q7); + sum70 = hn::MulAdd(q7_odd, k0_odd, sum70); + sum71 = hn::MulAdd(q7_odd, k1_odd, sum71); + } + } +} + +// Returns a tile of 1, 4 or 8 Q rows by 2NF K Q.K dot products, in float32. +// K is always pre-transposed to shape: +// [seq_len / 2kNF, layers * kv_heads * qkv_dim/2 * 2kNF * 2], where the /2, *2 +// represents that pairs of qkv_dim elements are kept together to make best use +// of BF16 dot product instructions. +// Note that this is optimized for the case where q and k are bf16, and there is +// a native_bf16 instruction. +template > +HWY_INLINE void QDotKTile148BF16Native( + DF df, const BF16* HWY_RESTRICT q, const uint32_t* HWY_RESTRICT q_offsets, + size_t half_cols, const MatPtrT& k, size_t pos, VF& sum00, VF& sum01, + VF& HWY_MAYBE_UNUSED sum10, VF& HWY_MAYBE_UNUSED sum11, + VF& HWY_MAYBE_UNUSED sum20, VF& HWY_MAYBE_UNUSED sum21, + VF& HWY_MAYBE_UNUSED sum30, VF& HWY_MAYBE_UNUSED sum31, + VF& HWY_MAYBE_UNUSED sum40, VF& HWY_MAYBE_UNUSED sum41, + VF& HWY_MAYBE_UNUSED sum50, VF& HWY_MAYBE_UNUSED sum51, + VF& HWY_MAYBE_UNUSED sum60, VF& HWY_MAYBE_UNUSED sum61, + VF& HWY_MAYBE_UNUSED sum70, VF& HWY_MAYBE_UNUSED sum71) { + ZeroResults(df, sum00, sum10, sum20, sum30, sum40, sum50, sum60, + sum70); + ZeroResults(df, sum01, sum11, sum21, sum31, sum41, sum51, sum61, + sum71); + VF unused_sum1 = hn::Zero(df); + using DBF = hn::ScalableTag; + const DBF dbf; + using VBF = hn::Vec; + const size_t kNF = hn::Lanes(df); + const float* HWY_RESTRICT q_base[kVTileSize]; + for (size_t i = 0; i < kVTileSize; ++i) { + q_base[i] = reinterpret_cast(q + q_offsets[i]); + } + const BF16* HWY_RESTRICT k_base = k.Row(pos / (2 * kNF)); + for (size_t i = 0; i < half_cols; ++i, k_base += kNF * 4) { + VBF k0_vec = hn::LoadU(dbf, k_base); + VBF k1_vec = hn::LoadU(dbf, k_base + kNF * 2); + VBF q0 = hn::BitCast(dbf, hn::Set(df, q_base[0][i])); + sum00 = hn::ReorderWidenMulAccumulate(df, q0, k0_vec, sum00, unused_sum1); + sum01 = hn::ReorderWidenMulAccumulate(df, q0, k1_vec, sum01, unused_sum1); + if constexpr (kVTileSize >= 4) { + VBF q1 = hn::BitCast(dbf, hn::Set(df, q_base[1][i])); + sum10 = hn::ReorderWidenMulAccumulate(df, q1, k0_vec, sum10, unused_sum1); + sum11 = hn::ReorderWidenMulAccumulate(df, q1, k1_vec, sum11, unused_sum1); + VBF q2 = hn::BitCast(dbf, hn::Set(df, q_base[2][i])); + sum20 = hn::ReorderWidenMulAccumulate(df, q2, k0_vec, sum20, unused_sum1); + sum21 = hn::ReorderWidenMulAccumulate(df, q2, k1_vec, sum21, unused_sum1); + VBF q3 = hn::BitCast(dbf, hn::Set(df, q_base[3][i])); + sum30 = hn::ReorderWidenMulAccumulate(df, q3, k0_vec, sum30, unused_sum1); + sum31 = hn::ReorderWidenMulAccumulate(df, q3, k1_vec, sum31, unused_sum1); + } + if constexpr (kVTileSize >= 8) { + VBF q4 = hn::BitCast(dbf, hn::Set(df, q_base[4][i])); + sum40 = hn::ReorderWidenMulAccumulate(df, q4, k0_vec, sum40, unused_sum1); + sum41 = hn::ReorderWidenMulAccumulate(df, q4, k1_vec, sum41, unused_sum1); + VBF q5 = hn::BitCast(dbf, hn::Set(df, q_base[5][i])); + sum50 = hn::ReorderWidenMulAccumulate(df, q5, k0_vec, sum50, unused_sum1); + sum51 = hn::ReorderWidenMulAccumulate(df, q5, k1_vec, sum51, unused_sum1); + VBF q6 = 
hn::BitCast(dbf, hn::Set(df, q_base[6][i])); + sum60 = hn::ReorderWidenMulAccumulate(df, q6, k0_vec, sum60, unused_sum1); + sum61 = hn::ReorderWidenMulAccumulate(df, q6, k1_vec, sum61, unused_sum1); + VBF q7 = hn::BitCast(dbf, hn::Set(df, q_base[7][i])); + sum70 = hn::ReorderWidenMulAccumulate(df, q7, k0_vec, sum70, unused_sum1); + sum71 = hn::ReorderWidenMulAccumulate(df, q7, k1_vec, sum71, unused_sum1); + } } } // Handles NF v rows of flash attention for NF q.k dot products from one q row. +// Automatically handles masking for causal attention and different start_pos +// and last_pos values. template > -float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max, +HWY_INLINE float SingleFlashAttentionRowVector(DF df, size_t start_pos, + size_t pos, size_t last_pos, + VF& x, float& old_max, float& old_d) { + if (pos < start_pos) { + size_t mask_size = start_pos - pos; + const VF neg_inf = hn::Neg(hn::Inf(df)); + x = hn::IfThenElse(hn::FirstN(df, mask_size), neg_inf, x); + } + if (pos + hn::Lanes(df) > last_pos) { + size_t mask_size = pos <= last_pos ? last_pos + 1 - pos : 0; + const VF neg_inf = hn::Neg(hn::Inf(df)); + x = hn::IfThenElse(hn::FirstN(df, mask_size), x, neg_inf); + } float m = hn::ReduceMax(df, x); m = std::max(m, old_max); x = hn::Exp(df, hn::Sub(x, hn::Set(df, m))); @@ -439,6 +502,60 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max, return scale; } +// Handles 2NF v rows of flash attention for 2NF q.k dot products from 1 q row. +// Automatically handles masking for causal attention and different start_pos +// and last_pos values. +template > +HWY_INLINE float DoubleFlashAttentionRowVector(DF df, size_t start_pos, + size_t pos, size_t last_pos, + VF& x0, VF& x1, float& old_max, + float& old_d) { + const size_t kNF = hn::Lanes(df); + const VF neg_inf = hn::Neg(hn::Inf(df)); + if (pos < start_pos) { + if (pos + kNF <= start_pos) { + x0 = neg_inf; + size_t mask_size = start_pos - (pos + kNF); + x1 = hn::IfThenElse(hn::FirstN(df, mask_size), neg_inf, x1); + } else { + size_t mask_size = start_pos - pos; + x0 = hn::IfThenElse(hn::FirstN(df, mask_size), neg_inf, x0); + } + } + if (pos + 2 * kNF > last_pos) { + if (pos + kNF > last_pos) { + size_t mask_size = pos <= last_pos ? last_pos + 1 - pos : 0; + x0 = hn::IfThenElse(hn::FirstN(df, mask_size), x0, neg_inf); + x1 = neg_inf; + } else { + size_t mask_size = last_pos + 1 - (pos + kNF); + x1 = hn::IfThenElse(hn::FirstN(df, mask_size), x1, neg_inf); + } + } + VF x_max = hn::Max(x0, x1); + float m = hn::ReduceMax(df, x_max); + m = std::max(m, old_max); + VF m_vec = hn::Set(df, m); + x0 = hn::Exp(df, hn::Sub(x0, m_vec)); + x1 = hn::Exp(df, hn::Sub(x1, m_vec)); + float scale = old_d * std::exp(old_max - m); + VF x_sum = hn::Add(x0, x1); + old_d = hn::ReduceSum(df, x_sum) + scale; + old_max = m; + if (old_d > 0.0f) { + const float one_over_d = 1.0f / old_d; + scale *= one_over_d; + VF one_over_d_vec = hn::Set(df, one_over_d); + x0 = hn::Mul(x0, one_over_d_vec); + x1 = hn::Mul(x1, one_over_d_vec); + } else { + scale = 0.0f; + x0 = hn::Zero(df); + x1 = hn::Zero(df); + } + return scale; +} + // Reduces each of x and stores in following lanes of max (tested with float32) template , class DF4 = hn::CappedTag, class VF4 = hn::Vec, @@ -590,159 +707,522 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap( } } -// Implements flash attention for a strip of 4 query vectors. -// It iterates through timesteps in K from `start_pos` up to `max_last_pos`. 
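+// Rough sketch of the intended enumeration (illustrative, not the exact code):
+// for each query index, head and tile of query rows, one FlashAttentionParams
+// entry records the q/out offsets plus the per-row [start_pos, last_pos]
+// range; if the resulting task count still falls short of target_parallelism,
+// the k-range of each entry is additionally split into chunks recorded in
+// split_flash_params.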
-// Timesteps up to `min_last_pos` (*) are processed in tiles of shape 4 Q rows -// by NF timesteps in K for efficiency while timesteps between `min_last_pos + -// 1` and `max_last_pos` are processed one-by-one to handle differing `last_pos` -// values within the strip. -// (*) Actually, it only iterates through -// `min_last_pos - (min_last_pos + 1 - start_pos) % NF` in tiles, as the tiled -// computation can, for obvious reasons, only process an integer number of -// tiles. +// Implements flash attention for a strip of tiles of size 1, 4 or 8 query +// vectors by 2NF positions in K. +// It iterates through tiles in K from `params.min_start_pos / 2NF * 2NF` up to +// `params.max_last_pos` (rounded up to the nearest multiple of 2NF). +// Masking allows each row within a tile to have a different start and end +// position. // +// @param params FlashAttentionParams containing the extent of the strip and +// size of the tiles. // @param q The query matrix [batch_size * q_heads, qkv_dim] in BF16 format. -// @param q_offsets Offsets from `q.Row(0)` to the start of the 4 query -// vectors to be processed in this tile. -// @param k Key matrix [seq_len, qkv_dim] from KV cache. -// @param start_pos The first token position in the KV cache to attend to. -// @param last_pos An array of 4 indices giving the last token position -// (inclusive) that each of the 4 queries may attend to. -// @param min_last_pos The minimum value in `last_pos`. Timesteps up to this -// position can be processed efficiently in batches. -// @param max_last_pos The maximum value in `last_pos`. Timesteps between -// `min_last_pos + 1` and this position are processed individually to -// respect each query's `last_pos` limit. +// @param k Key matrix from KV cache. K is always pre-transposed to shape: +// [seq_len / 2kNF, layers * kv_heads * qkv_dim/2 * 2kNF * 2], +// where the /2, *2 represents that pairs of qkv_dim elements are kept +// together to make best use of BF16 dot product instructions. // @param v Value matrix [seq_len, qkv_dim] from KV cache. // @param layer_idx The index of the current transformer layer. // @param activations Attention configurations and buffers. // @param att_out Output buffer for attention results. -// @param out_offsets Offsets from `att_out.Row(0)` to store the 4 output -// vectors. // @param ctx Threading context. // @param worker Worker thread index. -Tile4FlashState TileFlashAttention4( - const MatPtrT& q, const uint32_t* HWY_RESTRICT q_offsets, - const MatPtrT& k, const size_t start_pos, - const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos, - const size_t max_last_pos, const MatPtrT& v, const size_t layer_idx, +template +Tile4FlashState TileFlashAttention148( + const FlashAttentionParams& params, const MatPtrT& q, + const MatPtrT& k, const MatPtrT& v, const size_t layer_idx, const AttentionActivationsPtrs& activations, MatPtrT& att_out, - const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx, - const size_t worker) { - GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention4); + size_t qkv_dim, ThreadingContext& ctx, const size_t worker, + AttentionImpl attention_impl) { + constexpr Zones kZone = + kVTileSize == 8 + ? Zones::kFlashAttentionTileFlashAttention8 + : (kVTileSize == 4 ? 
Zones::kFlashAttentionTileFlashAttention4 + : Zones::kFlashAttentionTileFlashAttention1); + GCPP_ZONE(ctx, worker, kZone); using DF = hn::ScalableTag; const DF df; using VF = hn::Vec; - constexpr size_t kMaxNF = hn::MaxLanes(df); - const size_t kHTileSize = hn::Lanes(df); - HWY_DASSERT(kHTileSize <= kMaxNF); - constexpr size_t kVTileSize = 4; + float att_cap = activations.config.att_cap; + float one_over_cap = att_cap > 0.0f ? 1.0f / att_cap : 0.0f; + const size_t kHTileSize = 2 * hn::Lanes(df); float scales[kVTileSize]; for (size_t i = 0; i < kVTileSize; ++i) { - hwy::ZeroBytes(att_out.Row(0) + out_offsets[i], - v.Cols() * sizeof(att_out.Row(0)[0])); + hwy::ZeroBytes(att_out.Row(0) + params.out_offsets[i], + qkv_dim * sizeof(att_out.Row(0)[0])); } Tile4FlashState state; - size_t position = start_pos; - while (position + kHTileSize - 1 <= min_last_pos) { - int32_t k_offsets[kMaxNF]; - size_t v_pos[kMaxNF]; + size_t position = params.min_start_pos / kHTileSize * kHTileSize; + while (position <= params.max_last_pos) { + // Each pair of vectors covers 2NF positions in K, with up to 8 pairs of + // vectors covering 1, 4 or 8 queries. + VF x00, x01; + VF HWY_MAYBE_UNUSED x10, x11; + VF HWY_MAYBE_UNUSED x20, x21; + VF HWY_MAYBE_UNUSED x30, x31; + VF HWY_MAYBE_UNUSED x40, x41; + VF HWY_MAYBE_UNUSED x50, x51; + VF HWY_MAYBE_UNUSED x60, x61; + VF HWY_MAYBE_UNUSED x70, x71; + constexpr size_t kMaxNF = hn::MaxLanes(df); + size_t v_pos[2 * kMaxNF]; for (size_t i = 0; i < kHTileSize; ++i) { v_pos[i] = activations.div_seq_len.Remainder(position + i); - k_offsets[i] = k.Row(v_pos[i]) - k.Row(0); } - VF x0, x1, x2, x3; - QDotKTilex4(df, q.Row(0), q_offsets, k, k_offsets, x0, x1, x2, x3); - if (activations.config.att_cap > 0.0f) { + if constexpr (IsF32()) { + if constexpr (HWY_NATIVE_DOT_BF16) { + QDotKTile148FloatNative(df, q.Row(0), params.out_offsets, + qkv_dim / 2, k, position, x00, x01, + x10, x11, x20, x21, x30, x31, x40, + x41, x50, x51, x60, x61, x70, x71); + } else { + QDotKTile148FloatNotNative( + df, q.Row(0), params.out_offsets, qkv_dim / 2, k, position, x00, + x01, x10, x11, x20, x21, x30, x31, x40, x41, x50, x51, x60, x61, + x70, x71); + } + } else { + if constexpr (HWY_NATIVE_DOT_BF16) { + QDotKTile148BF16Native(df, q.Row(0), params.q_offsets, + qkv_dim / 2, k, position, x00, x01, + x10, x11, x20, x21, x30, x31, x40, + x41, x50, x51, x60, x61, x70, x71); + } else { + QDotKTile148BF16NotNative( + df, q.Row(0), params.q_offsets, qkv_dim / 2, k, position, x00, x01, + x10, x11, x20, x21, x30, x31, x40, x41, x50, x51, x60, x61, x70, + x71); + } + } + if (att_cap > 0.0f) { // Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile. 
- VF cap = hn::Set(df, activations.config.att_cap); - VF one_over_cap = hn::Div(hn::Set(df, 1.0f), cap); - x0 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x0, one_over_cap))); - x1 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x1, one_over_cap))); - x2 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x2, one_over_cap))); - x3 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x3, one_over_cap))); + ApplySoftCap(df, att_cap, one_over_cap, x00, x10, x20, x30, + x40, x50, x60, x70); + ApplySoftCap(df, att_cap, one_over_cap, x01, x11, x21, x31, + x41, x51, x61, x71); + } + scales[0] = DoubleFlashAttentionRowVector( + df, params.start_pos[0], position, params.last_pos[0], x00, x01, + state.row_states[0].max, state.row_states[0].d); + if constexpr (kVTileSize >= 4) { + scales[1] = DoubleFlashAttentionRowVector( + df, params.start_pos[1], position, params.last_pos[1], x10, x11, + state.row_states[1].max, state.row_states[1].d); + scales[2] = DoubleFlashAttentionRowVector( + df, params.start_pos[2], position, params.last_pos[2], x20, x21, + state.row_states[2].max, state.row_states[2].d); + scales[3] = DoubleFlashAttentionRowVector( + df, params.start_pos[3], position, params.last_pos[3], x30, x31, + state.row_states[3].max, state.row_states[3].d); + MulByConstAndAddVT4Mem(df, scales, x00, x01, x10, x11, x20, x21, x30, x31, + v, v_pos, params.max_last_pos + 1 - position, + att_out.Row(0), params.out_offsets, qkv_dim); + } else { + MulByConstAndAddVT1Mem(df, scales, x00, x01, v, v_pos, + params.max_last_pos + 1 - position, att_out.Row(0), + params.out_offsets, qkv_dim); + } + if constexpr (kVTileSize >= 8) { + scales[4] = DoubleFlashAttentionRowVector( + df, params.start_pos[4], position, params.last_pos[4], x40, x41, + state.row_states[4].max, state.row_states[4].d); + scales[5] = DoubleFlashAttentionRowVector( + df, params.start_pos[5], position, params.last_pos[5], x50, x51, + state.row_states[5].max, state.row_states[5].d); + scales[6] = DoubleFlashAttentionRowVector( + df, params.start_pos[6], position, params.last_pos[6], x60, x61, + state.row_states[6].max, state.row_states[6].d); + scales[7] = DoubleFlashAttentionRowVector( + df, params.start_pos[7], position, params.last_pos[7], x70, x71, + state.row_states[7].max, state.row_states[7].d); + MulByConstAndAddVT4Mem(df, scales + 4, x40, x41, x50, x51, x60, x61, x70, + x71, v, v_pos, params.max_last_pos + 1 - position, + att_out.Row(0), params.out_offsets + 4, qkv_dim); } - scales[0] = SingleFlashAttentionRowVector(df, x0, state.row_states[0].max, - state.row_states[0].d); - scales[1] = SingleFlashAttentionRowVector(df, x1, state.row_states[1].max, - state.row_states[1].d); - scales[2] = SingleFlashAttentionRowVector(df, x2, state.row_states[2].max, - state.row_states[2].d); - scales[3] = SingleFlashAttentionRowVector(df, x3, state.row_states[3].max, - state.row_states[3].d); - MulByConstAndAddTile4(df, scales, x0, x1, x2, x3, v, v_pos, att_out.Row(0), - out_offsets, v.Cols()); position += kHTileSize; } - const hn::ScalableTag dbf; - const size_t qkv_dim = k.Cols(); - - while (position <= max_last_pos) { - size_t k_pos = activations.div_seq_len.Remainder(position); - if (position <= last_pos[0]) { - // Past the last position, x0 doesn't count. 
- float x0 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[0], qkv_dim), 0, - k.Row(k_pos), qkv_dim); - SingleFlashAttentionStep(x0, activations.config.att_cap, - state.row_states[0].max, state.row_states[0].d, - v.Row(k_pos), v.Cols(), - att_out.Row(0) + out_offsets[0]); - } - if (position <= last_pos[1]) { - // Past the last position, x1 doesn't count. - float x1 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[1], qkv_dim), 0, - k.Row(k_pos), qkv_dim); - SingleFlashAttentionStep(x1, activations.config.att_cap, - state.row_states[1].max, state.row_states[1].d, - v.Row(k_pos), v.Cols(), - att_out.Row(0) + out_offsets[1]); - } - if (position <= last_pos[2]) { - // Past the last position, x2 doesn't count. - float x2 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[2], qkv_dim), 0, - k.Row(k_pos), qkv_dim); - SingleFlashAttentionStep(x2, activations.config.att_cap, - state.row_states[2].max, state.row_states[2].d, - v.Row(k_pos), v.Cols(), - att_out.Row(0) + out_offsets[2]); - } - if (position <= last_pos[3]) { - // Past the last position, x3 doesn't count. - float x3 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[3], qkv_dim), 0, - k.Row(k_pos), qkv_dim); - SingleFlashAttentionStep(x3, activations.config.att_cap, - state.row_states[3].max, state.row_states[3].d, - v.Row(k_pos), v.Cols(), - att_out.Row(0) + out_offsets[3]); - } - ++position; - } return state; } -// Rounds n to a number that can be used as the number of Q rows in a tile -// of flash attention. -static size_t RoundToSuitablePowerOf2(size_t n) { - if (n < 4) return 1; - if (n < 8) return 4; - if (n < 16) return 8; - if (n < 32) return 16; - return 32; -} - // The vertical tile size is determined by the ability to use tiling and the // target_parallelism. In practice the possible tile sizes in order of -// preference for efficiency are kNF, 4, 1, where kNF is likely to be 4 8 or -// 16. The final tile size is chosen to be the largest possible that allows -// for target_parallelism parallel tasks. +// preference for efficiency are 8, 4, 1. The final tile size is chosen to be +// the largest possible that allows for target_parallelism parallel tasks. size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, size_t total_tasks, size_t target_parallelism) { - const size_t kMaxEqualK = - RoundToSuitablePowerOf2(num_head_groups * num_tokens); - const size_t kMinTileSize = (total_tasks / 4 >= target_parallelism) ? 4 : 1; - return (kNF <= kMaxEqualK && total_tasks / kNF >= target_parallelism) - ? kNF - : std::min(kMinTileSize, kMaxEqualK); + const size_t kMaxEqualK = num_head_groups * num_tokens; + if (total_tasks / k8xNFVTileSize >= target_parallelism && + kMaxEqualK >= k8xNFVTileSize && kNF >= k8xNFVTileSize) { + return k8xNFVTileSize; + } + if (total_tasks / k4xNFVTileSize >= target_parallelism && + kMaxEqualK >= k4xNFVTileSize && kNF >= k4xNFVTileSize) { + return k4xNFVTileSize; + } + return 1; +} + +// Clears and fills the params vector with FlashAttentionParams for the given +// num_tokens, target_parallelism, and layer_idx. Computes tile sizes and +// offsets for each tile to achieve target_parallelism. 
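For a concrete feel for the tile-size selection above (before ComputeFlashParams() below uses it), consider two hypothetical shapes. The numbers are illustrative only, assuming kNF = 16 as on AVX3, 16 query heads, 2 KV heads (kHeadGroups = 8), and a qbatch of 1:

// Decode, 1 token: total_tasks = 1 * 16 = 16. With target_parallelism = 64,
// neither 16 / 8 nor 16 / 4 reaches 64, so
//   GetVTileSize(/*kNF=*/16, /*num_head_groups=*/8, /*num_tokens=*/1,
//                /*total_tasks=*/16, /*target_parallelism=*/64) == 1.
// Prefill, 64 tokens: total_tasks = 64 * 16 = 1024. Here 1024 / 8 >= 64,
// kMaxEqualK = 8 * 64 >= 8 and kNF = 16 >= 8, so
//   GetVTileSize(16, 8, 64, 1024, 64) == 8.

The decode case is where the k-position splitting described further down becomes necessary.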
+void ComputeFlashParams(size_t num_tokens, const size_t target_parallelism, + size_t layer_idx, AttentionActivationsPtrs& activations, + QBatch& qbatch, AttentionImpl attention_impl, + std::vector& params) { + const LayerConfig& layer_config = activations.config.layer_configs[layer_idx]; + const hwy::Divisor div_qbatch(qbatch.Size()); + const size_t qkv_dim = layer_config.qkv_dim; + using DF = hn::ScalableTag; + const DF df; + const size_t kNF = hn::Lanes(df); + + // A "head group" in the context of GQA refers to a collection of query + // heads that share the same key and value heads. + const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads; + const size_t cache_layer_size = layer_config.CacheLayerSize(); + const size_t token_batch = num_tokens * div_qbatch.GetDivisor(); + const size_t total_tasks = token_batch * layer_config.heads; + size_t kVTileSize = GetVTileSize(kNF, kHeadGroups, num_tokens, total_tasks, + target_parallelism); + // All layers should have the same number of heads. + HWY_DASSERT(activations.div_heads.GetDivisor() == layer_config.heads); + // To maximize adjacent tasks with the same kv matrices, task index is encoded + // thus: [qi][kv_head][batch_idx][head_group]. Note that the head index is + // split into kv_head and head_group, since the head_group does not affect + // the KV matrices, and kv_head does. batch_idx does not affect the KV + // matrices, but does affect the last position in the sequence. qi affects + // everything. + params.clear(); + for (uint32_t qi = 0; qi < div_qbatch.GetDivisor(); ++qi) { + for (uint32_t kv_head = 0; kv_head < layer_config.kv_heads; ++kv_head) { + const size_t head_offset = kv_head * qkv_dim * 2; + const uint32_t kv_offset = layer_idx * cache_layer_size + head_offset; + params.push_back(FlashAttentionParams{ + .qi_index = qi, + .kv_offset = kv_offset, + }); + for (uint32_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { + const size_t pos = qbatch.Pos(qi) + batch_idx; + const size_t start_pos = StartPos(pos, activations.config, layer_idx); + size_t last = pos; + const size_t prefix_end = qbatch.PrefixEnd(qi); + if (prefix_end > 0 && prefix_end - 1 > last) { + // last_pos is inclusive. + last = prefix_end - 1; + } + for (size_t head_group = 0; head_group < kHeadGroups; ++head_group) { + size_t tasks_remaining = kHeadGroups - head_group + + kHeadGroups * (num_tokens - 1 - batch_idx); + // We want to fill a tile of size kVTileSize or k4xNFVTileSize if + // smaller, otherwise everything is singles to the next head group. + size_t tasks_required = params.back().v_tile_size < k4xNFVTileSize + ? k4xNFVTileSize + : kVTileSize; + if (params.back().v_tile_size + tasks_remaining < tasks_required || + params.back().v_tile_size == kVTileSize) { + // We don't have enough tasks remaining to fill a tile, or the + // current tile is full so start new tile. 
+ params.push_back(FlashAttentionParams{ + .qi_index = qi, + .kv_offset = kv_offset, + }); + } + const size_t head = head_group + kHeadGroups * kv_head; + const size_t tq_idx = div_qbatch.GetDivisor() * batch_idx + qi; + auto& param = params.back(); + size_t offset = param.v_tile_size; + param.q_offsets[offset] = activations.q_bf.Row(tq_idx) + + head * qkv_dim - activations.q_bf.Row(0); + param.out_offsets[offset] = activations.att_out.Row(tq_idx) + + head * qkv_dim - + activations.att_out.Row(0); + param.tq_idx[offset] = tq_idx; + param.start_pos[offset] = start_pos; + param.min_start_pos = HWY_MIN(param.min_start_pos, start_pos); + param.last_pos[offset] = last; + param.max_last_pos = HWY_MAX(param.max_last_pos, last); + ++param.v_tile_size; + } + } + } + } +} + +// Returns the maximum number of tiles needed for any query in the batch. +size_t GetMaxTiles(const std::vector& params, + const size_t kHTileSize) { + size_t max_tiles = 0; + for (const auto& param : params) { + size_t start = param.min_start_pos / kHTileSize; + size_t last = param.max_last_pos / kHTileSize; + max_tiles = HWY_MAX(last + 1 - start, max_tiles); + } + return max_tiles; +} + +// Splits params into smaller k-strips to allow for more parallelism. +// The strips are of size num_tiles_per_task * kHTileSize. +// split_params is cleared and filled with the split tasks. +void SplitTasksByKPos(std::vector& params, + const size_t kHTileSize, const size_t num_tiles_per_task, + const size_t out_stride, + std::vector& split_params) { + split_params.clear(); + for (auto& param : params) { + param.split_index = split_params.size(); + size_t start = param.min_start_pos / kHTileSize; + size_t last = param.max_last_pos / kHTileSize; + for (size_t tile_pos = start; tile_pos <= last; + tile_pos += num_tiles_per_task) { + auto& split_param = split_params.emplace_back(param); + split_param.i_of_n = (tile_pos - start) / num_tiles_per_task; + uint32_t tile_last = (tile_pos + num_tiles_per_task) * kHTileSize - 1; + if (tile_last < param.max_last_pos) { + split_param.max_last_pos = tile_last; + for (auto& last_pos : split_param.last_pos) { + last_pos = std::min(last_pos, tile_last); + } + } + uint32_t tile_start = tile_pos * kHTileSize; + if (tile_start > param.min_start_pos) { + split_param.min_start_pos = tile_start; + for (auto& start_pos : split_param.start_pos) { + start_pos = std::max(start_pos, tile_start); + } + } + if (split_param.i_of_n > 0) { + for (size_t i = 0; i < split_param.v_tile_size; ++i) { + split_param.tq_idx[i] = + param.tq_idx[i] * AttentionActivations::kThreadReplicationFactor + + split_param.i_of_n - 1; + split_param.out_offsets[i] = + param.out_offsets[i] + + (split_param.tq_idx[i] - param.tq_idx[i]) * out_stride; + } + } + } + } +} + +// Clears and fills activations.flash_params with FlashAttentionParams for the +// given num_tokens, target_parallelism, and layer_idx. Computes tile sizes and +// offsets for each tile to achieve target_parallelism. +// If the parallelism is insufficient for this processor type, and the sequence +// length is sufficient, the tiles are upgraded to k4xNFVTileSize and the tasks +// are split along the k positions to achieve the desired parallelism. +// If splitting was required, returns that factor by which the tiles were +// upgraded, k4xNFVTileSize, otherwise returns 0. 
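Because the split strips later have to be merged back into a single attention output, it may help to see that merge in scalar form. The sketch below is illustrative only (the Partial struct and MergePartials name are not from the codebase); it assumes each split task yields a partial output that has already been normalized by its own softmax denominator, together with its running maximum and denominator, which is how the per-row OnlineSoftmaxState values are used by CombineSplitTasks1416() further down:

#include <algorithm>
#include <cmath>

// Merging two online-softmax partials that cover disjoint K ranges of the
// same query row. One float of the output is shown; the real code applies
// the same two scales across the whole qkv_dim vector.
struct Partial {
  float m;  // running maximum of the logits in this K range
  float d;  // sum of exp(logit - m) over this K range
  float o;  // partial output element, already divided by d
};

Partial MergePartials(const Partial& a, const Partial& b) {
  const float m = std::max(a.m, b.m);
  const float wa = a.d * std::exp(a.m - m);  // a's denominator rescaled to m
  const float wb = b.d * std::exp(b.m - m);  // b's denominator rescaled to m
  const float d = wa + wb;
  return Partial{m, d, (wa * a.o + wb * b.o) / d};
}

Under that assumption, applying MergePartials left-to-right over all splits of a row reproduces the result of an unsplit pass, so the split direction only costs the extra combine step and the att_out_reps storage.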
+uint32_t ComputeAndSplitFlashParams(const size_t kNF, const size_t num_tokens, + const size_t target_parallelism, + size_t layer_idx, + AttentionActivationsPtrs& activations, + QBatch& qbatch, ThreadingContext& ctx, + AttentionImpl attention_impl) { + ComputeFlashParams(num_tokens, target_parallelism, layer_idx, activations, + qbatch, attention_impl, activations.flash_params); + if (activations.flash_params.size() < ctx.pools.MaxWorkers()) { + // Insufficient parallelism for this processor type. Try splitting along the + // k positions. + size_t max_tiles = GetMaxTiles(activations.flash_params, kNF); + size_t desired_tiles_per_task = hwy::DivCeil( + activations.flash_params.size() * max_tiles, ctx.pools.MaxWorkers()); + // The cost of combining split tasks is significant, so we want a minimum + // number of tiles per task, and we want to use k4xNFVTileSize if possible. + constexpr size_t kMinTilesPerTask = 4; + if (desired_tiles_per_task >= k4xNFVTileSize * kMinTilesPerTask) { + // We can afford to use k4xNFVTileSize vertically, so recompute params. + ComputeFlashParams(num_tokens, + activations.flash_params.size() / k4xNFVTileSize, + layer_idx, activations, qbatch, attention_impl, + activations.flash_params); + desired_tiles_per_task = + hwy::DivCeil(desired_tiles_per_task, k4xNFVTileSize); + SplitTasksByKPos(activations.flash_params, kNF, desired_tiles_per_task, + activations.att_out_reps.Stride(), + activations.split_flash_params); + return k4xNFVTileSize; + } + } + return 0; +} + +// Combines results from split tasks, processing kNumNF * NF qkv values where +// kNumNF can be 1 4 or 16. This enables the intermediate results to be held in +// registers, which speeds up the combination step significantly. +template +void CombineSplitTasks1416(hwy::Span params, + size_t tile_pos, size_t qkv_offset, + AttentionActivationsPtrs& activations) { + using DF = hn::ScalableTag; + const DF df; + using VF = hn::Vec; + const size_t kNF = hn::Lanes(df); + float overall_m = params[0].end_state.row_states[tile_pos].max; + float overall_d = params[0].end_state.row_states[tile_pos].d; + float* HWY_RESTRICT att_out = + activations.att_out.Row(0) + params[0].out_offsets[tile_pos] + qkv_offset; + VF result_0 = hn::Load(df, att_out); + VF result_1, result_2, result_3, result_4, result_5, result_6, result_7; + VF result_8, result_9, result_10, result_11, result_12, result_13, result_14; + VF result_15; + if constexpr (kNumNF > 1) { + result_1 = hn::Load(df, att_out + kNF); + result_2 = hn::Load(df, att_out + 2 * kNF); + result_3 = hn::Load(df, att_out + 3 * kNF); + } + if constexpr (kNumNF == 16) { + result_4 = hn::Load(df, att_out + 4 * kNF); + result_5 = hn::Load(df, att_out + 5 * kNF); + result_6 = hn::Load(df, att_out + 6 * kNF); + result_7 = hn::Load(df, att_out + 7 * kNF); + result_8 = hn::Load(df, att_out + 8 * kNF); + result_9 = hn::Load(df, att_out + 9 * kNF); + result_10 = hn::Load(df, att_out + 10 * kNF); + result_11 = hn::Load(df, att_out + 11 * kNF); + result_12 = hn::Load(df, att_out + 12 * kNF); + result_13 = hn::Load(df, att_out + 13 * kNF); + result_14 = hn::Load(df, att_out + 14 * kNF); + result_15 = hn::Load(df, att_out + 15 * kNF); + } + for (size_t i = 1; i < params.size() && params[i].i_of_n > 0; ++i) { + float m = params[i].end_state.row_states[tile_pos].max; + float d = params[i].end_state.row_states[tile_pos].d; + float new_m = std::max(overall_m, m); + // Scale factor for existing total given the change in max. 
+ float old_scale = overall_d * std::exp(overall_m - new_m); + // Scale factor for new group to add. + float new_scale = d * std::exp(m - new_m); + float new_d = old_scale + new_scale; + float one_over_d = 1.0f / new_d; + old_scale *= one_over_d; + new_scale *= one_over_d; + overall_m = new_m; + overall_d = new_d; + float* HWY_RESTRICT att_in = activations.att_out_reps.Row(0) + + params[i].out_offsets[tile_pos] + qkv_offset; + VF old_scale_vec = hn::Set(df, old_scale); + VF new_scale_vec = hn::Set(df, new_scale); + result_0 = hn::Mul(result_0, old_scale_vec); + result_0 = hn::MulAdd(hn::Load(df, att_in), new_scale_vec, result_0); + if constexpr (kNumNF > 1) { + result_1 = hn::Mul(result_1, old_scale_vec); + result_2 = hn::Mul(result_2, old_scale_vec); + result_3 = hn::Mul(result_3, old_scale_vec); + result_1 = + hn::MulAdd(hn::Load(df, att_in + kNF), new_scale_vec, result_1); + result_2 = + hn::MulAdd(hn::Load(df, att_in + 2 * kNF), new_scale_vec, result_2); + result_3 = + hn::MulAdd(hn::Load(df, att_in + 3 * kNF), new_scale_vec, result_3); + } + if constexpr (kNumNF == 16) { + result_4 = hn::Mul(result_4, old_scale_vec); + result_5 = hn::Mul(result_5, old_scale_vec); + result_6 = hn::Mul(result_6, old_scale_vec); + result_7 = hn::Mul(result_7, old_scale_vec); + result_8 = hn::Mul(result_8, old_scale_vec); + result_9 = hn::Mul(result_9, old_scale_vec); + result_10 = hn::Mul(result_10, old_scale_vec); + result_11 = hn::Mul(result_11, old_scale_vec); + result_12 = hn::Mul(result_12, old_scale_vec); + result_13 = hn::Mul(result_13, old_scale_vec); + result_14 = hn::Mul(result_14, old_scale_vec); + result_15 = hn::Mul(result_15, old_scale_vec); + result_4 = + hn::MulAdd(hn::Load(df, att_in + 4 * kNF), new_scale_vec, result_4); + result_5 = + hn::MulAdd(hn::Load(df, att_in + 5 * kNF), new_scale_vec, result_5); + result_6 = + hn::MulAdd(hn::Load(df, att_in + 6 * kNF), new_scale_vec, result_6); + result_7 = + hn::MulAdd(hn::Load(df, att_in + 7 * kNF), new_scale_vec, result_7); + result_8 = + hn::MulAdd(hn::Load(df, att_in + 8 * kNF), new_scale_vec, result_8); + result_9 = + hn::MulAdd(hn::Load(df, att_in + 9 * kNF), new_scale_vec, result_9); + result_10 = + hn::MulAdd(hn::Load(df, att_in + 10 * kNF), new_scale_vec, result_10); + result_11 = + hn::MulAdd(hn::Load(df, att_in + 11 * kNF), new_scale_vec, result_11); + result_12 = + hn::MulAdd(hn::Load(df, att_in + 12 * kNF), new_scale_vec, result_12); + result_13 = + hn::MulAdd(hn::Load(df, att_in + 13 * kNF), new_scale_vec, result_13); + result_14 = + hn::MulAdd(hn::Load(df, att_in + 14 * kNF), new_scale_vec, result_14); + result_15 = + hn::MulAdd(hn::Load(df, att_in + 15 * kNF), new_scale_vec, result_15); + } + } + hn::Store(result_0, df, att_out); + if constexpr (kNumNF > 1) { + hn::Store(result_1, df, att_out + kNF); + hn::Store(result_2, df, att_out + 2 * kNF); + hn::Store(result_3, df, att_out + 3 * kNF); + } + if constexpr (kNumNF == 16) { + hn::Store(result_4, df, att_out + 4 * kNF); + hn::Store(result_5, df, att_out + 5 * kNF); + hn::Store(result_6, df, att_out + 6 * kNF); + hn::Store(result_7, df, att_out + 7 * kNF); + hn::Store(result_8, df, att_out + 8 * kNF); + hn::Store(result_9, df, att_out + 9 * kNF); + hn::Store(result_10, df, att_out + 10 * kNF); + hn::Store(result_11, df, att_out + 11 * kNF); + hn::Store(result_12, df, att_out + 12 * kNF); + hn::Store(result_13, df, att_out + 13 * kNF); + hn::Store(result_14, df, att_out + 14 * kNF); + hn::Store(result_15, df, att_out + 15 * kNF); + } +} + +// Recombines results from split tasks, 
activations.att_out_reps -> +// activations.att_out. Instead of repeatedly calling MultiplyByConstAndAdd, +// which reads/writes the sum each time, the result is kept entirely in +// registers, and the task is split into 16NF, 4NF, and NF chunks, so that there +// are enough registers to hold the intermediate results. +void CombineSplitTasks(size_t qkv_dim, uint32_t tile_factor, + AttentionActivationsPtrs& activations, + ThreadingContext& ctx) { + GCPP_ZONE(ctx, 0, Zones::kFlashAttentionCombineSplit); + using DF = hn::ScalableTag; + const DF df; + const size_t kNF = hn::Lanes(df); + uint32_t num_16 = qkv_dim / (16 * kNF); + uint32_t num_4 = (qkv_dim - kNF * 16 * num_16) / (4 * kNF); + uint32_t num_1 = hwy::DivCeil(qkv_dim - kNF * (16 * num_16 + 4 * num_4), kNF); + uint32_t tasks_per_qkv = num_16 + num_4 + num_1; + ParallelFor( + Parallelism::kFlat, + activations.flash_params.size() * tasks_per_qkv * tile_factor, ctx, + /*cluster_idx=*/0, Callers::kFlashAttention, + [&](size_t p, size_t worker) { + uint32_t tile = p / tasks_per_qkv; + uint32_t p_idx = + activations.flash_params[tile / tile_factor].split_index; + const auto& param = activations.split_flash_params[p_idx]; + size_t remaining_params = activations.split_flash_params.size() - p_idx; + tile %= tile_factor; + if (tile >= param.v_tile_size) return; + int32_t qkv_task = p % tasks_per_qkv; + if (qkv_task < num_16) { + uint32_t qkv_offset = qkv_task * 16 * kNF; + CombineSplitTasks1416<16>( + hwy::Span(¶m, remaining_params), + tile, qkv_offset, activations); + } else if (qkv_task < num_16 + num_4) { + uint32_t qkv_offset = (num_16 * 16 + (qkv_task - num_16) * 4) * kNF; + CombineSplitTasks1416<4>( + hwy::Span(¶m, remaining_params), + tile, qkv_offset, activations); + } else { + uint32_t qkv_offset = + (num_16 * 16 + num_4 * 4 + (qkv_task - num_16 - num_4)) * kNF; + CombineSplitTasks1416<1>( + hwy::Span(¶m, remaining_params), + tile, qkv_offset, activations); + } + }); } // The nominal aim of attention is to combine 3 inputs Q[L,D], K[L,D], V[L,D] @@ -756,49 +1236,28 @@ size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, // the one row of O takes L(4D+3) reads and L(D+3) writes. // For the whole of Q, this is L^2(4D+3) reads and L^2(D+3) writes. // -// Flash attention fuses these operations together, and has 3 operating modes: -// 1. NF rows of the result computed using tiles of registers of shape NFx8. -// 2. 4 rows of the result computed using tiles of registers of shape 4xNF. -// 3. One row (of Q and the result) at a time. -// In all cases the intermediate result (Q.KT) is never stored to memory. -// NF is the number of float lanes in a register, being 16 for AVX3. The softmax -// is converted to streaming form using the algorithm from: -// https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf. -// Q is transposed to Q_T[D,L] to make the dot product computation efficient. -// -// In mode 1: -// QDotKTileFloat computes NF Q rows x 8 K timesteps of Q.K dot products in one -// go, reducing reads of Q by 8 and reads of K by NF. The streaming softmax is -// computed entirely in registers, and a further NF registers to accumulate the -// results of the product of the softmax and V, reduce the number of reads of V -// by NF, and the reads/writes of O by 8. -// The reads are thus reduced to 2DL^2(1/8+1/NF) and writes reduced to DL^2/8, -// which on AVX3 is an overall reduction by about a factor of 10. 
-// Mode 1 can only be accessed if there is a large Qbatch size, or in multi-turn -// prefill, since in other cases, there is either a single K timestep (prefill) -// or a single num_heads set of Q rows (decode). -// -// In mode 2, the 4 rows of Q are computed against NF K timesteps in a tile, -// reducing the reads of Q by NF, and the reads of K by 4. The softmax and -// accumulation of the result is done in registers, cutting the reads of V by 4. -// The reads/writes of O are reduced by a factor of NF. -// The overall reduction is limited by the need to use gather to load K. -// Transposing K would be possible, but is complicated by the wraparound. -// Mode 2 can be used in all cases when there are at least 4 attention heads, -// but it may be prefereable to use mode 3 when the batch size is small to -// maximise parallelism. -// -// In mode 3, a single row of Q is computed against a single K timestep at a -// time, using SingleFlashAttention. In this case there is no reduction in the -// reads of Q or K, or V, or O, but the reads/writes of the intermediate A are -// still eliminated. +// Flash attention fuses these operations together, and operates on tiles of +// n Q rows x NF K positions, accumulated in n registers, where n is in +// {1, 4, 8} and NF is the number of float lanes in a register, being 16 for +// AVX3. This reduces the number of reads of Q by NF and reads of K by n. The +// softmax is converted to streaming form using the algorithm from: +// https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf, +// which eliminates the need to store A to memory. The accumulated Q.KT result +// is passed via the streaming softmax directly to the A.V step. +// To make the dot product computation more efficient, Q, K, and V are stored +// as BF16 and K is transposed to shape: +// [seq_len / NF, layers * kv_heads * qkv_dim/2 * NF * 2], where the /2, *2 +// represents that pairs of qkv_dim elements are kept together to make best +// use of BF16 dot product instructions, where each pair of adjacent BF16 +// values from Q and K are mul-added into a single F32 result. // // A further complication is that real attention is not as simple as documented // in the paper and above. There are multiple query heads, differing KV, and -// different sequence lengths, so a lot of the work in FlashAttention is making -// sure that a collection of q rows with the same KV and sequence length are -// grouped together so that mode 1 or 2 can be used, and choosing which of the -// 3 modes to use for best efficiency. +// different sequence lengths, and the difference between prefill and decode, +// so a lot of the work in FlashAttention is making sure that a collection of q +// rows with the same KV and sequence length are grouped together so that the +// largest possible tiles can be used. This is dealt with by the +// ComputeAndSplitFlashParams() function. 
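To complement the description above, here is the streaming-softmax recurrence in plain scalar form. This is a sketch, not the production kernel: the real code processes 1, 4 or 8 rows against 2NF BF16 key positions per iteration entirely in registers, applies LogitsSoftCap to the scores first when att_cap is set, and carries the per-row (max, d) pair in Tile4FlashState. The function and variable names below are illustrative only:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Computes softmax(scores) * V for one attention row without ever
// materializing the softmax vector A.
void StreamingAttentionRow(const std::vector<float>& scores,          // one row of Q.K^T
                           const std::vector<std::vector<float>>& v,  // V rows, [num_pos][qkv_dim]
                           std::vector<float>& out) {                 // [qkv_dim]
  float m = -std::numeric_limits<float>::infinity();  // running maximum
  float d = 0.0f;                                      // running denominator
  std::fill(out.begin(), out.end(), 0.0f);
  for (size_t p = 0; p < scores.size(); ++p) {
    const float new_m = std::max(m, scores[p]);
    const float correction = std::exp(m - new_m);  // rescales everything seen so far
    const float w = std::exp(scores[p] - new_m);
    d = d * correction + w;
    for (size_t i = 0; i < out.size(); ++i) {
      out[i] = out[i] * correction + w * v[p][i];
    }
    m = new_m;
  }
  // Final normalization by the denominator; shown here for completeness, the
  // kernels may defer it (e.g. until after the split-task combine).
  for (float& x : out) x /= d;
}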
void FlashAttention(const size_t num_tokens, const size_t target_parallelism, const size_t layer_idx, const MatPtr& query_norm_scale, AttentionActivationsPtrs& activations, QBatch& qbatch, @@ -806,8 +1265,16 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism, GCPP_ZONE(ctx, 0, Zones::kFlashAttentionInclusive); RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q, query_norm_scale, layer_idx, activations, ctx); - const hwy::Divisor div_qbatch(qbatch.Size()); + const LayerConfig& layer_config = activations.config.layer_configs[layer_idx]; + const size_t qkv_dim = layer_config.qkv_dim; + const size_t seq_len = + static_cast(activations.div_seq_len.GetDivisor()); + + using DF = hn::ScalableTag; + const DF df; + const size_t kNF = hn::Lanes(df); // Compress q to q_bf. + // TODO(rays): Move this into RMSNormAndPositionalEncoding(). ParallelFor( Parallelism::kWithinCluster, activations.q.Rows(), ctx, /*cluster_idx=*/0, Callers::kFlashAttention, @@ -818,168 +1285,53 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism, df, activations.q.Row(row), activations.q.Cols(), tls, MakeSpan(activations.q_bf.Row(row), activations.q_bf.Cols()), 0); }); - const LayerConfig& layer_config = activations.config.layer_configs[layer_idx]; - const size_t qkv_dim = layer_config.qkv_dim; - - // A "head group" in the context of GQA refers to a collection of query - // heads that share the same key and value heads. - const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads; - const size_t cache_layer_size = layer_config.CacheLayerSize(); - const size_t seq_len = - static_cast(activations.div_seq_len.GetDivisor()); - const size_t token_batch = num_tokens * div_qbatch.GetDivisor(); - const size_t total_tasks = token_batch * layer_config.heads; - - using DF = hn::ScalableTag; - const DF df; - const size_t kNF = hn::Lanes(df); - constexpr size_t kMaxNF = hn::MaxLanes(df); - HWY_DASSERT(kNF <= kMaxNF); - const size_t kVTileSize = GetVTileSize(kNF, kHeadGroups, num_tokens, - total_tasks, target_parallelism); - // Only transpose Q if we are using tiling. - if (kVTileSize == kNF) { - size_t max_last = 0, min_start = std::numeric_limits::max(); - for (size_t qi = 0; qi < qbatch.Size(); ++qi) { - size_t pos = qbatch.Pos(qi); - const size_t start = StartPos(pos, activations.config, layer_idx); - pos += num_tokens - 1; - const size_t end = qbatch.PrefixEnd(qi); - if (end > 0 && end - 1 > pos) { - pos = end - 1; - } - max_last = std::max(max_last, pos); - min_start = std::min(min_start, start); - } - if (max_last - min_start + 1 >= kNFx8HTileSize) { - // q has shape [batch, qbatch][head, qkv_dim]. - // We transpose it to [qkv_dim][qbatch, head, batch] in order to make the - // maximum possible number of consecutive columns have the same KV - // matrices. Each thread will process a tile of NF columns of QT so the - // starting column index of QT is just the task index * kVTileSize. - TransposeQ(activations.q, activations.q_T, qbatch.Size(), ctx); - } - } - const size_t num_thread_tasks = hwy::DivCeil(total_tasks, kVTileSize); - const hwy::Divisor div_tokens(num_tokens); - // All layers should have the same number of heads. - HWY_DASSERT(activations.div_heads.GetDivisor() == layer_config.heads); + int tile_factor = + ComputeAndSplitFlashParams(kNF, num_tokens, target_parallelism, layer_idx, + activations, qbatch, ctx, attention_impl); + auto& params = tile_factor >= 1 ? 
activations.split_flash_params + : activations.flash_params; + size_t num_tasks = params.size(); // For each head/token/query, compute fused flash Q.K, softmax and weighted V. const auto func = [&](const size_t task, size_t worker) HWY_ATTR { GCPP_ZONE(ctx, worker, Zones::kFlashAttentionFlashAttention); - // Offsets into original Q for each row in the tile. - uint32_t q_offsets[kMaxNF]; - // Offsets into att_out for each row in the tile. - uint32_t out_offsets[kMaxNF]; - // Start positions for each row in the tile. - size_t start_positions[kMaxNF]; - // Last positions for each row in the tile. Inclusive. - uint32_t last_pos[kMaxNF]; - // min and max last positions across all rows in the tile determines when - // TileFlashAttention switches to single vector mode to handle the - // ragged sequence lengths. - size_t min_last_pos = std::numeric_limits::max(); - size_t max_last_pos = 0; - // Indices into the qbatch.KV for each row in the tile. - size_t qi_indices[kMaxNF]; - // Indices into the kv_cache for each row in the tile. - size_t kv_offsets[kMaxNF]; - // first_task is [qbatch, head, token]. - const size_t first_task = task * kVTileSize; - const size_t last_task = first_task + kVTileSize - 1; - bool use_tile_attention = kVTileSize > 1 && last_task < total_tasks; - for (size_t offset = 0; - offset < kVTileSize && first_task + offset < total_tasks; ++offset) { - const size_t batch_idx = div_tokens.Remainder(first_task + offset); - const size_t qh = div_tokens.Divide(first_task + offset); - const size_t head = activations.div_heads.Remainder(qh); - const size_t qi = activations.div_heads.Divide(qh); - const size_t tq_idx = div_qbatch.GetDivisor() * batch_idx + qi; - qi_indices[offset] = qi; - - // Find the token position in the query and calculate - // the range of cache positions to attend to. - const size_t pos = qbatch.Pos(qi) + batch_idx; - const size_t start_pos = StartPos(pos, activations.config, layer_idx); - start_positions[offset] = start_pos; - size_t last = pos; - const size_t prefix_end = qbatch.PrefixEnd(qi); - if (prefix_end > 0 && prefix_end - 1 > last) { - // last_pos in `TileFlashAttention` is inclusive. - last = prefix_end - 1; - } - last_pos[offset] = last; - min_last_pos = HWY_MIN(min_last_pos, last); - max_last_pos = HWY_MAX(max_last_pos, last); - q_offsets[offset] = activations.q_bf.Row(tq_idx) + head * qkv_dim - - activations.q_bf.Row(0); - out_offsets[offset] = activations.att_out.Row(tq_idx) + head * qkv_dim - - activations.att_out.Row(0); - const size_t kv_index = head / kHeadGroups; - const size_t head_offset = kv_index * qkv_dim * 2; - kv_offsets[offset] = layer_idx * cache_layer_size + head_offset; - // If any of the parameters in this if statement differ within this task, - // then we can't use TileFlashAttention. TileFlashAttention requires that - // all rows in the tile have the same K and V matrices, and Q starts at - // the same position. The end positions do not have to be the equal. 
- if (start_positions[offset] != start_positions[0] || - qi_indices[offset] != qi_indices[0] || - kv_offsets[offset] != kv_offsets[0]) { - use_tile_attention = false; - } - } - for (size_t offset = 0; - offset < kVTileSize && first_task + offset < total_tasks; ++offset) { - auto& kv_cache = qbatch.KV(qi_indices[offset]).kv_cache; - MatPtrT k("k_view", Extents2D(seq_len, qkv_dim)); - k.SetPtr(kv_cache.Row(0) + kv_offsets[offset], kv_cache.Stride()); - MatPtrT v("v_view", Extents2D(seq_len, qkv_dim)); - v.SetPtr(kv_cache.Row(0) + kv_offsets[offset] + qkv_dim, - kv_cache.Stride()); - if (use_tile_attention) { - // To avoid duplicating the code to setup K and V, the call to - // TileFlashAttention is inside the loop over tasks, even though it - // handles all rows in the task at once. - StridedView qT = - StridedView(activations.q_T.Row(0) + first_task, kVTileSize, - activations.q_T.Stride()); - if (kVTileSize == kNF) { - // We can still use TileFlashAttention even if we didn't transpose Q - // above. The condition used for transposing Q above is more general - // and easier to compute than the condition used within - // TileFlashAttention that min_last_pos - start_positions[offset] < - // kNFx8HTileSize. In this case, qT is never used. Some tasks might - // use qT and some might not, which is why the more general condition - // is used above to catch all cases where qT will be used. - TileFlashAttention(activations.q_bf, q_offsets, qT, k, - start_positions[offset], last_pos, min_last_pos, - max_last_pos, v, layer_idx, activations, - activations.att_out, out_offsets, ctx, worker); - } else if (kVTileSize == 4) { - TileFlashAttention4(activations.q_bf, q_offsets, k, - start_positions[offset], last_pos, min_last_pos, - max_last_pos, v, layer_idx, activations, - activations.att_out, out_offsets, ctx, worker); - } else { - HWY_UNREACHABLE; - } - break; - } else { - SingleFlashAttention(start_positions[offset], last_pos[offset], - activations.q_bf.Row(0) + q_offsets[offset], k, v, - layer_idx, activations, - activations.att_out.Row(0) + out_offsets[offset], - ctx, worker); - } + auto& param = params[task]; + auto& kv_cache = qbatch.KV(param.qi_index).kv_cache; + auto& kT_cache = qbatch.KV(param.qi_index).k_cache; + MatPtrT kT("k_T_view", Extents2D(hwy::DivCeil(seq_len, 2 * kNF), + qkv_dim * 2 * kNF)); + kT.SetPtr(kT_cache.Row(0) + param.kv_offset * kNF, kT_cache.Stride()); + MatPtrT v("v_view", Extents2D(seq_len, qkv_dim)); + v.SetPtr(kv_cache.Row(0) + param.kv_offset + qkv_dim, kv_cache.Stride()); + auto& vT_cache = qbatch.KV(param.qi_index).v_cache; + MatPtrT vT("v_T_view", Extents2D(hwy::DivCeil(seq_len, 2 * kNF), + qkv_dim * 2 * kNF)); + vT.SetPtr(vT_cache.Row(0) + param.kv_offset * kNF, vT_cache.Stride()); + MatPtrT& att_out = + param.i_of_n == 0 ? activations.att_out : activations.att_out_reps; + if (param.v_tile_size == k8xNFVTileSize) { + param.end_state = TileFlashAttention148( + param, activations.q_bf, kT, vT, layer_idx, activations, att_out, + qkv_dim, ctx, worker, attention_impl); + } else if (param.v_tile_size == k4xNFVTileSize) { + param.end_state = TileFlashAttention148( + param, activations.q_bf, kT, vT, layer_idx, activations, att_out, + qkv_dim, ctx, worker, attention_impl); + } else { + param.end_state = TileFlashAttention148<1>( + param, activations.q_bf, kT, vT, layer_idx, activations, att_out, + qkv_dim, ctx, worker, attention_impl); } }; { PROFILER_ZONE("Gen.FlashAttention.ForkJoin"); // Full parallelism is helpful, SmallParallelFor is insufficient. 
- HierarchicalParallelFor(num_thread_tasks, ctx, Callers::kFlashAttention, - func); + HierarchicalParallelFor(num_tasks, ctx, Callers::kFlashAttention, func); + } + if (tile_factor >= 1) { + // Run the flash attention correction on the partial outputs. + CombineSplitTasks(qkv_dim, tile_factor, activations, ctx); } } diff --git a/gemma/flash_attention.h b/gemma/flash_attention.h index 81bfcdf..6466674 100644 --- a/gemma/flash_attention.h +++ b/gemma/flash_attention.h @@ -44,15 +44,6 @@ namespace gcpp { float* HWY_RESTRICT att_out, \ ThreadingContext& ctx, size_t worker); \ \ - Tile4FlashState TileFlashAttention4( \ - const MatPtrT& q, const uint32_t* HWY_RESTRICT q_offsets, \ - const MatPtrT& k, size_t start_pos, \ - const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos, \ - size_t max_last_pos, const MatPtrT& v, size_t layer_idx, \ - const LayerWeightsPtrs& layer, const AttentionActivations& activations, \ - MatPtrT& att_out, const uint32_t* HWY_RESTRICT out_offsets, \ - ThreadingContext& ctx, const size_t worker); \ - \ size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \ size_t total_tasks, size_t target_parallelism); \ \ diff --git a/gemma/flash_attention_test.cc b/gemma/flash_attention_test.cc index 6fbaa5f..85e73a8 100644 --- a/gemma/flash_attention_test.cc +++ b/gemma/flash_attention_test.cc @@ -62,16 +62,17 @@ namespace HWY_NAMESPACE { using FloatPtr = hwy::AlignedFreeUniquePtr; -void SetMat(const size_t offset, MatPtrT& mat) { +template +void SetMat(const size_t offset, MatPtrT& mat) { const size_t kOuter = mat.Extents().rows; const size_t kInner = mat.Extents().cols; const float i_scale = 1.0f / kInner; const float j_scale = 1.0f / kOuter; for (size_t i = 0; i < kOuter; ++i) { - float* row = mat.Row(i); + T* row = mat.Row(i); for (size_t j = 0; j < kInner; ++j) { - row[j] = - static_cast((i * kInner * i_scale + (j + offset) * j_scale)); + row[j] = hwy::ConvertScalarTo( + static_cast((i * kInner * i_scale + (j + offset) * j_scale))); } } } @@ -94,14 +95,15 @@ void AssertClose(const MatPtrT& a, const MatPtrT& b) { if (rel_abs_delta > 0.0f) { rel_abs_delta /= std::max(std::abs(a_row[c]), std::abs(b_row[c])); } - EXPECT_LT(rel_abs_delta, 1e-5) + EXPECT_LT(rel_abs_delta, 1e-3) << "a[" << r << "," << c << "]=" << a_row[c] << ", b[" << r << "," << c << "]=" << b_row[c]; } } } -void TestFlashAttention(size_t target_parallelism) { +void TestFlashAttention(size_t target_parallelism, + AttentionImpl attention_impl) { ThreadingArgs threading_args; ThreadingContext ctx(threading_args); constexpr size_t kOuter = 1024; @@ -131,7 +133,8 @@ void TestFlashAttention(size_t target_parallelism) { const size_t batch_size = kOuter; std::vector> row_ptrs; AttentionActivations attention_storage(config, layer_config, batch_size, - kOuter, runtime_config, ctx.allocator, + kOuter, runtime_config, + ctx.pools.MaxWorkers(), ctx.allocator, row_ptrs); AttentionActivationsPtrs attention(config, kOuter, attention_storage); const size_t qkv_dim = layer_config.qkv_dim; @@ -142,7 +145,10 @@ void TestFlashAttention(size_t target_parallelism) { const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads; const size_t seq_len = static_cast(attention.div_seq_len.GetDivisor()); + MaybeReshapeCache(qbatch.KV(0).kv_cache, qbatch.KV(0).k_cache); + MaybeReshapeCache(qbatch.KV(0).kv_cache, qbatch.KV(0).v_cache); auto& kvc = qbatch.KV(0).kv_cache; + const size_t kFloatsPerTile = 2 * FloatsPerVector(); for (size_t h = 0; h < layer_config.heads; ++h) { // Make strided views into the kv cache 
for // this query and head. @@ -153,6 +159,17 @@ void TestFlashAttention(size_t target_parallelism) { v.SetPtr(kvc.Row(0) + head_offset + qkv_dim, kvc.Stride()); SetMat(h + layer_config.heads, k); SetMat(h + layer_config.heads * 2, v); + for (size_t p = 0; p < tokens.size(); ++p) { + KV_t* HWY_RESTRICT k_src = k.Row(p); + KV_t* HWY_RESTRICT k_dest = qbatch.KV(0).k_cache.Row(p / kFloatsPerTile) + + head_offset * kFloatsPerTile / 2 + + p % kFloatsPerTile * 2; + KV_t* HWY_RESTRICT v_dest = qbatch.KV(0).v_cache.Row(p / kFloatsPerTile) + + head_offset * kFloatsPerTile / 2 + + p % kFloatsPerTile * kFloatsPerTile; + + TransposeKVCacheRow(k_src, k_dest, v_dest, qkv_dim); + } } SetMat(1, attention.q); DotSoftmaxWeightedSum(tokens.size(), 0, layers.query_norm_scale, attention, @@ -167,18 +184,19 @@ void TestFlashAttention(size_t target_parallelism) { tokens.size() * div_qbatch.GetDivisor() * layer_config.heads; const size_t kVTileSize = GetVTileSize(kNF, kHeadGroups, tokens.size(), total_tasks, target_parallelism); - printf("FlashAttention: target_parallelism=%zu, kNF=%zu, kVTileSize=%zu\n", - target_parallelism, kNF, kVTileSize); + printf("FlashAttention: parallelism=%zu, kNF=%zu, kVTileSize=%zu, mode %s\n", + target_parallelism, kNF, kVTileSize, + GetAttentionImplName(attention_impl).c_str()); FlashAttention(tokens.size(), target_parallelism, 0, layers.query_norm_scale, - attention, qbatch, ctx, AttentionImpl::kFlash); + attention, qbatch, ctx, attention_impl); AssertClose(attention.att_out, *saved_att); ctx.profiler.PrintResults(); } void TestAttention() { - TestFlashAttention(8192); - TestFlashAttention(2048); - TestFlashAttention(256); + TestFlashAttention(8192, AttentionImpl::kFlash); + TestFlashAttention(2048, AttentionImpl::kFlash); + TestFlashAttention(256, AttentionImpl::kFlash); } // NOLINTNEXTLINE(google-readability-namespace-comments) diff --git a/gemma/flash_structs.h b/gemma/flash_structs.h index 73563fe..6e35a4d 100644 --- a/gemma/flash_structs.h +++ b/gemma/flash_structs.h @@ -2,11 +2,19 @@ #define THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_STRUCTS_H_ #include +#include #include namespace gcpp { +// The vertical tile size in flash attention when register lanes correspond to +// K-timesteps, and the number of registers is 4 for 4 Q-rows. +static constexpr size_t k4xNFVTileSize = 4; +// The vertical tile size in flash attention when register lanes correspond to +// K-timesteps, and the number of registers is 8 for 8 Q-rows. +static constexpr size_t k8xNFVTileSize = 8; + // State for computing softmax in a streaming ("online") manner, // avoiding large intermediate values by subtracting the running maximum. // For a sequence x_1, ..., x_n: @@ -20,10 +28,44 @@ struct OnlineSoftmaxState { float d = 0.0f; }; -static constexpr size_t kVTileSize4 = 4; - struct Tile4FlashState { - OnlineSoftmaxState row_states[kVTileSize4]; + OnlineSoftmaxState row_states[k8xNFVTileSize]; +}; + +// Parameters for a strip of tiles of flash attention. For processing a strip +// of tiles, each of 1, k4xNFVTileSize, or k8xNFVTileSize Q-rows, by NF +// k-positions. The total width of the strip might cover the entire sequence, +// or a part of it, depending on whether the strip has been split. +struct FlashAttentionParams { + // Vertical tile size gives the number used in the k8xNFVTileSize arrays. + // It is the number of Q rows in the tile. + uint32_t v_tile_size = 0; + // min start position across all rows in the tile determines the + // mask used for the tile. 
+ uint32_t min_start_pos = std::numeric_limits::max(); + // max last position across all rows in the tile determines the mask + // used for the tile. + uint32_t max_last_pos = 0; + // Index into the qbatch.KV is the same for each row in the tile. + uint32_t qi_index; + // Index into the kv_cache is the same for each row in the tile. + uint32_t kv_offset; + // In the original task, the index to the split tasks of the first split task. + uint32_t split_index = 0; + // The index of the split for running split attention. + uint32_t i_of_n = 0; + // Offsets into original Q for each row in the tile. + uint32_t q_offsets[k8xNFVTileSize]; + // Offsets into att_out for each row in the tile. + uint32_t out_offsets[k8xNFVTileSize]; + // Start k-positions for each row in the tile. + uint32_t start_pos[k8xNFVTileSize]; + // Last k-positions for each row in the tile. Inclusive. + uint32_t last_pos[k8xNFVTileSize]; + // Row index to att_out. + uint32_t tq_idx[k8xNFVTileSize]; + // Flash attention state for the tile. + Tile4FlashState end_state; }; } // namespace gcpp diff --git a/gemma/kv_cache.cc b/gemma/kv_cache.cc index 49276f8..e241c34 100644 --- a/gemma/kv_cache.cc +++ b/gemma/kv_cache.cc @@ -43,6 +43,17 @@ static size_t CappedSeqLen(const ModelConfig& config, KVCache::KVCache(const Extents2D& kv_extents, const Allocator& allocator) : kv_cache("kv", kv_extents, allocator, MatPadding::kOdd), + // WARNING: the rows and cols of k_cache and v_cache will be modified + // before use! + // The rows will be reduced by a factor of 2xkFloatsPerVector, and the + // cols will be increased by 2xkFloatsPerVector on first use. This is to + // avoid making KVCache another class that has to be duplicated for each + // machine architecture, since kFloatsPerVector is architecture dependent. + // The change is shape is safe only if the padding is kPacked. + k_cache("k", Extents2D(kv_extents.rows, kv_extents.cols / 2), allocator, + MatPadding::kPacked), + v_cache("v", Extents2D(kv_extents.rows, kv_extents.cols / 2), allocator, + MatPadding::kPacked), allocator_(allocator) {} KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args, @@ -55,6 +66,8 @@ KVCache KVCache::Copy() { KVCache copy(kv_cache.Extents(), allocator_); CopyMat(kv_cache, copy.kv_cache); + CopyMat(k_cache, copy.k_cache); + CopyMat(v_cache, copy.v_cache); return copy; } diff --git a/gemma/kv_cache.h b/gemma/kv_cache.h index fe6a1ff..3d5d821 100644 --- a/gemma/kv_cache.h +++ b/gemma/kv_cache.h @@ -30,7 +30,7 @@ namespace gcpp { -using KV_t = float; +using KV_t = BF16; // A non-owning view of a KVCache. struct KVCachePtr { @@ -38,6 +38,8 @@ struct KVCachePtr { size_t SeqLen() const; MatPtrT kv_cache; + MatPtrT k_cache; + MatPtrT v_cache; }; struct KVCache { @@ -52,10 +54,33 @@ struct KVCache { } MatStorageT kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2] + // The format of k_cache indicates that there are pairs of values from + // qkv_dim in groups of 2x kFloatsPerVector(=NF) elements from the sequence, + // in groups of qkv_dim/2 elements in groups of kv_heads elements. + // This enables sequential loading of the data when filling 2 vectors with + // NF sequence elements of pairs of BF16 qkv values. The next vector then + // continues reading the rest of qkv. + // [seq_len / 2NF, layers * kv_heads * qkv_dim/2 * 2NF * 2] + MatStorageT k_cache; + // v_cache is formatted to allow sequential access to V during scaling and + // update of att_out. 
+ // Originally [seq_len, layers * kv_heads * qkv_dim] + // v_cache is transposed to: + // [layers, kv_heads, seq_len, qkv_dim], reshaped to: + // [layers, kv_heads, seq_len/(2NF), 2NF, qkv_dim/(2NF), 2NF] + // then transposed to: + // [seq_len/(2NF), layers, kv_heads, qkv_dim/(2NF), 2NF, 2NF] + // and finally packed in a 2D MatStorageT as: + // [seq_len/(2NF), layers * kv_heads * qkv_dim/(2NF) * 2NF * 2NF] + // This allows sequential reads of 2NF registers each of 2NF BF16 values, + // repeatedly until all of qkv_dim is read. + MatStorageT v_cache; KVCachePtr ToPtr() { return KVCachePtr{ .kv_cache = kv_cache, + .k_cache = k_cache, + .v_cache = v_cache, }; } diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 0eeec31..c68b6c5 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -614,267 +614,6 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(const float c, }); } -template , HWY_IF_V_SIZE_GT_D(DF, 63)> -HWY_INLINE HWY_MAYBE_UNUSED void Mul16(DF df, const VF scale, VF& sum0, - VF& sum1, VF& sum2, VF& sum3, VF& sum4, - VF& sum5, VF& sum6, VF& sum7, VF& sum8, - VF& sum9, VF& sum10, VF& sum11, - VF& sum12, VF& sum13, VF& sum14, - VF& sum15) { - sum0 = hn::Mul(sum0, hn::BroadcastLane<0>(scale)); - sum1 = hn::Mul(sum1, hn::BroadcastLane<1>(scale)); - sum2 = hn::Mul(sum2, hn::BroadcastLane<2>(scale)); - sum3 = hn::Mul(sum3, hn::BroadcastLane<3>(scale)); - sum4 = hn::Mul(sum4, hn::BroadcastLane<4>(scale)); - sum5 = hn::Mul(sum5, hn::BroadcastLane<5>(scale)); - sum6 = hn::Mul(sum6, hn::BroadcastLane<6>(scale)); - sum7 = hn::Mul(sum7, hn::BroadcastLane<7>(scale)); - sum8 = hn::Mul(sum8, hn::BroadcastLane<8>(scale)); - sum9 = hn::Mul(sum9, hn::BroadcastLane<9>(scale)); - sum10 = hn::Mul(sum10, hn::BroadcastLane<10>(scale)); - sum11 = hn::Mul(sum11, hn::BroadcastLane<11>(scale)); - sum12 = hn::Mul(sum12, hn::BroadcastLane<12>(scale)); - sum13 = hn::Mul(sum13, hn::BroadcastLane<13>(scale)); - sum14 = hn::Mul(sum14, hn::BroadcastLane<14>(scale)); - sum15 = hn::Mul(sum15, hn::BroadcastLane<15>(scale)); -} - -template , HWY_IF_V_SIZE_LE_D(DF, 63)> -HWY_INLINE HWY_MAYBE_UNUSED void Mul16(DF df, const VF scale, VF& sum0, - VF& sum1, VF& sum2, VF& sum3, VF& sum4, - VF& sum5, VF& sum6, VF& sum7, VF& sum8, - VF& sum9, VF& sum10, VF& sum11, - VF& sum12, VF& sum13, VF& sum14, - VF& sum15) {} - -template , HWY_IF_V_SIZE_GT_D(DF, 31)> -HWY_INLINE HWY_MAYBE_UNUSED void Mul8(DF df, const VF scale, VF& sum0, VF& sum1, - VF& sum2, VF& sum3, VF& sum4, VF& sum5, - VF& sum6, VF& sum7) { - sum0 = hn::Mul(sum0, hn::BroadcastLane<0>(scale)); - sum1 = hn::Mul(sum1, hn::BroadcastLane<1>(scale)); - sum2 = hn::Mul(sum2, hn::BroadcastLane<2>(scale)); - sum3 = hn::Mul(sum3, hn::BroadcastLane<3>(scale)); - sum4 = hn::Mul(sum4, hn::BroadcastLane<4>(scale)); - sum5 = hn::Mul(sum5, hn::BroadcastLane<5>(scale)); - sum6 = hn::Mul(sum6, hn::BroadcastLane<6>(scale)); - sum7 = hn::Mul(sum7, hn::BroadcastLane<7>(scale)); -} - -template , HWY_IF_V_SIZE_LE_D(DF, 31)> -HWY_INLINE HWY_MAYBE_UNUSED void Mul8(DF df, const VF scale, VF& sum0, VF& sum1, - VF& sum2, VF& sum3, VF& sum4, VF& sum5, - VF& sum6, VF& sum7) {} - -template , HWY_IF_V_SIZE_GT_D(DF, 63)> -HWY_INLINE HWY_MAYBE_UNUSED void MulAdd16( - DF df, const VF common, const VF split, VF& sum0, VF& sum1, VF& sum2, - VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7, VF& sum8, VF& sum9, - VF& sum10, VF& sum11, VF& sum12, VF& sum13, VF& sum14, VF& sum15) { - sum0 = hn::MulAdd(common, hn::BroadcastLane<0>(split), sum0); - sum1 = hn::MulAdd(common, hn::BroadcastLane<1>(split), sum1); - sum2 = 
hn::MulAdd(common, hn::BroadcastLane<2>(split), sum2); - sum3 = hn::MulAdd(common, hn::BroadcastLane<3>(split), sum3); - sum4 = hn::MulAdd(common, hn::BroadcastLane<4>(split), sum4); - sum5 = hn::MulAdd(common, hn::BroadcastLane<5>(split), sum5); - sum6 = hn::MulAdd(common, hn::BroadcastLane<6>(split), sum6); - sum7 = hn::MulAdd(common, hn::BroadcastLane<7>(split), sum7); - sum8 = hn::MulAdd(common, hn::BroadcastLane<8>(split), sum8); - sum9 = hn::MulAdd(common, hn::BroadcastLane<9>(split), sum9); - sum10 = hn::MulAdd(common, hn::BroadcastLane<10>(split), sum10); - sum11 = hn::MulAdd(common, hn::BroadcastLane<11>(split), sum11); - sum12 = hn::MulAdd(common, hn::BroadcastLane<12>(split), sum12); - sum13 = hn::MulAdd(common, hn::BroadcastLane<13>(split), sum13); - sum14 = hn::MulAdd(common, hn::BroadcastLane<14>(split), sum14); - sum15 = hn::MulAdd(common, hn::BroadcastLane<15>(split), sum15); -} - -template , HWY_IF_V_SIZE_LE_D(DF, 63)> -HWY_INLINE HWY_MAYBE_UNUSED void MulAdd16( - DF df, const VF common, const VF split, VF& sum0, VF& sum1, VF& sum2, - VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7, VF& sum8, VF& sum9, - VF& sum10, VF& sum11, VF& sum12, VF& sum13, VF& sum14, VF& sum15) {} - -template , HWY_IF_V_SIZE_GT_D(DF, 31)> -HWY_INLINE HWY_MAYBE_UNUSED void MulAdd8(DF df, const VF common, const VF split, - VF& sum0, VF& sum1, VF& sum2, VF& sum3, - VF& sum4, VF& sum5, VF& sum6, - VF& sum7) { - sum0 = hn::MulAdd(common, hn::BroadcastLane<0>(split), sum0); - sum1 = hn::MulAdd(common, hn::BroadcastLane<1>(split), sum1); - sum2 = hn::MulAdd(common, hn::BroadcastLane<2>(split), sum2); - sum3 = hn::MulAdd(common, hn::BroadcastLane<3>(split), sum3); - sum4 = hn::MulAdd(common, hn::BroadcastLane<4>(split), sum4); - sum5 = hn::MulAdd(common, hn::BroadcastLane<5>(split), sum5); - sum6 = hn::MulAdd(common, hn::BroadcastLane<6>(split), sum6); - sum7 = hn::MulAdd(common, hn::BroadcastLane<7>(split), sum7); -} - -template , HWY_IF_V_SIZE_LE_D(DF, 31)> -HWY_INLINE HWY_MAYBE_UNUSED void MulAdd8(DF df, const VF common, const VF split, - VF& sum0, VF& sum1, VF& sum2, VF& sum3, - VF& sum4, VF& sum5, VF& sum6, - VF& sum7) {} - -template > -HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4(DF df, const VF common, const VF split, - VF& sum0, VF& sum1, VF& sum2, - VF& sum3) { - sum0 = hn::MulAdd(common, hn::BroadcastLane<0>(split), sum0); - sum1 = hn::MulAdd(common, hn::BroadcastLane<1>(split), sum1); - sum2 = hn::MulAdd(common, hn::BroadcastLane<2>(split), sum2); - sum3 = hn::MulAdd(common, hn::BroadcastLane<3>(split), sum3); -} - -// For an 8xNF tile of float values in 8xNF-lane registers, multiplies 8 rows -// of V by the corresponding values in c0-c7 and adds them to NF rows of out, -// after first prescaling out by scale. -// The depth (size) must be a multiple of NF. 
-template > -HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile( - DF df, const VF scale, const VF c0, const VF c1, const VF c2, const VF c3, - const VF c4, const VF c5, const VF c6, const VF c7, const MatPtrT& v, - const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out, - const uint32_t* HWY_RESTRICT out_offsets, const size_t size) { - namespace hn = hwy::HWY_NAMESPACE; - HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df); - - size_t i = 0; - while (i + NF <= size) { - if HWY_LANES_CONSTEXPR (NF == 16) { - VF out0, out1, out2, out3, out4, out5, out6, out7; - VF out8, out9, out10, out11, out12, out13, out14, out15; - out0 = hn::Load(df, out + i + out_offsets[0]); - out1 = hn::Load(df, out + i + out_offsets[1]); - out2 = hn::Load(df, out + i + out_offsets[2]); - out3 = hn::Load(df, out + i + out_offsets[3]); - out4 = hn::Load(df, out + i + out_offsets[4]); - out5 = hn::Load(df, out + i + out_offsets[5]); - out6 = hn::Load(df, out + i + out_offsets[6]); - out7 = hn::Load(df, out + i + out_offsets[7]); - out8 = hn::Load(df, out + i + out_offsets[8]); - out9 = hn::Load(df, out + i + out_offsets[9]); - out10 = hn::Load(df, out + i + out_offsets[10]); - out11 = hn::Load(df, out + i + out_offsets[11]); - out12 = hn::Load(df, out + i + out_offsets[12]); - out13 = hn::Load(df, out + i + out_offsets[13]); - out14 = hn::Load(df, out + i + out_offsets[14]); - out15 = hn::Load(df, out + i + out_offsets[15]); - Mul16(df, scale, out0, out1, out2, out3, out4, out5, out6, out7, out8, - out9, out10, out11, out12, out13, out14, out15); - VF x0 = hn::Load(df, v.Row(pos[0]) + i); - MulAdd16(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7, out8, - out9, out10, out11, out12, out13, out14, out15); - VF x1 = hn::Load(df, v.Row(pos[1]) + i); - MulAdd16(df, x1, c1, out0, out1, out2, out3, out4, out5, out6, out7, out8, - out9, out10, out11, out12, out13, out14, out15); - VF x2 = hn::Load(df, v.Row(pos[2]) + i); - MulAdd16(df, x2, c2, out0, out1, out2, out3, out4, out5, out6, out7, out8, - out9, out10, out11, out12, out13, out14, out15); - VF x3 = hn::Load(df, v.Row(pos[3]) + i); - MulAdd16(df, x3, c3, out0, out1, out2, out3, out4, out5, out6, out7, out8, - out9, out10, out11, out12, out13, out14, out15); - VF x4 = hn::Load(df, v.Row(pos[4]) + i); - MulAdd16(df, x4, c4, out0, out1, out2, out3, out4, out5, out6, out7, out8, - out9, out10, out11, out12, out13, out14, out15); - VF x5 = hn::Load(df, v.Row(pos[5]) + i); - MulAdd16(df, x5, c5, out0, out1, out2, out3, out4, out5, out6, out7, out8, - out9, out10, out11, out12, out13, out14, out15); - VF x6 = hn::Load(df, v.Row(pos[6]) + i); - MulAdd16(df, x6, c6, out0, out1, out2, out3, out4, out5, out6, out7, out8, - out9, out10, out11, out12, out13, out14, out15); - VF x7 = hn::Load(df, v.Row(pos[7]) + i); - MulAdd16(df, x7, c7, out0, out1, out2, out3, out4, out5, out6, out7, out8, - out9, out10, out11, out12, out13, out14, out15); - hn::Store(out0, df, out + i + out_offsets[0]); - hn::Store(out1, df, out + i + out_offsets[1]); - hn::Store(out2, df, out + i + out_offsets[2]); - hn::Store(out3, df, out + i + out_offsets[3]); - hn::Store(out4, df, out + i + out_offsets[4]); - hn::Store(out5, df, out + i + out_offsets[5]); - hn::Store(out6, df, out + i + out_offsets[6]); - hn::Store(out7, df, out + i + out_offsets[7]); - hn::Store(out8, df, out + i + out_offsets[8]); - hn::Store(out9, df, out + i + out_offsets[9]); - hn::Store(out10, df, out + i + out_offsets[10]); - hn::Store(out11, df, out + i + out_offsets[11]); - hn::Store(out12, df, out + i + out_offsets[12]); 
- hn::Store(out13, df, out + i + out_offsets[13]); - hn::Store(out14, df, out + i + out_offsets[14]); - hn::Store(out15, df, out + i + out_offsets[15]); - } - if HWY_LANES_CONSTEXPR (NF == 8) { - VF out0, out1, out2, out3, out4, out5, out6, out7; - out0 = hn::Load(df, out + i + out_offsets[0]); - out1 = hn::Load(df, out + i + out_offsets[1]); - out2 = hn::Load(df, out + i + out_offsets[2]); - out3 = hn::Load(df, out + i + out_offsets[3]); - out4 = hn::Load(df, out + i + out_offsets[4]); - out5 = hn::Load(df, out + i + out_offsets[5]); - out6 = hn::Load(df, out + i + out_offsets[6]); - out7 = hn::Load(df, out + i + out_offsets[7]); - Mul8(df, scale, out0, out1, out2, out3, out4, out5, out6, out7); - VF x0 = hn::Load(df, v.Row(pos[0]) + i); - MulAdd8(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7); - VF x1 = hn::Load(df, v.Row(pos[1]) + i); - MulAdd8(df, x1, c1, out0, out1, out2, out3, out4, out5, out6, out7); - VF x2 = hn::Load(df, v.Row(pos[2]) + i); - MulAdd8(df, x2, c2, out0, out1, out2, out3, out4, out5, out6, out7); - VF x3 = hn::Load(df, v.Row(pos[3]) + i); - MulAdd8(df, x3, c3, out0, out1, out2, out3, out4, out5, out6, out7); - VF x4 = hn::Load(df, v.Row(pos[4]) + i); - MulAdd8(df, x4, c4, out0, out1, out2, out3, out4, out5, out6, out7); - VF x5 = hn::Load(df, v.Row(pos[5]) + i); - MulAdd8(df, x5, c5, out0, out1, out2, out3, out4, out5, out6, out7); - VF x6 = hn::Load(df, v.Row(pos[6]) + i); - MulAdd8(df, x6, c6, out0, out1, out2, out3, out4, out5, out6, out7); - VF x7 = hn::Load(df, v.Row(pos[7]) + i); - MulAdd8(df, x7, c7, out0, out1, out2, out3, out4, out5, out6, out7); - hn::Store(out0, df, out + i + out_offsets[0]); - hn::Store(out1, df, out + i + out_offsets[1]); - hn::Store(out2, df, out + i + out_offsets[2]); - hn::Store(out3, df, out + i + out_offsets[3]); - hn::Store(out4, df, out + i + out_offsets[4]); - hn::Store(out5, df, out + i + out_offsets[5]); - hn::Store(out6, df, out + i + out_offsets[6]); - hn::Store(out7, df, out + i + out_offsets[7]); - } - if HWY_LANES_CONSTEXPR (NF == 4) { - VF out0, out1, out2, out3; - out0 = hn::Load(df, out + i + out_offsets[0]); - out1 = hn::Load(df, out + i + out_offsets[1]); - out2 = hn::Load(df, out + i + out_offsets[2]); - out3 = hn::Load(df, out + i + out_offsets[3]); - out0 = hn::Mul(out0, hn::BroadcastLane<0>(scale)); - out1 = hn::Mul(out1, hn::BroadcastLane<1>(scale)); - out2 = hn::Mul(out2, hn::BroadcastLane<2>(scale)); - out3 = hn::Mul(out3, hn::BroadcastLane<3>(scale)); - VF x0 = hn::Load(df, v.Row(pos[0]) + i); - MulAdd4(df, x0, c0, out0, out1, out2, out3); - VF x1 = hn::Load(df, v.Row(pos[1]) + i); - MulAdd4(df, x1, c1, out0, out1, out2, out3); - VF x2 = hn::Load(df, v.Row(pos[2]) + i); - MulAdd4(df, x2, c2, out0, out1, out2, out3); - VF x3 = hn::Load(df, v.Row(pos[3]) + i); - MulAdd4(df, x3, c3, out0, out1, out2, out3); - VF x4 = hn::Load(df, v.Row(pos[4]) + i); - MulAdd4(df, x4, c4, out0, out1, out2, out3); - VF x5 = hn::Load(df, v.Row(pos[5]) + i); - MulAdd4(df, x5, c5, out0, out1, out2, out3); - VF x6 = hn::Load(df, v.Row(pos[6]) + i); - MulAdd4(df, x6, c6, out0, out1, out2, out3); - VF x7 = hn::Load(df, v.Row(pos[7]) + i); - MulAdd4(df, x7, c7, out0, out1, out2, out3); - hn::Store(out0, df, out + i + out_offsets[0]); - hn::Store(out1, df, out + i + out_offsets[1]); - hn::Store(out2, df, out + i + out_offsets[2]); - hn::Store(out3, df, out + i + out_offsets[3]); - } - i += NF; - } - HWY_DASSERT(size == i); -} - template > HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4(DF df, const VF common, const VF c0, const VF 
                                          const VF c1, const VF c2, const VF c3,
@@ -887,240 +626,134 @@ HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4(DF df, const VF common, const VF c0,
 }
 
 template <class DF, class VF = hn::Vec<DF>>
-HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4Lanes(DF df, const MatPtrT<float>& v,
-                                              const size_t* HWY_RESTRICT pos,
-                                              const size_t offset, const VF c0,
-                                              const VF c1, const VF c2,
-                                              const VF c3, VF& sum0, VF& sum1,
-                                              VF& sum2, VF& sum3) {
-  // TODO(rays): Check whether a transpose of c0-c3 is applicable and faster.
-  VF x0 = hn::Load(df, v.Row(pos[0]) + offset);
-  MulAdd4(df, x0, hn::BroadcastLane<0>(c0), hn::BroadcastLane<0>(c1),
-          hn::BroadcastLane<0>(c2), hn::BroadcastLane<0>(c3), sum0, sum1, sum2,
-          sum3);
-  VF x1 = hn::Load(df, v.Row(pos[1]) + offset);
-  MulAdd4(df, x1, hn::BroadcastLane<1>(c0), hn::BroadcastLane<1>(c1),
-          hn::BroadcastLane<1>(c2), hn::BroadcastLane<1>(c3), sum0, sum1, sum2,
-          sum3);
-  VF x2 = hn::Load(df, v.Row(pos[2]) + offset);
-  MulAdd4(df, x2, hn::BroadcastLane<2>(c0), hn::BroadcastLane<2>(c1),
-          hn::BroadcastLane<2>(c2), hn::BroadcastLane<2>(c3), sum0, sum1, sum2,
-          sum3);
-  VF x3 = hn::Load(df, v.Row(pos[3]) + offset);
-  MulAdd4(df, x3, hn::BroadcastLane<3>(c0), hn::BroadcastLane<3>(c1),
-          hn::BroadcastLane<3>(c2), hn::BroadcastLane<3>(c3), sum0, sum1, sum2,
-          sum3);
+HWY_INLINE HWY_MAYBE_UNUSED void MulAddNLanesVT4(
+    DF df, const BF16* HWY_RESTRICT v, const float* HWY_RESTRICT c,
+    const size_t num_lanes, VF& sum0a, VF& sum1a, VF& sum2a, VF& sum3a,
+    VF& sum0b, VF& sum1b, VF& sum2b, VF& sum3b) {
+  using DBF = hn::ScalableTag<BF16>;
+  const DBF dbf;
+  using VBF = hn::Vec<DBF>;
+  const size_t kNF = hn::Lanes(df);
+  for (size_t lane = 0; lane < num_lanes; ++lane, v += 2 * kNF) {
+    VBF v0 = hn::Load(dbf, v);
+    VF c0 = hn::Set(df, *c++);
+    VF c1 = hn::Set(df, *c++);
+    VF c2 = hn::Set(df, *c++);
+    VF c3 = hn::Set(df, *c++);
+    VF v0a = hn::PromoteLowerTo(df, v0);
+    VF v0b = hn::PromoteUpperTo(df, v0);
+    MulAdd4(df, v0a, c0, c1, c2, c3, sum0a, sum1a, sum2a, sum3a);
+    MulAdd4(df, v0b, c0, c1, c2, c3, sum0b, sum1b, sum2b, sum3b);
+  }
 }
 
-template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 31)>
-HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond4Lanes(
-    DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
-    const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
-    VF& sum0, VF& sum1, VF& sum2, VF& sum3) {
-  VF x4 = hn::Load(df, v.Row(pos[4]) + offset);
-  MulAdd4(df, x4, hn::BroadcastLane<4>(c0), hn::BroadcastLane<4>(c1),
-          hn::BroadcastLane<4>(c2), hn::BroadcastLane<4>(c3), sum0, sum1, sum2,
-          sum3);
-  VF x5 = hn::Load(df, v.Row(pos[5]) + offset);
-  MulAdd4(df, x5, hn::BroadcastLane<5>(c0), hn::BroadcastLane<5>(c1),
-          hn::BroadcastLane<5>(c2), hn::BroadcastLane<5>(c3), sum0, sum1, sum2,
-          sum3);
-  VF x6 = hn::Load(df, v.Row(pos[6]) + offset);
-  MulAdd4(df, x6, hn::BroadcastLane<6>(c0), hn::BroadcastLane<6>(c1),
-          hn::BroadcastLane<6>(c2), hn::BroadcastLane<6>(c3), sum0, sum1, sum2,
-          sum3);
-  VF x7 = hn::Load(df, v.Row(pos[7]) + offset);
-  MulAdd4(df, x7, hn::BroadcastLane<7>(c0), hn::BroadcastLane<7>(c1),
-          hn::BroadcastLane<7>(c2), hn::BroadcastLane<7>(c3), sum0, sum1, sum2,
-          sum3);
-}
-
-template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 31)>
-HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond4Lanes(
-    DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
-    const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
-    VF& sum0, VF& sum1, VF& sum2, VF& sum3) {}
-
-template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 63)>
-HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond8Lanes(
-    DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
-    const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
-    VF& sum0, VF& sum1, VF& sum2, VF& sum3) {
-  VF x8 = hn::Load(df, v.Row(pos[8]) + offset);
-  MulAdd4(df, x8, hn::BroadcastLane<8>(c0), hn::BroadcastLane<8>(c1),
-          hn::BroadcastLane<8>(c2), hn::BroadcastLane<8>(c3), sum0, sum1, sum2,
-          sum3);
-  VF x9 = hn::Load(df, v.Row(pos[9]) + offset);
-  MulAdd4(df, x9, hn::BroadcastLane<9>(c0), hn::BroadcastLane<9>(c1),
-          hn::BroadcastLane<9>(c2), hn::BroadcastLane<9>(c3), sum0, sum1, sum2,
-          sum3);
-  VF x10 = hn::Load(df, v.Row(pos[10]) + offset);
-  MulAdd4(df, x10, hn::BroadcastLane<10>(c0), hn::BroadcastLane<10>(c1),
-          hn::BroadcastLane<10>(c2), hn::BroadcastLane<10>(c3), sum0, sum1,
-          sum2, sum3);
-  VF x11 = hn::Load(df, v.Row(pos[11]) + offset);
-  MulAdd4(df, x11, hn::BroadcastLane<11>(c0), hn::BroadcastLane<11>(c1),
-          hn::BroadcastLane<11>(c2), hn::BroadcastLane<11>(c3), sum0, sum1,
-          sum2, sum3);
-  VF x12 = hn::Load(df, v.Row(pos[12]) + offset);
-  MulAdd4(df, x12, hn::BroadcastLane<12>(c0), hn::BroadcastLane<12>(c1),
-          hn::BroadcastLane<12>(c2), hn::BroadcastLane<12>(c3), sum0, sum1,
-          sum2, sum3);
-  VF x13 = hn::Load(df, v.Row(pos[13]) + offset);
-  MulAdd4(df, x13, hn::BroadcastLane<13>(c0), hn::BroadcastLane<13>(c1),
-          hn::BroadcastLane<13>(c2), hn::BroadcastLane<13>(c3), sum0, sum1,
-          sum2, sum3);
-  VF x14 = hn::Load(df, v.Row(pos[14]) + offset);
-  MulAdd4(df, x14, hn::BroadcastLane<14>(c0), hn::BroadcastLane<14>(c1),
-          hn::BroadcastLane<14>(c2), hn::BroadcastLane<14>(c3), sum0, sum1,
-          sum2, sum3);
-  VF x15 = hn::Load(df, v.Row(pos[15]) + offset);
-  MulAdd4(df, x15, hn::BroadcastLane<15>(c0), hn::BroadcastLane<15>(c1),
-          hn::BroadcastLane<15>(c2), hn::BroadcastLane<15>(c3), sum0, sum1,
-          sum2, sum3);
-}
-
-template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 63)>
-HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond8Lanes(
-    DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
-    const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
-    VF& sum0, VF& sum1, VF& sum2, VF& sum3) {}
-
-// For an NFx4 tile of float values in 4xNF-lane registers, multiplies NF rows
-// of V by the corresponding values in c0-c3 and adds them to NF rows of out,
+// For a 2NFx4 tile of float values in 8xNF-lane registers, multiplies 2NF rows
+// of V by the corresponding values in c00-c31 and adds them to 2NF rows of out,
 // after first prescaling out by scale.
-// The depth (size) must be a multiple of NF.
+// The depth (size) must be a multiple of 2NF.
 template <class DF, class VF = hn::Vec<DF>>
-HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile4(
-    DF df, const float* HWY_RESTRICT scales, const VF c0, const VF c1,
-    const VF c2, const VF c3, const MatPtrT<float>& v,
-    const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out,
+HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddVT4Mem(
+    DF df, const float* HWY_RESTRICT scales, const VF c00, const VF c01,
+    const VF c10, const VF c11, const VF c20, const VF c21, const VF c30,
+    const VF c31, const MatPtrT<BF16>& v, const size_t* HWY_RESTRICT pos,
+    size_t num_lanes, float* HWY_RESTRICT out,
     const uint32_t* HWY_RESTRICT out_offsets, const size_t size) {
   namespace hn = hwy::HWY_NAMESPACE;
   HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
+  constexpr size_t kMaxNF = hn::MaxLanes(df);
+  const BF16* HWY_RESTRICT v_bf = v.Row(pos[0] / (2 * NF));
+  HWY_DASSERT(pos[0] % (2 * NF) == 0);
+  HWY_ALIGN float c_mem[8 * kMaxNF];
+  hn::StoreInterleaved4(c00, c10, c20, c30, df, c_mem);
+  hn::StoreInterleaved4(c01, c11, c21, c31, df, c_mem + 4 * NF);
 
   size_t i = 0;
-  while (i + NF <= size) {
-    VF out0, out1, out2, out3;
-    out0 = hn::Load(df, out + i + out_offsets[0]);
-    out1 = hn::Load(df, out + i + out_offsets[1]);
-    out2 = hn::Load(df, out + i + out_offsets[2]);
-    out3 = hn::Load(df, out + i + out_offsets[3]);
-    out0 = hn::Mul(out0, hn::Set(df, scales[0]));
-    out1 = hn::Mul(out1, hn::Set(df, scales[1]));
-    out2 = hn::Mul(out2, hn::Set(df, scales[2]));
-    out3 = hn::Mul(out3, hn::Set(df, scales[3]));
-    MulAdd4Lanes(df, v, pos, i, c0, c1, c2, c3, out0, out1, out2, out3);
-    if HWY_LANES_CONSTEXPR (NF >= 8) {
-      MulAddSecond4Lanes(df, v, pos, i, c0, c1, c2, c3, out0, out1, out2, out3);
-      if HWY_LANES_CONSTEXPR (NF >= 16) {
-        MulAddSecond8Lanes(df, v, pos, i, c0, c1, c2, c3, out0, out1, out2,
-                           out3);
-      }
-    }
-    hn::Store(out0, df, out + i + out_offsets[0]);
-    hn::Store(out1, df, out + i + out_offsets[1]);
-    hn::Store(out2, df, out + i + out_offsets[2]);
-    hn::Store(out3, df, out + i + out_offsets[3]);
-    i += NF;
+  while (i + NF * 2 <= size) {
+    VF out0a, out1a, out2a, out3a, out0b, out1b, out2b, out3b;
+    out0a = hn::Load(df, out + i + out_offsets[0]);
+    out1a = hn::Load(df, out + i + out_offsets[1]);
+    out2a = hn::Load(df, out + i + out_offsets[2]);
+    out3a = hn::Load(df, out + i + out_offsets[3]);
+    VF scale0 = hn::Set(df, scales[0]);
+    VF scale1 = hn::Set(df, scales[1]);
+    VF scale2 = hn::Set(df, scales[2]);
+    VF scale3 = hn::Set(df, scales[3]);
+    out0a = hn::Mul(out0a, scale0);
+    out1a = hn::Mul(out1a, scale1);
+    out2a = hn::Mul(out2a, scale2);
+    out3a = hn::Mul(out3a, scale3);
+    out0b = hn::Load(df, out + i + NF + out_offsets[0]);
+    out1b = hn::Load(df, out + i + NF + out_offsets[1]);
+    out2b = hn::Load(df, out + i + NF + out_offsets[2]);
+    out3b = hn::Load(df, out + i + NF + out_offsets[3]);
+    out0b = hn::Mul(out0b, scale0);
+    out1b = hn::Mul(out1b, scale1);
+    out2b = hn::Mul(out2b, scale2);
+    out3b = hn::Mul(out3b, scale3);
+    MulAddNLanesVT4(df, v_bf, c_mem, HWY_MIN(num_lanes, 2 * NF), out0a, out1a,
+                    out2a, out3a, out0b, out1b, out2b, out3b);
+    hn::Store(out0a, df, out + i + out_offsets[0]);
+    hn::Store(out1a, df, out + i + out_offsets[1]);
+    hn::Store(out2a, df, out + i + out_offsets[2]);
+    hn::Store(out3a, df, out + i + out_offsets[3]);
+    hn::Store(out0b, df, out + i + NF + out_offsets[0]);
+    hn::Store(out1b, df, out + i + NF + out_offsets[1]);
+    hn::Store(out2b, df, out + i + NF + out_offsets[2]);
+    hn::Store(out3b, df, out + i + NF + out_offsets[3]);
+    i += NF * 2;
+    v_bf += 4 * NF * NF;
   }
   HWY_DASSERT(size == i);
 }
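// [Editor's illustration, not part of the patch] Reading the coefficient
// staging in MulByConstAndAddVT4Mem above: StoreInterleaved4(c00, c10, c20,
// c30, df, c_mem) writes lane p of each vector into four consecutive floats,
// and the second call does the same for lanes NF..2*NF-1 at c_mem + 4 * NF.
// Assuming that layout, c_mem[4 * p + r] is the weight of value row p in
// output row r, for p < 2 * NF, which is exactly the order MulAddNLanesVT4
// consumes with its four *c++ reads per lane. A scalar sketch of one depth
// block, with hypothetical names:
//
//   for (size_t p = 0; p < num_lanes; ++p) {      // value rows in this k-tile
//     const float* c4 = c_mem + 4 * p;            // weights for out rows 0..3
//     for (size_t r = 0; r < 4; ++r) {
//       for (size_t d = 0; d < 2 * NF; ++d) {
//         out_rows[r][i + d] += c4[r] * v_rows[p][i + d];
//       }
//     }
//   }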
-// Prescales NF rows of out by scale, then multiplies 1 row of V by the
-// corresponding values in c0 and adds them to the NF rows of out.
-// The depth (size) must be a multiple of NF.
 template <class DF, class VF = hn::Vec<DF>>
-HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
-    DF df, const VF scale, const VF c0, const MatPtrT<float>& v,
-    const size_t pos, float* HWY_RESTRICT out,
-    const uint32_t* HWY_RESTRICT out_offsets, const size_t size) {
+HWY_INLINE HWY_MAYBE_UNUSED void MulAddNLanesVT1(DF df,
+                                                 const BF16* HWY_RESTRICT v,
+                                                 const float* HWY_RESTRICT c,
+                                                 const size_t num_lanes,
+                                                 VF& sum0a, VF& sum0b) {
+  using DBF = hn::ScalableTag<BF16>;
+  const DBF dbf;
+  using VBF = hn::Vec<DBF>;
+  const size_t kNF = hn::Lanes(df);
+  for (size_t lane = 0; lane < num_lanes; ++lane, v += 2 * kNF) {
+    VBF v0 = hn::Load(dbf, v);
+    VF c0 = hn::Set(df, *c++);
+    VF v0a = hn::PromoteLowerTo(df, v0);
+    VF v0b = hn::PromoteUpperTo(df, v0);
+    sum0a = hn::MulAdd(v0a, c0, sum0a);
+    sum0b = hn::MulAdd(v0b, c0, sum0b);
+  }
+}
+
+template <class DF, class VF = hn::Vec<DF>>
+HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddVT1Mem(
+    DF df, const float* HWY_RESTRICT scales, const VF c00, const VF c01,
+    const MatPtrT<BF16>& v, const size_t* HWY_RESTRICT pos, size_t num_lanes,
+    float* HWY_RESTRICT out, const uint32_t* HWY_RESTRICT out_offsets,
+    const size_t size) {
   namespace hn = hwy::HWY_NAMESPACE;
   HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
+  constexpr size_t kMaxNF = hn::MaxLanes(df);
+  const BF16* HWY_RESTRICT v_bf = v.Row(pos[0] / (2 * NF));
+  HWY_DASSERT(pos[0] % (2 * NF) == 0);
+  HWY_ALIGN float c_mem[2 * kMaxNF];
+  hn::Store(c00, df, c_mem);
+  hn::Store(c01, df, c_mem + NF);
 
   size_t i = 0;
-  while (i + NF <= size) {
-    if HWY_LANES_CONSTEXPR (NF == 16) {
-      VF out0, out1, out2, out3, out4, out5, out6, out7;
-      VF out8, out9, out10, out11, out12, out13, out14, out15;
-      out0 = hn::Load(df, out + i + out_offsets[0]);
-      out1 = hn::Load(df, out + i + out_offsets[1]);
-      out2 = hn::Load(df, out + i + out_offsets[2]);
-      out3 = hn::Load(df, out + i + out_offsets[3]);
-      out4 = hn::Load(df, out + i + out_offsets[4]);
-      out5 = hn::Load(df, out + i + out_offsets[5]);
-      out6 = hn::Load(df, out + i + out_offsets[6]);
-      out7 = hn::Load(df, out + i + out_offsets[7]);
-      out8 = hn::Load(df, out + i + out_offsets[8]);
-      out9 = hn::Load(df, out + i + out_offsets[9]);
-      out10 = hn::Load(df, out + i + out_offsets[10]);
-      out11 = hn::Load(df, out + i + out_offsets[11]);
-      out12 = hn::Load(df, out + i + out_offsets[12]);
-      out13 = hn::Load(df, out + i + out_offsets[13]);
-      out14 = hn::Load(df, out + i + out_offsets[14]);
-      out15 = hn::Load(df, out + i + out_offsets[15]);
-      Mul16(df, scale, out0, out1, out2, out3, out4, out5, out6, out7, out8,
-            out9, out10, out11, out12, out13, out14, out15);
-      VF x0 = hn::Load(df, v.Row(pos) + i);
-      MulAdd16(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7, out8,
-               out9, out10, out11, out12, out13, out14, out15);
-      hn::Store(out0, df, out + i + out_offsets[0]);
-      hn::Store(out1, df, out + i + out_offsets[1]);
-      hn::Store(out2, df, out + i + out_offsets[2]);
-      hn::Store(out3, df, out + i + out_offsets[3]);
-      hn::Store(out4, df, out + i + out_offsets[4]);
-      hn::Store(out5, df, out + i + out_offsets[5]);
-      hn::Store(out6, df, out + i + out_offsets[6]);
-      hn::Store(out7, df, out + i + out_offsets[7]);
-      hn::Store(out8, df, out + i + out_offsets[8]);
-      hn::Store(out9, df, out + i + out_offsets[9]);
-      hn::Store(out10, df, out + i + out_offsets[10]);
-      hn::Store(out11, df, out + i + out_offsets[11]);
-      hn::Store(out12, df, out + i + out_offsets[12]);
-      hn::Store(out13, df, out + i + out_offsets[13]);
-      hn::Store(out14, df, out + i + out_offsets[14]);
-      hn::Store(out15, df, out + i + out_offsets[15]);
-    }
-    if HWY_LANES_CONSTEXPR (NF == 8) {
-      VF out0, out1, out2, out3, out4, out5, out6, out7;
-      out0 = hn::Load(df, out + i + out_offsets[0]);
-      out1 = hn::Load(df, out + i + out_offsets[1]);
-      out2 = hn::Load(df, out + i + out_offsets[2]);
-      out3 = hn::Load(df, out + i + out_offsets[3]);
-      out4 = hn::Load(df, out + i + out_offsets[4]);
-      out5 = hn::Load(df, out + i + out_offsets[5]);
-      out6 = hn::Load(df, out + i + out_offsets[6]);
-      out7 = hn::Load(df, out + i + out_offsets[7]);
-      Mul8(df, scale, out0, out1, out2, out3, out4, out5, out6, out7);
-      VF x0 = hn::Load(df, v.Row(pos) + i);
-      MulAdd8(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7);
-      hn::Store(out0, df, out + i + out_offsets[0]);
-      hn::Store(out1, df, out + i + out_offsets[1]);
-      hn::Store(out2, df, out + i + out_offsets[2]);
-      hn::Store(out3, df, out + i + out_offsets[3]);
-      hn::Store(out4, df, out + i + out_offsets[4]);
-      hn::Store(out5, df, out + i + out_offsets[5]);
-      hn::Store(out6, df, out + i + out_offsets[6]);
-      hn::Store(out7, df, out + i + out_offsets[7]);
-    }
-    if HWY_LANES_CONSTEXPR (NF == 4) {
-      VF out0, out1, out2, out3;
-      out0 = hn::Load(df, out + i + out_offsets[0]);
-      out1 = hn::Load(df, out + i + out_offsets[1]);
-      out2 = hn::Load(df, out + i + out_offsets[2]);
-      out3 = hn::Load(df, out + i + out_offsets[3]);
-      out0 = hn::Mul(out0, hn::BroadcastLane<0>(scale));
-      out1 = hn::Mul(out1, hn::BroadcastLane<1>(scale));
-      out2 = hn::Mul(out2, hn::BroadcastLane<2>(scale));
-      out3 = hn::Mul(out3, hn::BroadcastLane<3>(scale));
-      VF x0 = hn::Load(df, v.Row(pos) + i);
-      MulAdd4(df, x0, c0, out0, out1, out2, out3);
-      hn::Store(out0, df, out + i + out_offsets[0]);
-      hn::Store(out1, df, out + i + out_offsets[1]);
-      hn::Store(out2, df, out + i + out_offsets[2]);
-      hn::Store(out3, df, out + i + out_offsets[3]);
-    }
-    i += NF;
+  while (i + NF * 2 <= size) {
+    VF out0a, out0b;
+    out0a = hn::Load(df, out + i + out_offsets[0]);
+    VF scale0 = hn::Set(df, scales[0]);
+    out0a = hn::Mul(out0a, scale0);
+    out0b = hn::Load(df, out + i + NF + out_offsets[0]);
+    out0b = hn::Mul(out0b, scale0);
+    MulAddNLanesVT1(df, v_bf, c_mem, HWY_MIN(num_lanes, 2 * NF), out0a, out0b);
+    hn::Store(out0a, df, out + i + out_offsets[0]);
+    hn::Store(out0b, df, out + i + NF + out_offsets[0]);
+    i += NF * 2;
+    v_bf += 4 * NF * NF;
   }
   HWY_DASSERT(size == i);
 }
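Editorial note (illustration only, not code from the patch): the VT4/VT1 kernels above implement the usual flash-attention accumulation step, out_r <- out_r * scale_r + sum_p c[r][p] * v[p], with the value rows held as BF16 tiles of 2*NF positions. A minimal scalar reference of that contract, using hypothetical flat arrays in place of MatPtrT and the staged c_mem:

#include <cstddef>

// Scalar model of the 4-row accumulation: each output row is rescaled by its
// softmax correction factor and then receives a weighted sum of up to
// `num_rows` value rows. Names and flat layouts are illustrative assumptions.
void MulByConstAndAddTile4Reference(const float scales[4],
                                    const float* c,  // [4 * num_rows], c[4*p+r]
                                    const float* v,  // [num_rows * depth]
                                    size_t num_rows, size_t depth,
                                    float* out) {    // [4 * depth]
  for (size_t r = 0; r < 4; ++r) {
    float* out_row = out + r * depth;
    for (size_t d = 0; d < depth; ++d) out_row[d] *= scales[r];
    for (size_t p = 0; p < num_rows; ++p) {
      const float weight = c[4 * p + r];
      const float* v_row = v + p * depth;
      for (size_t d = 0; d < depth; ++d) out_row[d] += weight * v_row[d];
    }
  }
}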
diff --git a/util/mat.h b/util/mat.h
index 25f2cb2..e157473 100644
--- a/util/mat.h
+++ b/util/mat.h
@@ -202,6 +202,17 @@ class MatPtr : public IFields {
     override_rows_ = static_cast<uint32_t>(rows);
   }
 
+  // Changes the number of rows and columns without reallocating the memory.
+  // Increases cols by factor and reduces rows by factor.
+  // The rows must be divisible by factor and the matrix must be packed.
+  void ReshapePackedRowsToCols(size_t factor) {
+    HWY_ASSERT(IsPacked());
+    HWY_ASSERT(private_rows_ % factor == 0);
+    private_rows_ /= factor;
+    cols_ *= factor;
+    stride_ *= factor;
+  }
+
   // Offset by which to advance pointers to the next row.
   size_t Stride() const { return stride_; }
diff --git a/util/test_util.h b/util/test_util.h
index 19342e4..443990f 100644
--- a/util/test_util.h
+++ b/util/test_util.h
@@ -106,7 +106,8 @@ template <typename T> void FillMatPtrT(MatPtrT<T>& mat) {
   for (int i = 0; i < mat.Rows(); ++i) {
     for (int j = 0; j < mat.Cols(); ++j) {
-      mat.Row(i)[j] = hwy::Unpredictable1() * 0.01f * (i + j + 1);
+      mat.Row(i)[j] =
+          hwy::ConvertScalarTo<T>(hwy::Unpredictable1() * 0.01f * (i + j + 1));
     }
   }
 }
diff --git a/util/zones.cc b/util/zones.cc
index 6480b96..edcddfb 100644
--- a/util/zones.cc
+++ b/util/zones.cc
@@ -17,14 +17,14 @@ const char* ZoneName(Zones zone) {
       return "FlashAttention.Inclusive";
     case Zones::kFlashAttentionRmsNormAndPositionalEncoding:
       return "FlashAttention.RMSNormAndPositionalEncoding";
-    case Zones::kFlashAttentionSingleFlashAttention:
-      return "FlashAttention.SingleFlashAttention";
-    case Zones::kFlashAttentionTileFlashAttention:
-      return "FlashAttention.TileFlashAttention";
+    case Zones::kFlashAttentionTileFlashAttention1:
+      return "FlashAttention.TileFlashAttention1";
     case Zones::kFlashAttentionTileFlashAttention4:
       return "FlashAttention.TileFlashAttention4";
-    case Zones::kFlashAttentionTransposeQ:
-      return "FlashAttention.TransposeQ";
+    case Zones::kFlashAttentionTileFlashAttention8:
+      return "FlashAttention.TileFlashAttention8";
+    case Zones::kFlashAttentionCombineSplit:
+      return "FlashAttention.CombineSplit";
     case Zones::kGenActivation:
       return "Gen.Activation";
     case Zones::kGenActivationFused:
diff --git a/util/zones.h b/util/zones.h
index ac96ad0..e53065e 100644
--- a/util/zones.h
+++ b/util/zones.h
@@ -14,10 +14,10 @@ enum class Zones {  // Keep sorted
   kFlashAttentionFlashAttention,
   kFlashAttentionInclusive,
   kFlashAttentionRmsNormAndPositionalEncoding,
-  kFlashAttentionSingleFlashAttention,
-  kFlashAttentionTileFlashAttention,
+  kFlashAttentionTileFlashAttention1,
   kFlashAttentionTileFlashAttention4,
-  kFlashAttentionTransposeQ,
+  kFlashAttentionTileFlashAttention8,
+  kFlashAttentionCombineSplit,
   kGenActivation,
   kGenActivationFused,
   kGenAttention,
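Editorial note on the mat.h change (illustration only): a packed matrix stores its rows contiguously (stride == cols), so merging `factor` consecutive rows into one wider row reinterprets the same buffer without moving data; element (r, c) simply becomes element (r / factor, (r % factor) * cols + c). The VT kernels above index the value cache as v.Row(pos[0] / (2 * NF)), i.e. they expect it to have been reshaped this way into groups of 2*NF positions. A standalone sketch of the invariant, with a hypothetical PackedView standing in for MatPtr:

#include <cassert>
#include <cstddef>
#include <vector>

// Minimal stand-in for a packed (stride == cols) row-major matrix view.
struct PackedView {
  size_t rows, cols;
  std::vector<float> data;  // rows * cols elements

  float* Row(size_t r) { return data.data() + r * cols; }

  // Same effect as MatPtr::ReshapePackedRowsToCols: rows x cols becomes
  // (rows / factor) x (cols * factor) over the identical buffer.
  void ReshapePackedRowsToCols(size_t factor) {
    assert(rows % factor == 0);
    rows /= factor;
    cols *= factor;
  }
};

// Usage: an 8 x 16 view becomes 2 x 64; Row(0) then spans what used to be
// rows 0..3, with no copy because the underlying layout is unchanged.
//   PackedView v{8, 16, std::vector<float>(8 * 16)};
//   v.ReshapePackedRowsToCols(4);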