Added flash attention, with both a single-q function, and a register-tiled function.

The register-tiled version achieves a speed-up by a factor of about 9.7 over the previous attention function on an AVX3-enabled machine.

PiperOrigin-RevId: 804913784
This commit is contained in:
Ray Smith 2025-09-09 08:04:45 -07:00 committed by Copybara-Service
parent 24b1760f03
commit f10ac41a20
9 changed files with 1146 additions and 11 deletions

View File

@ -117,6 +117,27 @@ cc_library(
], ],
) )
cc_test(
name = "flash_attention_test",
srcs = ["gemma/flash_attention_test.cc"],
deps = [
":configs",
":gemma_args",
":gemma_lib",
":kv_cache",
":mat",
":matmul",
":ops",
":threading_context",
":weights",
"@googletest//:gtest_main", # buildcleaner: keep
"//compression:compress",
"//compression:types",
"@highway//:hwy",
"@highway//:hwy_test_util",
],
)
cc_test( cc_test(
name = "threading_test", name = "threading_test",
srcs = ["util/threading_test.cc"], srcs = ["util/threading_test.cc"],
@ -526,12 +547,14 @@ cc_library(
name = "gemma_lib", name = "gemma_lib",
srcs = [ srcs = [
"gemma/attention.cc", "gemma/attention.cc",
"gemma/flash_attention.cc",
"gemma/gemma.cc", "gemma/gemma.cc",
"gemma/vit.cc", "gemma/vit.cc",
], ],
hdrs = [ hdrs = [
"gemma/activations.h", "gemma/activations.h",
"gemma/attention.h", "gemma/attention.h",
"gemma/flash_attention.h",
"gemma/gemma.h", "gemma/gemma.h",
"gemma/vit.h", "gemma/vit.h",
], ],

View File

@ -79,6 +79,8 @@ set(SOURCES
gemma/attention.h gemma/attention.h
gemma/configs.cc gemma/configs.cc
gemma/configs.h gemma/configs.h
gemma/flash_attention.cc
gemma/flash_attention.h
gemma/gemma_args.h gemma/gemma_args.h
gemma/gemma-inl.h gemma/gemma-inl.h
gemma/gemma.cc gemma/gemma.cc
@ -216,6 +218,7 @@ set(GEMMA_TEST_FILES
compression/nuq_test.cc compression/nuq_test.cc
compression/sfp_test.cc compression/sfp_test.cc
evals/gemma_test.cc evals/gemma_test.cc
gemma/flash_attention_test.cc
gemma/tensor_info_test.cc gemma/tensor_info_test.cc
io/blob_store_test.cc io/blob_store_test.cc
io/fields_test.cc io/fields_test.cc

View File

@ -56,7 +56,11 @@ struct AttentionActivations {
? layer_config.heads * 3 * layer_config.qkv_dim ? layer_config.heads * 3 * layer_config.qkv_dim
: layer_config.heads * layer_config.qkv_dim, : layer_config.heads * layer_config.qkv_dim,
allocator)), allocator)),
q_T(MatFactory("q_T", layer_config.qkv_dim,
config.vocab_size == 0
? batch_size * layer_config.heads * 3
: batch_size * layer_config.heads,
allocator)),
pre_att_rms_out(MatFactory("pre_att_rms_out", batch_size, pre_att_rms_out(MatFactory("pre_att_rms_out", batch_size,
config.model_dim, allocator)), config.model_dim, allocator)),
att(MatFactory("att", batch_size, layer_config.heads * seq_len, att(MatFactory("att", batch_size, layer_config.heads * seq_len,
@ -90,11 +94,13 @@ struct AttentionActivations {
// If we forget any MatMul outputs here, debug builds print a warning but // If we forget any MatMul outputs here, debug builds print a warning but
// fill them in each MatMul call. // fill them in each MatMul call.
q.AllocateAndAttachRowPtrs(row_ptrs); q.AllocateAndAttachRowPtrs(row_ptrs);
q_T.AllocateAndAttachRowPtrs(row_ptrs);
att_sums.AllocateAndAttachRowPtrs(row_ptrs); att_sums.AllocateAndAttachRowPtrs(row_ptrs);
} }
void SetBatchSize(size_t batch_size) { void SetBatchSize(size_t batch_size) {
q.OverrideRows(batch_size); q.OverrideRows(batch_size);
q_T.OverrideRows(batch_size);
pre_att_rms_out.OverrideRows(batch_size); pre_att_rms_out.OverrideRows(batch_size);
att.OverrideRows(batch_size); att.OverrideRows(batch_size);
@ -105,6 +111,7 @@ struct AttentionActivations {
const ModelConfig& config; const ModelConfig& config;
MatStorageT<float> q; // query MatStorageT<float> q; // query
MatStorageT<float> q_T; // Transposed to maximize attention speed.
MatStorageT<float> pre_att_rms_out; MatStorageT<float> pre_att_rms_out;
MatStorageT<float> att; // attention vector MatStorageT<float> att; // attention vector

View File

@ -41,12 +41,16 @@
#include "hwy/highway.h" #include "hwy/highway.h"
// After highway.h // After highway.h
#include "compression/compress-inl.h" #include "compression/compress-inl.h"
#include "gemma/flash_attention.h"
#include "ops/ops-inl.h" #include "ops/ops-inl.h"
HWY_BEFORE_NAMESPACE(); HWY_BEFORE_NAMESPACE();
namespace gcpp { namespace gcpp {
namespace HWY_NAMESPACE { namespace HWY_NAMESPACE {
constexpr int kFlagReserved = 1; // LINTER: unused, reserved for future use.
constexpr int kUseOldAttention = 2;
// Computes Q.K scores, which are "logits" (or scores) stored to att. // Computes Q.K scores, which are "logits" (or scores) stored to att.
// `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim]. // `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos, static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
@ -71,11 +75,11 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
} }
} }
static void PositionalEncodingQK(float* qk, const size_t layer_idx, void PositionalEncodingQK(float* qk, const size_t layer_idx,
const LayerWeightsPtrs& layer, const LayerWeightsPtrs& layer,
const AttentionActivations& activations, const AttentionActivations& activations,
hwy::Profiler& p, const size_t worker, hwy::Profiler& p, const size_t worker,
const size_t pos, const float mul = 1.0f) { const size_t pos, const float mul) {
const size_t qkv_dim = layer.layer_config.qkv_dim; const size_t qkv_dim = layer.layer_config.qkv_dim;
const PostQKType& post_qk = layer.layer_config.post_qk; const PostQKType& post_qk = layer.layer_config.post_qk;
// qk is either q or k, so qkv_dim is the length we operate on. // qk is either q or k, so qkv_dim is the length we operate on.
@ -165,8 +169,7 @@ void SingleDotSoftmaxWeightedSum(
// The attention window usually starts at 0 unless `pos` is larger than // The attention window usually starts at 0 unless `pos` is larger than
// the attention window size, then it is `pos` - window_size + 1. // the attention window size, then it is `pos` - window_size + 1.
static HWY_INLINE size_t StartPos(size_t pos, const ModelConfig& config, size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx) {
size_t layer_idx) {
const size_t att_window_size = config.attention_window_sizes[layer_idx]; const size_t att_window_size = config.attention_window_sizes[layer_idx];
return pos - HWY_MIN(att_window_size - 1, pos); return pos - HWY_MIN(att_window_size - 1, pos);
} }
@ -314,7 +317,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
} }
PositionalEncodingQK(kv_f32, layer_idx, layer, activations, PositionalEncodingQK(kv_f32, layer_idx, layer, activations,
env.ctx.profiler, worker, pos); env.ctx.profiler, worker, pos, /*mul=*/1.0f);
CompressPerThread tls; CompressPerThread tls;
Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0); Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
}); });
@ -354,8 +357,12 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
(void)layer_config; // only used in HWY_DASSERT (void)layer_config; // only used in HWY_DASSERT
ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env); ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env);
DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch, if (flags & kUseOldAttention) {
env.ctx); DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
env.ctx);
} else {
FlashAttention(num_tokens, layer_idx, layer, activations, qbatch, env.ctx);
}
SumHeads(layer, activations, env); SumHeads(layer, activations, env);
} }

View File

@ -28,6 +28,14 @@ namespace gcpp {
// Passed to HWY_VISIT_TARGETS; declares for one target. // Passed to HWY_VISIT_TARGETS; declares for one target.
#define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \ #define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \
namespace NAMESPACE { \ namespace NAMESPACE { \
void PositionalEncodingQK(float* qk, size_t layer_idx, \
const LayerWeightsPtrs& layer, \
const AttentionActivations& activations, \
hwy::Profiler& p, size_t worker, size_t pos, \
float mul); \
\
size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx); \
\
void SingleDotSoftmaxWeightedSum( \ void SingleDotSoftmaxWeightedSum( \
const size_t pos, const size_t start_pos, const size_t last_pos, \ const size_t pos, const size_t start_pos, const size_t last_pos, \
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \ float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \

510
gemma/flash_attention.cc Normal file
View File

@ -0,0 +1,510 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stddef.h>
#include <stdint.h>
#include <algorithm>
#include <cmath>
#include <limits>
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
#include "util/threading_context.h"
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS
#include "gemma/activations.h"
#include "gemma/configs.h" // kMaxQKVDim
#include "gemma/gemma.h"
#include "gemma/weights.h"
#include "util/threading.h"
#include "hwy/profiler.h"
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/flash_attention.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "compression/compress-inl.h"
#include "gemma/attention.h"
#include "ops/ops-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
// Transposes q into q_t.
// Both are 4D tensors stuffed into a 2-D MatPtrT.
// q has shape [batch, qbatch][head, qkv_dim].
// q_t has shape [qkv_dim][qbatch, head, batch] in order to make the maximum
// possible consecutive elements have the same KV.
static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
const size_t qbatch_size, ThreadingContext& ctx) {
static const auto zone = ctx.profiler.AddZone("Gen.Attention.TransposeQ");
const size_t num_heads = q.Cols() / q_t.Rows();
const size_t batch_size = q.Rows() / qbatch_size;
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
PROFILER_ZONE3(ctx.profiler, worker, zone);
float* HWY_RESTRICT qt_row = q_t.Row(task);
for (size_t qi = 0; qi < qbatch_size; ++qi)
for (size_t h = 0; h < num_heads; ++h) {
for (size_t b = 0; b < batch_size; ++b) {
qt_row[(qi * num_heads + h) * batch_size + b] =
q.Row(b * qbatch_size + qi)[h * q_t.Rows() + task];
}
}
};
{
// Full parallelism is helpful, SmallParallelFor is insufficient.
HierarchicalParallelFor(q_t.Rows(), ctx.pools, func);
}
}
// Updates q in place for RMSNorm and positional encoding.
void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
MatPtrT<KV_t>& q, const size_t layer_idx,
const LayerWeightsPtrs& layer,
const AttentionActivations& activations,
ThreadingContext& ctx) {
static const auto zone =
ctx.profiler.AddZone("Gen.Attention.RMSNormAndPositionalEncoding");
const float query_scale = activations.query_scale;
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
PROFILER_ZONE3(ctx.profiler, worker, zone);
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
for (size_t h = 0; h < layer.layer_config.heads; ++h) {
const size_t tq_idx = qbatch.Size() * task + qi;
// Find the token position in the query and calculate
// the range of cache positions to attend to.
const size_t pos = qbatch.Pos(qi) + task;
float* HWY_RESTRICT q_row =
q.Row(tq_idx) + h * layer.layer_config.qkv_dim;
// Apply rope and scaling to Q.
if (layer.query_norm_scale.HasPtr()) {
CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), q_row,
layer.layer_config.qkv_dim, ctx.profiler, worker);
});
}
PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx.profiler,
worker, pos, query_scale);
}
}
};
{
// Full parallelism is helpful, SmallParallelFor is insufficient.
HierarchicalParallelFor(num_tokens, ctx.pools, func);
}
}
// Calculates the complete attention outputs for a single row of q.
void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
const float* HWY_RESTRICT q, const MatPtrT<KV_t>& k,
const MatPtrT<KV_t>& v, const size_t layer_idx,
const LayerWeightsPtrs& layer,
const AttentionActivations& activations,
float* HWY_RESTRICT att_out, hwy::Profiler& p,
const size_t worker) {
static const auto zone = p.AddZone("Gen.Attention.SingleFlashAttention");
PROFILER_ZONE3(p, worker, zone);
const size_t pos_mod = activations.div_seq_len.Remainder(start_pos);
float m = Dot(q, k.Row(pos_mod), k.Cols());
float d = 1.0f;
// This is just a copy of the first token.
MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), p, worker);
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
const size_t pos_mod = activations.div_seq_len.Remainder(pos);
float x = Dot(q, k.Row(pos_mod), k.Cols());
if (activations.config.att_cap > 0.0f) {
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x.
x = activations.config.att_cap *
std::tanh(x / activations.config.att_cap);
}
float m_new = std::max(m, x);
float scale = d * std::exp(m - m_new);
x = std::exp(x - m_new);
m = m_new;
d = scale + x;
float one_over_d = 1.0f / d;
x *= one_over_d;
scale *= one_over_d;
MulByConst(scale, att_out, v.Cols(), p, worker);
MulByConstAndAdd(x, v.Row(pos_mod), att_out, v.Cols(), p, worker);
}
}
// Computes and returns a single vector of NF Q.K dot products, which represents
// the dot products of NF rows of Q for a single K timestep.
template <class DF, class VF = hn::Vec<DF>>
VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets,
const size_t k_pos, const MatPtrT<KV_t>& q,
const MatPtrT<KV_t>& k, hwy::Profiler& p, const size_t worker) {
hn::TFromD<DF> results[hn::MaxLanes(df)];
for (size_t i = 0; i < hn::Lanes(df); ++i) {
results[i] = Dot(q.Row(0) + q_offsets[i], k.Row(k_pos), k.Cols());
}
return hn::LoadU(df, results);
}
// Returns an 8xNF tile of Q.K dot products, in single precision.
// This is the result of NF rows of Q against 8 K timesteps, with positions
// given by k_pos[0..7]. Q has been transposed so that the NF rows are read in
// consecutive elements, and other columns by adding q_stride.
template <class DF, class VF = hn::Vec<DF>>
void QDotKTileFloat(DF df, const float* HWY_RESTRICT q, const size_t q_stride,
const MatPtrT<KV_t>& k, const size_t* k_pos,
hwy::Profiler& p, const size_t worker, VF& sum0, VF& sum1,
VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6,
VF& sum7) {
constexpr size_t kHTileSize = 8;
sum0 = hn::Zero(df);
sum1 = hn::Zero(df);
sum2 = hn::Zero(df);
sum3 = hn::Zero(df);
sum4 = hn::Zero(df);
sum5 = hn::Zero(df);
sum6 = hn::Zero(df);
sum7 = hn::Zero(df);
const float* HWY_RESTRICT k_row[kHTileSize];
for (int i = 0; i < kHTileSize; ++i) {
k_row[i] = k.Row(k_pos[i]);
}
for (size_t i = 0; i < k.Cols(); ++i) {
VF q_vec = hn::Load(df, q);
VF k_0 = hn::Set(df, k_row[0][i]);
sum0 = hn::MulAdd(q_vec, k_0, sum0);
VF k_1 = hn::Set(df, k_row[1][i]);
sum1 = hn::MulAdd(q_vec, k_1, sum1);
VF k_2 = hn::Set(df, k_row[2][i]);
sum2 = hn::MulAdd(q_vec, k_2, sum2);
VF k_3 = hn::Set(df, k_row[3][i]);
sum3 = hn::MulAdd(q_vec, k_3, sum3);
VF k_4 = hn::Set(df, k_row[4][i]);
sum4 = hn::MulAdd(q_vec, k_4, sum4);
VF k_5 = hn::Set(df, k_row[5][i]);
sum5 = hn::MulAdd(q_vec, k_5, sum5);
VF k_6 = hn::Set(df, k_row[6][i]);
sum6 = hn::MulAdd(q_vec, k_6, sum6);
VF k_7 = hn::Set(df, k_row[7][i]);
sum7 = hn::MulAdd(q_vec, k_7, sum7);
q += q_stride;
}
}
// Returns the element-wise maximum of 8 vectors, in a single vector.
template <class DF, class VF = hn::Vec<DF>>
VF HWY_INLINE ElementwiseMaxOf8(DF df, const VF& x0, const VF& x1, const VF& x2,
const VF& x3, const VF& x4, const VF& x5,
const VF& x6, const VF& x7) {
VF m0 = hn::Max(x0, x1);
VF m1 = hn::Max(x2, x3);
VF m2 = hn::Max(x4, x5);
VF m3 = hn::Max(x6, x7);
m0 = hn::Max(m0, m1);
m2 = hn::Max(m2, m3);
return hn::Max(m0, m2);
}
// Returns the element-wise sum of 8 vectors, in a single vector.
template <class DF, class VF = hn::Vec<DF>>
VF HWY_INLINE ElementwiseSumOf8(DF df, const VF& x0, const VF& x1, const VF& x2,
const VF& x3, const VF& x4, const VF& x5,
const VF& x6, const VF& x7) {
VF sum0 = hn::Add(x0, x1);
VF sum1 = hn::Add(x2, x3);
VF sum2 = hn::Add(x4, x5);
VF sum3 = hn::Add(x6, x7);
sum0 = hn::Add(sum0, sum1);
sum2 = hn::Add(sum2, sum3);
return hn::Add(sum0, sum2);
}
// Sweeps a tile of 8xNF accumulators from start_pos to min_last_pos, then
// sweeps the remaining timesteps in the range (min_last_pos, max_last_pos].
void TileFlashAttention(
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
const StridedView<float>& qT, const MatPtrT<KV_t>& k,
const size_t start_pos, const uint32_t* HWY_RESTRICT last_pos,
const size_t min_last_pos, const size_t max_last_pos,
const MatPtrT<KV_t>& v, const size_t layer_idx,
const LayerWeightsPtrs& layer, const AttentionActivations& activations,
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets,
hwy::Profiler& p, const size_t worker) {
static const auto zone = p.AddZone("Gen.Attention.TileFlashAttention");
PROFILER_ZONE3(p, worker, zone);
constexpr int kHTileSize = 8;
using DF = hn::ScalableTag<float>;
const DF df;
using VF = hn::Vec<DF>;
using DI = hn::ScalableTag<uint32_t>;
const DI di;
using VI = hn::Vec<DI>;
VI lasts = hn::LoadU(di, last_pos);
VF old_m = hn::Set(df, -std::numeric_limits<float>::max() / 2.0f);
VF old_d = hn::Zero(df);
const float* HWY_RESTRICT qT_row = qT.Row(0);
const size_t qT_stride = qT.Stride();
size_t position = start_pos;
while (position + kHTileSize - 1 <= min_last_pos) {
size_t k_pos[kHTileSize];
for (size_t i = 0; i < kHTileSize; ++i) {
k_pos[i] = activations.div_seq_len.Remainder(position + i);
}
VF x0, x1, x2, x3, x4, x5, x6, x7;
QDotKTileFloat(df, qT_row, qT_stride, k, k_pos, p, worker, x0, x1, x2, x3,
x4, x5, x6, x7);
if (activations.config.att_cap > 0.0f) {
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile.
VF cap = hn::Set(df, activations.config.att_cap);
VF one_over_cap = hn::Div(hn::Set(df, 1.0f), cap);
x0 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x0, one_over_cap)));
x1 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x1, one_over_cap)));
x2 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x2, one_over_cap)));
x3 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x3, one_over_cap)));
x4 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x4, one_over_cap)));
x5 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x5, one_over_cap)));
x6 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x6, one_over_cap)));
x7 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x7, one_over_cap)));
}
VF m = ElementwiseMaxOf8(df, x0, x1, x2, x3, x4, x5, x6, x7);
m = hn::Max(old_m, m);
x0 = hn::Exp(df, x0 - m);
x1 = hn::Exp(df, x1 - m);
x2 = hn::Exp(df, x2 - m);
x3 = hn::Exp(df, x3 - m);
x4 = hn::Exp(df, x4 - m);
x5 = hn::Exp(df, x5 - m);
x6 = hn::Exp(df, x6 - m);
x7 = hn::Exp(df, x7 - m);
VF scale = hn::Mul(old_d, hn::Exp(df, old_m - m));
old_d = ElementwiseSumOf8(df, x0, x1, x2, x3, x4, x5, x6, x7);
old_d = hn::Add(scale, old_d);
old_m = m;
VF one_over_d = hn::Div(hn::Set(df, 1.0f), old_d);
scale = hn::Mul(scale, one_over_d);
x0 = hn::Mul(x0, one_over_d);
x1 = hn::Mul(x1, one_over_d);
x2 = hn::Mul(x2, one_over_d);
x3 = hn::Mul(x3, one_over_d);
x4 = hn::Mul(x4, one_over_d);
x5 = hn::Mul(x5, one_over_d);
x6 = hn::Mul(x6, one_over_d);
x7 = hn::Mul(x7, one_over_d);
MulByConstAndAddTile(df, scale, x0, x1, x2, x3, x4, x5, x6, x7, v, k_pos,
att_out.Row(0), out_offsets, v.Cols(), p, worker);
position += kHTileSize;
}
while (position <= max_last_pos) {
size_t k_pos = activations.div_seq_len.Remainder(position);
VF x0 = QDotKVector(df, q_offsets, k_pos, q, k, p, worker);
if (activations.config.att_cap > 0.0f) {
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the vector.
VF cap = hn::Set(df, activations.config.att_cap);
VF one_over_cap = hn::Div(hn::Set(df, 1.0f), cap);
x0 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x0, one_over_cap)));
}
// Past the last position, x0 doesn't count.
auto mask = hn::Gt(hn::Set(di, position), lasts);
VF causal_offset = hn::MaskedSet(df, RebindMask(df, mask),
std::numeric_limits<float>::max() / 2.0f);
x0 = hn::Sub(x0, causal_offset);
VF m = hn::Max(old_m, x0);
x0 = hn::Exp(df, x0 - m);
VF scale = hn::Mul(old_d, hn::Exp(df, old_m - m));
old_m = m;
old_d = hn::Add(scale, x0);
VF one_over_d = hn::Div(hn::Set(df, 1.0f), old_d);
x0 = hn::Mul(x0, one_over_d);
scale = hn::Mul(scale, one_over_d);
MulByConstAndAddVector(df, scale, x0, v, k_pos, att_out.Row(0), out_offsets,
v.Cols(), p, worker);
++position;
}
}
// The nominal aim of attention is to combine 3 inputs Q[L,D], K[L,D], V[L,D]
// into a single output O[L,D].
// Conventional attention first computes A[L,L] = Q . KT
// followed by A = softmax(A) (over invididual rows).
// Then A is multiplied by V to get O[L,D].
// For each row of O, this takes a read of one row of Q L times, all of K,
// 3 write/reads of one row of A, read all of V, an read.write the one row of O
// L times. Ignoring the computation for now, and focusing just on memory,
// the one row of O takes L(4D+3) reads and L(D+3) writes.
// For the whole of Q, this is L^2(4D+3) reads and L^2(D+3) writes.
//
// Flash attention fuses these operations together, and (where possible)
// computes NF rows of the result using 8 accumulator registers and two more to
// keep running results. NF is the number of float lanes in a register, being 16
// for AVX3. The softmax is converted to streaming form using the
// algortihm from:
// https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf.
// Q is transposed to Q_T[D,L] to make the dot product computation efficient.
// QDotKTileFloat computes 8xNF rows of Q.K dot products in one go, reducing
// reads of Q by 8 and reads of K by NF. The streaming softmax is computed
// entirely in registers, and a further NF registers to accumulate the results
// of the product of the softmax and V, reduce the number of reads of V by NF,
// and the reads/writes of O by 8.
// The reads are thus reduced to 2DL^2(1/8+1/NF) and writes reduced to DL^2/8,
// which on AVX3 is an overall reduction by about a factor of 10.
//
// A further complication is that real attention is not as simple as documented
// in the paper and above. There are multiple query heads, differing KV, and
// different sequence lengths, so a lot of the work in FlashAttention is making
// sure that a collection of q rows can use the TileFlashAttention path.
void FlashAttention(const size_t num_tokens, const size_t layer_idx,
const LayerWeightsPtrs& layer,
AttentionActivations& activations, QBatch& qbatch,
ThreadingContext& ctx) {
static const auto zone = ctx.profiler.AddZone("Gen.Attention.FlashAttention");
RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q, layer_idx,
layer, activations, ctx);
const hwy::Divisor div_qbatch(qbatch.Size());
const LayerConfig& layer_config = layer.layer_config;
const size_t qkv_dim = layer_config.qkv_dim;
// A "head group" in the context of GQA refers to a collection of query
// heads that share the same key and value heads.
const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
using DF = hn::ScalableTag<float>;
const DF df;
constexpr size_t kVTileSize = hn::MaxLanes(df);
const size_t cache_layer_size = layer_config.CacheLayerSize();
const size_t seq_len =
static_cast<size_t>(activations.div_seq_len.GetDivisor());
const size_t token_batch = num_tokens * div_qbatch.GetDivisor();
const size_t total_tasks = token_batch * layer_config.heads;
// q has shape [batch, qbatch][head, qkv_dim].
// We transpose it to [qkv_dim][qbatch, head, batch] in order to make the
// maximum possible number of consecutive columns have the same KV matrices.
// Each thread will process a tile of NF columns of QT so the starting column
// index of QT is just the task index * kVTileSize.
TransposeQ(activations.q, activations.q_T, qbatch.Size(), ctx);
const size_t num_thread_tasks = hwy::DivCeil(total_tasks, kVTileSize);
const hwy::Divisor div_tokens(num_tokens);
// All layers should have the same number of heads.
HWY_DASSERT(activations.div_heads.GetDivisor() == layer_config.heads);
// For each head/token/query, compute fused flash Q.K, softmax and weighted V.
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
PROFILER_ZONE3(ctx.profiler, worker, zone);
// Offsets into original Q for each row in the tile.
uint32_t q_offsets[kVTileSize];
// Offsets into att_out for each row in the tile.
uint32_t out_offsets[kVTileSize];
// Start positions for each row in the tile.
size_t start_positions[kVTileSize];
// Last positions for each row in the tile. Inclusive.
uint32_t last_pos[kVTileSize];
// min and max last positions across all rows in the tile determines when
// TileFlashAttention switches to single vector mode to handle the
// ragged sequence lengths.
size_t min_last_pos = std::numeric_limits<size_t>::max();
size_t max_last_pos = 0;
// Indices into the qbatch.KV for each row in the tile.
size_t qi_indices[kVTileSize];
// Indices into the kv_cache for each row in the tile.
size_t kv_offsets[kVTileSize];
// first_task is [qbatch, head, token].
const size_t first_task = task * kVTileSize;
const size_t last_task = first_task + kVTileSize - 1;
bool use_tile_attention = last_task < total_tasks;
for (size_t offset = 0;
offset < kVTileSize && first_task + offset < total_tasks; ++offset) {
const size_t batch_idx = div_tokens.Remainder(first_task + offset);
const size_t qh = div_tokens.Divide(first_task + offset);
const size_t head = activations.div_heads.Remainder(qh);
const size_t qi = activations.div_heads.Divide(qh);
const size_t tq_idx = div_qbatch.GetDivisor() * batch_idx + qi;
qi_indices[offset] = qi;
// Find the token position in the query and calculate
// the range of cache positions to attend to.
const size_t pos = qbatch.Pos(qi) + batch_idx;
const size_t start_pos = StartPos(pos, activations.config, layer_idx);
start_positions[offset] = start_pos;
size_t last = pos;
const size_t prefix_end = qbatch.PrefixEnd(qi);
if (prefix_end > 0 && prefix_end - 1 > last) {
// last_pos in QDotK and WeightedSumV is inclusive.
last = prefix_end - 1;
}
last_pos[offset] = last;
min_last_pos = HWY_MIN(min_last_pos, last);
max_last_pos = HWY_MAX(max_last_pos, last);
q_offsets[offset] =
activations.q.Row(tq_idx) + head * qkv_dim - activations.q.Row(0);
out_offsets[offset] = activations.att_out.Row(tq_idx) + head * qkv_dim -
activations.att_out.Row(0);
const size_t kv_index = head / kHeadGroups;
const size_t head_offset = kv_index * qkv_dim * 2;
kv_offsets[offset] = layer_idx * cache_layer_size + head_offset;
// If any of the parameters in this if statement differ within this task,
// then we can't use TileFlashAttention. TileFlashAttention requires that
// all rows in the tile have the same K and V matrices, and Q starts at
// the same position. The end positions do not have to be the equal.
if (start_positions[offset] != start_positions[0] ||
qi_indices[offset] != qi_indices[0] ||
kv_offsets[offset] != kv_offsets[0]) {
use_tile_attention = false;
}
}
for (size_t offset = 0;
offset < kVTileSize && first_task + offset < total_tasks; ++offset) {
auto& kv_cache = qbatch.KV(qi_indices[offset]).kv_cache;
MatPtrT<KV_t> k("k_view", Extents2D(seq_len, qkv_dim));
k.SetPtr(kv_cache.Row(0) + kv_offsets[offset], kv_cache.Stride());
MatPtrT<KV_t> v("v_view", Extents2D(seq_len, qkv_dim));
v.SetPtr(kv_cache.Row(0) + kv_offsets[offset] + qkv_dim,
kv_cache.Stride());
if (use_tile_attention) {
// To avoid duplicating the code to setup K and V, the call to
// TileFlashAttention is inside the loop over tasks, even thought it
// handles all rows in the task at once.
StridedView<float> qT =
StridedView<float>(activations.q_T.Row(0) + first_task, kVTileSize,
activations.q_T.Stride());
TileFlashAttention(
activations.q, q_offsets, qT, k, start_positions[offset], last_pos,
min_last_pos, max_last_pos, v, layer_idx, layer, activations,
activations.att_out, out_offsets, ctx.profiler, worker);
break;
} else {
SingleFlashAttention(start_positions[offset], last_pos[offset],
activations.q.Row(0) + q_offsets[offset], k, v,
layer_idx, layer, activations,
activations.att_out.Row(0) + out_offsets[offset],
ctx.profiler, worker);
}
}
};
{
PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin");
// Full parallelism is helpful, SmallParallelFor is insufficient.
HierarchicalParallelFor(num_thread_tasks, ctx.pools, func);
}
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();

61
gemma/flash_attention.h Normal file
View File

@ -0,0 +1,61 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_ATTENTION_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_ATTENTION_H_
// Declares FlashAttention for all SIMD targets.
#include <stddef.h>
#include "gemma/gemma.h"
#include "hwy/highway.h"
namespace gcpp {
// Passed to HWY_VISIT_TARGETS; declares for one target.
#define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \
namespace NAMESPACE { \
void RMSNormAndPositionalEncoding(size_t num_tokens, const QBatch& qbatch, \
MatPtrT<KV_t>& q, size_t layer_idx, \
const LayerWeightsPtrs& layer, \
const AttentionActivations& activations, \
ThreadingContext& ctx); \
\
void SingleFlashAttention(size_t start_pos, size_t last_pos, \
const float* HWY_RESTRICT q, \
const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
size_t layer_idx, const LayerWeightsPtrs& layer, \
const AttentionActivations& activations, \
float* HWY_RESTRICT att_out, hwy::Profiler& p, \
size_t worker); \
\
void FlashAttention(size_t num_tokens, size_t layer_idx, \
const LayerWeightsPtrs& layer, \
AttentionActivations& activations, QBatch& qbatch, \
ThreadingContext& ctx); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE
// Function declarations for each SIMD target. Allows direct call from the
// per-target namespace. We may later replace this with dynamic dispatch if
// the overhead is acceptable.
HWY_VISIT_TARGETS(GEMMA_DECL_FLASH_ATTENTION)
#undef GEMMA_DECL_FLASH_ATTENTION
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_ATTENTION_H_

View File

@ -0,0 +1,171 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstring>
#include <numeric>
#include <vector>
#include "compression/types.h"
#include "gemma/activations.h"
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
#include "gemma/kv_cache.h"
#include "gemma/weights.h"
#include "ops/matmul.h"
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS
#include <stddef.h>
#include <stdio.h>
#include <algorithm> // std::max
#include <cmath> // std::abs
#include <memory>
#include "util/mat.h"
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/flash_attention_test.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "compression/compress-inl.h"
#include "gemma/attention.h"
#include "gemma/configs.h"
#include "gemma/flash_attention.h"
#include "ops/matvec-inl.h"
#include "hwy/tests/test_util-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
using FloatPtr = hwy::AlignedFreeUniquePtr<float[]>;
void SetMat(const size_t offset, MatPtrT<float>& mat) {
const size_t kOuter = mat.Extents().rows;
const size_t kInner = mat.Extents().cols;
const float i_scale = 1.0f / kInner;
const float j_scale = 1.0f / kOuter;
for (size_t i = 0; i < kOuter; ++i) {
float* row = mat.Row(i);
for (size_t j = 0; j < kInner; ++j) {
row[j] =
static_cast<float>((i * kInner * i_scale + (j + offset) * j_scale));
}
}
}
std::unique_ptr<MatStorageT<float>> MakeCopyOfMat(const MatPtrT<float>& mat,
const Allocator& allocator) {
auto copy = std::make_unique<MatStorageT<float>>("TestMat", mat.Extents(),
allocator, MatPadding::kOdd);
CopyMat(mat, *copy);
return copy;
}
void AssertClose(const MatPtrT<float>& a, const MatPtrT<float>& b) {
// Avoid comparing the padding bytes, which are uninitialized.
for (size_t r = 0; r < a.Rows(); ++r) {
const float* HWY_RESTRICT a_row = a.Row(r);
const float* HWY_RESTRICT b_row = b.Row(r);
for (size_t c = 0; c < a.Cols(); ++c) {
float rel_abs_delta = std::abs(a_row[c] - b_row[c]);
if (rel_abs_delta > 0.0f) {
rel_abs_delta /= std::max(std::abs(a_row[c]), std::abs(b_row[c]));
}
EXPECT_LT(rel_abs_delta, 1e-5)
<< "a[" << r << "," << c << "]=" << a_row[c] << ", b[" << r << ","
<< c << "]=" << b_row[c];
}
}
}
void TestAttention() {
ThreadingArgs threading_args;
ThreadingContext ctx(threading_args);
// hwy::ThreadPool& pool = ctx.pools.Pool();
constexpr size_t kOuter = 1024;
constexpr size_t kInner = 256;
ModelConfig config(Model::GEMMA2_2B, Type::kF32, PromptWrapping::GEMMA_PT);
TensorInfoRegistry tensor_info_registry(config);
const LayerConfig& layer_config = config.layer_configs[0];
const LayerWeightsPtrs layers(0, layer_config, tensor_info_registry);
InferenceArgs inference_args;
RuntimeConfig runtime_config;
KVCache kv_cache(config, inference_args, ctx.allocator);
MatMulEnv env(ctx);
Activations activations(config, runtime_config.prefill_tbatch_size,
kv_cache.SeqLen(), env.ctx, env.row_ptrs);
std::vector<int> tokens(kOuter);
std::iota(tokens.begin(), tokens.end(), 1);
PromptTokens prompt(tokens);
AllQueries all_queries(hwy::Span<const PromptTokens>(&prompt, 1),
hwy::Span<KVCache>(&kv_cache, 1));
QBatch qbatch(/*start=*/0, /*max_size=*/kOuter, all_queries);
const size_t batch_size = kOuter;
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
AttentionActivations attention(config, layer_config, batch_size, kOuter,
ctx.allocator, row_ptrs);
const size_t qkv_dim = layer_config.qkv_dim;
ASSERT_EQ(qkv_dim, kInner);
const hwy::Divisor div_qbatch(qbatch.Size());
// A "head group" in the context of GQA refers to a collection of query
// heads that share the same key and value heads.
const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
const size_t seq_len =
static_cast<size_t>(attention.div_seq_len.GetDivisor());
auto& kvc = qbatch.KV(0).kv_cache;
for (size_t h = 0; h < layer_config.heads; ++h) {
// Make strided views into the kv cache for
// this query and head.
const size_t head_offset = (h / kHeadGroups) * qkv_dim * 2;
MatPtrT<KV_t> k("k_view", Extents2D(seq_len, qkv_dim));
k.SetPtr(kvc.Row(0) + head_offset, kvc.Stride());
MatPtrT<KV_t> v("v_view", Extents2D(seq_len, qkv_dim));
v.SetPtr(kvc.Row(0) + head_offset + qkv_dim, kvc.Stride());
SetMat(h + layer_config.heads, k);
SetMat(h + layer_config.heads * 2, v);
}
SetMat(1, attention.q);
DotSoftmaxWeightedSum(tokens.size(), 0, layers, attention, qbatch, ctx);
// Copy the output to saved_att to allow for comparison.
auto saved_att = MakeCopyOfMat(attention.att_out, ctx.allocator);
SetMat(1, attention.q);
FlashAttention(tokens.size(), 0, layers, attention, qbatch, ctx);
AssertClose(attention.att_out, *saved_att);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_BEFORE_TEST(FlashAttentionTest);
HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestAttention);
HWY_AFTER_TEST();
} // namespace gcpp
#endif

View File

@ -613,6 +613,351 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
}); });
} }
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd16(
DF df, const VF common, const VF split, VF& sum0, VF& sum1, VF& sum2,
VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7, VF& sum8, VF& sum9,
VF& sum10, VF& sum11, VF& sum12, VF& sum13, VF& sum14, VF& sum15) {
sum0 = hn::MulAdd(common, hn::Set(df, split.raw[0]), sum0);
sum1 = hn::MulAdd(common, hn::Set(df, split.raw[1]), sum1);
sum2 = hn::MulAdd(common, hn::Set(df, split.raw[2]), sum2);
sum3 = hn::MulAdd(common, hn::Set(df, split.raw[3]), sum3);
sum4 = hn::MulAdd(common, hn::Set(df, split.raw[4]), sum4);
sum5 = hn::MulAdd(common, hn::Set(df, split.raw[5]), sum5);
sum6 = hn::MulAdd(common, hn::Set(df, split.raw[6]), sum6);
sum7 = hn::MulAdd(common, hn::Set(df, split.raw[7]), sum7);
sum8 = hn::MulAdd(common, hn::Set(df, split.raw[8]), sum8);
sum9 = hn::MulAdd(common, hn::Set(df, split.raw[9]), sum9);
sum10 = hn::MulAdd(common, hn::Set(df, split.raw[10]), sum10);
sum11 = hn::MulAdd(common, hn::Set(df, split.raw[11]), sum11);
sum12 = hn::MulAdd(common, hn::Set(df, split.raw[12]), sum12);
sum13 = hn::MulAdd(common, hn::Set(df, split.raw[13]), sum13);
sum14 = hn::MulAdd(common, hn::Set(df, split.raw[14]), sum14);
sum15 = hn::MulAdd(common, hn::Set(df, split.raw[15]), sum15);
}
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd8(DF df, const VF common, const VF split,
VF& sum0, VF& sum1, VF& sum2, VF& sum3,
VF& sum4, VF& sum5, VF& sum6,
VF& sum7) {
sum0 = hn::MulAdd(common, hn::Set(df, split.raw[0]), sum0);
sum1 = hn::MulAdd(common, hn::Set(df, split.raw[1]), sum1);
sum2 = hn::MulAdd(common, hn::Set(df, split.raw[2]), sum2);
sum3 = hn::MulAdd(common, hn::Set(df, split.raw[3]), sum3);
sum4 = hn::MulAdd(common, hn::Set(df, split.raw[4]), sum4);
sum5 = hn::MulAdd(common, hn::Set(df, split.raw[5]), sum5);
sum6 = hn::MulAdd(common, hn::Set(df, split.raw[6]), sum6);
sum7 = hn::MulAdd(common, hn::Set(df, split.raw[7]), sum7);
}
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4(DF df, const VF common, const VF split,
VF& sum0, VF& sum1, VF& sum2,
VF& sum3) {
sum0 = hn::MulAdd(common, hn::Set(df, split.raw[0]), sum0);
sum1 = hn::MulAdd(common, hn::Set(df, split.raw[1]), sum1);
sum2 = hn::MulAdd(common, hn::Set(df, split.raw[2]), sum2);
sum3 = hn::MulAdd(common, hn::Set(df, split.raw[3]), sum3);
}
// For an 8xNF tile of float values in 8xNF-lane registers, multiplies 8 rows
// of V by the corresponding values in c0-c7 and adds them to NF rows of out,
// after first prescaling out by scale.
// The depth (size) must be a multiple of NF.
template <class DF, class VF = hn::Vec<DF>>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile(
DF df, const VF scale, const VF c0, const VF c1, const VF c2, const VF c3,
const VF c4, const VF c5, const VF c6, const VF c7, const MatPtrT<float>& v,
const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out,
const uint32_t* HWY_RESTRICT out_offsets, const size_t size,
hwy::Profiler& p, const size_t worker) {
static const auto zone = p.AddZone("Ops.MulByConstAndAdd");
PROFILER_ZONE3(p, worker, zone);
namespace hn = hwy::HWY_NAMESPACE;
HWY_LANES_CONSTEXPR size_t NF = hn::MaxLanes(df);
size_t i = 0;
while (i + NF <= size) {
if HWY_LANES_CONSTEXPR (NF == 16) {
VF out0, out1, out2, out3, out4, out5, out6, out7;
VF out8, out9, out10, out11, out12, out13, out14, out15;
out0 = hn::Load(df, out + i + out_offsets[0]);
out1 = hn::Load(df, out + i + out_offsets[1]);
out2 = hn::Load(df, out + i + out_offsets[2]);
out3 = hn::Load(df, out + i + out_offsets[3]);
out4 = hn::Load(df, out + i + out_offsets[4]);
out5 = hn::Load(df, out + i + out_offsets[5]);
out6 = hn::Load(df, out + i + out_offsets[6]);
out7 = hn::Load(df, out + i + out_offsets[7]);
out8 = hn::Load(df, out + i + out_offsets[8]);
out9 = hn::Load(df, out + i + out_offsets[9]);
out10 = hn::Load(df, out + i + out_offsets[10]);
out11 = hn::Load(df, out + i + out_offsets[11]);
out12 = hn::Load(df, out + i + out_offsets[12]);
out13 = hn::Load(df, out + i + out_offsets[13]);
out14 = hn::Load(df, out + i + out_offsets[14]);
out15 = hn::Load(df, out + i + out_offsets[15]);
out0 = hn::Mul(out0, hn::Set(df, scale.raw[0]));
out1 = hn::Mul(out1, hn::Set(df, scale.raw[1]));
out2 = hn::Mul(out2, hn::Set(df, scale.raw[2]));
out3 = hn::Mul(out3, hn::Set(df, scale.raw[3]));
out4 = hn::Mul(out4, hn::Set(df, scale.raw[4]));
out5 = hn::Mul(out5, hn::Set(df, scale.raw[5]));
out6 = hn::Mul(out6, hn::Set(df, scale.raw[6]));
out7 = hn::Mul(out7, hn::Set(df, scale.raw[7]));
out8 = hn::Mul(out8, hn::Set(df, scale.raw[8]));
out9 = hn::Mul(out9, hn::Set(df, scale.raw[9]));
out10 = hn::Mul(out10, hn::Set(df, scale.raw[10]));
out11 = hn::Mul(out11, hn::Set(df, scale.raw[11]));
out12 = hn::Mul(out12, hn::Set(df, scale.raw[12]));
out13 = hn::Mul(out13, hn::Set(df, scale.raw[13]));
out14 = hn::Mul(out14, hn::Set(df, scale.raw[14]));
out15 = hn::Mul(out15, hn::Set(df, scale.raw[15]));
VF x0 = hn::Load(df, v.Row(pos[0]) + i);
MulAdd16(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15);
VF x1 = hn::Load(df, v.Row(pos[1]) + i);
MulAdd16(df, x1, c1, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15);
VF x2 = hn::Load(df, v.Row(pos[2]) + i);
MulAdd16(df, x2, c2, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15);
VF x3 = hn::Load(df, v.Row(pos[3]) + i);
MulAdd16(df, x3, c3, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15);
VF x4 = hn::Load(df, v.Row(pos[4]) + i);
MulAdd16(df, x4, c4, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15);
VF x5 = hn::Load(df, v.Row(pos[5]) + i);
MulAdd16(df, x5, c5, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15);
VF x6 = hn::Load(df, v.Row(pos[6]) + i);
MulAdd16(df, x6, c6, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15);
VF x7 = hn::Load(df, v.Row(pos[7]) + i);
MulAdd16(df, x7, c7, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15);
hn::Store(out0, df, out + i + out_offsets[0]);
hn::Store(out1, df, out + i + out_offsets[1]);
hn::Store(out2, df, out + i + out_offsets[2]);
hn::Store(out3, df, out + i + out_offsets[3]);
hn::Store(out4, df, out + i + out_offsets[4]);
hn::Store(out5, df, out + i + out_offsets[5]);
hn::Store(out6, df, out + i + out_offsets[6]);
hn::Store(out7, df, out + i + out_offsets[7]);
hn::Store(out8, df, out + i + out_offsets[8]);
hn::Store(out9, df, out + i + out_offsets[9]);
hn::Store(out10, df, out + i + out_offsets[10]);
hn::Store(out11, df, out + i + out_offsets[11]);
hn::Store(out12, df, out + i + out_offsets[12]);
hn::Store(out13, df, out + i + out_offsets[13]);
hn::Store(out14, df, out + i + out_offsets[14]);
hn::Store(out15, df, out + i + out_offsets[15]);
}
if HWY_LANES_CONSTEXPR (NF == 8) {
VF out0, out1, out2, out3, out4, out5, out6, out7;
out0 = hn::Load(df, out + i + out_offsets[0]);
out1 = hn::Load(df, out + i + out_offsets[1]);
out2 = hn::Load(df, out + i + out_offsets[2]);
out3 = hn::Load(df, out + i + out_offsets[3]);
out4 = hn::Load(df, out + i + out_offsets[4]);
out5 = hn::Load(df, out + i + out_offsets[5]);
out6 = hn::Load(df, out + i + out_offsets[6]);
out7 = hn::Load(df, out + i + out_offsets[7]);
out0 = hn::Mul(out0, hn::Set(df, scale.raw[0]));
out1 = hn::Mul(out1, hn::Set(df, scale.raw[1]));
out2 = hn::Mul(out2, hn::Set(df, scale.raw[2]));
out3 = hn::Mul(out3, hn::Set(df, scale.raw[3]));
out4 = hn::Mul(out4, hn::Set(df, scale.raw[4]));
out5 = hn::Mul(out5, hn::Set(df, scale.raw[5]));
out6 = hn::Mul(out6, hn::Set(df, scale.raw[6]));
out7 = hn::Mul(out7, hn::Set(df, scale.raw[7]));
VF x0 = hn::Load(df, v.Row(pos[0]) + i);
MulAdd8(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7);
VF x1 = hn::Load(df, v.Row(pos[1]) + i);
MulAdd8(df, x1, c1, out0, out1, out2, out3, out4, out5, out6, out7);
VF x2 = hn::Load(df, v.Row(pos[2]) + i);
MulAdd8(df, x2, c2, out0, out1, out2, out3, out4, out5, out6, out7);
VF x3 = hn::Load(df, v.Row(pos[3]) + i);
MulAdd8(df, x3, c3, out0, out1, out2, out3, out4, out5, out6, out7);
VF x4 = hn::Load(df, v.Row(pos[4]) + i);
MulAdd8(df, x4, c4, out0, out1, out2, out3, out4, out5, out6, out7);
VF x5 = hn::Load(df, v.Row(pos[5]) + i);
MulAdd8(df, x5, c5, out0, out1, out2, out3, out4, out5, out6, out7);
VF x6 = hn::Load(df, v.Row(pos[6]) + i);
MulAdd8(df, x6, c6, out0, out1, out2, out3, out4, out5, out6, out7);
VF x7 = hn::Load(df, v.Row(pos[7]) + i);
MulAdd8(df, x7, c7, out0, out1, out2, out3, out4, out5, out6, out7);
hn::Store(out0, df, out + i + out_offsets[0]);
hn::Store(out1, df, out + i + out_offsets[1]);
hn::Store(out2, df, out + i + out_offsets[2]);
hn::Store(out3, df, out + i + out_offsets[3]);
hn::Store(out4, df, out + i + out_offsets[4]);
hn::Store(out5, df, out + i + out_offsets[5]);
hn::Store(out6, df, out + i + out_offsets[6]);
hn::Store(out7, df, out + i + out_offsets[7]);
}
if HWY_LANES_CONSTEXPR (NF == 4) {
VF out0, out1, out2, out3;
out0 = hn::Load(df, out + i + out_offsets[0]);
out1 = hn::Load(df, out + i + out_offsets[1]);
out2 = hn::Load(df, out + i + out_offsets[2]);
out3 = hn::Load(df, out + i + out_offsets[3]);
out0 = hn::Mul(out0, hn::Set(df, scale.raw[0]));
out1 = hn::Mul(out1, hn::Set(df, scale.raw[1]));
out2 = hn::Mul(out2, hn::Set(df, scale.raw[2]));
out3 = hn::Mul(out3, hn::Set(df, scale.raw[3]));
VF x0 = hn::Load(df, v.Row(pos[0]) + i);
MulAdd4(df, x0, c0, out0, out1, out2, out3);
VF x1 = hn::Load(df, v.Row(pos[1]) + i);
MulAdd4(df, x1, c1, out0, out1, out2, out3);
VF x2 = hn::Load(df, v.Row(pos[2]) + i);
MulAdd4(df, x2, c2, out0, out1, out2, out3);
VF x3 = hn::Load(df, v.Row(pos[3]) + i);
MulAdd4(df, x3, c3, out0, out1, out2, out3);
VF x4 = hn::Load(df, v.Row(pos[4]) + i);
MulAdd4(df, x4, c4, out0, out1, out2, out3);
VF x5 = hn::Load(df, v.Row(pos[5]) + i);
MulAdd4(df, x5, c5, out0, out1, out2, out3);
VF x6 = hn::Load(df, v.Row(pos[6]) + i);
MulAdd4(df, x6, c6, out0, out1, out2, out3);
VF x7 = hn::Load(df, v.Row(pos[7]) + i);
MulAdd4(df, x7, c7, out0, out1, out2, out3);
hn::Store(out0, df, out + i + out_offsets[0]);
hn::Store(out1, df, out + i + out_offsets[1]);
hn::Store(out2, df, out + i + out_offsets[2]);
hn::Store(out3, df, out + i + out_offsets[3]);
}
i += NF;
}
const size_t remaining = size - i;
HWY_DASSERT(remaining == 0);
}
// Prescales NF rows of out by scale, then multiplies 1 row of V by the
// corresponding values in c0 and adds them to the NF rows of out.
// The depth (size) must be a multiple of NF.
template <class DF, class VF = hn::Vec<DF>>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
DF df, const VF scale, const VF c0, const MatPtrT<float>& v,
const size_t pos, float* HWY_RESTRICT out,
const uint32_t* HWY_RESTRICT out_offsets, const size_t size,
hwy::Profiler& p, const size_t worker) {
static const auto zone = p.AddZone("Ops.MulByConstAndAdd");
PROFILER_ZONE3(p, worker, zone);
namespace hn = hwy::HWY_NAMESPACE;
const size_t NF = hn::MaxLanes(df);
size_t i = 0;
while (i + NF <= size) {
if constexpr (NF == 16) {
VF out0, out1, out2, out3, out4, out5, out6, out7;
VF out8, out9, out10, out11, out12, out13, out14, out15;
out0 = hn::Load(df, out + i + out_offsets[0]);
out1 = hn::Load(df, out + i + out_offsets[1]);
out2 = hn::Load(df, out + i + out_offsets[2]);
out3 = hn::Load(df, out + i + out_offsets[3]);
out4 = hn::Load(df, out + i + out_offsets[4]);
out5 = hn::Load(df, out + i + out_offsets[5]);
out6 = hn::Load(df, out + i + out_offsets[6]);
out7 = hn::Load(df, out + i + out_offsets[7]);
out8 = hn::Load(df, out + i + out_offsets[8]);
out9 = hn::Load(df, out + i + out_offsets[9]);
out10 = hn::Load(df, out + i + out_offsets[10]);
out11 = hn::Load(df, out + i + out_offsets[11]);
out12 = hn::Load(df, out + i + out_offsets[12]);
out13 = hn::Load(df, out + i + out_offsets[13]);
out14 = hn::Load(df, out + i + out_offsets[14]);
out15 = hn::Load(df, out + i + out_offsets[15]);
out0 = hn::Mul(out0, hn::Set(df, scale.raw[0]));
out1 = hn::Mul(out1, hn::Set(df, scale.raw[1]));
out2 = hn::Mul(out2, hn::Set(df, scale.raw[2]));
out3 = hn::Mul(out3, hn::Set(df, scale.raw[3]));
out4 = hn::Mul(out4, hn::Set(df, scale.raw[4]));
out5 = hn::Mul(out5, hn::Set(df, scale.raw[5]));
out6 = hn::Mul(out6, hn::Set(df, scale.raw[6]));
out7 = hn::Mul(out7, hn::Set(df, scale.raw[7]));
out8 = hn::Mul(out8, hn::Set(df, scale.raw[8]));
out9 = hn::Mul(out9, hn::Set(df, scale.raw[9]));
out10 = hn::Mul(out10, hn::Set(df, scale.raw[10]));
out11 = hn::Mul(out11, hn::Set(df, scale.raw[11]));
out12 = hn::Mul(out12, hn::Set(df, scale.raw[12]));
out13 = hn::Mul(out13, hn::Set(df, scale.raw[13]));
out14 = hn::Mul(out14, hn::Set(df, scale.raw[14]));
out15 = hn::Mul(out15, hn::Set(df, scale.raw[15]));
VF x0 = hn::Load(df, v.Row(pos) + i);
MulAdd16(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15);
hn::Store(out0, df, out + i + out_offsets[0]);
hn::Store(out1, df, out + i + out_offsets[1]);
hn::Store(out2, df, out + i + out_offsets[2]);
hn::Store(out3, df, out + i + out_offsets[3]);
hn::Store(out4, df, out + i + out_offsets[4]);
hn::Store(out5, df, out + i + out_offsets[5]);
hn::Store(out6, df, out + i + out_offsets[6]);
hn::Store(out7, df, out + i + out_offsets[7]);
hn::Store(out8, df, out + i + out_offsets[8]);
hn::Store(out9, df, out + i + out_offsets[9]);
hn::Store(out10, df, out + i + out_offsets[10]);
hn::Store(out11, df, out + i + out_offsets[11]);
hn::Store(out12, df, out + i + out_offsets[12]);
hn::Store(out13, df, out + i + out_offsets[13]);
hn::Store(out14, df, out + i + out_offsets[14]);
hn::Store(out15, df, out + i + out_offsets[15]);
} else if constexpr (NF == 8) {
VF out0, out1, out2, out3, out4, out5, out6, out7;
out0 = hn::Load(df, out + i + out_offsets[0]);
out1 = hn::Load(df, out + i + out_offsets[1]);
out2 = hn::Load(df, out + i + out_offsets[2]);
out3 = hn::Load(df, out + i + out_offsets[3]);
out4 = hn::Load(df, out + i + out_offsets[4]);
out5 = hn::Load(df, out + i + out_offsets[5]);
out6 = hn::Load(df, out + i + out_offsets[6]);
out7 = hn::Load(df, out + i + out_offsets[7]);
out0 = hn::Mul(out0, hn::Set(df, scale.raw[0]));
out1 = hn::Mul(out1, hn::Set(df, scale.raw[1]));
out2 = hn::Mul(out2, hn::Set(df, scale.raw[2]));
out3 = hn::Mul(out3, hn::Set(df, scale.raw[3]));
out4 = hn::Mul(out4, hn::Set(df, scale.raw[4]));
out5 = hn::Mul(out5, hn::Set(df, scale.raw[5]));
out6 = hn::Mul(out6, hn::Set(df, scale.raw[6]));
out7 = hn::Mul(out7, hn::Set(df, scale.raw[7]));
VF x0 = hn::Load(df, v.Row(pos) + i);
MulAdd8(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7);
hn::Store(out0, df, out + i + out_offsets[0]);
hn::Store(out1, df, out + i + out_offsets[1]);
hn::Store(out2, df, out + i + out_offsets[2]);
hn::Store(out3, df, out + i + out_offsets[3]);
hn::Store(out4, df, out + i + out_offsets[4]);
hn::Store(out5, df, out + i + out_offsets[5]);
hn::Store(out6, df, out + i + out_offsets[6]);
hn::Store(out7, df, out + i + out_offsets[7]);
} else if constexpr (NF == 4) {
VF out0, out1, out2, out3;
out0 = hn::Load(df, out + i + out_offsets[0]);
out1 = hn::Load(df, out + i + out_offsets[1]);
out2 = hn::Load(df, out + i + out_offsets[2]);
out3 = hn::Load(df, out + i + out_offsets[3]);
out0 = hn::Mul(out0, hn::Set(df, scale.raw[0]));
out1 = hn::Mul(out1, hn::Set(df, scale.raw[1]));
out2 = hn::Mul(out2, hn::Set(df, scale.raw[2]));
out3 = hn::Mul(out3, hn::Set(df, scale.raw[3]));
VF x0 = hn::Load(df, v.Row(pos) + i);
MulAdd4(df, x0, c0, out0, out1, out2, out3);
hn::Store(out0, df, out + i + out_offsets[0]);
hn::Store(out1, df, out + i + out_offsets[1]);
hn::Store(out2, df, out + i + out_offsets[2]);
hn::Store(out3, df, out + i + out_offsets[3]);
} else {
HWY_DASSERT(false);
}
i += NF;
}
const size_t remaining = size - i;
HWY_DASSERT(remaining == 0);
}
// See below for a specialized version for top-1 sampling. // See below for a specialized version for top-1 sampling.
// TODO: support bf16 logits using Decompress2. // TODO: support bf16 logits using Decompress2.
static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p, static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p,