From f10ac41a20bd31974b60253eb6c6b150442a63ac Mon Sep 17 00:00:00 2001 From: Ray Smith Date: Tue, 9 Sep 2025 08:04:45 -0700 Subject: [PATCH] Added flash attention, with both a single-q function, and a register-tiled function. The register-tiled version achieves a speed-up by a factor of about 9.7 over the previous attention function on an AVX3-enabled machine. PiperOrigin-RevId: 804913784 --- BUILD.bazel | 23 ++ CMakeLists.txt | 3 + gemma/activations.h | 9 +- gemma/attention.cc | 27 +- gemma/attention.h | 8 + gemma/flash_attention.cc | 510 ++++++++++++++++++++++++++++++++++ gemma/flash_attention.h | 61 ++++ gemma/flash_attention_test.cc | 171 ++++++++++++ ops/ops-inl.h | 345 +++++++++++++++++++++++ 9 files changed, 1146 insertions(+), 11 deletions(-) create mode 100644 gemma/flash_attention.cc create mode 100644 gemma/flash_attention.h create mode 100644 gemma/flash_attention_test.cc diff --git a/BUILD.bazel b/BUILD.bazel index dbe52b7..02c54bd 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -117,6 +117,27 @@ cc_library( ], ) +cc_test( + name = "flash_attention_test", + srcs = ["gemma/flash_attention_test.cc"], + deps = [ + ":configs", + ":gemma_args", + ":gemma_lib", + ":kv_cache", + ":mat", + ":matmul", + ":ops", + ":threading_context", + ":weights", + "@googletest//:gtest_main", # buildcleaner: keep + "//compression:compress", + "//compression:types", + "@highway//:hwy", + "@highway//:hwy_test_util", + ], +) + cc_test( name = "threading_test", srcs = ["util/threading_test.cc"], @@ -526,12 +547,14 @@ cc_library( name = "gemma_lib", srcs = [ "gemma/attention.cc", + "gemma/flash_attention.cc", "gemma/gemma.cc", "gemma/vit.cc", ], hdrs = [ "gemma/activations.h", "gemma/attention.h", + "gemma/flash_attention.h", "gemma/gemma.h", "gemma/vit.h", ], diff --git a/CMakeLists.txt b/CMakeLists.txt index 4bc0e80..cb2911f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -79,6 +79,8 @@ set(SOURCES gemma/attention.h gemma/configs.cc gemma/configs.h + gemma/flash_attention.cc + gemma/flash_attention.h gemma/gemma_args.h gemma/gemma-inl.h gemma/gemma.cc @@ -216,6 +218,7 @@ set(GEMMA_TEST_FILES compression/nuq_test.cc compression/sfp_test.cc evals/gemma_test.cc + gemma/flash_attention_test.cc gemma/tensor_info_test.cc io/blob_store_test.cc io/fields_test.cc diff --git a/gemma/activations.h b/gemma/activations.h index 71523e4..9460d15 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -56,7 +56,11 @@ struct AttentionActivations { ? layer_config.heads * 3 * layer_config.qkv_dim : layer_config.heads * layer_config.qkv_dim, allocator)), - + q_T(MatFactory("q_T", layer_config.qkv_dim, + config.vocab_size == 0 + ? batch_size * layer_config.heads * 3 + : batch_size * layer_config.heads, + allocator)), pre_att_rms_out(MatFactory("pre_att_rms_out", batch_size, config.model_dim, allocator)), att(MatFactory("att", batch_size, layer_config.heads * seq_len, @@ -90,11 +94,13 @@ struct AttentionActivations { // If we forget any MatMul outputs here, debug builds print a warning but // fill them in each MatMul call. q.AllocateAndAttachRowPtrs(row_ptrs); + q_T.AllocateAndAttachRowPtrs(row_ptrs); att_sums.AllocateAndAttachRowPtrs(row_ptrs); } void SetBatchSize(size_t batch_size) { q.OverrideRows(batch_size); + q_T.OverrideRows(batch_size); pre_att_rms_out.OverrideRows(batch_size); att.OverrideRows(batch_size); @@ -105,6 +111,7 @@ struct AttentionActivations { const ModelConfig& config; MatStorageT q; // query + MatStorageT q_T; // Transposed to maximize attention speed. 
MatStorageT pre_att_rms_out; MatStorageT att; // attention vector diff --git a/gemma/attention.cc b/gemma/attention.cc index 31ed4d1..61d76ef 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -41,12 +41,16 @@ #include "hwy/highway.h" // After highway.h #include "compression/compress-inl.h" +#include "gemma/flash_attention.h" #include "ops/ops-inl.h" HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { +constexpr int kFlagReserved = 1; // LINTER: unused, reserved for future use. +constexpr int kUseOldAttention = 2; + // Computes Q.K scores, which are "logits" (or scores) stored to att. // `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim]. static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos, @@ -71,11 +75,11 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos, } } -static void PositionalEncodingQK(float* qk, const size_t layer_idx, - const LayerWeightsPtrs& layer, - const AttentionActivations& activations, - hwy::Profiler& p, const size_t worker, - const size_t pos, const float mul = 1.0f) { +void PositionalEncodingQK(float* qk, const size_t layer_idx, + const LayerWeightsPtrs& layer, + const AttentionActivations& activations, + hwy::Profiler& p, const size_t worker, + const size_t pos, const float mul) { const size_t qkv_dim = layer.layer_config.qkv_dim; const PostQKType& post_qk = layer.layer_config.post_qk; // qk is either q or k, so qkv_dim is the length we operate on. @@ -165,8 +169,7 @@ void SingleDotSoftmaxWeightedSum( // The attention window usually starts at 0 unless `pos` is larger than // the attention window size, then it is `pos` - window_size + 1. -static HWY_INLINE size_t StartPos(size_t pos, const ModelConfig& config, - size_t layer_idx) { +size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx) { const size_t att_window_size = config.attention_window_sizes[layer_idx]; return pos - HWY_MIN(att_window_size - 1, pos); } @@ -314,7 +317,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, } PositionalEncodingQK(kv_f32, layer_idx, layer, activations, - env.ctx.profiler, worker, pos); + env.ctx.profiler, worker, pos, /*mul=*/1.0f); CompressPerThread tls; Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0); }); @@ -354,8 +357,12 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx, (void)layer_config; // only used in HWY_DASSERT ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env); - DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch, - env.ctx); + if (flags & kUseOldAttention) { + DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch, + env.ctx); + } else { + FlashAttention(num_tokens, layer_idx, layer, activations, qbatch, env.ctx); + } SumHeads(layer, activations, env); } diff --git a/gemma/attention.h b/gemma/attention.h index c69cc8f..a0af4ff 100644 --- a/gemma/attention.h +++ b/gemma/attention.h @@ -28,6 +28,14 @@ namespace gcpp { // Passed to HWY_VISIT_TARGETS; declares for one target. 
#define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \ namespace NAMESPACE { \ + void PositionalEncodingQK(float* qk, size_t layer_idx, \ + const LayerWeightsPtrs& layer, \ + const AttentionActivations& activations, \ + hwy::Profiler& p, size_t worker, size_t pos, \ + float mul); \ + \ + size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx); \ + \ void SingleDotSoftmaxWeightedSum( \ const size_t pos, const size_t start_pos, const size_t last_pos, \ float* HWY_RESTRICT q, const MatPtrT& k, const MatPtrT& v, \ diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc new file mode 100644 index 0000000..40096d1 --- /dev/null +++ b/gemma/flash_attention.cc @@ -0,0 +1,510 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include +#include +#include + +#include "compression/types.h" // GEMMA_DISABLED_TARGETS +#include "util/threading_context.h" +#ifndef HWY_DISABLED_TARGETS +#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS +#endif // HWY_DISABLED_TARGETS + +#include "gemma/activations.h" +#include "gemma/configs.h" // kMaxQKVDim +#include "gemma/gemma.h" +#include "gemma/weights.h" +#include "util/threading.h" +#include "hwy/profiler.h" + +// Compiles this file for multiple architectures via "foreach_target.h", to +// which we pass the filename via macro 'argument'. +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "gemma/flash_attention.cc" // NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +// After highway.h +#include "compression/compress-inl.h" +#include "gemma/attention.h" +#include "ops/ops-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { + +// Transposes q into q_t. +// Both are 4D tensors stuffed into a 2-D MatPtrT. +// q has shape [batch, qbatch][head, qkv_dim]. +// q_t has shape [qkv_dim][qbatch, head, batch] in order to make the maximum +// possible consecutive elements have the same KV. +static void TransposeQ(const MatPtrT& q, MatPtrT& q_t, + const size_t qbatch_size, ThreadingContext& ctx) { + static const auto zone = ctx.profiler.AddZone("Gen.Attention.TransposeQ"); + const size_t num_heads = q.Cols() / q_t.Rows(); + const size_t batch_size = q.Rows() / qbatch_size; + const auto func = [&](const size_t task, size_t worker) HWY_ATTR { + PROFILER_ZONE3(ctx.profiler, worker, zone); + float* HWY_RESTRICT qt_row = q_t.Row(task); + for (size_t qi = 0; qi < qbatch_size; ++qi) + for (size_t h = 0; h < num_heads; ++h) { + for (size_t b = 0; b < batch_size; ++b) { + qt_row[(qi * num_heads + h) * batch_size + b] = + q.Row(b * qbatch_size + qi)[h * q_t.Rows() + task]; + } + } + }; + { + // Full parallelism is helpful, SmallParallelFor is insufficient. + HierarchicalParallelFor(q_t.Rows(), ctx.pools, func); + } +} + +// Updates q in place for RMSNorm and positional encoding. 
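+// Concretely (a sketch of what the function below does): row
+// (task * qbatch.Size() + qi) of q holds all heads for token `task` of query
+// `qi`, back to back. Each head's qkv_dim-wide slice gets an optional RMSNorm
+// (when query_norm_scale is present), then RoPE at position
+// qbatch.Pos(qi) + task, scaled by query_scale.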
+void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch, + MatPtrT& q, const size_t layer_idx, + const LayerWeightsPtrs& layer, + const AttentionActivations& activations, + ThreadingContext& ctx) { + static const auto zone = + ctx.profiler.AddZone("Gen.Attention.RMSNormAndPositionalEncoding"); + const float query_scale = activations.query_scale; + const auto func = [&](const size_t task, size_t worker) HWY_ATTR { + PROFILER_ZONE3(ctx.profiler, worker, zone); + for (size_t qi = 0; qi < qbatch.Size(); ++qi) { + for (size_t h = 0; h < layer.layer_config.heads; ++h) { + const size_t tq_idx = qbatch.Size() * task + qi; + // Find the token position in the query and calculate + // the range of cache positions to attend to. + const size_t pos = qbatch.Pos(qi) + task; + float* HWY_RESTRICT q_row = + q.Row(tq_idx) + h * layer.layer_config.qkv_dim; + // Apply rope and scaling to Q. + if (layer.query_norm_scale.HasPtr()) { + CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) { + RMSNormInplace(weights_t->PackedScale1(), q_row, + layer.layer_config.qkv_dim, ctx.profiler, worker); + }); + } + PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx.profiler, + worker, pos, query_scale); + } + } + }; + { + // Full parallelism is helpful, SmallParallelFor is insufficient. + HierarchicalParallelFor(num_tokens, ctx.pools, func); + } +} + +// Calculates the complete attention outputs for a single row of q. +void SingleFlashAttention(const size_t start_pos, const size_t last_pos, + const float* HWY_RESTRICT q, const MatPtrT& k, + const MatPtrT& v, const size_t layer_idx, + const LayerWeightsPtrs& layer, + const AttentionActivations& activations, + float* HWY_RESTRICT att_out, hwy::Profiler& p, + const size_t worker) { + static const auto zone = p.AddZone("Gen.Attention.SingleFlashAttention"); + PROFILER_ZONE3(p, worker, zone); + const size_t pos_mod = activations.div_seq_len.Remainder(start_pos); + float m = Dot(q, k.Row(pos_mod), k.Cols()); + float d = 1.0f; + // This is just a copy of the first token. + MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), p, worker); + for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) { + const size_t pos_mod = activations.div_seq_len.Remainder(pos); + float x = Dot(q, k.Row(pos_mod), k.Cols()); + if (activations.config.att_cap > 0.0f) { + // Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x. + x = activations.config.att_cap * + std::tanh(x / activations.config.att_cap); + } + float m_new = std::max(m, x); + float scale = d * std::exp(m - m_new); + x = std::exp(x - m_new); + m = m_new; + d = scale + x; + float one_over_d = 1.0f / d; + x *= one_over_d; + scale *= one_over_d; + MulByConst(scale, att_out, v.Cols(), p, worker); + MulByConstAndAdd(x, v.Row(pos_mod), att_out, v.Cols(), p, worker); + } +} + +// Computes and returns a single vector of NF Q.K dot products, which represents +// the dot products of NF rows of Q for a single K timestep. +template > +VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets, + const size_t k_pos, const MatPtrT& q, + const MatPtrT& k, hwy::Profiler& p, const size_t worker) { + hn::TFromD results[hn::MaxLanes(df)]; + for (size_t i = 0; i < hn::Lanes(df); ++i) { + results[i] = Dot(q.Row(0) + q_offsets[i], k.Row(k_pos), k.Cols()); + } + return hn::LoadU(df, results); +} + +// Returns an 8xNF tile of Q.K dot products, in single precision. +// This is the result of NF rows of Q against 8 K timesteps, with positions +// given by k_pos[0..7]. 
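+// In scalar terms (a sketch), lane r of sum_j ends up as
+//   sum_j[r] = sum over i of q[i * q_stride + r] * k.Row(k_pos[j])[i],
+// i.e. the dot product of the r-th Q row of the tile (one column of the
+// transposed Q) with K row k_pos[j].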
Q has been transposed so that the NF rows are read in +// consecutive elements, and other columns by adding q_stride. +template > +void QDotKTileFloat(DF df, const float* HWY_RESTRICT q, const size_t q_stride, + const MatPtrT& k, const size_t* k_pos, + hwy::Profiler& p, const size_t worker, VF& sum0, VF& sum1, + VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6, + VF& sum7) { + constexpr size_t kHTileSize = 8; + sum0 = hn::Zero(df); + sum1 = hn::Zero(df); + sum2 = hn::Zero(df); + sum3 = hn::Zero(df); + sum4 = hn::Zero(df); + sum5 = hn::Zero(df); + sum6 = hn::Zero(df); + sum7 = hn::Zero(df); + const float* HWY_RESTRICT k_row[kHTileSize]; + for (int i = 0; i < kHTileSize; ++i) { + k_row[i] = k.Row(k_pos[i]); + } + for (size_t i = 0; i < k.Cols(); ++i) { + VF q_vec = hn::Load(df, q); + VF k_0 = hn::Set(df, k_row[0][i]); + sum0 = hn::MulAdd(q_vec, k_0, sum0); + VF k_1 = hn::Set(df, k_row[1][i]); + sum1 = hn::MulAdd(q_vec, k_1, sum1); + VF k_2 = hn::Set(df, k_row[2][i]); + sum2 = hn::MulAdd(q_vec, k_2, sum2); + VF k_3 = hn::Set(df, k_row[3][i]); + sum3 = hn::MulAdd(q_vec, k_3, sum3); + VF k_4 = hn::Set(df, k_row[4][i]); + sum4 = hn::MulAdd(q_vec, k_4, sum4); + VF k_5 = hn::Set(df, k_row[5][i]); + sum5 = hn::MulAdd(q_vec, k_5, sum5); + VF k_6 = hn::Set(df, k_row[6][i]); + sum6 = hn::MulAdd(q_vec, k_6, sum6); + VF k_7 = hn::Set(df, k_row[7][i]); + sum7 = hn::MulAdd(q_vec, k_7, sum7); + q += q_stride; + } +} + +// Returns the element-wise maximum of 8 vectors, in a single vector. +template > +VF HWY_INLINE ElementwiseMaxOf8(DF df, const VF& x0, const VF& x1, const VF& x2, + const VF& x3, const VF& x4, const VF& x5, + const VF& x6, const VF& x7) { + VF m0 = hn::Max(x0, x1); + VF m1 = hn::Max(x2, x3); + VF m2 = hn::Max(x4, x5); + VF m3 = hn::Max(x6, x7); + m0 = hn::Max(m0, m1); + m2 = hn::Max(m2, m3); + return hn::Max(m0, m2); +} + +// Returns the element-wise sum of 8 vectors, in a single vector. +template > +VF HWY_INLINE ElementwiseSumOf8(DF df, const VF& x0, const VF& x1, const VF& x2, + const VF& x3, const VF& x4, const VF& x5, + const VF& x6, const VF& x7) { + VF sum0 = hn::Add(x0, x1); + VF sum1 = hn::Add(x2, x3); + VF sum2 = hn::Add(x4, x5); + VF sum3 = hn::Add(x6, x7); + sum0 = hn::Add(sum0, sum1); + sum2 = hn::Add(sum2, sum3); + return hn::Add(sum0, sum2); +} + +// Sweeps a tile of 8xNF accumulators from start_pos to min_last_pos, then +// sweeps the remaining timesteps in the range (min_last_pos, max_last_pos]. 
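+// For reference, a scalar sketch of the per-lane streaming-softmax recurrence
+// that SingleFlashAttention above and TileFlashAttention below keep in
+// registers. Here q, k[t], v[t] and o are plain float arrays of length D
+// (illustrative names, not the MatPtrT views used in this file):
+//
+//   float m = -std::numeric_limits<float>::max();  // running max of scores
+//   float d = 0.0f;                                 // running softmax denominator
+//   std::vector<float> o(D, 0.0f);                  // running weighted sum of V
+//   for (size_t t = start_pos; t <= last_pos; ++t) {
+//     float x = Dot(q, k[t], D);                    // optionally soft-capped
+//     const float m_new = std::max(m, x);
+//     const float scale = d * std::exp(m - m_new);  // rescales the old state
+//     const float w = std::exp(x - m_new);          // weight of timestep t
+//     d = scale + w;
+//     for (size_t j = 0; j < D; ++j) o[j] = (o[j] * scale + w * v[t][j]) / d;
+//     m = m_new;
+//   }
+//
+// After the loop, o equals softmax(q . K^T) . V over the attended range,
+// because every step divides by the running denominator d.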
+void TileFlashAttention( + const MatPtrT& q, const uint32_t* HWY_RESTRICT q_offsets, + const StridedView& qT, const MatPtrT& k, + const size_t start_pos, const uint32_t* HWY_RESTRICT last_pos, + const size_t min_last_pos, const size_t max_last_pos, + const MatPtrT& v, const size_t layer_idx, + const LayerWeightsPtrs& layer, const AttentionActivations& activations, + MatPtrT& att_out, const uint32_t* HWY_RESTRICT out_offsets, + hwy::Profiler& p, const size_t worker) { + static const auto zone = p.AddZone("Gen.Attention.TileFlashAttention"); + PROFILER_ZONE3(p, worker, zone); + constexpr int kHTileSize = 8; + using DF = hn::ScalableTag; + const DF df; + using VF = hn::Vec; + using DI = hn::ScalableTag; + const DI di; + using VI = hn::Vec; + VI lasts = hn::LoadU(di, last_pos); + VF old_m = hn::Set(df, -std::numeric_limits::max() / 2.0f); + VF old_d = hn::Zero(df); + const float* HWY_RESTRICT qT_row = qT.Row(0); + const size_t qT_stride = qT.Stride(); + size_t position = start_pos; + while (position + kHTileSize - 1 <= min_last_pos) { + size_t k_pos[kHTileSize]; + for (size_t i = 0; i < kHTileSize; ++i) { + k_pos[i] = activations.div_seq_len.Remainder(position + i); + } + VF x0, x1, x2, x3, x4, x5, x6, x7; + QDotKTileFloat(df, qT_row, qT_stride, k, k_pos, p, worker, x0, x1, x2, x3, + x4, x5, x6, x7); + if (activations.config.att_cap > 0.0f) { + // Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile. + VF cap = hn::Set(df, activations.config.att_cap); + VF one_over_cap = hn::Div(hn::Set(df, 1.0f), cap); + x0 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x0, one_over_cap))); + x1 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x1, one_over_cap))); + x2 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x2, one_over_cap))); + x3 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x3, one_over_cap))); + x4 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x4, one_over_cap))); + x5 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x5, one_over_cap))); + x6 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x6, one_over_cap))); + x7 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x7, one_over_cap))); + } + VF m = ElementwiseMaxOf8(df, x0, x1, x2, x3, x4, x5, x6, x7); + m = hn::Max(old_m, m); + x0 = hn::Exp(df, x0 - m); + x1 = hn::Exp(df, x1 - m); + x2 = hn::Exp(df, x2 - m); + x3 = hn::Exp(df, x3 - m); + x4 = hn::Exp(df, x4 - m); + x5 = hn::Exp(df, x5 - m); + x6 = hn::Exp(df, x6 - m); + x7 = hn::Exp(df, x7 - m); + VF scale = hn::Mul(old_d, hn::Exp(df, old_m - m)); + old_d = ElementwiseSumOf8(df, x0, x1, x2, x3, x4, x5, x6, x7); + old_d = hn::Add(scale, old_d); + old_m = m; + VF one_over_d = hn::Div(hn::Set(df, 1.0f), old_d); + scale = hn::Mul(scale, one_over_d); + x0 = hn::Mul(x0, one_over_d); + x1 = hn::Mul(x1, one_over_d); + x2 = hn::Mul(x2, one_over_d); + x3 = hn::Mul(x3, one_over_d); + x4 = hn::Mul(x4, one_over_d); + x5 = hn::Mul(x5, one_over_d); + x6 = hn::Mul(x6, one_over_d); + x7 = hn::Mul(x7, one_over_d); + MulByConstAndAddTile(df, scale, x0, x1, x2, x3, x4, x5, x6, x7, v, k_pos, + att_out.Row(0), out_offsets, v.Cols(), p, worker); + position += kHTileSize; + } + while (position <= max_last_pos) { + size_t k_pos = activations.div_seq_len.Remainder(position); + VF x0 = QDotKVector(df, q_offsets, k_pos, q, k, p, worker); + if (activations.config.att_cap > 0.0f) { + // Compute tanh(x / cap) * cap, being LogitsSoftCap on the vector. + VF cap = hn::Set(df, activations.config.att_cap); + VF one_over_cap = hn::Div(hn::Set(df, 1.0f), cap); + x0 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x0, one_over_cap))); + } + // Past the last position, x0 doesn't count. 
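+    // Lanes whose query row has already passed its own last position
+    // (position > last_pos[lane]) get a huge constant subtracted from their
+    // score, so exp(x0 - m) is effectively zero and they contribute nothing
+    // to d or to the weighted sum of V. This is how the ragged tail between
+    // min_last_pos and max_last_pos is handled per lane.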
+    auto mask = hn::Gt(hn::Set(di, position), lasts);
+    VF causal_offset = hn::MaskedSet(df, RebindMask(df, mask),
+                                     std::numeric_limits::max() / 2.0f);
+    x0 = hn::Sub(x0, causal_offset);
+    VF m = hn::Max(old_m, x0);
+    x0 = hn::Exp(df, x0 - m);
+    VF scale = hn::Mul(old_d, hn::Exp(df, old_m - m));
+    old_m = m;
+    old_d = hn::Add(scale, x0);
+    VF one_over_d = hn::Div(hn::Set(df, 1.0f), old_d);
+    x0 = hn::Mul(x0, one_over_d);
+    scale = hn::Mul(scale, one_over_d);
+    MulByConstAndAddVector(df, scale, x0, v, k_pos, att_out.Row(0), out_offsets,
+                           v.Cols(), p, worker);
+    ++position;
+  }
+}
+
+// The nominal aim of attention is to combine 3 inputs Q[L,D], K[L,D], V[L,D]
+// into a single output O[L,D].
+// Conventional attention first computes A[L,L] = Q . K^T,
+// followed by A = softmax(A) (over individual rows).
+// Then A is multiplied by V to get O[L,D].
+// For each row of O, this takes L reads of one row of Q, a read of all of K,
+// 3 reads/writes of one row of A, a read of all of V, and L reads/writes of
+// the one row of O. Ignoring the computation for now, and focusing just on
+// memory, the one row of O takes L(4D+3) reads and L(D+3) writes.
+// For the whole of Q, this is L^2(4D+3) reads and L^2(D+3) writes.
+//
+// Flash attention fuses these operations together, and (where possible)
+// computes NF rows of the result using 8 accumulator registers and two more
+// to keep running results. NF is the number of float lanes in a register,
+// being 16 for AVX3. The softmax is converted to streaming form using the
+// algorithm from:
+// https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf.
+// Q is transposed to Q_T[D,L] to make the dot product computation efficient.
+// QDotKTileFloat computes 8xNF rows of Q.K dot products in one go, reducing
+// reads of Q by 8 and reads of K by NF. The streaming softmax is computed
+// entirely in registers, and a further NF registers accumulate the results
+// of the product of the softmax and V, reducing the number of reads of V by
+// NF and the reads/writes of O by 8.
+// The reads are thus reduced to 2DL^2(1/8+1/NF) and writes reduced to DL^2/8,
+// which on AVX3 is an overall reduction by about a factor of 10. (For example,
+// with D=256 and NF=16, reads per L^2 drop from 4D+3 = 1027 to
+// 2D(1/8+1/16) = 96, and writes from D+3 = 259 to D/8 = 32.)
+//
+// A further complication is that real attention is not as simple as documented
+// in the paper and above. There are multiple query heads, differing KV, and
+// different sequence lengths, so a lot of the work in FlashAttention is making
+// sure that a collection of q rows can use the TileFlashAttention path.
+void FlashAttention(const size_t num_tokens, const size_t layer_idx,
+                    const LayerWeightsPtrs& layer,
+                    AttentionActivations& activations, QBatch& qbatch,
+                    ThreadingContext& ctx) {
+  static const auto zone = ctx.profiler.AddZone("Gen.Attention.FlashAttention");
+  RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q, layer_idx,
+                               layer, activations, ctx);
+  const hwy::Divisor div_qbatch(qbatch.Size());
+  const LayerConfig& layer_config = layer.layer_config;
+  const size_t qkv_dim = layer_config.qkv_dim;
+
+  // A "head group" in the context of GQA refers to a collection of query
+  // heads that share the same key and value heads.
+  const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
+
+  using DF = hn::ScalableTag;
+  const DF df;
+  constexpr size_t kVTileSize = hn::MaxLanes(df);
+  const size_t cache_layer_size = layer_config.CacheLayerSize();
+  const size_t seq_len =
+      static_cast(activations.div_seq_len.GetDivisor());
+  const size_t token_batch = num_tokens * div_qbatch.GetDivisor();
+  const size_t total_tasks = token_batch * layer_config.heads;
+  // q has shape [batch, qbatch][head, qkv_dim].
+  // We transpose it to [qkv_dim][qbatch, head, batch] in order to make the
+  // maximum possible number of consecutive columns have the same KV matrices.
+  // Each thread will process a tile of NF columns of QT so the starting column
+  // index of QT is just the task index * kVTileSize.
+  TransposeQ(activations.q, activations.q_T, qbatch.Size(), ctx);
+  const size_t num_thread_tasks = hwy::DivCeil(total_tasks, kVTileSize);
+  const hwy::Divisor div_tokens(num_tokens);
+  // All layers should have the same number of heads.
+  HWY_DASSERT(activations.div_heads.GetDivisor() == layer_config.heads);
+
+  // For each head/token/query, compute fused flash Q.K, softmax and weighted V.
+  const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
+    PROFILER_ZONE3(ctx.profiler, worker, zone);
+    // Offsets into original Q for each row in the tile.
+    uint32_t q_offsets[kVTileSize];
+    // Offsets into att_out for each row in the tile.
+    uint32_t out_offsets[kVTileSize];
+    // Start positions for each row in the tile.
+    size_t start_positions[kVTileSize];
+    // Last positions for each row in the tile. Inclusive.
+    uint32_t last_pos[kVTileSize];
+    // The min and max last positions across all rows in the tile determine
+    // when TileFlashAttention switches to single vector mode to handle the
+    // ragged sequence lengths.
+    size_t min_last_pos = std::numeric_limits::max();
+    size_t max_last_pos = 0;
+    // Indices into the qbatch.KV for each row in the tile.
+    size_t qi_indices[kVTileSize];
+    // Indices into the kv_cache for each row in the tile.
+    size_t kv_offsets[kVTileSize];
+    // first_task is [qbatch, head, token].
+    const size_t first_task = task * kVTileSize;
+    const size_t last_task = first_task + kVTileSize - 1;
+    bool use_tile_attention = last_task < total_tasks;
+    for (size_t offset = 0;
+         offset < kVTileSize && first_task + offset < total_tasks; ++offset) {
+      const size_t batch_idx = div_tokens.Remainder(first_task + offset);
+      const size_t qh = div_tokens.Divide(first_task + offset);
+      const size_t head = activations.div_heads.Remainder(qh);
+      const size_t qi = activations.div_heads.Divide(qh);
+      const size_t tq_idx = div_qbatch.GetDivisor() * batch_idx + qi;
+      qi_indices[offset] = qi;
+
+      // Find the token position in the query and calculate
+      // the range of cache positions to attend to.
+      const size_t pos = qbatch.Pos(qi) + batch_idx;
+      const size_t start_pos = StartPos(pos, activations.config, layer_idx);
+      start_positions[offset] = start_pos;
+      size_t last = pos;
+      const size_t prefix_end = qbatch.PrefixEnd(qi);
+      if (prefix_end > 0 && prefix_end - 1 > last) {
+        // last_pos in QDotK and WeightedSumV is inclusive.
+        last = prefix_end - 1;
+      }
+      last_pos[offset] = last;
+      min_last_pos = HWY_MIN(min_last_pos, last);
+      max_last_pos = HWY_MAX(max_last_pos, last);
+      q_offsets[offset] =
+          activations.q.Row(tq_idx) + head * qkv_dim - activations.q.Row(0);
+      out_offsets[offset] = activations.att_out.Row(tq_idx) + head * qkv_dim -
+                            activations.att_out.Row(0);
+      const size_t kv_index = head / kHeadGroups;
+      const size_t head_offset = kv_index * qkv_dim * 2;
+      kv_offsets[offset] = layer_idx * cache_layer_size + head_offset;
+      // If any of the parameters in this if statement differ within this task,
+      // then we can't use TileFlashAttention. TileFlashAttention requires that
+      // all rows in the tile have the same K and V matrices, and that Q starts
+      // at the same position. The end positions do not have to be equal.
+      if (start_positions[offset] != start_positions[0] ||
+          qi_indices[offset] != qi_indices[0] ||
+          kv_offsets[offset] != kv_offsets[0]) {
+        use_tile_attention = false;
+      }
+    }
+    for (size_t offset = 0;
+         offset < kVTileSize && first_task + offset < total_tasks; ++offset) {
+      auto& kv_cache = qbatch.KV(qi_indices[offset]).kv_cache;
+      MatPtrT k("k_view", Extents2D(seq_len, qkv_dim));
+      k.SetPtr(kv_cache.Row(0) + kv_offsets[offset], kv_cache.Stride());
+      MatPtrT v("v_view", Extents2D(seq_len, qkv_dim));
+      v.SetPtr(kv_cache.Row(0) + kv_offsets[offset] + qkv_dim,
+               kv_cache.Stride());
+      if (use_tile_attention) {
+        // To avoid duplicating the code to set up K and V, the call to
+        // TileFlashAttention is inside the loop over tasks, even though it
+        // handles all rows in the task at once.
+        StridedView qT =
+            StridedView(activations.q_T.Row(0) + first_task, kVTileSize,
+                        activations.q_T.Stride());
+        TileFlashAttention(
+            activations.q, q_offsets, qT, k, start_positions[offset], last_pos,
+            min_last_pos, max_last_pos, v, layer_idx, layer, activations,
+            activations.att_out, out_offsets, ctx.profiler, worker);
+        break;
+      } else {
+        SingleFlashAttention(start_positions[offset], last_pos[offset],
+                             activations.q.Row(0) + q_offsets[offset], k, v,
+                             layer_idx, layer, activations,
+                             activations.att_out.Row(0) + out_offsets[offset],
+                             ctx.profiler, worker);
+      }
+    }
+  };
+
+  {
+    PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin");
+    // Full parallelism is helpful, SmallParallelFor is insufficient.
+    HierarchicalParallelFor(num_thread_tasks, ctx.pools, func);
+  }
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace gcpp
+HWY_AFTER_NAMESPACE();
diff --git a/gemma/flash_attention.h b/gemma/flash_attention.h
new file mode 100644
index 0000000..b505d6f
--- /dev/null
+++ b/gemma/flash_attention.h
@@ -0,0 +1,61 @@
+// Copyright 2025 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_ATTENTION_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_ATTENTION_H_
+
+// Declares FlashAttention for all SIMD targets.
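+//
+// Usage (a sketch mirroring GemmaAttention in attention.cc): after ComputeQKV
+// has filled activations.q and the KV cache, the per-target code calls
+//   FlashAttention(num_tokens, layer_idx, layer, activations, qbatch, env.ctx);
+// unless the kUseOldAttention flag selects the previous DotSoftmaxWeightedSum
+// path.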
+ +#include + +#include "gemma/gemma.h" +#include "hwy/highway.h" + +namespace gcpp { + +// Passed to HWY_VISIT_TARGETS; declares for one target. +#define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \ + namespace NAMESPACE { \ + void RMSNormAndPositionalEncoding(size_t num_tokens, const QBatch& qbatch, \ + MatPtrT& q, size_t layer_idx, \ + const LayerWeightsPtrs& layer, \ + const AttentionActivations& activations, \ + ThreadingContext& ctx); \ + \ + void SingleFlashAttention(size_t start_pos, size_t last_pos, \ + const float* HWY_RESTRICT q, \ + const MatPtrT& k, const MatPtrT& v, \ + size_t layer_idx, const LayerWeightsPtrs& layer, \ + const AttentionActivations& activations, \ + float* HWY_RESTRICT att_out, hwy::Profiler& p, \ + size_t worker); \ + \ + void FlashAttention(size_t num_tokens, size_t layer_idx, \ + const LayerWeightsPtrs& layer, \ + AttentionActivations& activations, QBatch& qbatch, \ + ThreadingContext& ctx); \ + /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ + } // namespace NAMESPACE + +// Function declarations for each SIMD target. Allows direct call from the +// per-target namespace. We may later replace this with dynamic dispatch if +// the overhead is acceptable. +HWY_VISIT_TARGETS(GEMMA_DECL_FLASH_ATTENTION) + +#undef GEMMA_DECL_FLASH_ATTENTION + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_ATTENTION_H_ diff --git a/gemma/flash_attention_test.cc b/gemma/flash_attention_test.cc new file mode 100644 index 0000000..efb210e --- /dev/null +++ b/gemma/flash_attention_test.cc @@ -0,0 +1,171 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
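+
+// This test fills q and the per-head K and V views of the KV cache with
+// deterministic values (SetMat), runs the existing DotSoftmaxWeightedSum path,
+// saves its output, then resets q and runs FlashAttention, checking that
+// att_out matches to within a relative error of 1e-5 (AssertClose).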
+ +#include +#include +#include + +#include "compression/types.h" +#include "gemma/activations.h" +#include "gemma/gemma.h" +#include "gemma/gemma_args.h" +#include "gemma/kv_cache.h" +#include "gemma/weights.h" +#include "ops/matmul.h" +#ifndef HWY_DISABLED_TARGETS +#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS +#endif // HWY_DISABLED_TARGETS + +#include +#include + +#include // std::max +#include // std::abs +#include + +#include "util/mat.h" +#include "util/threading_context.h" +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" + +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "gemma/flash_attention_test.cc" // NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +// After highway.h +#include "compression/compress-inl.h" +#include "gemma/attention.h" +#include "gemma/configs.h" +#include "gemma/flash_attention.h" +#include "ops/matvec-inl.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { + +using FloatPtr = hwy::AlignedFreeUniquePtr; + +void SetMat(const size_t offset, MatPtrT& mat) { + const size_t kOuter = mat.Extents().rows; + const size_t kInner = mat.Extents().cols; + const float i_scale = 1.0f / kInner; + const float j_scale = 1.0f / kOuter; + for (size_t i = 0; i < kOuter; ++i) { + float* row = mat.Row(i); + for (size_t j = 0; j < kInner; ++j) { + row[j] = + static_cast((i * kInner * i_scale + (j + offset) * j_scale)); + } + } +} + +std::unique_ptr> MakeCopyOfMat(const MatPtrT& mat, + const Allocator& allocator) { + auto copy = std::make_unique>("TestMat", mat.Extents(), + allocator, MatPadding::kOdd); + CopyMat(mat, *copy); + return copy; +} + +void AssertClose(const MatPtrT& a, const MatPtrT& b) { + // Avoid comparing the padding bytes, which are uninitialized. 
+ for (size_t r = 0; r < a.Rows(); ++r) { + const float* HWY_RESTRICT a_row = a.Row(r); + const float* HWY_RESTRICT b_row = b.Row(r); + for (size_t c = 0; c < a.Cols(); ++c) { + float rel_abs_delta = std::abs(a_row[c] - b_row[c]); + if (rel_abs_delta > 0.0f) { + rel_abs_delta /= std::max(std::abs(a_row[c]), std::abs(b_row[c])); + } + EXPECT_LT(rel_abs_delta, 1e-5) + << "a[" << r << "," << c << "]=" << a_row[c] << ", b[" << r << "," + << c << "]=" << b_row[c]; + } + } +} + +void TestAttention() { + ThreadingArgs threading_args; + ThreadingContext ctx(threading_args); + // hwy::ThreadPool& pool = ctx.pools.Pool(); + constexpr size_t kOuter = 1024; + constexpr size_t kInner = 256; + ModelConfig config(Model::GEMMA2_2B, Type::kF32, PromptWrapping::GEMMA_PT); + TensorInfoRegistry tensor_info_registry(config); + const LayerConfig& layer_config = config.layer_configs[0]; + const LayerWeightsPtrs layers(0, layer_config, tensor_info_registry); + InferenceArgs inference_args; + RuntimeConfig runtime_config; + KVCache kv_cache(config, inference_args, ctx.allocator); + MatMulEnv env(ctx); + Activations activations(config, runtime_config.prefill_tbatch_size, + kv_cache.SeqLen(), env.ctx, env.row_ptrs); + std::vector tokens(kOuter); + std::iota(tokens.begin(), tokens.end(), 1); + PromptTokens prompt(tokens); + AllQueries all_queries(hwy::Span(&prompt, 1), + hwy::Span(&kv_cache, 1)); + QBatch qbatch(/*start=*/0, /*max_size=*/kOuter, all_queries); + const size_t batch_size = kOuter; + std::vector> row_ptrs; + AttentionActivations attention(config, layer_config, batch_size, kOuter, + ctx.allocator, row_ptrs); + const size_t qkv_dim = layer_config.qkv_dim; + ASSERT_EQ(qkv_dim, kInner); + const hwy::Divisor div_qbatch(qbatch.Size()); + // A "head group" in the context of GQA refers to a collection of query + // heads that share the same key and value heads. + const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads; + const size_t seq_len = + static_cast(attention.div_seq_len.GetDivisor()); + auto& kvc = qbatch.KV(0).kv_cache; + for (size_t h = 0; h < layer_config.heads; ++h) { + // Make strided views into the kv cache for + // this query and head. + const size_t head_offset = (h / kHeadGroups) * qkv_dim * 2; + MatPtrT k("k_view", Extents2D(seq_len, qkv_dim)); + k.SetPtr(kvc.Row(0) + head_offset, kvc.Stride()); + MatPtrT v("v_view", Extents2D(seq_len, qkv_dim)); + v.SetPtr(kvc.Row(0) + head_offset + qkv_dim, kvc.Stride()); + SetMat(h + layer_config.heads, k); + SetMat(h + layer_config.heads * 2, v); + } + SetMat(1, attention.q); + DotSoftmaxWeightedSum(tokens.size(), 0, layers, attention, qbatch, ctx); + // Copy the output to saved_att to allow for comparison. 
+ auto saved_att = MakeCopyOfMat(attention.att_out, ctx.allocator); + SetMat(1, attention.q); + FlashAttention(tokens.size(), 0, layers, attention, qbatch, ctx); + AssertClose(attention.att_out, *saved_att); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace gcpp { +HWY_BEFORE_TEST(FlashAttentionTest); +HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestAttention); +HWY_AFTER_TEST(); + +} // namespace gcpp + +#endif diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 0c6bd50..cfd85ae 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -613,6 +613,351 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd( }); } +template > +HWY_INLINE HWY_MAYBE_UNUSED void MulAdd16( + DF df, const VF common, const VF split, VF& sum0, VF& sum1, VF& sum2, + VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7, VF& sum8, VF& sum9, + VF& sum10, VF& sum11, VF& sum12, VF& sum13, VF& sum14, VF& sum15) { + sum0 = hn::MulAdd(common, hn::Set(df, split.raw[0]), sum0); + sum1 = hn::MulAdd(common, hn::Set(df, split.raw[1]), sum1); + sum2 = hn::MulAdd(common, hn::Set(df, split.raw[2]), sum2); + sum3 = hn::MulAdd(common, hn::Set(df, split.raw[3]), sum3); + sum4 = hn::MulAdd(common, hn::Set(df, split.raw[4]), sum4); + sum5 = hn::MulAdd(common, hn::Set(df, split.raw[5]), sum5); + sum6 = hn::MulAdd(common, hn::Set(df, split.raw[6]), sum6); + sum7 = hn::MulAdd(common, hn::Set(df, split.raw[7]), sum7); + sum8 = hn::MulAdd(common, hn::Set(df, split.raw[8]), sum8); + sum9 = hn::MulAdd(common, hn::Set(df, split.raw[9]), sum9); + sum10 = hn::MulAdd(common, hn::Set(df, split.raw[10]), sum10); + sum11 = hn::MulAdd(common, hn::Set(df, split.raw[11]), sum11); + sum12 = hn::MulAdd(common, hn::Set(df, split.raw[12]), sum12); + sum13 = hn::MulAdd(common, hn::Set(df, split.raw[13]), sum13); + sum14 = hn::MulAdd(common, hn::Set(df, split.raw[14]), sum14); + sum15 = hn::MulAdd(common, hn::Set(df, split.raw[15]), sum15); +} + +template > +HWY_INLINE HWY_MAYBE_UNUSED void MulAdd8(DF df, const VF common, const VF split, + VF& sum0, VF& sum1, VF& sum2, VF& sum3, + VF& sum4, VF& sum5, VF& sum6, + VF& sum7) { + sum0 = hn::MulAdd(common, hn::Set(df, split.raw[0]), sum0); + sum1 = hn::MulAdd(common, hn::Set(df, split.raw[1]), sum1); + sum2 = hn::MulAdd(common, hn::Set(df, split.raw[2]), sum2); + sum3 = hn::MulAdd(common, hn::Set(df, split.raw[3]), sum3); + sum4 = hn::MulAdd(common, hn::Set(df, split.raw[4]), sum4); + sum5 = hn::MulAdd(common, hn::Set(df, split.raw[5]), sum5); + sum6 = hn::MulAdd(common, hn::Set(df, split.raw[6]), sum6); + sum7 = hn::MulAdd(common, hn::Set(df, split.raw[7]), sum7); +} + +template > +HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4(DF df, const VF common, const VF split, + VF& sum0, VF& sum1, VF& sum2, + VF& sum3) { + sum0 = hn::MulAdd(common, hn::Set(df, split.raw[0]), sum0); + sum1 = hn::MulAdd(common, hn::Set(df, split.raw[1]), sum1); + sum2 = hn::MulAdd(common, hn::Set(df, split.raw[2]), sum2); + sum3 = hn::MulAdd(common, hn::Set(df, split.raw[3]), sum3); +} + +// For an 8xNF tile of float values in 8xNF-lane registers, multiplies 8 rows +// of V by the corresponding values in c0-c7 and adds them to NF rows of out, +// after first prescaling out by scale. +// The depth (size) must be a multiple of NF. 
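+// In scalar terms (a sketch, with r indexing the NF lanes and j the 8 rows of
+// V): each output row out_r, starting at out + out_offsets[r], is updated as
+//   out_r = out_r * scale[r] + sum over j of c_j[r] * v.Row(pos[j]),
+// i.e. the 8 weights for lane r are applied to 8 rows of V and accumulated
+// into that lane's output row in one pass.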
+template > +HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile( + DF df, const VF scale, const VF c0, const VF c1, const VF c2, const VF c3, + const VF c4, const VF c5, const VF c6, const VF c7, const MatPtrT& v, + const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out, + const uint32_t* HWY_RESTRICT out_offsets, const size_t size, + hwy::Profiler& p, const size_t worker) { + static const auto zone = p.AddZone("Ops.MulByConstAndAdd"); + PROFILER_ZONE3(p, worker, zone); + namespace hn = hwy::HWY_NAMESPACE; + HWY_LANES_CONSTEXPR size_t NF = hn::MaxLanes(df); + + size_t i = 0; + while (i + NF <= size) { + if HWY_LANES_CONSTEXPR (NF == 16) { + VF out0, out1, out2, out3, out4, out5, out6, out7; + VF out8, out9, out10, out11, out12, out13, out14, out15; + out0 = hn::Load(df, out + i + out_offsets[0]); + out1 = hn::Load(df, out + i + out_offsets[1]); + out2 = hn::Load(df, out + i + out_offsets[2]); + out3 = hn::Load(df, out + i + out_offsets[3]); + out4 = hn::Load(df, out + i + out_offsets[4]); + out5 = hn::Load(df, out + i + out_offsets[5]); + out6 = hn::Load(df, out + i + out_offsets[6]); + out7 = hn::Load(df, out + i + out_offsets[7]); + out8 = hn::Load(df, out + i + out_offsets[8]); + out9 = hn::Load(df, out + i + out_offsets[9]); + out10 = hn::Load(df, out + i + out_offsets[10]); + out11 = hn::Load(df, out + i + out_offsets[11]); + out12 = hn::Load(df, out + i + out_offsets[12]); + out13 = hn::Load(df, out + i + out_offsets[13]); + out14 = hn::Load(df, out + i + out_offsets[14]); + out15 = hn::Load(df, out + i + out_offsets[15]); + out0 = hn::Mul(out0, hn::Set(df, scale.raw[0])); + out1 = hn::Mul(out1, hn::Set(df, scale.raw[1])); + out2 = hn::Mul(out2, hn::Set(df, scale.raw[2])); + out3 = hn::Mul(out3, hn::Set(df, scale.raw[3])); + out4 = hn::Mul(out4, hn::Set(df, scale.raw[4])); + out5 = hn::Mul(out5, hn::Set(df, scale.raw[5])); + out6 = hn::Mul(out6, hn::Set(df, scale.raw[6])); + out7 = hn::Mul(out7, hn::Set(df, scale.raw[7])); + out8 = hn::Mul(out8, hn::Set(df, scale.raw[8])); + out9 = hn::Mul(out9, hn::Set(df, scale.raw[9])); + out10 = hn::Mul(out10, hn::Set(df, scale.raw[10])); + out11 = hn::Mul(out11, hn::Set(df, scale.raw[11])); + out12 = hn::Mul(out12, hn::Set(df, scale.raw[12])); + out13 = hn::Mul(out13, hn::Set(df, scale.raw[13])); + out14 = hn::Mul(out14, hn::Set(df, scale.raw[14])); + out15 = hn::Mul(out15, hn::Set(df, scale.raw[15])); + VF x0 = hn::Load(df, v.Row(pos[0]) + i); + MulAdd16(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7, out8, + out9, out10, out11, out12, out13, out14, out15); + VF x1 = hn::Load(df, v.Row(pos[1]) + i); + MulAdd16(df, x1, c1, out0, out1, out2, out3, out4, out5, out6, out7, out8, + out9, out10, out11, out12, out13, out14, out15); + VF x2 = hn::Load(df, v.Row(pos[2]) + i); + MulAdd16(df, x2, c2, out0, out1, out2, out3, out4, out5, out6, out7, out8, + out9, out10, out11, out12, out13, out14, out15); + VF x3 = hn::Load(df, v.Row(pos[3]) + i); + MulAdd16(df, x3, c3, out0, out1, out2, out3, out4, out5, out6, out7, out8, + out9, out10, out11, out12, out13, out14, out15); + VF x4 = hn::Load(df, v.Row(pos[4]) + i); + MulAdd16(df, x4, c4, out0, out1, out2, out3, out4, out5, out6, out7, out8, + out9, out10, out11, out12, out13, out14, out15); + VF x5 = hn::Load(df, v.Row(pos[5]) + i); + MulAdd16(df, x5, c5, out0, out1, out2, out3, out4, out5, out6, out7, out8, + out9, out10, out11, out12, out13, out14, out15); + VF x6 = hn::Load(df, v.Row(pos[6]) + i); + MulAdd16(df, x6, c6, out0, out1, out2, out3, out4, out5, out6, out7, out8, + out9, 
out10, out11, out12, out13, out14, out15); + VF x7 = hn::Load(df, v.Row(pos[7]) + i); + MulAdd16(df, x7, c7, out0, out1, out2, out3, out4, out5, out6, out7, out8, + out9, out10, out11, out12, out13, out14, out15); + hn::Store(out0, df, out + i + out_offsets[0]); + hn::Store(out1, df, out + i + out_offsets[1]); + hn::Store(out2, df, out + i + out_offsets[2]); + hn::Store(out3, df, out + i + out_offsets[3]); + hn::Store(out4, df, out + i + out_offsets[4]); + hn::Store(out5, df, out + i + out_offsets[5]); + hn::Store(out6, df, out + i + out_offsets[6]); + hn::Store(out7, df, out + i + out_offsets[7]); + hn::Store(out8, df, out + i + out_offsets[8]); + hn::Store(out9, df, out + i + out_offsets[9]); + hn::Store(out10, df, out + i + out_offsets[10]); + hn::Store(out11, df, out + i + out_offsets[11]); + hn::Store(out12, df, out + i + out_offsets[12]); + hn::Store(out13, df, out + i + out_offsets[13]); + hn::Store(out14, df, out + i + out_offsets[14]); + hn::Store(out15, df, out + i + out_offsets[15]); + } + if HWY_LANES_CONSTEXPR (NF == 8) { + VF out0, out1, out2, out3, out4, out5, out6, out7; + out0 = hn::Load(df, out + i + out_offsets[0]); + out1 = hn::Load(df, out + i + out_offsets[1]); + out2 = hn::Load(df, out + i + out_offsets[2]); + out3 = hn::Load(df, out + i + out_offsets[3]); + out4 = hn::Load(df, out + i + out_offsets[4]); + out5 = hn::Load(df, out + i + out_offsets[5]); + out6 = hn::Load(df, out + i + out_offsets[6]); + out7 = hn::Load(df, out + i + out_offsets[7]); + out0 = hn::Mul(out0, hn::Set(df, scale.raw[0])); + out1 = hn::Mul(out1, hn::Set(df, scale.raw[1])); + out2 = hn::Mul(out2, hn::Set(df, scale.raw[2])); + out3 = hn::Mul(out3, hn::Set(df, scale.raw[3])); + out4 = hn::Mul(out4, hn::Set(df, scale.raw[4])); + out5 = hn::Mul(out5, hn::Set(df, scale.raw[5])); + out6 = hn::Mul(out6, hn::Set(df, scale.raw[6])); + out7 = hn::Mul(out7, hn::Set(df, scale.raw[7])); + VF x0 = hn::Load(df, v.Row(pos[0]) + i); + MulAdd8(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7); + VF x1 = hn::Load(df, v.Row(pos[1]) + i); + MulAdd8(df, x1, c1, out0, out1, out2, out3, out4, out5, out6, out7); + VF x2 = hn::Load(df, v.Row(pos[2]) + i); + MulAdd8(df, x2, c2, out0, out1, out2, out3, out4, out5, out6, out7); + VF x3 = hn::Load(df, v.Row(pos[3]) + i); + MulAdd8(df, x3, c3, out0, out1, out2, out3, out4, out5, out6, out7); + VF x4 = hn::Load(df, v.Row(pos[4]) + i); + MulAdd8(df, x4, c4, out0, out1, out2, out3, out4, out5, out6, out7); + VF x5 = hn::Load(df, v.Row(pos[5]) + i); + MulAdd8(df, x5, c5, out0, out1, out2, out3, out4, out5, out6, out7); + VF x6 = hn::Load(df, v.Row(pos[6]) + i); + MulAdd8(df, x6, c6, out0, out1, out2, out3, out4, out5, out6, out7); + VF x7 = hn::Load(df, v.Row(pos[7]) + i); + MulAdd8(df, x7, c7, out0, out1, out2, out3, out4, out5, out6, out7); + hn::Store(out0, df, out + i + out_offsets[0]); + hn::Store(out1, df, out + i + out_offsets[1]); + hn::Store(out2, df, out + i + out_offsets[2]); + hn::Store(out3, df, out + i + out_offsets[3]); + hn::Store(out4, df, out + i + out_offsets[4]); + hn::Store(out5, df, out + i + out_offsets[5]); + hn::Store(out6, df, out + i + out_offsets[6]); + hn::Store(out7, df, out + i + out_offsets[7]); + } + if HWY_LANES_CONSTEXPR (NF == 4) { + VF out0, out1, out2, out3; + out0 = hn::Load(df, out + i + out_offsets[0]); + out1 = hn::Load(df, out + i + out_offsets[1]); + out2 = hn::Load(df, out + i + out_offsets[2]); + out3 = hn::Load(df, out + i + out_offsets[3]); + out0 = hn::Mul(out0, hn::Set(df, scale.raw[0])); + out1 = hn::Mul(out1, 
hn::Set(df, scale.raw[1])); + out2 = hn::Mul(out2, hn::Set(df, scale.raw[2])); + out3 = hn::Mul(out3, hn::Set(df, scale.raw[3])); + VF x0 = hn::Load(df, v.Row(pos[0]) + i); + MulAdd4(df, x0, c0, out0, out1, out2, out3); + VF x1 = hn::Load(df, v.Row(pos[1]) + i); + MulAdd4(df, x1, c1, out0, out1, out2, out3); + VF x2 = hn::Load(df, v.Row(pos[2]) + i); + MulAdd4(df, x2, c2, out0, out1, out2, out3); + VF x3 = hn::Load(df, v.Row(pos[3]) + i); + MulAdd4(df, x3, c3, out0, out1, out2, out3); + VF x4 = hn::Load(df, v.Row(pos[4]) + i); + MulAdd4(df, x4, c4, out0, out1, out2, out3); + VF x5 = hn::Load(df, v.Row(pos[5]) + i); + MulAdd4(df, x5, c5, out0, out1, out2, out3); + VF x6 = hn::Load(df, v.Row(pos[6]) + i); + MulAdd4(df, x6, c6, out0, out1, out2, out3); + VF x7 = hn::Load(df, v.Row(pos[7]) + i); + MulAdd4(df, x7, c7, out0, out1, out2, out3); + hn::Store(out0, df, out + i + out_offsets[0]); + hn::Store(out1, df, out + i + out_offsets[1]); + hn::Store(out2, df, out + i + out_offsets[2]); + hn::Store(out3, df, out + i + out_offsets[3]); + } + i += NF; + } + const size_t remaining = size - i; + HWY_DASSERT(remaining == 0); +} + +// Prescales NF rows of out by scale, then multiplies 1 row of V by the +// corresponding values in c0 and adds them to the NF rows of out. +// The depth (size) must be a multiple of NF. +template > +HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector( + DF df, const VF scale, const VF c0, const MatPtrT& v, + const size_t pos, float* HWY_RESTRICT out, + const uint32_t* HWY_RESTRICT out_offsets, const size_t size, + hwy::Profiler& p, const size_t worker) { + static const auto zone = p.AddZone("Ops.MulByConstAndAdd"); + PROFILER_ZONE3(p, worker, zone); + namespace hn = hwy::HWY_NAMESPACE; + const size_t NF = hn::MaxLanes(df); + + size_t i = 0; + while (i + NF <= size) { + if constexpr (NF == 16) { + VF out0, out1, out2, out3, out4, out5, out6, out7; + VF out8, out9, out10, out11, out12, out13, out14, out15; + out0 = hn::Load(df, out + i + out_offsets[0]); + out1 = hn::Load(df, out + i + out_offsets[1]); + out2 = hn::Load(df, out + i + out_offsets[2]); + out3 = hn::Load(df, out + i + out_offsets[3]); + out4 = hn::Load(df, out + i + out_offsets[4]); + out5 = hn::Load(df, out + i + out_offsets[5]); + out6 = hn::Load(df, out + i + out_offsets[6]); + out7 = hn::Load(df, out + i + out_offsets[7]); + out8 = hn::Load(df, out + i + out_offsets[8]); + out9 = hn::Load(df, out + i + out_offsets[9]); + out10 = hn::Load(df, out + i + out_offsets[10]); + out11 = hn::Load(df, out + i + out_offsets[11]); + out12 = hn::Load(df, out + i + out_offsets[12]); + out13 = hn::Load(df, out + i + out_offsets[13]); + out14 = hn::Load(df, out + i + out_offsets[14]); + out15 = hn::Load(df, out + i + out_offsets[15]); + out0 = hn::Mul(out0, hn::Set(df, scale.raw[0])); + out1 = hn::Mul(out1, hn::Set(df, scale.raw[1])); + out2 = hn::Mul(out2, hn::Set(df, scale.raw[2])); + out3 = hn::Mul(out3, hn::Set(df, scale.raw[3])); + out4 = hn::Mul(out4, hn::Set(df, scale.raw[4])); + out5 = hn::Mul(out5, hn::Set(df, scale.raw[5])); + out6 = hn::Mul(out6, hn::Set(df, scale.raw[6])); + out7 = hn::Mul(out7, hn::Set(df, scale.raw[7])); + out8 = hn::Mul(out8, hn::Set(df, scale.raw[8])); + out9 = hn::Mul(out9, hn::Set(df, scale.raw[9])); + out10 = hn::Mul(out10, hn::Set(df, scale.raw[10])); + out11 = hn::Mul(out11, hn::Set(df, scale.raw[11])); + out12 = hn::Mul(out12, hn::Set(df, scale.raw[12])); + out13 = hn::Mul(out13, hn::Set(df, scale.raw[13])); + out14 = hn::Mul(out14, hn::Set(df, scale.raw[14])); + out15 = 
hn::Mul(out15, hn::Set(df, scale.raw[15])); + VF x0 = hn::Load(df, v.Row(pos) + i); + MulAdd16(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7, out8, + out9, out10, out11, out12, out13, out14, out15); + hn::Store(out0, df, out + i + out_offsets[0]); + hn::Store(out1, df, out + i + out_offsets[1]); + hn::Store(out2, df, out + i + out_offsets[2]); + hn::Store(out3, df, out + i + out_offsets[3]); + hn::Store(out4, df, out + i + out_offsets[4]); + hn::Store(out5, df, out + i + out_offsets[5]); + hn::Store(out6, df, out + i + out_offsets[6]); + hn::Store(out7, df, out + i + out_offsets[7]); + hn::Store(out8, df, out + i + out_offsets[8]); + hn::Store(out9, df, out + i + out_offsets[9]); + hn::Store(out10, df, out + i + out_offsets[10]); + hn::Store(out11, df, out + i + out_offsets[11]); + hn::Store(out12, df, out + i + out_offsets[12]); + hn::Store(out13, df, out + i + out_offsets[13]); + hn::Store(out14, df, out + i + out_offsets[14]); + hn::Store(out15, df, out + i + out_offsets[15]); + } else if constexpr (NF == 8) { + VF out0, out1, out2, out3, out4, out5, out6, out7; + out0 = hn::Load(df, out + i + out_offsets[0]); + out1 = hn::Load(df, out + i + out_offsets[1]); + out2 = hn::Load(df, out + i + out_offsets[2]); + out3 = hn::Load(df, out + i + out_offsets[3]); + out4 = hn::Load(df, out + i + out_offsets[4]); + out5 = hn::Load(df, out + i + out_offsets[5]); + out6 = hn::Load(df, out + i + out_offsets[6]); + out7 = hn::Load(df, out + i + out_offsets[7]); + out0 = hn::Mul(out0, hn::Set(df, scale.raw[0])); + out1 = hn::Mul(out1, hn::Set(df, scale.raw[1])); + out2 = hn::Mul(out2, hn::Set(df, scale.raw[2])); + out3 = hn::Mul(out3, hn::Set(df, scale.raw[3])); + out4 = hn::Mul(out4, hn::Set(df, scale.raw[4])); + out5 = hn::Mul(out5, hn::Set(df, scale.raw[5])); + out6 = hn::Mul(out6, hn::Set(df, scale.raw[6])); + out7 = hn::Mul(out7, hn::Set(df, scale.raw[7])); + VF x0 = hn::Load(df, v.Row(pos) + i); + MulAdd8(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7); + hn::Store(out0, df, out + i + out_offsets[0]); + hn::Store(out1, df, out + i + out_offsets[1]); + hn::Store(out2, df, out + i + out_offsets[2]); + hn::Store(out3, df, out + i + out_offsets[3]); + hn::Store(out4, df, out + i + out_offsets[4]); + hn::Store(out5, df, out + i + out_offsets[5]); + hn::Store(out6, df, out + i + out_offsets[6]); + hn::Store(out7, df, out + i + out_offsets[7]); + } else if constexpr (NF == 4) { + VF out0, out1, out2, out3; + out0 = hn::Load(df, out + i + out_offsets[0]); + out1 = hn::Load(df, out + i + out_offsets[1]); + out2 = hn::Load(df, out + i + out_offsets[2]); + out3 = hn::Load(df, out + i + out_offsets[3]); + out0 = hn::Mul(out0, hn::Set(df, scale.raw[0])); + out1 = hn::Mul(out1, hn::Set(df, scale.raw[1])); + out2 = hn::Mul(out2, hn::Set(df, scale.raw[2])); + out3 = hn::Mul(out3, hn::Set(df, scale.raw[3])); + VF x0 = hn::Load(df, v.Row(pos) + i); + MulAdd4(df, x0, c0, out0, out1, out2, out3); + hn::Store(out0, df, out + i + out_offsets[0]); + hn::Store(out1, df, out + i + out_offsets[1]); + hn::Store(out2, df, out + i + out_offsets[2]); + hn::Store(out3, df, out + i + out_offsets[3]); + } else { + HWY_DASSERT(false); + } + i += NF; + } + const size_t remaining = size - i; + HWY_DASSERT(remaining == 0); +} + // See below for a specialized version for top-1 sampling. // TODO: support bf16 logits using Decompress2. static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p,