No public description

PiperOrigin-RevId: 877333188
This commit is contained in:
The gemma.cpp Authors 2026-03-02 04:31:45 -08:00 committed by Copybara-Service
parent 16c1b29b89
commit a3d994915f
18 changed files with 1234 additions and 1507 deletions

View File

@ -548,7 +548,6 @@ cc_library(
deps = [
":basics",
":configs",
":flash_structs",
":gemma_args",
":kv_cache",
":mat",
@ -596,11 +595,6 @@ cc_test(
INTERNAL_DEPS = []
cc_library(
name = "flash_structs",
hdrs = ["gemma/flash_structs.h"],
)
cc_library(
name = "attention",
srcs = [
@ -610,6 +604,7 @@ cc_library(
hdrs = [
"gemma/attention.h",
"gemma/flash_attention.h",
"gemma/flash_structs.h",
],
textual_hdrs = [
"gemma/gemma-inl.h",
@ -618,7 +613,6 @@ cc_library(
":activations",
":basics",
":configs",
":flash_structs",
":kv_cache",
":mat",
":matmul",

View File

@ -24,7 +24,6 @@
#include <vector>
#include "gemma/configs.h" // ModelConfig
#include "gemma/flash_structs.h"
#include "gemma/gemma_args.h" // AttentionImpl
#include "gemma/kv_cache.h"
#include "gemma/tensor_stats.h"
@ -53,13 +52,10 @@ struct AttentionActivations {
AttentionActivations(
const ModelConfig& config, const LayerConfig& layer_config,
size_t batch_size, size_t seq_len, const RuntimeConfig& runtime_config,
size_t max_workers, const Allocator& allocator,
const Allocator& allocator,
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
: rep_factor(max_workers *
AttentionActivations::kThreadReplicationFactor /
layer_config.heads),
// `vocab_size == 0` means it is for Vit part, VitAttention
// is still MHA and does not use an external KV cache.
: // `vocab_size == 0` means it is for Vit part, VitAttention is still
// MHA and does not use an external KV cache.
q(MatFactory("q", batch_size,
config.vocab_size == 0
? layer_config.heads * 3 * layer_config.qkv_dim
@ -90,9 +86,6 @@ struct AttentionActivations {
att_out(MatFactory("att_out", batch_size,
layer_config.heads * layer_config.qkv_dim,
allocator)),
att_out_reps(MatFactory("att_out", batch_size * rep_factor,
layer_config.heads * layer_config.qkv_dim,
allocator)),
softmax_max(MatFactory("softmax_max", batch_size, layer_config.heads,
allocator)),
softmax_d(
@ -114,11 +107,6 @@ struct AttentionActivations {
}
return;
}
// This is a guess at the maximum number of params we might need to avoid
// reallocations. The actual number of params is determined by the number of
// query tiles, which is not known here.
flash_params.reserve(batch_size * layer_config.heads);
split_flash_params.reserve(batch_size * layer_config.heads);
// For MatMul outputs, precompute their row pointers.
// If we forget any MatMul outputs here, debug builds print a warning but
@ -142,10 +130,6 @@ struct AttentionActivations {
pre_att_rms_out.OverrideRows(batch_size);
att.OverrideRows(batch_size);
att_out.OverrideRows(batch_size);
att_out_reps.OverrideRows(batch_size * rep_factor);
// There is no override for [split_]flash_params, because we reserved an
// upper bound, and flash attention controls the actual size when it
// calculates the size and number of tiles.
softmax_max.OverrideRows(batch_size);
softmax_d.OverrideRows(batch_size);
att_sums.OverrideRows(batch_size);
@ -153,15 +137,6 @@ struct AttentionActivations {
// `inv_timescale*` are not batched.
}
// Maximum factor by which we might scale-up work to maximize parallelism.
size_t rep_factor = 1;
// Parameters for flash attention. The size of the vector is somewhere between
// the number of query rows and 1/8th of that.
std::vector<Tile148Params> flash_params;
// Parameters for flash attention, split by k-position. May be significantly
// larger than flash_params in decode mode, when the number of query rows is
// small.
std::vector<Tile148Params> split_flash_params;
MatStorageT<float> q; // query
MatStorageT<BF16> q_bf;
MatStorageT<BF16> q_T; // Transposed to maximize attention speed.
@ -173,7 +148,6 @@ struct AttentionActivations {
MatStorageT<float> pre_att_rms_out;
MatStorageT<float> att; // attention vector
MatStorageT<float> att_out; // attention output
MatStorageT<float> att_out_reps; // attention output for each thread.
MatStorageT<float> softmax_max; // see OnlineSoftmaxState
MatStorageT<float> softmax_d; // see OnlineSoftmaxState
// Accumulation of attention outputs over heads
@ -190,26 +164,19 @@ struct AttentionActivations {
// Rope
MatStorageT<float> inv_timescale;
MatStorageT<float> inv_timescale_global;
// Replication factor to help evenly share work over threads.
static constexpr size_t kThreadReplicationFactor = 4;
};
// A non-owning view of AttentionActivations.
struct AttentionActivationsPtrs {
AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len,
std::vector<Tile148Params>& flash_params,
std::vector<Tile148Params>& split_flash_params)
AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len)
: config(config),
flash_params(flash_params),
split_flash_params(split_flash_params),
div_seq_len(static_cast<uint32_t>(seq_len)),
div_heads(static_cast<uint32_t>(config.layer_configs[0].heads)),
query_scale(ChooseQueryScale(config)) {}
AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len,
AttentionActivations& activations)
: AttentionActivationsPtrs(config, seq_len, activations.flash_params,
activations.split_flash_params) {
const AttentionActivations& activations)
: AttentionActivationsPtrs(config, seq_len) {
q = activations.q;
q_bf = activations.q_bf;
q_T = activations.q_T;
@ -219,7 +186,6 @@ struct AttentionActivationsPtrs {
pre_att_rms_out = activations.pre_att_rms_out;
att = activations.att;
att_out = activations.att_out;
att_out_reps = activations.att_out_reps;
softmax_max = activations.softmax_max;
softmax_d = activations.softmax_d;
att_sums = activations.att_sums;
@ -250,9 +216,6 @@ struct AttentionActivationsPtrs {
}
const ModelConfig& config;
// Parameters for flash attention.
std::vector<Tile148Params>& flash_params;
std::vector<Tile148Params>& split_flash_params;
// For the matrices below, the batch_size dimension is really qbatch.Size() *
// token_batch_size, but in all known uses, one of those is 1. Specifically,
@ -278,7 +241,6 @@ struct AttentionActivationsPtrs {
// Attention output computed from att * V, size batch_size x (q_heads *
// qkv_dim).
MatPtrT<float> att_out;
MatPtrT<float> att_out_reps;
// The maximum logit value encountered when computing att_out from att,
// size batch_size x q_heads . See OnlineSoftmaxState for details.
// WARNING: Only filled in for AttentionImpl::kOld.
@ -343,8 +305,7 @@ struct Activations {
s_w_linear_w(config.num_layers, max_workers),
attention_impl(runtime_config.attention_impl),
attention_storage(config, layer_config, batch_size, seq_len,
runtime_config, ctx.pools.MaxWorkers(), ctx.allocator,
row_ptrs),
runtime_config, ctx.allocator, row_ptrs),
attention(config, seq_len, attention_storage) {
HWY_ASSERT(batch_size != 0);

View File

@ -49,39 +49,6 @@ HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
// Returns the number of floats per vector (aka NF).
size_t FloatsPerVector() {
using DF = hn::ScalableTag<float>;
const DF df;
return hn::Lanes(df);
}
// The k-cache and v-cache are setup without knowing NF. So if it hasn't been
// done already, reshape it to take NF into account.
void MaybeReshapeCache(const MatPtrT<KV_t>& kv, MatPtrT<KV_t>& cache) {
if (kv.Cols() > cache.Cols()) {
cache.ReshapePackedRowsToCols(2 * FloatsPerVector());
}
}
// Transposes a single row of the kv cache into the k-cache and v-cache.
void TransposeKVCacheRow(const KV_t* HWY_RESTRICT kv, KV_t* HWY_RESTRICT k,
KV_t* HWY_RESTRICT v, size_t qkv_dim) {
// This is inefficient, as the writes are scattered over cache lines, but it
// is a tiny fraction of the overall computation, and it is linear in the
// token length.
const size_t kFloatsPerTile = 2 * FloatsPerVector();
for (size_t i = 0; i < qkv_dim; i += 2) {
k[i * kFloatsPerTile] = kv[i];
k[i * kFloatsPerTile + 1] = kv[i + 1];
}
for (size_t i = 0; i < qkv_dim; i += kFloatsPerTile) {
for (size_t j = 0; j < kFloatsPerTile; j++) {
v[i * kFloatsPerTile + j] = kv[i + j + qkv_dim];
}
}
}
// Computes Q.K scores, which are "logits" (or scores) stored to att.
// `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
@ -313,11 +280,6 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
kv_rows.AttachRowPtrs(env.row_ptrs[0].get());
CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
/*add=*/nullptr, env, kv_rows);
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
MaybeReshapeCache(qbatch.KV(qi).kv_cache, qbatch.KV(qi).k_cache);
MaybeReshapeCache(qbatch.KV(qi).kv_cache, qbatch.KV(qi).v_cache);
}
const size_t kFloatsPerVector = FloatsPerVector();
// Apply positional encodings for K.
// Note that 2D parallelism is not worth the fork/join overhead because the
@ -337,26 +299,6 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
KV_t* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
layer_idx * cache_layer_size +
head * qkv_dim * 2;
// Note that k_cache and v_cache are different shapes.
// The innermost dimension of k is 2 values from qkv_dim because they
// are going to be used in a BF16 dot product involving pairs of
// values over NF k positions.
// The innermost dimension of v is 2NF values from qkv_dim because they
// will be loaded into a BF16 vector to be scaled and added to the
// cached attention output in 2 NF-sized registers.
// TODO(rays): factor out these calculations into functions.
auto& k_cache = qbatch.KV(qi).k_cache;
KV_t* HWY_RESTRICT k =
k_cache.Row(cache_pos / (2 * kFloatsPerVector)) +
(layer_idx * cache_layer_size + head * qkv_dim * 2) *
kFloatsPerVector +
(cache_pos % (2 * kFloatsPerVector)) * 2;
auto& v_cache = qbatch.KV(qi).v_cache;
KV_t* HWY_RESTRICT v =
v_cache.Row(cache_pos / (2 * kFloatsPerVector)) +
(layer_idx * cache_layer_size + head * qkv_dim * 2) *
kFloatsPerVector +
(cache_pos % (2 * kFloatsPerVector)) * 2 * kFloatsPerVector;
HWY_ALIGN float kv_f32[2 * kMaxQKVDim];
const hn::ScalableTag<float> df;
@ -377,10 +319,6 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
/*mul=*/1.0f);
CompressPerThread tls;
Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
// This is inefficient, as multiple threads are writing the same K
// cache line, but the input is generated by a matmul, so it is
// difficult to change, and it probably isn't significant.
TransposeKVCacheRow(kv, k, v, qkv_dim);
});
}
@ -403,8 +341,7 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
} else {
// * 2 does not help on Turin.
FlashAttention(num_tokens,
/*target_parallelism=*/env.ctx.pools.MaxWorkers() *
AttentionActivations::kThreadReplicationFactor,
/*target_parallelism=*/env.ctx.pools.MaxWorkers() * 1,
layer_idx, layer.query_norm_scale, activations, qbatch,
env.ctx, attention_impl);
}

View File

@ -31,13 +31,6 @@ namespace gcpp {
// Passed to HWY_VISIT_TARGETS; declares for one target.
#define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \
namespace NAMESPACE { \
size_t FloatsPerVector(); \
\
void MaybeReshapeCache(const MatPtrT<KV_t>& kv, MatPtrT<KV_t>& cache); \
\
void TransposeKVCacheRow(const KV_t* HWY_RESTRICT kv, KV_t* HWY_RESTRICT k, \
KV_t* HWY_RESTRICT v, size_t qkv_dim); \
\
void PositionalEncodingQK(float* qk, size_t layer_idx, \
const AttentionActivationsPtrs& activations, \
ThreadingContext& ctx, size_t worker, size_t pos, \

View File

@ -1,10 +1,8 @@
#include <cstddef>
#include <cstdlib>
#include <cstring> // strcmp
#include <memory>
#include <numeric>
#include <optional>
#include <string>
#include <vector>
#include "gtest/gtest.h"
@ -107,8 +105,7 @@ struct TestAttentionState {
tokens(num_tokens),
attention_storage_(model_state.config, model_state.layer_config,
batch_size, num_tokens, runtime_config,
state.ctx.pools.MaxWorkers(), state.ctx.allocator,
row_ptrs_),
state.ctx.allocator, row_ptrs_),
attention(model_state.config, num_tokens, attention_storage_) {
for (size_t i = 0; i < qbatch_size; ++i) {
kv_caches.emplace_back(model_state.config, inference_args,
@ -146,7 +143,6 @@ struct TestAttentionState {
};
double GetTolerance() {
if (IsBF16<KV_t>()) return 1e-2;
const char* target_name = hwy::TargetName(HWY_TARGET);
if (strncmp(target_name, "AVX2", 4) == 0) {
return 2e-2;
@ -159,20 +155,6 @@ double GetTolerance() {
}
}
template <typename T>
bool CompareArraySimilar(const T* expected, const T* actual, size_t count,
const char* target_name, const char* filename,
int line) {
if constexpr (IsBF16<KV_t>()) {
constexpr double kTolerance = 3e-2;
return hwy::CompareArraySimilar(expected, actual, count, kTolerance,
target_name, filename, line);
} else {
return hwy::CompareArraySimilar(expected, actual, count, GetTolerance(),
target_name, filename, line);
}
}
template <size_t kNumTokens, size_t kQBatchSize, size_t kDims>
void CompareAttSumsWithGolden(
const AttentionActivationsPtrs& attention,
@ -188,9 +170,9 @@ void CompareAttSumsWithGolden(
for (size_t j = 0; j < kDims; ++j) {
actual_row[j] = hwy::F32FromBF16(attention.att_sums.Row(i)[j]);
}
EXPECT_TRUE(CompareArraySimilar(golden[token_idx][qi], actual_row.get(),
kDims, hwy::TargetName(HWY_TARGET),
__FILE__, __LINE__))
EXPECT_TRUE(hwy::CompareArraySimilar(
golden[token_idx][qi], actual_row.get(), kDims, GetTolerance(),
hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
<< "att_sums mismatch for token_idx=" << token_idx << " qi=" << qi;
}
}
@ -218,20 +200,19 @@ void CompareKVCacheWithGolden(
for (size_t token_idx = 0; token_idx < kNumTokens; ++token_idx) {
for (size_t qi = 0; qi < kQBatchSize; ++qi) {
const BF16* cache_row =
const float* cache_row =
kv_caches[qi].kv_cache.Row(start_offset + token_idx);
for (size_t j = 0; j < kDims; ++j) {
actual_k_row[j] = hwy::ConvertScalarTo<float>(cache_row[kv_offset + j]);
actual_v_row[j] =
hwy::ConvertScalarTo<float>(cache_row[kv_offset + qkv_dim + j]);
actual_k_row[j] = cache_row[kv_offset + j];
actual_v_row[j] = cache_row[kv_offset + qkv_dim + j];
}
EXPECT_TRUE(CompareArraySimilar(
k_golden[token_idx][qi], actual_k_row.get(), kDims,
EXPECT_TRUE(hwy::CompareArraySimilar(
k_golden[token_idx][qi], actual_k_row.get(), kDims, GetTolerance(),
hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
<< "K cache mismatch for token_idx=" << token_idx << " qi=" << qi
<< " kv_head=" << kv_head;
EXPECT_TRUE(CompareArraySimilar(
v_golden[token_idx][qi], actual_v_row.get(), kDims,
EXPECT_TRUE(hwy::CompareArraySimilar(
v_golden[token_idx][qi], actual_v_row.get(), kDims, GetTolerance(),
hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
<< "V cache mismatch for token_idx=" << token_idx << " qi=" << qi
<< " kv_head=" << kv_head;
@ -257,8 +238,8 @@ void CompareQVecsWithGolden(
for (size_t j = 0; j < kDims; ++j) {
actual_q_row[j] = q_row[head_offset + j];
}
EXPECT_TRUE(CompareArraySimilar(
q_golden[token_idx][qi], actual_q_row.get(), kDims,
EXPECT_TRUE(hwy::CompareArraySimilar(
q_golden[token_idx][qi], actual_q_row.get(), kDims, GetTolerance(),
hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
<< "Q vec mismatch for token_idx=" << token_idx << " qi=" << qi
<< " q_head=" << q_head;
@ -282,46 +263,46 @@ const size_t kDimsToCompare = 17; // greater than AVX-512 vector of floats
// Layer 0
const float kGoldenAttSums[kNumTokens][kQBatchSize][kDimsToCompare] = {
{{46.5, 56.5, 10.0625, 65.5, -2.239375, 135, 15.8125, 51, -100, 52.5,
{{46.5, 56.5, 10.0625, 65.5, -2.109375, 135, 15.8125, 51, -100, 52.5,
26.875, 63, 3.34375, -67.5, 31.125, -190, 125},
{-30.375, -17.875, 51.75, -78, -84, 6.40625, 15.375, 70, -22.875, 20.125,
-14.9375, -109.5, 76, 9.25, -142, 29.5, -105}},
{{-32.75, 38.25, 78.5, 107.5, 20.25, 197, -136, 42.5, -84, 25.625, 5.35875,
{{-32.75, 38.25, 78.5, 107.5, 20.25, 197, -136, 42.5, -84, 25.625, 4.96875,
128, 27.25, -161, 19.125, -58, 97.5},
{-17.625, -15.375, 135, -13.4375, -3.343, -45.75, 29.625, 93, 18.625, 75.5,
{-18.5, -18, 135, -13.4375, -6.625, -45.75, 29.625, 93, 18.625, 75.5,
102.5, -184, 52.75, 83.5, -71, 46.5, -52}},
{{-16.375, -61.5, -58.25, -27.375, -28, 71, -109.5, 60.25, 3.625, -29.125,
6.4625, 150, 144, -155, -47.25, -98.5, 3.5625},
{-19, -16.75, 129, 0.628925, -82, 123.5, 60.75, -36.75, -77, 26.625, 51,
-66.5, -0.62165625, -46.5, -152, -2.9375, -81}},
{{3.684375, 83, -41.75, 39.5, -203, 110, -76, 131, 1.0069375, -44.5, -63.75,
{{-16.375, -61.5, -58.25, -27.375, -28, 71, -109.5, 60.25, 3.125, -29.125,
6.90625, 150, 144, -155, -47.25, -98.5, 3.5625},
{-19, -16.75, 129, 0.59765625, -82, 123.5, 60.75, -36.75, -77, 26.625, 51,
-66.5, -0.84765625, -46.5, -152, -2.9375, -81}},
{{3.984375, 83, -41.75, 39.5, -203, 110, -76, 131, 0.4609375, -44.5, -63.75,
-46, -22, -19.375, -16.125, -148, 20.875},
{-47, -17.5, 58, 81.5, 23.35, -30, -118, 44.25, -149, 22.5, 188, -66.5, 33,
{-47, -19.5, 58, 81.5, 21.75, -30, -118, 44.25, -149, 22.5, 188, -66.5, 33,
10.9375, -52.5, 23.25, 75}},
{{64, -31, -89, -92.5, -11.1875, -54.75, -302, 4.213125, -108, 39.25,
{{64, -31, -89, -92.5, -11.1875, -54.75, -302, 3.453125, -108, 39.25,
-34.75, 18, -52, 100, -186, -75.5, 50.75},
{7.1875, -80, -40, 32.25, -30.25, 90, -41, 44.25, -140, -2.2675, 82.5,
{7.6875, -80, -40, 32.25, -30.25, 90, -41, 44.25, -140, -2.4375, 82.5,
39.25, 65, 47.25, -89.5, -34.25, 137}},
{{39.75, 17.875, 115, 38.75, -44, 139, -53.25, -23.875, -12.625, 38.5, 32.5,
53.75, 109, 4.62375, 57.5, -20.5, 132},
{143, 249, 4.9375, 1.33984375, 27.875, -5.84375, 30.25, -101.5, 65.5, 13.5,
195, -10.0625, 97.5, 1.903125, -97.5, -100, -19.25}},
{{39.75, 17.875, 115, 38.75, -44, 139, -53.25, -23.875, -13.0625, 38.5,
32.5, 53.75, 109, 4.09375, 57.5, -20.5, 132},
{143, 249, 5.09375, 0.83984375, 27.875, -5.84375, 30.25, -101.5, 65.5,
13.5, 195, -10.0625, 97.5, 2.203125, -97.5, -100, -19.25}},
{{-30.125, -169, -150, 58, -35.75, 22.75, 36.5, -32.25, -8.9375, 55.25,
-117, 26.375, 39.5, 125, 66, 48.75, 20.75},
{137, 3.85, 61.25, 37, -42.75, 240, 62, -164, 10.3125, 173, 174, 23.5,
88.5, 48.5, -46.25, -35.5, 101.5}},
{{-103, -41.5, 39, -52, -62.7, 121, -136, 99, 80, -47.5, 107.5, 43.75, 97.5,
125, -53.5, -11.625, 262},
{28.075, 6.64375, -36.75, -13.35, -27.5, 44.75, -67.5, -40.75, 71.5, 172,
81, -28.5, -3.875, 111, -167, 59, 176}},
{137, 5.25, 61.25, 37, -42.75, 240, 62, -164, 11.3125, 173, 174, 23.5,
88.5, 48.5, -46.25, -36.75, 101.5}},
{{-103, -47.5, 39, -48, -67.5, 121, -136, 99, 80, -47.5, 107.5, 48.75, 97.5,
125, -53.5, -14.625, 262},
{29.875, 7.34375, -36.75, -14.5, -27.5, 44.75, -67.5, -40.75, 71.5, 172,
81, -27.25, -3.03125, 111, -167, 59, 176}},
{{-37.25, 109.5, -26.125, -115.5, 108, 57.25, 1.3671875, 72, -122.5, 59.25,
-52, -12.625, 43.25, 16.25, -41.75, 26.5, 70.5},
{40.25, 53.25, -142, 78.5, 38, 4.625, -27.75, -134, -85, 107.5, 2.5, 93.5,
{40.25, 53.25, -142, 78.5, 38, 4.3125, -27.75, -134, -85, 107.5, 2.5, 93.5,
58.25, 173, -53.5, 25.125, 4.8125}},
{{-8.4375, -35, -35.5, 131, -33.25, 106, 109.5, -92, -135, 80, 21.5,
-17.125, 15.25, 143, -27, 103, 101},
{-77, 40.75, -10.5, 33.25, -33, 104, -7.6875, 85.5, -40, 93, 61, 14.5625,
8.55, -99.5, 14.6875, -12.25, 33}},
{-77, 40.75, -10.125, 33.25, -33, 104, -7.6875, 85.5, -40, 93, 61, 14.5625,
8.125, -99.5, 13.6875, -11.6875, 33}},
};
// Layer 0, *K*V Head 0

View File

@ -81,8 +81,8 @@ static inline bool EnumValid(LayerAttentionType type) {
}
enum class AttentionImpl {
kOld, // Previous Attention implementation
kFlash, // Flash Attention (default)
kOld,
kFlash,
kFlashTransposedQs,
kFlashTransposedQsBF16,
kSentinel,

File diff suppressed because it is too large Load Diff

View File

@ -50,6 +50,25 @@ namespace gcpp {
float* HWY_RESTRICT att_out, \
ThreadingContext& ctx, size_t worker); \
\
Tile4FlashState TileFlashAttention4( \
const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets, \
const MatPtrT<KV_t>& k, size_t start_pos, \
const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos, \
size_t max_last_pos, const MatPtrT<KV_t>& v, size_t layer_idx, \
const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out, \
const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx, \
const size_t worker); \
\
void TileFlashAttention( \
const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets, \
const StridedView<BF16>& qT, const MatPtrT<KV_t>& k, \
const size_t start_pos, const uint32_t* HWY_RESTRICT last_pos, \
const size_t min_last_pos, const size_t max_last_pos, \
const MatPtrT<KV_t>& v, const size_t layer_idx, \
const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out, \
const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx, \
const size_t worker); \
\
size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \
size_t total_tasks, size_t target_parallelism); \
\
@ -73,6 +92,7 @@ namespace gcpp {
hwy::Span<const size_t> last_pos_per_query, const float att_cap, \
MatPtrT<float>& att_out, float* HWY_RESTRICT exp_denominator_sums, \
float* HWY_RESTRICT max_logits); \
\
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE

View File

@ -62,17 +62,16 @@ namespace HWY_NAMESPACE {
using FloatPtr = hwy::AlignedFreeUniquePtr<float[]>;
template <typename T>
void SetMat(const size_t offset, MatPtrT<T>& mat) {
void SetMat(const size_t offset, MatPtrT<float>& mat) {
const size_t kOuter = mat.Extents().rows;
const size_t kInner = mat.Extents().cols;
const float i_scale = 1.0f / kInner;
const float j_scale = 1.0f / kOuter;
for (size_t i = 0; i < kOuter; ++i) {
T* HWY_RESTRICT row = mat.Row(i);
float* HWY_RESTRICT row = mat.Row(i);
for (size_t j = 0; j < kInner; ++j) {
row[j] = hwy::ConvertScalarTo<T>(
static_cast<float>((i * kInner * i_scale + (j + offset) * j_scale)));
row[j] =
static_cast<float>((i * kInner * i_scale + (j + offset) * j_scale));
}
}
}
@ -95,15 +94,14 @@ void AssertClose(const MatPtrT<float>& a, const MatPtrT<float>& b) {
if (rel_abs_delta > 0.0f) {
rel_abs_delta /= std::max(std::abs(a_row[c]), std::abs(b_row[c]));
}
EXPECT_LT(rel_abs_delta, 1e-3)
EXPECT_LT(rel_abs_delta, 1e-5)
<< "a[" << r << "," << c << "]=" << a_row[c] << ", b[" << r << ","
<< c << "]=" << b_row[c];
}
}
}
void TestFlashAttention(size_t target_parallelism,
AttentionImpl attention_impl) {
void TestFlashAttention(size_t target_parallelism) {
ThreadingArgs threading_args;
ThreadingContext ctx(threading_args);
constexpr size_t kOuter = 1024;
@ -132,9 +130,9 @@ void TestFlashAttention(size_t target_parallelism,
QBatch qbatch(/*start=*/0, /*max_size=*/kOuter, all_queries);
const size_t batch_size = kOuter;
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
AttentionActivations attention_storage(
config, layer_config, batch_size, kOuter, runtime_config,
ctx.pools.MaxWorkers(), ctx.allocator, row_ptrs);
AttentionActivations attention_storage(config, layer_config, batch_size,
kOuter, runtime_config, ctx.allocator,
row_ptrs);
AttentionActivationsPtrs attention(config, kOuter, attention_storage);
const size_t qkv_dim = layer_config.qkv_dim;
ASSERT_EQ(qkv_dim, kInner);
@ -144,10 +142,7 @@ void TestFlashAttention(size_t target_parallelism,
const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
const size_t seq_len =
static_cast<size_t>(attention.div_seq_len.GetDivisor());
MaybeReshapeCache(qbatch.KV(0).kv_cache, qbatch.KV(0).k_cache);
MaybeReshapeCache(qbatch.KV(0).kv_cache, qbatch.KV(0).v_cache);
auto& kvc = qbatch.KV(0).kv_cache;
const size_t kFloatsPerTile = 2 * FloatsPerVector();
for (size_t h = 0; h < layer_config.heads; ++h) {
// Make strided views into the kv cache for
// this query and head.
@ -158,17 +153,6 @@ void TestFlashAttention(size_t target_parallelism,
v.SetPtr(kvc.Row(0) + head_offset + qkv_dim, kvc.Stride());
SetMat(h + layer_config.heads, k);
SetMat(h + layer_config.heads * 2, v);
for (size_t p = 0; p < tokens.size(); ++p) {
KV_t* HWY_RESTRICT k_src = k.Row(p);
KV_t* HWY_RESTRICT k_dest = qbatch.KV(0).k_cache.Row(p / kFloatsPerTile) +
head_offset * kFloatsPerTile / 2 +
p % kFloatsPerTile * 2;
KV_t* HWY_RESTRICT v_dest = qbatch.KV(0).v_cache.Row(p / kFloatsPerTile) +
head_offset * kFloatsPerTile / 2 +
p % kFloatsPerTile * kFloatsPerTile;
TransposeKVCacheRow(k_src, k_dest, v_dest, qkv_dim);
}
}
SetMat(1, attention.q);
DotSoftmaxWeightedSum(tokens.size(), 0, layers.query_norm_scale, attention,
@ -183,19 +167,18 @@ void TestFlashAttention(size_t target_parallelism,
tokens.size() * div_qbatch.GetDivisor() * layer_config.heads;
const size_t kVTileSize = GetVTileSize(kNF, kHeadGroups, tokens.size(),
total_tasks, target_parallelism);
printf("FlashAttention: parallelism=%zu, kNF=%zu, kVTileSize=%zu, mode %s\n",
target_parallelism, kNF, kVTileSize,
GetAttentionImplName(attention_impl).c_str());
printf("FlashAttention: target_parallelism=%zu, kNF=%zu, kVTileSize=%zu\n",
target_parallelism, kNF, kVTileSize);
FlashAttention(tokens.size(), target_parallelism, 0, layers.query_norm_scale,
attention, qbatch, ctx, attention_impl);
attention, qbatch, ctx, AttentionImpl::kFlash);
AssertClose(attention.att_out, *saved_att);
ctx.profiler.PrintResults();
}
void TestAttention() {
TestFlashAttention(8192, AttentionImpl::kFlash);
TestFlashAttention(2048, AttentionImpl::kFlash);
TestFlashAttention(256, AttentionImpl::kFlash);
TestFlashAttention(8192);
TestFlashAttention(2048);
TestFlashAttention(256);
}
const std::vector<float> exp_denominator_sums_gold = {

View File

@ -2,19 +2,11 @@
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_STRUCTS_H_
#include <stddef.h>
#include <stdint.h>
#include <limits>
namespace gcpp {
// The vertical tile size in flash attention when register lanes correspond to
// K-timesteps, and the number of registers is 4 for 4 Q-rows.
static constexpr size_t k4xNFVTileSize = 4;
// The vertical tile size in flash attention when register lanes correspond to
// K-timesteps, and the number of registers is 8 for 8 Q-rows.
static constexpr size_t k8xNFVTileSize = 8;
// State for computing softmax in a streaming ("online") manner,
// avoiding large intermediate values by subtracting the running maximum.
// For a sequence x_1, ..., x_n:
@ -28,46 +20,10 @@ struct OnlineSoftmaxState {
float d = 0.0f;
};
struct Tile4FlashState {
OnlineSoftmaxState row_states[k8xNFVTileSize];
};
static constexpr size_t kVTileSize4 = 4;
// Parameters for a strip of tiles of flash attention. For processing a strip
// of tiles, each of 1, k4xNFVTileSize, or k8xNFVTileSize Q-rows, by NF
// k-positions. The total width of the strip might cover the entire sequence,
// or a part of it, depending on whether the strip has been split.
struct Tile148Params {
// Vertical tile size gives the number used in the k8xNFVTileSize arrays.
// It is the number of Q rows in the tile.
uint32_t v_tile_size = 0;
// min start position across all rows in the tile determines the
// mask used for the tile.
uint32_t min_start_pos = std::numeric_limits<uint32_t>::max();
// max last position across all rows in the tile determines the mask
// used for the tile.
uint32_t max_last_pos = 0;
// Index into the qbatch.KV is the same for each row in the tile.
uint32_t qi_index;
// Index into the kv_cache is the same for each row in the tile.
uint32_t kv_offset;
// In the original task, the index to the split tasks of the first split task.
uint32_t split_index = 0;
// The index of the split for running split attention.
uint32_t i_of_n = 0;
// The number of splits for running split attention.
uint32_t n_of_n = 0;
// Offsets into original Q for each row in the tile.
uint32_t q_offsets[k8xNFVTileSize];
// Offsets into att_out for each row in the tile.
uint32_t out_offsets[k8xNFVTileSize];
// Start k-positions for each row in the tile.
uint32_t start_pos[k8xNFVTileSize];
// Last k-positions for each row in the tile. Inclusive.
uint32_t last_pos[k8xNFVTileSize];
// Row index to att_out.
uint32_t tq_idx[k8xNFVTileSize];
// Flash attention state for the tile.
Tile4FlashState end_state;
struct Tile4FlashState {
OnlineSoftmaxState row_states[kVTileSize4];
};
} // namespace gcpp

View File

@ -29,11 +29,6 @@
namespace gcpp {
// TODO: rays - Remove this once hwy is updated.
#ifndef HWY_ARCH_MAX_BYTES
#define HWY_ARCH_MAX_BYTES 256
#endif
// Number of rows for KV cache. Note that both rows and cols are u32, and
// the total number of elements can exceed 2^32.
static size_t CappedSeqLen(const ModelConfig& config,
@ -48,23 +43,6 @@ static size_t CappedSeqLen(const ModelConfig& config,
KVCache::KVCache(const Extents2D& kv_extents, const Allocator& allocator)
: kv_cache("kv", kv_extents, allocator, MatPadding::kOdd),
// WARNING: the rows and cols of k_cache and v_cache will be modified
// before use!
// The rows will be reduced by a factor of 2xkFloatsPerVector, and the
// cols will be increased by 2xkFloatsPerVector on first use. This is to
// avoid making KVCache another class that has to be duplicated for each
// machine architecture, since kFloatsPerVector is architecture dependent.
// The change is shape is safe only if the padding is kPacked.
k_cache("k",
Extents2D(HWY_MAX(kv_extents.rows,
2 * HWY_ARCH_MAX_BYTES / sizeof(float)),
kv_extents.cols / 2),
allocator, MatPadding::kPacked),
v_cache("v",
Extents2D(HWY_MAX(kv_extents.rows,
2 * HWY_ARCH_MAX_BYTES / sizeof(float)),
kv_extents.cols / 2),
allocator, MatPadding::kPacked),
allocator_(allocator) {}
KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
@ -79,16 +57,14 @@ KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
: allocator_(allocator) {
if (runtime_config.attention_impl == AttentionImpl::kFlashTransposedQs ||
runtime_config.attention_impl == AttentionImpl::kFlashTransposedQsBF16
|| ((runtime_config.attention_impl == AttentionImpl::kFlashTransposedQs
) &&
hwy::IsSame<KV_t, BF16>())) {
) {
const size_t num_tiles =
hwy::DivCeil(CappedSeqLen(config, inference_args), kTileSize);
tiled_seq_len = num_tiles * kTileSize;
int tile_length = 2 * config.layer_configs[0].qkv_dim * kTileSize;
Type kv_cache_type;
if (runtime_config.attention_impl == AttentionImpl::kFlashTransposedQsBF16
|| hwy::IsSame<KV_t, BF16>()) {
) {
kv_cache_type = runtime_config.kv_cache_type.value_or(Type::kBF16);
} else {
kv_cache_type = runtime_config.kv_cache_type.value_or(Type::kF32);
@ -138,6 +114,9 @@ KVCache KVCache::Copy() {
KVCache copy(kv_cache.Extents(), allocator_);
CopyMat(kv_cache, copy.kv_cache);
CopyMat(compact_kv_cache_ptr, copy.compact_kv_cache_ptr);
copy.tiled_seq_len = tiled_seq_len;
return copy;
}

View File

@ -30,7 +30,7 @@
namespace gcpp {
using KV_t = BF16;
using KV_t = float;
struct KVCache;
// A non-owning view of a KVCache.
@ -40,8 +40,6 @@ struct KVCachePtr {
bool IsTiled() const;
MatPtrT<KV_t> kv_cache;
MatPtrT<KV_t> k_cache;
MatPtrT<KV_t> v_cache;
KVCache* cache = nullptr;
};
@ -125,33 +123,11 @@ struct KVCache {
// kv_head_ptrs[...].Rows().
std::vector<MatPtr> kv_head_ptrs;
MatStorageT<KV_t> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]
// The format of k_cache indicates that there are pairs of values from
// qkv_dim in groups of 2x kFloatsPerVector(=NF) elements from the sequence,
// in groups of qkv_dim/2 elements in groups of kv_heads elements.
// This enables sequential loading of the data when filling 2 vectors with
// NF sequence elements of pairs of BF16 qkv values. The next vector then
// continues reading the rest of qkv.
// [seq_len / 2NF, layers * kv_heads * qkv_dim/2 * 2NF * 2]
MatStorageT<KV_t> k_cache;
// v_cache is formatted to allow sequential access to V during scaling and
// update of att_out.
// Originally [seq_len, layers * kv_heads * qkv_dim]
// v_cache is transposed to:
// [layers, kv_heads, seq_len, qkv_dim], reshaped to:
// [layers, kv_heads, seq_len/(2NF), 2NF, qkv_dim/(2NF), 2NF]
// then transposed to:
// [seq_len/(2NF), layers, kv_heads, qkv_dim/(2NF), 2NF, 2NF]
// and finally packed in a 2D MatStorageT as:
// [seq_len/(2NF), layers * kv_heads * qkv_dim/(2NF) * 2NF * 2NF]
// This allows sequential reads of 2NF registers each of 2NF BF16 values,
// repeatedly until all of qkv_dim is read.
MatStorageT<KV_t> v_cache;
KVCachePtr ToPtr() {
return KVCachePtr{
.kv_cache = kv_cache,
.k_cache = k_cache,
.v_cache = v_cache,
.cache = this,
};
}

View File

@ -98,7 +98,7 @@ struct AttentionTestEnv {
}
}
} else if (kv_caches.back().compact_kv_cache_ptr.HasPtr()) {
MatPtrT<KV_t> compact_kv_cache = kv_caches.back().compact_kv_cache_ptr;
MatPtrT<float> compact_kv_cache = kv_caches.back().compact_kv_cache_ptr;
FillMatPtrT(compact_kv_cache);
} else {
FillMatPtrT(kv_caches.back().kv_cache);
@ -735,13 +735,12 @@ HWY_AFTER_NAMESPACE();
namespace gcpp {
HWY_BEFORE_TEST(TiledAttentionTest);
HWY_EXPORT_AND_TEST_P(TiledAttentionTest, TestTransposeStridedQueries);
// TODO() Fix the goldens for the change in KV_t to BF16
// HWY_EXPORT_AND_TEST_P(TiledAttentionTest,
// TestLocalAttentionForAllHeadsTokensAndBatch);
HWY_EXPORT_AND_TEST_P(TiledAttentionTest,
TestLocalAttentionForAllHeadsTokensAndBatch);
HWY_EXPORT_AND_TEST_P(TiledAttentionTest, TestAttentionMultipleTokens);
HWY_EXPORT_AND_TEST_P(TiledAttentionTest, TestAttentionMultipleTokensBF16);
// HWY_EXPORT_AND_TEST_P(TiledAttentionTest,
// TestAttentionMultipleTokensAttentionWindowSizeEdgeCase);
HWY_EXPORT_AND_TEST_P(TiledAttentionTest,
TestAttentionMultipleTokensAttentionWindowSizeEdgeCase);
HWY_AFTER_TEST();

View File

@ -613,6 +613,267 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(const float c,
});
}
// Scales 16 accumulators in place: sum_i *= scale[lane i].
// Enabled only for vectors wider than 63 bytes (i.e. 16 float lanes).
// Manually unrolled because hn::BroadcastLane requires a compile-time lane
// index.
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 63)>
HWY_INLINE HWY_MAYBE_UNUSED void Mul16(DF df, const VF scale, VF& sum0,
                                       VF& sum1, VF& sum2, VF& sum3, VF& sum4,
                                       VF& sum5, VF& sum6, VF& sum7, VF& sum8,
                                       VF& sum9, VF& sum10, VF& sum11,
                                       VF& sum12, VF& sum13, VF& sum14,
                                       VF& sum15) {
  sum0 = hn::Mul(sum0, hn::BroadcastLane<0>(scale));
  sum1 = hn::Mul(sum1, hn::BroadcastLane<1>(scale));
  sum2 = hn::Mul(sum2, hn::BroadcastLane<2>(scale));
  sum3 = hn::Mul(sum3, hn::BroadcastLane<3>(scale));
  sum4 = hn::Mul(sum4, hn::BroadcastLane<4>(scale));
  sum5 = hn::Mul(sum5, hn::BroadcastLane<5>(scale));
  sum6 = hn::Mul(sum6, hn::BroadcastLane<6>(scale));
  sum7 = hn::Mul(sum7, hn::BroadcastLane<7>(scale));
  sum8 = hn::Mul(sum8, hn::BroadcastLane<8>(scale));
  sum9 = hn::Mul(sum9, hn::BroadcastLane<9>(scale));
  sum10 = hn::Mul(sum10, hn::BroadcastLane<10>(scale));
  sum11 = hn::Mul(sum11, hn::BroadcastLane<11>(scale));
  sum12 = hn::Mul(sum12, hn::BroadcastLane<12>(scale));
  sum13 = hn::Mul(sum13, hn::BroadcastLane<13>(scale));
  sum14 = hn::Mul(sum14, hn::BroadcastLane<14>(scale));
  sum15 = hn::Mul(sum15, hn::BroadcastLane<15>(scale));
}
// No-op overload of Mul16 for vectors of at most 63 bytes (fewer than 16
// float lanes). Never reached at runtime: callers guard with
// `if HWY_LANES_CONSTEXPR (NF == 16)`, but both overloads must exist so the
// guarded branch still compiles on narrow targets.
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 63)>
HWY_INLINE HWY_MAYBE_UNUSED void Mul16(DF df, const VF scale, VF& sum0,
                                       VF& sum1, VF& sum2, VF& sum3, VF& sum4,
                                       VF& sum5, VF& sum6, VF& sum7, VF& sum8,
                                       VF& sum9, VF& sum10, VF& sum11,
                                       VF& sum12, VF& sum13, VF& sum14,
                                       VF& sum15) {}
// Scales 8 accumulators in place: sum_i *= scale[lane i].
// Enabled only for vectors wider than 31 bytes (i.e. at least 8 float lanes),
// so BroadcastLane<7> is valid. Unrolled for the same reason as Mul16.
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 31)>
HWY_INLINE HWY_MAYBE_UNUSED void Mul8(DF df, const VF scale, VF& sum0, VF& sum1,
                                      VF& sum2, VF& sum3, VF& sum4, VF& sum5,
                                      VF& sum6, VF& sum7) {
  sum0 = hn::Mul(sum0, hn::BroadcastLane<0>(scale));
  sum1 = hn::Mul(sum1, hn::BroadcastLane<1>(scale));
  sum2 = hn::Mul(sum2, hn::BroadcastLane<2>(scale));
  sum3 = hn::Mul(sum3, hn::BroadcastLane<3>(scale));
  sum4 = hn::Mul(sum4, hn::BroadcastLane<4>(scale));
  sum5 = hn::Mul(sum5, hn::BroadcastLane<5>(scale));
  sum6 = hn::Mul(sum6, hn::BroadcastLane<6>(scale));
  sum7 = hn::Mul(sum7, hn::BroadcastLane<7>(scale));
}
// No-op overload of Mul8 for vectors of at most 31 bytes (fewer than 8 float
// lanes); exists only so the `NF == 8` branch in callers compiles everywhere.
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 31)>
HWY_INLINE HWY_MAYBE_UNUSED void Mul8(DF df, const VF scale, VF& sum0, VF& sum1,
                                      VF& sum2, VF& sum3, VF& sum4, VF& sum5,
                                      VF& sum6, VF& sum7) {}
// Fused multiply-add into 16 accumulators: sum_i += common * split[lane i].
// `common` is shared across all accumulators; each takes a different lane of
// `split` as its scalar coefficient. Enabled only for vectors wider than 63
// bytes (16 float lanes); unrolled because BroadcastLane needs a constant
// lane index.
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 63)>
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd16(
    DF df, const VF common, const VF split, VF& sum0, VF& sum1, VF& sum2,
    VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7, VF& sum8, VF& sum9,
    VF& sum10, VF& sum11, VF& sum12, VF& sum13, VF& sum14, VF& sum15) {
  sum0 = hn::MulAdd(common, hn::BroadcastLane<0>(split), sum0);
  sum1 = hn::MulAdd(common, hn::BroadcastLane<1>(split), sum1);
  sum2 = hn::MulAdd(common, hn::BroadcastLane<2>(split), sum2);
  sum3 = hn::MulAdd(common, hn::BroadcastLane<3>(split), sum3);
  sum4 = hn::MulAdd(common, hn::BroadcastLane<4>(split), sum4);
  sum5 = hn::MulAdd(common, hn::BroadcastLane<5>(split), sum5);
  sum6 = hn::MulAdd(common, hn::BroadcastLane<6>(split), sum6);
  sum7 = hn::MulAdd(common, hn::BroadcastLane<7>(split), sum7);
  sum8 = hn::MulAdd(common, hn::BroadcastLane<8>(split), sum8);
  sum9 = hn::MulAdd(common, hn::BroadcastLane<9>(split), sum9);
  sum10 = hn::MulAdd(common, hn::BroadcastLane<10>(split), sum10);
  sum11 = hn::MulAdd(common, hn::BroadcastLane<11>(split), sum11);
  sum12 = hn::MulAdd(common, hn::BroadcastLane<12>(split), sum12);
  sum13 = hn::MulAdd(common, hn::BroadcastLane<13>(split), sum13);
  sum14 = hn::MulAdd(common, hn::BroadcastLane<14>(split), sum14);
  sum15 = hn::MulAdd(common, hn::BroadcastLane<15>(split), sum15);
}
// No-op overload of MulAdd16 for vectors of at most 63 bytes; lets the
// `NF == 16` branch in callers compile on targets with fewer float lanes.
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 63)>
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd16(
    DF df, const VF common, const VF split, VF& sum0, VF& sum1, VF& sum2,
    VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7, VF& sum8, VF& sum9,
    VF& sum10, VF& sum11, VF& sum12, VF& sum13, VF& sum14, VF& sum15) {}
// Fused multiply-add into 8 accumulators: sum_i += common * split[lane i].
// Enabled only for vectors wider than 31 bytes (at least 8 float lanes), so
// BroadcastLane<7> is valid.
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 31)>
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd8(DF df, const VF common, const VF split,
                                         VF& sum0, VF& sum1, VF& sum2, VF& sum3,
                                         VF& sum4, VF& sum5, VF& sum6,
                                         VF& sum7) {
  sum0 = hn::MulAdd(common, hn::BroadcastLane<0>(split), sum0);
  sum1 = hn::MulAdd(common, hn::BroadcastLane<1>(split), sum1);
  sum2 = hn::MulAdd(common, hn::BroadcastLane<2>(split), sum2);
  sum3 = hn::MulAdd(common, hn::BroadcastLane<3>(split), sum3);
  sum4 = hn::MulAdd(common, hn::BroadcastLane<4>(split), sum4);
  sum5 = hn::MulAdd(common, hn::BroadcastLane<5>(split), sum5);
  sum6 = hn::MulAdd(common, hn::BroadcastLane<6>(split), sum6);
  sum7 = hn::MulAdd(common, hn::BroadcastLane<7>(split), sum7);
}
// No-op overload of MulAdd8 for vectors of at most 31 bytes; lets the
// `NF == 8` branch in callers compile on narrower targets.
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 31)>
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd8(DF df, const VF common, const VF split,
                                         VF& sum0, VF& sum1, VF& sum2, VF& sum3,
                                         VF& sum4, VF& sum5, VF& sum6,
                                         VF& sum7) {}
// Fused multiply-add into 4 accumulators: sum_i += common * split[lane i].
// Defined for all vector widths (BroadcastLane<3> requires at least 4 lanes,
// the minimum NF used by callers).
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4(DF df, const VF common, const VF split,
                                         VF& sum0, VF& sum1, VF& sum2,
                                         VF& sum3) {
  sum0 = hn::MulAdd(common, hn::BroadcastLane<0>(split), sum0);
  sum1 = hn::MulAdd(common, hn::BroadcastLane<1>(split), sum1);
  sum2 = hn::MulAdd(common, hn::BroadcastLane<2>(split), sum2);
  sum3 = hn::MulAdd(common, hn::BroadcastLane<3>(split), sum3);
}
// For an 8xNF tile of float values in 8xNF-lane registers, multiplies 8 rows
// of V by the corresponding values in c0-c7 and adds them to NF rows of out,
// after first prescaling out by scale.
// The depth (size) must be a multiple of NF.
//
// Inputs:
//  - scale: per-output-row prescale factors, one per lane.
//  - c0..c7: coefficient vectors; lane j of c_r is the weight of V row
//    pos[r] in output row j.
//  - v/pos: the 8 source rows of V, addressed as v.Row(pos[r]).
//  - out/out_offsets: base pointer plus per-output-row offsets; output row j
//    lives at out + out_offsets[j].
// Exactly one of the NF == 16/8/4 branches below is taken; the comparisons
// are HWY_LANES_CONSTEXPR so the untaken branches are compiled out (their
// wide helpers are no-op stubs on narrow targets).
template <class DF, class VF = hn::Vec<DF>>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile(
    DF df, const VF scale, const VF c0, const VF c1, const VF c2, const VF c3,
    const VF c4, const VF c5, const VF c6, const VF c7, const MatPtrT<float>& v,
    const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out,
    const uint32_t* HWY_RESTRICT out_offsets, const size_t size) {
  namespace hn = hwy::HWY_NAMESPACE;
  HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
  size_t i = 0;
  // One vector-width column slice of the tile per iteration.
  while (i + NF <= size) {
    if HWY_LANES_CONSTEXPR (NF == 16) {
      // 16 output rows: load, prescale by the matching lane of `scale`,
      // accumulate all 8 V rows, store back.
      VF out0, out1, out2, out3, out4, out5, out6, out7;
      VF out8, out9, out10, out11, out12, out13, out14, out15;
      out0 = hn::Load(df, out + i + out_offsets[0]);
      out1 = hn::Load(df, out + i + out_offsets[1]);
      out2 = hn::Load(df, out + i + out_offsets[2]);
      out3 = hn::Load(df, out + i + out_offsets[3]);
      out4 = hn::Load(df, out + i + out_offsets[4]);
      out5 = hn::Load(df, out + i + out_offsets[5]);
      out6 = hn::Load(df, out + i + out_offsets[6]);
      out7 = hn::Load(df, out + i + out_offsets[7]);
      out8 = hn::Load(df, out + i + out_offsets[8]);
      out9 = hn::Load(df, out + i + out_offsets[9]);
      out10 = hn::Load(df, out + i + out_offsets[10]);
      out11 = hn::Load(df, out + i + out_offsets[11]);
      out12 = hn::Load(df, out + i + out_offsets[12]);
      out13 = hn::Load(df, out + i + out_offsets[13]);
      out14 = hn::Load(df, out + i + out_offsets[14]);
      out15 = hn::Load(df, out + i + out_offsets[15]);
      Mul16(df, scale, out0, out1, out2, out3, out4, out5, out6, out7, out8,
            out9, out10, out11, out12, out13, out14, out15);
      // Each V row r contributes c_r[j] * V[pos[r]] to output row j.
      VF x0 = hn::Load(df, v.Row(pos[0]) + i);
      MulAdd16(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7, out8,
               out9, out10, out11, out12, out13, out14, out15);
      VF x1 = hn::Load(df, v.Row(pos[1]) + i);
      MulAdd16(df, x1, c1, out0, out1, out2, out3, out4, out5, out6, out7, out8,
               out9, out10, out11, out12, out13, out14, out15);
      VF x2 = hn::Load(df, v.Row(pos[2]) + i);
      MulAdd16(df, x2, c2, out0, out1, out2, out3, out4, out5, out6, out7, out8,
               out9, out10, out11, out12, out13, out14, out15);
      VF x3 = hn::Load(df, v.Row(pos[3]) + i);
      MulAdd16(df, x3, c3, out0, out1, out2, out3, out4, out5, out6, out7, out8,
               out9, out10, out11, out12, out13, out14, out15);
      VF x4 = hn::Load(df, v.Row(pos[4]) + i);
      MulAdd16(df, x4, c4, out0, out1, out2, out3, out4, out5, out6, out7, out8,
               out9, out10, out11, out12, out13, out14, out15);
      VF x5 = hn::Load(df, v.Row(pos[5]) + i);
      MulAdd16(df, x5, c5, out0, out1, out2, out3, out4, out5, out6, out7, out8,
               out9, out10, out11, out12, out13, out14, out15);
      VF x6 = hn::Load(df, v.Row(pos[6]) + i);
      MulAdd16(df, x6, c6, out0, out1, out2, out3, out4, out5, out6, out7, out8,
               out9, out10, out11, out12, out13, out14, out15);
      VF x7 = hn::Load(df, v.Row(pos[7]) + i);
      MulAdd16(df, x7, c7, out0, out1, out2, out3, out4, out5, out6, out7, out8,
               out9, out10, out11, out12, out13, out14, out15);
      hn::Store(out0, df, out + i + out_offsets[0]);
      hn::Store(out1, df, out + i + out_offsets[1]);
      hn::Store(out2, df, out + i + out_offsets[2]);
      hn::Store(out3, df, out + i + out_offsets[3]);
      hn::Store(out4, df, out + i + out_offsets[4]);
      hn::Store(out5, df, out + i + out_offsets[5]);
      hn::Store(out6, df, out + i + out_offsets[6]);
      hn::Store(out7, df, out + i + out_offsets[7]);
      hn::Store(out8, df, out + i + out_offsets[8]);
      hn::Store(out9, df, out + i + out_offsets[9]);
      hn::Store(out10, df, out + i + out_offsets[10]);
      hn::Store(out11, df, out + i + out_offsets[11]);
      hn::Store(out12, df, out + i + out_offsets[12]);
      hn::Store(out13, df, out + i + out_offsets[13]);
      hn::Store(out14, df, out + i + out_offsets[14]);
      hn::Store(out15, df, out + i + out_offsets[15]);
    }
    if HWY_LANES_CONSTEXPR (NF == 8) {
      // Same scheme with 8 output rows.
      VF out0, out1, out2, out3, out4, out5, out6, out7;
      out0 = hn::Load(df, out + i + out_offsets[0]);
      out1 = hn::Load(df, out + i + out_offsets[1]);
      out2 = hn::Load(df, out + i + out_offsets[2]);
      out3 = hn::Load(df, out + i + out_offsets[3]);
      out4 = hn::Load(df, out + i + out_offsets[4]);
      out5 = hn::Load(df, out + i + out_offsets[5]);
      out6 = hn::Load(df, out + i + out_offsets[6]);
      out7 = hn::Load(df, out + i + out_offsets[7]);
      Mul8(df, scale, out0, out1, out2, out3, out4, out5, out6, out7);
      VF x0 = hn::Load(df, v.Row(pos[0]) + i);
      MulAdd8(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7);
      VF x1 = hn::Load(df, v.Row(pos[1]) + i);
      MulAdd8(df, x1, c1, out0, out1, out2, out3, out4, out5, out6, out7);
      VF x2 = hn::Load(df, v.Row(pos[2]) + i);
      MulAdd8(df, x2, c2, out0, out1, out2, out3, out4, out5, out6, out7);
      VF x3 = hn::Load(df, v.Row(pos[3]) + i);
      MulAdd8(df, x3, c3, out0, out1, out2, out3, out4, out5, out6, out7);
      VF x4 = hn::Load(df, v.Row(pos[4]) + i);
      MulAdd8(df, x4, c4, out0, out1, out2, out3, out4, out5, out6, out7);
      VF x5 = hn::Load(df, v.Row(pos[5]) + i);
      MulAdd8(df, x5, c5, out0, out1, out2, out3, out4, out5, out6, out7);
      VF x6 = hn::Load(df, v.Row(pos[6]) + i);
      MulAdd8(df, x6, c6, out0, out1, out2, out3, out4, out5, out6, out7);
      VF x7 = hn::Load(df, v.Row(pos[7]) + i);
      MulAdd8(df, x7, c7, out0, out1, out2, out3, out4, out5, out6, out7);
      hn::Store(out0, df, out + i + out_offsets[0]);
      hn::Store(out1, df, out + i + out_offsets[1]);
      hn::Store(out2, df, out + i + out_offsets[2]);
      hn::Store(out3, df, out + i + out_offsets[3]);
      hn::Store(out4, df, out + i + out_offsets[4]);
      hn::Store(out5, df, out + i + out_offsets[5]);
      hn::Store(out6, df, out + i + out_offsets[6]);
      hn::Store(out7, df, out + i + out_offsets[7]);
    }
    if HWY_LANES_CONSTEXPR (NF == 4) {
      // 4 output rows; still accumulates all 8 V rows via two MulAdd4 calls
      // per pair of coefficient vectors (scale broadcast is inlined here
      // because there is no Mul4 helper).
      VF out0, out1, out2, out3;
      out0 = hn::Load(df, out + i + out_offsets[0]);
      out1 = hn::Load(df, out + i + out_offsets[1]);
      out2 = hn::Load(df, out + i + out_offsets[2]);
      out3 = hn::Load(df, out + i + out_offsets[3]);
      out0 = hn::Mul(out0, hn::BroadcastLane<0>(scale));
      out1 = hn::Mul(out1, hn::BroadcastLane<1>(scale));
      out2 = hn::Mul(out2, hn::BroadcastLane<2>(scale));
      out3 = hn::Mul(out3, hn::BroadcastLane<3>(scale));
      VF x0 = hn::Load(df, v.Row(pos[0]) + i);
      MulAdd4(df, x0, c0, out0, out1, out2, out3);
      VF x1 = hn::Load(df, v.Row(pos[1]) + i);
      MulAdd4(df, x1, c1, out0, out1, out2, out3);
      VF x2 = hn::Load(df, v.Row(pos[2]) + i);
      MulAdd4(df, x2, c2, out0, out1, out2, out3);
      VF x3 = hn::Load(df, v.Row(pos[3]) + i);
      MulAdd4(df, x3, c3, out0, out1, out2, out3);
      VF x4 = hn::Load(df, v.Row(pos[4]) + i);
      MulAdd4(df, x4, c4, out0, out1, out2, out3);
      VF x5 = hn::Load(df, v.Row(pos[5]) + i);
      MulAdd4(df, x5, c5, out0, out1, out2, out3);
      VF x6 = hn::Load(df, v.Row(pos[6]) + i);
      MulAdd4(df, x6, c6, out0, out1, out2, out3);
      VF x7 = hn::Load(df, v.Row(pos[7]) + i);
      MulAdd4(df, x7, c7, out0, out1, out2, out3);
      hn::Store(out0, df, out + i + out_offsets[0]);
      hn::Store(out1, df, out + i + out_offsets[1]);
      hn::Store(out2, df, out + i + out_offsets[2]);
      hn::Store(out3, df, out + i + out_offsets[3]);
    }
    i += NF;
  }
  // `size` must be a multiple of NF; there is no scalar tail loop.
  HWY_DASSERT(size == i);
}
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4(DF df, const VF common, const VF c0,
const VF c1, const VF c2, const VF c3,
@ -625,147 +886,145 @@ HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4(DF df, const VF common, const VF c0,
}
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE HWY_MAYBE_UNUSED void MulAddNLanesVT4(
DF df, const BF16* HWY_RESTRICT v, const float* HWY_RESTRICT c,
const size_t num_lanes, VF& sum0a, VF& sum1a, VF& sum2a, VF& sum3a,
VF& sum0b, VF& sum1b, VF& sum2b, VF& sum3b) {
using DBF = hn::ScalableTag<BF16>;
const DBF dbf;
using VBF = hn::Vec<DBF>;
const size_t kNF = hn::Lanes(df);
for (size_t lane = 0; lane < num_lanes; ++lane, v += 2 * kNF) {
VBF v0 = hn::Load(dbf, v);
VF c0 = hn::Set(df, *c++);
VF c1 = hn::Set(df, *c++);
VF c2 = hn::Set(df, *c++);
VF c3 = hn::Set(df, *c++);
VF v0a = hn::PromoteLowerTo(df, v0);
VF v0b = hn::PromoteUpperTo(df, v0);
MulAdd4(df, v0a, c0, c1, c2, c3, sum0a, sum1a, sum2a, sum3a);
MulAdd4(df, v0b, c0, c1, c2, c3, sum0b, sum1b, sum2b, sum3b);
}
// Accumulates 4 rows of V into 4 output accumulators: for r in 0..3,
// sum_j += V[pos[r]][offset..] * c_j[lane r]. Lane r of each coefficient
// vector c0..c3 is the weight of V row pos[r] in accumulator sum_j.
// NOTE(review): DF/VF come from a template declaration preceding this
// definition (not visible in this excerpt).
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4Lanes(DF df, const MatPtrT<float>& v,
                                              const size_t* HWY_RESTRICT pos,
                                              const size_t offset, const VF c0,
                                              const VF c1, const VF c2,
                                              const VF c3, VF& sum0, VF& sum1,
                                              VF& sum2, VF& sum3) {
  // TODO(rays): Check whether a transpose of c0-c3 is applicable and faster.
  VF x0 = hn::Load(df, v.Row(pos[0]) + offset);
  MulAdd4(df, x0, hn::BroadcastLane<0>(c0), hn::BroadcastLane<0>(c1),
          hn::BroadcastLane<0>(c2), hn::BroadcastLane<0>(c3), sum0, sum1, sum2,
          sum3);
  VF x1 = hn::Load(df, v.Row(pos[1]) + offset);
  MulAdd4(df, x1, hn::BroadcastLane<1>(c0), hn::BroadcastLane<1>(c1),
          hn::BroadcastLane<1>(c2), hn::BroadcastLane<1>(c3), sum0, sum1, sum2,
          sum3);
  VF x2 = hn::Load(df, v.Row(pos[2]) + offset);
  MulAdd4(df, x2, hn::BroadcastLane<2>(c0), hn::BroadcastLane<2>(c1),
          hn::BroadcastLane<2>(c2), hn::BroadcastLane<2>(c3), sum0, sum1, sum2,
          sum3);
  VF x3 = hn::Load(df, v.Row(pos[3]) + offset);
  MulAdd4(df, x3, hn::BroadcastLane<3>(c0), hn::BroadcastLane<3>(c1),
          hn::BroadcastLane<3>(c2), hn::BroadcastLane<3>(c3), sum0, sum1, sum2,
          sum3);
}
// For a 2NFx4 tile of float values in 8xNF-lane registers, multiplies 2NF rows
// of V by the corresponding values in c00-c31 and adds them to 2NF rows of out,
// Continuation of MulAdd4Lanes for V rows pos[4]..pos[7], using lanes 4..7 of
// the coefficient vectors. Enabled only for vectors wider than 31 bytes (at
// least 8 float lanes) so BroadcastLane<4..7> is valid.
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 31)>
HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond4Lanes(
    DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
    const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
    VF& sum0, VF& sum1, VF& sum2, VF& sum3) {
  VF x4 = hn::Load(df, v.Row(pos[4]) + offset);
  MulAdd4(df, x4, hn::BroadcastLane<4>(c0), hn::BroadcastLane<4>(c1),
          hn::BroadcastLane<4>(c2), hn::BroadcastLane<4>(c3), sum0, sum1, sum2,
          sum3);
  VF x5 = hn::Load(df, v.Row(pos[5]) + offset);
  MulAdd4(df, x5, hn::BroadcastLane<5>(c0), hn::BroadcastLane<5>(c1),
          hn::BroadcastLane<5>(c2), hn::BroadcastLane<5>(c3), sum0, sum1, sum2,
          sum3);
  VF x6 = hn::Load(df, v.Row(pos[6]) + offset);
  MulAdd4(df, x6, hn::BroadcastLane<6>(c0), hn::BroadcastLane<6>(c1),
          hn::BroadcastLane<6>(c2), hn::BroadcastLane<6>(c3), sum0, sum1, sum2,
          sum3);
  VF x7 = hn::Load(df, v.Row(pos[7]) + offset);
  MulAdd4(df, x7, hn::BroadcastLane<7>(c0), hn::BroadcastLane<7>(c1),
          hn::BroadcastLane<7>(c2), hn::BroadcastLane<7>(c3), sum0, sum1, sum2,
          sum3);
}
// No-op overload of MulAddSecond4Lanes for vectors of at most 31 bytes;
// callers guard with `if HWY_LANES_CONSTEXPR (NF >= 8)`, but this stub is
// needed so that branch compiles on narrow targets.
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 31)>
HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond4Lanes(
    DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
    const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
    VF& sum0, VF& sum1, VF& sum2, VF& sum3) {}
// Continuation of MulAdd4Lanes for V rows pos[8]..pos[15], using lanes 8..15
// of the coefficient vectors. Enabled only for vectors wider than 63 bytes
// (16 float lanes) so BroadcastLane<8..15> is valid.
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 63)>
HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond8Lanes(
    DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
    const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
    VF& sum0, VF& sum1, VF& sum2, VF& sum3) {
  VF x8 = hn::Load(df, v.Row(pos[8]) + offset);
  MulAdd4(df, x8, hn::BroadcastLane<8>(c0), hn::BroadcastLane<8>(c1),
          hn::BroadcastLane<8>(c2), hn::BroadcastLane<8>(c3), sum0, sum1, sum2,
          sum3);
  VF x9 = hn::Load(df, v.Row(pos[9]) + offset);
  MulAdd4(df, x9, hn::BroadcastLane<9>(c0), hn::BroadcastLane<9>(c1),
          hn::BroadcastLane<9>(c2), hn::BroadcastLane<9>(c3), sum0, sum1, sum2,
          sum3);
  VF x10 = hn::Load(df, v.Row(pos[10]) + offset);
  MulAdd4(df, x10, hn::BroadcastLane<10>(c0), hn::BroadcastLane<10>(c1),
          hn::BroadcastLane<10>(c2), hn::BroadcastLane<10>(c3), sum0, sum1,
          sum2, sum3);
  VF x11 = hn::Load(df, v.Row(pos[11]) + offset);
  MulAdd4(df, x11, hn::BroadcastLane<11>(c0), hn::BroadcastLane<11>(c1),
          hn::BroadcastLane<11>(c2), hn::BroadcastLane<11>(c3), sum0, sum1,
          sum2, sum3);
  VF x12 = hn::Load(df, v.Row(pos[12]) + offset);
  MulAdd4(df, x12, hn::BroadcastLane<12>(c0), hn::BroadcastLane<12>(c1),
          hn::BroadcastLane<12>(c2), hn::BroadcastLane<12>(c3), sum0, sum1,
          sum2, sum3);
  VF x13 = hn::Load(df, v.Row(pos[13]) + offset);
  MulAdd4(df, x13, hn::BroadcastLane<13>(c0), hn::BroadcastLane<13>(c1),
          hn::BroadcastLane<13>(c2), hn::BroadcastLane<13>(c3), sum0, sum1,
          sum2, sum3);
  VF x14 = hn::Load(df, v.Row(pos[14]) + offset);
  MulAdd4(df, x14, hn::BroadcastLane<14>(c0), hn::BroadcastLane<14>(c1),
          hn::BroadcastLane<14>(c2), hn::BroadcastLane<14>(c3), sum0, sum1,
          sum2, sum3);
  VF x15 = hn::Load(df, v.Row(pos[15]) + offset);
  MulAdd4(df, x15, hn::BroadcastLane<15>(c0), hn::BroadcastLane<15>(c1),
          hn::BroadcastLane<15>(c2), hn::BroadcastLane<15>(c3), sum0, sum1,
          sum2, sum3);
}
// No-op overload of MulAddSecond8Lanes for vectors of at most 63 bytes;
// callers guard with `if HWY_LANES_CONSTEXPR (NF >= 16)`, but this stub is
// needed so that branch compiles on narrower targets.
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 63)>
HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond8Lanes(
    DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
    const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
    VF& sum0, VF& sum1, VF& sum2, VF& sum3) {}
// For an NFx4 tile of float values in 4xNF-lane registers, multiplies NF rows
// of V by the corresponding values in c0-c3 and adds them to NF rows of out,
// after first prescaling out by scale.
// The depth (size) must be a multiple of 2NF.
// The depth (size) must be a multiple of NF.
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddVT4Mem(
DF df, const float* HWY_RESTRICT scales, const VF c00, const VF c01,
const VF c10, const VF c11, const VF c20, const VF c21, const VF c30,
const VF c31, const MatPtrT<BF16>& v, const size_t* HWY_RESTRICT pos,
size_t num_lanes, float* HWY_RESTRICT out,
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile4(
DF df, const float* HWY_RESTRICT scales, const VF c0, const VF c1,
const VF c2, const VF c3, const MatPtrT<float>& v,
const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out,
const uint32_t* HWY_RESTRICT out_offsets, const size_t size) {
namespace hn = hwy::HWY_NAMESPACE;
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
constexpr size_t kMaxNF = hn::MaxLanes(df);
const BF16* HWY_RESTRICT v_bf = v.Row(pos[0] / (2 * NF));
HWY_DASSERT(pos[0] % (2 * NF) == 0);
HWY_ALIGN float c_mem[8 * kMaxNF];
hn::StoreInterleaved4(c00, c10, c20, c30, df, c_mem);
hn::StoreInterleaved4(c01, c11, c21, c31, df, c_mem + 4 * NF);
size_t i = 0;
while (i + NF * 2 <= size) {
VF out0a, out1a, out2a, out3a, out0b, out1b, out2b, out3b;
out0a = hn::Load(df, out + i + out_offsets[0]);
out1a = hn::Load(df, out + i + out_offsets[1]);
out2a = hn::Load(df, out + i + out_offsets[2]);
out3a = hn::Load(df, out + i + out_offsets[3]);
VF scale0 = hn::Set(df, scales[0]);
VF scale1 = hn::Set(df, scales[1]);
VF scale2 = hn::Set(df, scales[2]);
VF scale3 = hn::Set(df, scales[3]);
out0a = hn::Mul(out0a, scale0);
out1a = hn::Mul(out1a, scale1);
out2a = hn::Mul(out2a, scale2);
out3a = hn::Mul(out3a, scale3);
out0b = hn::Load(df, out + i + NF + out_offsets[0]);
out1b = hn::Load(df, out + i + NF + out_offsets[1]);
out2b = hn::Load(df, out + i + NF + out_offsets[2]);
out3b = hn::Load(df, out + i + NF + out_offsets[3]);
out0b = hn::Mul(out0b, scale0);
out1b = hn::Mul(out1b, scale1);
out2b = hn::Mul(out2b, scale2);
out3b = hn::Mul(out3b, scale3);
MulAddNLanesVT4(df, v_bf, c_mem, HWY_MIN(num_lanes, 2 * NF), out0a, out1a,
out2a, out3a, out0b, out1b, out2b, out3b);
hn::Store(out0a, df, out + i + out_offsets[0]);
hn::Store(out1a, df, out + i + out_offsets[1]);
hn::Store(out2a, df, out + i + out_offsets[2]);
hn::Store(out3a, df, out + i + out_offsets[3]);
hn::Store(out0b, df, out + i + NF + out_offsets[0]);
hn::Store(out1b, df, out + i + NF + out_offsets[1]);
hn::Store(out2b, df, out + i + NF + out_offsets[2]);
hn::Store(out3b, df, out + i + NF + out_offsets[3]);
i += NF * 2;
v_bf += 4 * NF * NF;
while (i + NF <= size) {
VF out0, out1, out2, out3;
out0 = hn::Load(df, out + i + out_offsets[0]);
out1 = hn::Load(df, out + i + out_offsets[1]);
out2 = hn::Load(df, out + i + out_offsets[2]);
out3 = hn::Load(df, out + i + out_offsets[3]);
out0 = hn::Mul(out0, hn::Set(df, scales[0]));
out1 = hn::Mul(out1, hn::Set(df, scales[1]));
out2 = hn::Mul(out2, hn::Set(df, scales[2]));
out3 = hn::Mul(out3, hn::Set(df, scales[3]));
MulAdd4Lanes(df, v, pos, i, c0, c1, c2, c3, out0, out1, out2, out3);
if HWY_LANES_CONSTEXPR (NF >= 8) {
MulAddSecond4Lanes(df, v, pos, i, c0, c1, c2, c3, out0, out1, out2, out3);
if HWY_LANES_CONSTEXPR (NF >= 16) {
MulAddSecond8Lanes(df, v, pos, i, c0, c1, c2, c3, out0, out1, out2,
out3);
}
}
hn::Store(out0, df, out + i + out_offsets[0]);
hn::Store(out1, df, out + i + out_offsets[1]);
hn::Store(out2, df, out + i + out_offsets[2]);
hn::Store(out3, df, out + i + out_offsets[3]);
i += NF;
}
HWY_DASSERT(size == i);
}
// Accumulates `num_lanes` packed BF16 vectors of V into one pair of float
// accumulators: for each lane, loads 2*NF BF16 values from `v`, promotes the
// lower/upper halves to float, and adds c[lane] * half into sum0a/sum0b.
// `v` advances by 2*kNF BF16 elements per lane (the packed-V layout);
// `c` is read sequentially, one scalar coefficient per lane.
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE HWY_MAYBE_UNUSED void MulAddNLanesVT1(DF df,
                                                 const BF16* HWY_RESTRICT v,
                                                 const float* HWY_RESTRICT c,
                                                 const size_t num_lanes,
                                                 VF& sum0a, VF& sum0b) {
  // One BF16 vector holds 2*NF values, split into two float vectors below.
  using DBF = hn::ScalableTag<BF16>;
  const DBF dbf;
  using VBF = hn::Vec<DBF>;
  const size_t kNF = hn::Lanes(df);
  for (size_t lane = 0; lane < num_lanes; ++lane, v += 2 * kNF) {
    VBF v0 = hn::Load(dbf, v);
    VF c0 = hn::Set(df, *c++);
    VF v0a = hn::PromoteLowerTo(df, v0);
    VF v0b = hn::PromoteUpperTo(df, v0);
    sum0a = hn::MulAdd(v0a, c0, sum0a);
    sum0b = hn::MulAdd(v0b, c0, sum0b);
  }
}
// Single-output-row variant over the packed BF16 V layout: prescales output
// row out_offsets[0] by scales[0], then accumulates up to 2*NF V lanes
// weighted by the coefficients in c00/c01.
// `pos[0]` must be a multiple of 2*NF (it indexes the packed row of v).
// Processes 2*NF floats per iteration; a scalar loop handles any remainder.
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddVT1Mem(
    DF df, const float* HWY_RESTRICT scales, const VF c00, const VF c01,
    const MatPtrT<BF16>& v, const size_t* HWY_RESTRICT pos, size_t num_lanes,
    float* HWY_RESTRICT out, const uint32_t* HWY_RESTRICT out_offsets,
    const size_t size) {
  namespace hn = hwy::HWY_NAMESPACE;
  HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
  constexpr size_t kMaxNF = hn::MaxLanes(df);
  // Packed layout: row index is pos[0] / (2*NF).
  const BF16* HWY_RESTRICT v_bf = v.Row(pos[0] / (2 * NF));
  HWY_DASSERT(pos[0] % (2 * NF) == 0);
  // Spill the coefficient vectors so the inner loop can read them as scalars.
  HWY_ALIGN float c_mem[2 * kMaxNF];
  hn::Store(c00, df, c_mem);
  hn::Store(c01, df, c_mem + NF);
  size_t i = 0;
  while (i + NF * 2 <= size) {
    VF out0a, out0b;
    out0a = hn::Load(df, out + i + out_offsets[0]);
    VF scale0 = hn::Set(df, scales[0]);
    out0a = hn::Mul(out0a, scale0);
    out0b = hn::Load(df, out + i + NF + out_offsets[0]);
    out0b = hn::Mul(out0b, scale0);
    MulAddNLanesVT1(df, v_bf, c_mem, HWY_MIN(num_lanes, 2 * NF), out0a, out0b);
    hn::Store(out0a, df, out + i + out_offsets[0]);
    hn::Store(out0b, df, out + i + NF + out_offsets[0]);
    i += NF * 2;
    // Each packed "row chunk" of v covers 2*NF lanes of 2*NF BF16 values.
    v_bf += 4 * NF * NF;
  }
  // Scalar tail for the remainder when size is not a multiple of 2*NF.
  // NOTE(review): the tail does not prescale out[...] by scales[0] before the
  // lane loop in the same way; it multiplies once up front — confirm size is
  // normally a multiple of 2*NF so this path is rarely taken.
  while (i < size) {
    float sum = out[i + out_offsets[0]] * scales[0];
    const BF16* HWY_RESTRICT v_local = v_bf;
    for (size_t lane = 0; lane < HWY_MIN(num_lanes, 2 * NF);
         ++lane, v_local += 2 * NF) {
      sum += hwy::ConvertScalarTo<float>(*v_local) * c_mem[lane];
    }
    ++i;
    ++v_bf;
  }
}
template <int32_t N, typename DF, class VF = hn::Vec<DF>>
static HWY_INLINE void StoreUpTo8Times2(DF df, MatPtrT<float>& out,
size_t start_col, VF out0_0, VF out0_1,
@ -1211,6 +1470,104 @@ HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddTileUpTo8_BF16(
HWY_DASSERT(qkv_dim == i);
}
// Prescales NF rows of out by scale, then multiplies 1 row of V by the
// corresponding values in c0 and adds them to the NF rows of out.
// The depth (size) must be a multiple of NF.
//
// Single-V-row counterpart of MulByConstAndAddTile: out row j becomes
// out_j * scale[j] + c0[j] * V[pos]. Exactly one of the NF == 16/8/4
// branches is taken (HWY_LANES_CONSTEXPR); the wide helpers are no-op stubs
// on narrow targets so all branches compile everywhere.
template <class DF, class VF = hn::Vec<DF>>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
    DF df, const VF scale, const VF c0, const MatPtrT<float>& v,
    const size_t pos, float* HWY_RESTRICT out,
    const uint32_t* HWY_RESTRICT out_offsets, const size_t size) {
  namespace hn = hwy::HWY_NAMESPACE;
  HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
  size_t i = 0;
  // One vector-width column slice per iteration.
  while (i + NF <= size) {
    if HWY_LANES_CONSTEXPR (NF == 16) {
      VF out0, out1, out2, out3, out4, out5, out6, out7;
      VF out8, out9, out10, out11, out12, out13, out14, out15;
      out0 = hn::Load(df, out + i + out_offsets[0]);
      out1 = hn::Load(df, out + i + out_offsets[1]);
      out2 = hn::Load(df, out + i + out_offsets[2]);
      out3 = hn::Load(df, out + i + out_offsets[3]);
      out4 = hn::Load(df, out + i + out_offsets[4]);
      out5 = hn::Load(df, out + i + out_offsets[5]);
      out6 = hn::Load(df, out + i + out_offsets[6]);
      out7 = hn::Load(df, out + i + out_offsets[7]);
      out8 = hn::Load(df, out + i + out_offsets[8]);
      out9 = hn::Load(df, out + i + out_offsets[9]);
      out10 = hn::Load(df, out + i + out_offsets[10]);
      out11 = hn::Load(df, out + i + out_offsets[11]);
      out12 = hn::Load(df, out + i + out_offsets[12]);
      out13 = hn::Load(df, out + i + out_offsets[13]);
      out14 = hn::Load(df, out + i + out_offsets[14]);
      out15 = hn::Load(df, out + i + out_offsets[15]);
      Mul16(df, scale, out0, out1, out2, out3, out4, out5, out6, out7, out8,
            out9, out10, out11, out12, out13, out14, out15);
      // Single V row: each output row j gains c0[j] * V[pos].
      VF x0 = hn::Load(df, v.Row(pos) + i);
      MulAdd16(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7, out8,
               out9, out10, out11, out12, out13, out14, out15);
      hn::Store(out0, df, out + i + out_offsets[0]);
      hn::Store(out1, df, out + i + out_offsets[1]);
      hn::Store(out2, df, out + i + out_offsets[2]);
      hn::Store(out3, df, out + i + out_offsets[3]);
      hn::Store(out4, df, out + i + out_offsets[4]);
      hn::Store(out5, df, out + i + out_offsets[5]);
      hn::Store(out6, df, out + i + out_offsets[6]);
      hn::Store(out7, df, out + i + out_offsets[7]);
      hn::Store(out8, df, out + i + out_offsets[8]);
      hn::Store(out9, df, out + i + out_offsets[9]);
      hn::Store(out10, df, out + i + out_offsets[10]);
      hn::Store(out11, df, out + i + out_offsets[11]);
      hn::Store(out12, df, out + i + out_offsets[12]);
      hn::Store(out13, df, out + i + out_offsets[13]);
      hn::Store(out14, df, out + i + out_offsets[14]);
      hn::Store(out15, df, out + i + out_offsets[15]);
    }
    if HWY_LANES_CONSTEXPR (NF == 8) {
      VF out0, out1, out2, out3, out4, out5, out6, out7;
      out0 = hn::Load(df, out + i + out_offsets[0]);
      out1 = hn::Load(df, out + i + out_offsets[1]);
      out2 = hn::Load(df, out + i + out_offsets[2]);
      out3 = hn::Load(df, out + i + out_offsets[3]);
      out4 = hn::Load(df, out + i + out_offsets[4]);
      out5 = hn::Load(df, out + i + out_offsets[5]);
      out6 = hn::Load(df, out + i + out_offsets[6]);
      out7 = hn::Load(df, out + i + out_offsets[7]);
      Mul8(df, scale, out0, out1, out2, out3, out4, out5, out6, out7);
      VF x0 = hn::Load(df, v.Row(pos) + i);
      MulAdd8(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7);
      hn::Store(out0, df, out + i + out_offsets[0]);
      hn::Store(out1, df, out + i + out_offsets[1]);
      hn::Store(out2, df, out + i + out_offsets[2]);
      hn::Store(out3, df, out + i + out_offsets[3]);
      hn::Store(out4, df, out + i + out_offsets[4]);
      hn::Store(out5, df, out + i + out_offsets[5]);
      hn::Store(out6, df, out + i + out_offsets[6]);
      hn::Store(out7, df, out + i + out_offsets[7]);
    }
    if HWY_LANES_CONSTEXPR (NF == 4) {
      VF out0, out1, out2, out3;
      out0 = hn::Load(df, out + i + out_offsets[0]);
      out1 = hn::Load(df, out + i + out_offsets[1]);
      out2 = hn::Load(df, out + i + out_offsets[2]);
      out3 = hn::Load(df, out + i + out_offsets[3]);
      // Inline prescale: no Mul4 helper exists.
      out0 = hn::Mul(out0, hn::BroadcastLane<0>(scale));
      out1 = hn::Mul(out1, hn::BroadcastLane<1>(scale));
      out2 = hn::Mul(out2, hn::BroadcastLane<2>(scale));
      out3 = hn::Mul(out3, hn::BroadcastLane<3>(scale));
      VF x0 = hn::Load(df, v.Row(pos) + i);
      MulAdd4(df, x0, c0, out0, out1, out2, out3);
      hn::Store(out0, df, out + i + out_offsets[0]);
      hn::Store(out1, df, out + i + out_offsets[1]);
      hn::Store(out2, df, out + i + out_offsets[2]);
      hn::Store(out3, df, out + i + out_offsets[3]);
    }
    i += NF;
  }
  // `size` must be a multiple of NF; no scalar tail.
  HWY_DASSERT(size == i);
}
// See below for a specialized version for top-1 sampling.
// TODO: support bf16 logits using Decompress2.
// Computes softmax probabilities for the given logits, normalizing in-place.

View File

@ -202,16 +202,6 @@ class MatPtr : public IFields {
override_rows_ = static_cast<uint32_t>(rows);
}
// Changes the number of rows and columns without reallocating the memory.
// Increases cols by factor and reduces rows by factor.
// The rows must be divisible by factor and the matrix must be packed.
void ReshapePackedRowsToCols(size_t factor) {
HWY_ASSERT(IsPacked());
private_rows_ = hwy::DivCeil(private_rows_, factor);
cols_ *= factor;
stride_ *= factor;
}
// Offset by which to advance pointers to the next row.
size_t Stride() const { return stride_; }

View File

@ -106,8 +106,7 @@ template <typename T>
void FillMatPtrT(MatPtrT<T>& mat) {
for (int i = 0; i < mat.Rows(); ++i) {
for (int j = 0; j < mat.Cols(); ++j) {
mat.Row(i)[j] =
hwy::ConvertScalarTo<T>(hwy::Unpredictable1() * 0.01f * (i + j + 1));
mat.Row(i)[j] = hwy::Unpredictable1() * 0.01f * (i + j + 1);
}
}
}

View File

@ -17,14 +17,14 @@ const char* ZoneName(Zones zone) {
return "FlashAttention.Inclusive";
case Zones::kFlashAttentionRmsNormAndPositionalEncoding:
return "FlashAttention.RMSNormAndPositionalEncoding";
case Zones::kFlashAttentionTileFlashAttention1:
return "FlashAttention.TileFlashAttention1";
case Zones::kFlashAttentionSingleFlashAttention:
return "FlashAttention.SingleFlashAttention";
case Zones::kFlashAttentionTileFlashAttention:
return "FlashAttention.TileFlashAttention";
case Zones::kFlashAttentionTileFlashAttention4:
return "FlashAttention.TileFlashAttention4";
case Zones::kFlashAttentionTileFlashAttention8:
return "FlashAttention.TileFlashAttention8";
case Zones::kFlashAttentionCombineSplit:
return "FlashAttention.CombineSplit";
case Zones::kFlashAttentionTransposeQ:
return "FlashAttention.TransposeQ";
case Zones::kGenActivation:
return "Gen.Activation";
case Zones::kGenActivationFused:

View File

@ -14,10 +14,10 @@ enum class Zones { // Keep sorted
kFlashAttentionFlashAttention,
kFlashAttentionInclusive,
kFlashAttentionRmsNormAndPositionalEncoding,
kFlashAttentionTileFlashAttention1,
kFlashAttentionSingleFlashAttention,
kFlashAttentionTileFlashAttention,
kFlashAttentionTileFlashAttention4,
kFlashAttentionTileFlashAttention8,
kFlashAttentionCombineSplit,
kFlashAttentionTransposeQ,
kGenActivation,
kGenActivationFused,
kGenAttention,