Rewrote flash attention to use BF16 and transposed K and V caches, rewrote the task distribution, increased parallelism on decode, and doubled the registers used for the core of flash attention.

PiperOrigin-RevId: 868146247
Ray Smith 2026-02-10 07:55:17 -08:00 committed by Copybara-Service
parent 76d7951242
commit a814aa411e
17 changed files with 1424 additions and 1169 deletions

View File

@ -547,6 +547,7 @@ cc_library(
deps = [
":basics",
":configs",
":flash_structs",
":gemma_args",
":kv_cache",
":mat",
@ -594,6 +595,11 @@ cc_test(
INTERNAL_DEPS = []
cc_library(
name = "flash_structs",
hdrs = ["gemma/flash_structs.h"],
)
cc_library(
name = "attention",
srcs = [
@ -603,7 +609,6 @@ cc_library(
hdrs = [
"gemma/attention.h",
"gemma/flash_attention.h",
"gemma/flash_structs.h",
],
textual_hdrs = [
"gemma/gemma-inl.h",
@ -612,6 +617,7 @@ cc_library(
":activations",
":basics",
":configs",
":flash_structs",
":kv_cache",
":mat",
":matmul",

View File

@ -24,6 +24,7 @@
#include <vector>
#include "gemma/configs.h" // ModelConfig
#include "gemma/flash_structs.h"
#include "gemma/gemma_args.h" // AttentionImpl
#include "gemma/kv_cache.h"
#include "gemma/tensor_stats.h"
@ -52,10 +53,13 @@ struct AttentionActivations {
AttentionActivations(
const ModelConfig& config, const LayerConfig& layer_config,
size_t batch_size, size_t seq_len, const RuntimeConfig& runtime_config,
const Allocator& allocator,
size_t max_workers, const Allocator& allocator,
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
: // `vocab_size == 0` means it is for Vit part, VitAttention is still
// MHA and does not use an external KV cache.
: rep_factor(max_workers *
AttentionActivations::kThreadReplicationFactor /
layer_config.heads),
// `vocab_size == 0` means it is for Vit part, VitAttention
// is still MHA and does not use an external KV cache.
q(MatFactory("q", batch_size,
config.vocab_size == 0
? layer_config.heads * 3 * layer_config.qkv_dim
@ -86,6 +90,9 @@ struct AttentionActivations {
att_out(MatFactory("att_out", batch_size,
layer_config.heads * layer_config.qkv_dim,
allocator)),
att_out_reps(MatFactory("att_out_reps", batch_size * rep_factor,
layer_config.heads * layer_config.qkv_dim,
allocator)),
softmax_max(MatFactory("softmax_max", batch_size, layer_config.heads,
allocator)),
softmax_d(
@ -107,6 +114,11 @@ struct AttentionActivations {
}
return;
}
// This is a guess at the maximum number of params we might need; reserving it
// avoids reallocations. The actual number of params is determined by the
// number of query tiles, which is not known here.
flash_params.reserve(batch_size * layer_config.heads);
split_flash_params.reserve(batch_size * layer_config.heads);
// For MatMul outputs, precompute their row pointers.
// If we forget any MatMul outputs here, debug builds print a warning but
@ -130,6 +142,7 @@ struct AttentionActivations {
pre_att_rms_out.OverrideRows(batch_size);
att.OverrideRows(batch_size);
att_out.OverrideRows(batch_size);
att_out_reps.OverrideRows(batch_size * rep_factor);
softmax_max.OverrideRows(batch_size);
softmax_d.OverrideRows(batch_size);
att_sums.OverrideRows(batch_size);
@ -137,6 +150,15 @@ struct AttentionActivations {
// `inv_timescale*` are not batched.
}
// Maximum factor by which we might scale up work to maximize parallelism.
size_t rep_factor = 1;
// Parameters for flash attention. The size of the vector is somewhere between
// the number of query rows and 1/8th of that.
std::vector<FlashAttentionParams> flash_params;
// Parameters for flash attention, split by k-position. May be significantly
// larger than flash_params in decode mode, when the number of query rows is
// small.
std::vector<FlashAttentionParams> split_flash_params;
MatStorageT<float> q; // query
MatStorageT<BF16> q_bf;
MatStorageT<BF16> q_T; // Transposed to maximize attention speed.
@ -148,6 +170,7 @@ struct AttentionActivations {
MatStorageT<float> pre_att_rms_out;
MatStorageT<float> att; // attention vector
MatStorageT<float> att_out; // attention output
MatStorageT<float> att_out_reps; // attention output for each thread.
MatStorageT<float> softmax_max; // see OnlineSoftmaxState
MatStorageT<float> softmax_d; // see OnlineSoftmaxState
// Accumulation of attention outputs over heads
@ -156,19 +179,27 @@ struct AttentionActivations {
// Rope
MatStorageT<float> inv_timescale;
MatStorageT<float> inv_timescale_global;
// Replication factor to help evenly share work over threads.
static constexpr size_t kThreadReplicationFactor = 4;
};
// A non-owning view of AttentionActivations.
struct AttentionActivationsPtrs {
AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len)
AttentionActivationsPtrs(
const ModelConfig& config, size_t seq_len,
std::vector<FlashAttentionParams>& flash_params,
std::vector<FlashAttentionParams>& split_flash_params)
: config(config),
flash_params(flash_params),
split_flash_params(split_flash_params),
div_seq_len(static_cast<uint32_t>(seq_len)),
div_heads(static_cast<uint32_t>(config.layer_configs[0].heads)),
query_scale(ChooseQueryScale(config)) {}
AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len,
const AttentionActivations& activations)
: AttentionActivationsPtrs(config, seq_len) {
AttentionActivations& activations)
: AttentionActivationsPtrs(config, seq_len, activations.flash_params,
activations.split_flash_params) {
q = activations.q;
q_bf = activations.q_bf;
q_T = activations.q_T;
@ -178,6 +209,7 @@ struct AttentionActivationsPtrs {
pre_att_rms_out = activations.pre_att_rms_out;
att = activations.att;
att_out = activations.att_out;
att_out_reps = activations.att_out_reps;
softmax_max = activations.softmax_max;
softmax_d = activations.softmax_d;
att_sums = activations.att_sums;
@ -208,6 +240,9 @@ struct AttentionActivationsPtrs {
}
const ModelConfig& config;
// Parameters for flash attention.
std::vector<FlashAttentionParams>& flash_params;
std::vector<FlashAttentionParams>& split_flash_params;
// For the matrices below, the batch_size dimension is really qbatch.Size() *
// token_batch_size, but in all known uses, one of those is 1. Specifically,
@ -233,6 +268,7 @@ struct AttentionActivationsPtrs {
// Attention output computed from att * V, size batch_size x (q_heads *
// qkv_dim).
MatPtrT<float> att_out;
MatPtrT<float> att_out_reps;
// The maximum logit value encountered when computing att_out from att,
// size batch_size x q_heads . See OnlineSoftmaxState for details.
// WARNING: Only filled in for AttentionImpl::kOld.
@ -287,7 +323,8 @@ struct Activations {
s_w_linear_w(config.num_layers, max_workers),
attention_impl(runtime_config.attention_impl),
attention_storage(config, layer_config, batch_size, seq_len,
runtime_config, ctx.allocator, row_ptrs),
runtime_config, ctx.pools.MaxWorkers(), ctx.allocator,
row_ptrs),
attention(config, seq_len, attention_storage) {
HWY_ASSERT(batch_size != 0);

View File

@ -49,6 +49,39 @@ HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
// Returns the number of floats per vector (aka NF).
size_t FloatsPerVector() {
using DF = hn::ScalableTag<float>;
const DF df;
return hn::Lanes(df);
}
// The k-cache and v-cache are set up without knowing NF. So if it hasn't been
// done already, reshape the cache to take NF into account.
void MaybeReshapeCache(const MatPtrT<KV_t>& kv, MatPtrT<KV_t>& cache) {
if (kv.Cols() > cache.Cols()) {
cache.ReshapePackedRowsToCols(2 * FloatsPerVector());
}
}
// Transposes a single row of the kv cache into the k-cache and v-cache.
void TransposeKVCacheRow(const KV_t* HWY_RESTRICT kv, KV_t* HWY_RESTRICT k,
KV_t* HWY_RESTRICT v, size_t qkv_dim) {
// This is inefficient, as the writes are scattered over cache lines, but it
// is a tiny fraction of the overall computation, and it is linear in the
// token length.
const size_t kFloatsPerTile = 2 * FloatsPerVector();
for (size_t i = 0; i < qkv_dim; i += 2) {
k[i * kFloatsPerTile] = kv[i];
k[i * kFloatsPerTile + 1] = kv[i + 1];
}
for (size_t i = 0; i < qkv_dim; i += kFloatsPerTile) {
for (size_t j = 0; j < kFloatsPerTile; j++) {
v[i * kFloatsPerTile + j] = kv[i + j + qkv_dim];
}
}
}
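For illustration, a minimal standalone sketch (not from this change) of the K-side index mapping that TransposeKVCacheRow and the pointer arithmetic in ComputeQKV produce, assuming NF = 4 lanes per float vector and a toy qkv_dim = 8; layer and head offsets are omitted:

#include <cstddef>
#include <cstdio>

int main() {
  const size_t kNF = 4;                   // assumed FloatsPerVector()
  const size_t kFloatsPerTile = 2 * kNF;  // 8
  const size_t qkv_dim = 8;               // toy dimension
  // Element d of K at k-position pos lands at (row, col) in k_cache:
  // each (even, odd) pair of qkv values is stored for 2*NF consecutive
  // k-positions before moving on to the next pair.
  for (size_t pos = 0; pos < 2 * kFloatsPerTile; ++pos) {
    for (size_t d = 0; d < qkv_dim; ++d) {
      const size_t row = pos / kFloatsPerTile;
      const size_t pair = d / 2;
      const size_t col = pair * 2 * kFloatsPerTile +
                         (pos % kFloatsPerTile) * 2 + (d % 2);
      std::printf("K[pos=%zu][d=%zu] -> k_cache[%zu][%zu]\n", pos, d, row, col);
    }
  }
  return 0;
}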
// Computes Q.K scores, which are "logits" (or scores) stored to att.
// `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
@ -280,6 +313,11 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
kv_rows.AttachRowPtrs(env.row_ptrs[0].get());
CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
/*add=*/nullptr, env, kv_rows);
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
MaybeReshapeCache(qbatch.KV(qi).kv_cache, qbatch.KV(qi).k_cache);
MaybeReshapeCache(qbatch.KV(qi).kv_cache, qbatch.KV(qi).v_cache);
}
const size_t kFloatsPerVector = FloatsPerVector();
// Apply positional encodings for K.
// Note that 2D parallelism is not worth the fork/join overhead because the
@ -299,6 +337,26 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
KV_t* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
layer_idx * cache_layer_size +
head * qkv_dim * 2;
// Note that k_cache and v_cache are different shapes.
// The innermost dimension of k is 2 values from qkv_dim because they
// are going to be used in a BF16 dot product involving pairs of
// values over NF k positions.
// The innermost dimension of v is 2NF values from qkv_dim because they
// will be loaded into a BF16 vector to be scaled and added to the
// cached attention output in 2 NF-sized registers.
// TODO(rays): factor out these calculations into functions.
auto& k_cache = qbatch.KV(qi).k_cache;
KV_t* HWY_RESTRICT k =
k_cache.Row(cache_pos / (2 * kFloatsPerVector)) +
(layer_idx * cache_layer_size + head * qkv_dim * 2) *
kFloatsPerVector +
(cache_pos % (2 * kFloatsPerVector)) * 2;
auto& v_cache = qbatch.KV(qi).v_cache;
KV_t* HWY_RESTRICT v =
v_cache.Row(cache_pos / (2 * kFloatsPerVector)) +
(layer_idx * cache_layer_size + head * qkv_dim * 2) *
kFloatsPerVector +
(cache_pos % (2 * kFloatsPerVector)) * 2 * kFloatsPerVector;
HWY_ALIGN float kv_f32[2 * kMaxQKVDim];
const hn::ScalableTag<float> df;
@ -319,6 +377,10 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
/*mul=*/1.0f);
CompressPerThread tls;
Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
// This is inefficient, as multiple threads are writing the same K
// cache line, but the input is generated by a matmul, so it is
// difficult to change, and it probably isn't significant.
TransposeKVCacheRow(kv, k, v, qkv_dim);
});
}
@ -341,7 +403,8 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
} else {
// * 2 does not help on Turin.
FlashAttention(num_tokens,
/*target_parallelism=*/env.ctx.pools.MaxWorkers() * 1,
/*target_parallelism=*/env.ctx.pools.MaxWorkers() *
AttentionActivations::kThreadReplicationFactor,
layer_idx, layer.query_norm_scale, activations, qbatch,
env.ctx, attention_impl);
}

View File

@ -31,6 +31,13 @@ namespace gcpp {
// Passed to HWY_VISIT_TARGETS; declares for one target.
#define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \
namespace NAMESPACE { \
size_t FloatsPerVector(); \
\
void MaybeReshapeCache(const MatPtrT<KV_t>& kv, MatPtrT<KV_t>& cache); \
\
void TransposeKVCacheRow(const KV_t* HWY_RESTRICT kv, KV_t* HWY_RESTRICT k, \
KV_t* HWY_RESTRICT v, size_t qkv_dim); \
\
void PositionalEncodingQK(float* qk, size_t layer_idx, \
const AttentionActivationsPtrs& activations, \
ThreadingContext& ctx, size_t worker, size_t pos, \

View File

@ -1,8 +1,10 @@
#include <cstddef>
#include <cstdlib>
#include <cstring> // strcmp
#include <memory>
#include <numeric>
#include <optional>
#include <string>
#include <vector>
#include "gtest/gtest.h"
@ -105,7 +107,8 @@ struct TestAttentionState {
tokens(num_tokens),
attention_storage_(model_state.config, model_state.layer_config,
batch_size, num_tokens, runtime_config,
state.ctx.allocator, row_ptrs_),
state.ctx.pools.MaxWorkers(), state.ctx.allocator,
row_ptrs_),
attention(model_state.config, num_tokens, attention_storage_) {
for (size_t i = 0; i < qbatch_size; ++i) {
kv_caches.emplace_back(model_state.config, inference_args,
@ -143,6 +146,7 @@ struct TestAttentionState {
};
double GetTolerance() {
if (IsBF16<KV_t>()) return 1e-2;
const char* target_name = hwy::TargetName(HWY_TARGET);
if (strncmp(target_name, "AVX2", 4) == 0) {
return 2e-2;
@ -155,6 +159,57 @@ double GetTolerance() {
}
}
// Comparison function for computations that used BF16, whether the result is
// stored in BF16 or F32.
// Compare with absolute tolerance for values with small magnitudes.
// Compare with relative tolerance for values with larger magnitudes.
template <typename T>
bool CompareArraySimilarBF16(const T* expected, const T* actual, size_t count,
const char* target_name, const char* filename,
int line) {
constexpr double kTolerance = 3e-2;
for (size_t i = 0; i < count; ++i) {
const double exp = hwy::ConvertScalarTo<double>(expected[i]);
const double act = hwy::ConvertScalarTo<double>(actual[i]);
const double l1 = std::abs(act - exp);
// Cannot divide, so check absolute error.
if (std::abs(exp) <= 1.0) {
if (l1 > kTolerance) {
std::string array_values = hwy::detail::FormatMismatchedArrays(
expected, actual, count, kTolerance);
HWY_WARN("%s %s:%d %s mismatch %zu of %zu: %E %E l1 %E tol %E%s\n",
target_name, filename, line, "BF16", i, count, exp, act, l1,
kTolerance, array_values.c_str());
return false;
}
} else { // relative
const double rel = l1 / std::abs(exp);
if (rel > kTolerance) {
std::string array_values = hwy::detail::FormatMismatchedArrays(
expected, actual, count, kTolerance);
HWY_WARN("%s %s:%d %s mismatch %zu of %zu: %E %E rel %E tol %E%s\n",
target_name, filename, line, "BF16", i, count, exp, act, rel,
kTolerance, array_values.c_str());
return false;
}
}
}
return true;
}
template <typename T>
bool CompareArraySimilar(const T* expected, const T* actual, size_t count,
const char* target_name, const char* filename,
int line) {
if constexpr (IsBF16<KV_t>()) {
return CompareArraySimilarBF16(expected, actual, count, target_name,
filename, line);
} else {
return hwy::CompareArraySimilar(expected, actual, count, GetTolerance(),
target_name, filename, line);
}
}
template <size_t kNumTokens, size_t kQBatchSize, size_t kDims>
void CompareAttSumsWithGolden(
const AttentionActivationsPtrs& attention,
@ -170,9 +225,9 @@ void CompareAttSumsWithGolden(
for (size_t j = 0; j < kDims; ++j) {
actual_row[j] = hwy::F32FromBF16(attention.att_sums.Row(i)[j]);
}
EXPECT_TRUE(hwy::CompareArraySimilar(
golden[token_idx][qi], actual_row.get(), kDims, GetTolerance(),
hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
EXPECT_TRUE(CompareArraySimilar(golden[token_idx][qi], actual_row.get(),
kDims, hwy::TargetName(HWY_TARGET),
__FILE__, __LINE__))
<< "att_sums mismatch for token_idx=" << token_idx << " qi=" << qi;
}
}
@ -200,19 +255,20 @@ void CompareKVCacheWithGolden(
for (size_t token_idx = 0; token_idx < kNumTokens; ++token_idx) {
for (size_t qi = 0; qi < kQBatchSize; ++qi) {
const float* cache_row =
const BF16* cache_row =
kv_caches[qi].kv_cache.Row(start_offset + token_idx);
for (size_t j = 0; j < kDims; ++j) {
actual_k_row[j] = cache_row[kv_offset + j];
actual_v_row[j] = cache_row[kv_offset + qkv_dim + j];
actual_k_row[j] = hwy::ConvertScalarTo<float>(cache_row[kv_offset + j]);
actual_v_row[j] =
hwy::ConvertScalarTo<float>(cache_row[kv_offset + qkv_dim + j]);
}
EXPECT_TRUE(hwy::CompareArraySimilar(
k_golden[token_idx][qi], actual_k_row.get(), kDims, GetTolerance(),
EXPECT_TRUE(CompareArraySimilar(
k_golden[token_idx][qi], actual_k_row.get(), kDims,
hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
<< "K cache mismatch for token_idx=" << token_idx << " qi=" << qi
<< " kv_head=" << kv_head;
EXPECT_TRUE(hwy::CompareArraySimilar(
v_golden[token_idx][qi], actual_v_row.get(), kDims, GetTolerance(),
EXPECT_TRUE(CompareArraySimilar(
v_golden[token_idx][qi], actual_v_row.get(), kDims,
hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
<< "V cache mismatch for token_idx=" << token_idx << " qi=" << qi
<< " kv_head=" << kv_head;
@ -238,8 +294,8 @@ void CompareQVecsWithGolden(
for (size_t j = 0; j < kDims; ++j) {
actual_q_row[j] = q_row[head_offset + j];
}
EXPECT_TRUE(hwy::CompareArraySimilar(
q_golden[token_idx][qi], actual_q_row.get(), kDims, GetTolerance(),
EXPECT_TRUE(CompareArraySimilar(
q_golden[token_idx][qi], actual_q_row.get(), kDims,
hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
<< "Q vec mismatch for token_idx=" << token_idx << " qi=" << qi
<< " q_head=" << q_head;
@ -267,42 +323,42 @@ const float kGoldenAttSums[kNumTokens][kQBatchSize][kDimsToCompare] = {
26.875, 63, 3.34375, -67.5, 31.125, -190, 125},
{-30.375, -17.875, 51.75, -78, -84, 6.40625, 15.375, 70, -22.875, 20.125,
-14.9375, -109.5, 76, 9.25, -142, 29.5, -105}},
{{-32.75, 38.25, 78.5, 107.5, 20.25, 197, -136, 42.5, -84, 25.625, 4.96875,
{{-32.75, 38.25, 78.5, 107.5, 20.25, 197, -136, 42.5, -84, 25.625, 5.35875,
128, 27.25, -161, 19.125, -58, 97.5},
{-18.5, -18, 135, -13.4375, -6.625, -45.75, 29.625, 93, 18.625, 75.5,
102.5, -184, 52.75, 83.5, -71, 46.5, -52}},
{{-16.375, -61.5, -58.25, -27.375, -28, 71, -109.5, 60.25, 3.125, -29.125,
6.90625, 150, 144, -155, -47.25, -98.5, 3.5625},
{-19, -16.75, 129, 0.59765625, -82, 123.5, 60.75, -36.75, -77, 26.625, 51,
-66.5, -0.84765625, -46.5, -152, -2.9375, -81}},
{{3.984375, 83, -41.75, 39.5, -203, 110, -76, 131, 0.4609375, -44.5, -63.75,
{{-16.375, -61.5, -58.25, -27.375, -28, 71, -109.5, 60.25, 3.625, -29.125,
6.4625, 150, 144, -155, -47.25, -98.5, 3.5625},
{-19, -16.75, 129, 0.628925, -82, 123.5, 60.75, -36.75, -77, 26.625, 51,
-66.5, -0.62165625, -46.5, -152, -2.9375, -81}},
{{3.684375, 83, -41.75, 39.5, -203, 110, -76, 131, 1.0069375, -44.5, -63.75,
-46, -22, -19.375, -16.125, -148, 20.875},
{-47, -19.5, 58, 81.5, 21.75, -30, -118, 44.25, -149, 22.5, 188, -66.5, 33,
{-47, -19.5, 58, 81.5, 23.35, -30, -118, 44.25, -149, 22.5, 188, -66.5, 33,
10.9375, -52.5, 23.25, 75}},
{{64, -31, -89, -92.5, -11.1875, -54.75, -302, 3.453125, -108, 39.25,
{{64, -31, -89, -92.5, -11.1875, -54.75, -302, 4.213125, -108, 39.25,
-34.75, 18, -52, 100, -186, -75.5, 50.75},
{7.6875, -80, -40, 32.25, -30.25, 90, -41, 44.25, -140, -2.4375, 82.5,
{7.1875, -80, -40, 32.25, -30.25, 90, -41, 44.25, -140, -2.4375, 82.5,
39.25, 65, 47.25, -89.5, -34.25, 137}},
{{39.75, 17.875, 115, 38.75, -44, 139, -53.25, -23.875, -13.0625, 38.5,
32.5, 53.75, 109, 4.09375, 57.5, -20.5, 132},
{143, 249, 5.09375, 0.83984375, 27.875, -5.84375, 30.25, -101.5, 65.5,
13.5, 195, -10.0625, 97.5, 2.203125, -97.5, -100, -19.25}},
32.5, 53.75, 109, 4.62375, 57.5, -20.5, 132},
{143, 249, 4.9375, 1.33984375, 27.875, -5.84375, 30.25, -101.5, 65.5, 13.5,
195, -10.0625, 97.5, 1.903125, -97.5, -100, -19.25}},
{{-30.125, -169, -150, 58, -35.75, 22.75, 36.5, -32.25, -8.9375, 55.25,
-117, 26.375, 39.5, 125, 66, 48.75, 20.75},
{137, 5.25, 61.25, 37, -42.75, 240, 62, -164, 11.3125, 173, 174, 23.5,
{137, 3.85, 61.25, 37, -42.75, 240, 62, -164, 10.3125, 173, 174, 23.5,
88.5, 48.5, -46.25, -36.75, 101.5}},
{{-103, -47.5, 39, -48, -67.5, 121, -136, 99, 80, -47.5, 107.5, 48.75, 97.5,
{{-103, -47.5, 39, -48, -67.5, 121, -136, 99, 80, -47.5, 107.5, 43.75, 97.5,
125, -53.5, -14.625, 262},
{29.875, 7.34375, -36.75, -14.5, -27.5, 44.75, -67.5, -40.75, 71.5, 172,
{28.075, 6.64375, -36.75, -14.5, -27.5, 44.75, -67.5, -40.75, 71.5, 172,
81, -27.25, -3.03125, 111, -167, 59, 176}},
{{-37.25, 109.5, -26.125, -115.5, 108, 57.25, 1.3671875, 72, -122.5, 59.25,
-52, -12.625, 43.25, 16.25, -41.75, 26.5, 70.5},
{40.25, 53.25, -142, 78.5, 38, 4.3125, -27.75, -134, -85, 107.5, 2.5, 93.5,
{40.25, 53.25, -142, 78.5, 38, 4.625, -27.75, -134, -85, 107.5, 2.5, 93.5,
58.25, 173, -53.5, 25.125, 4.8125}},
{{-8.4375, -35, -35.5, 131, -33.25, 106, 109.5, -92, -135, 80, 21.5,
-17.125, 15.25, 143, -27, 103, 101},
{-77, 40.75, -10.125, 33.25, -33, 104, -7.6875, 85.5, -40, 93, 61, 14.5625,
8.125, -99.5, 13.6875, -11.6875, 33}},
8.55, -99.5, 14.6875, -11.6875, 33}},
};
// Layer 0, *K*V Head 0

View File

@ -81,8 +81,8 @@ static inline bool EnumValid(LayerAttentionType type) {
}
enum class AttentionImpl {
kOld,
kFlash,
kOld, // Previous Attention implementation
kFlash, // Flash Attention (default)
kSentinel,
};

File diff suppressed because it is too large

View File

@ -44,15 +44,6 @@ namespace gcpp {
float* HWY_RESTRICT att_out, \
ThreadingContext& ctx, size_t worker); \
\
Tile4FlashState TileFlashAttention4( \
const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets, \
const MatPtrT<KV_t>& k, size_t start_pos, \
const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos, \
size_t max_last_pos, const MatPtrT<KV_t>& v, size_t layer_idx, \
const LayerWeightsPtrs& layer, const AttentionActivations& activations, \
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets, \
ThreadingContext& ctx, const size_t worker); \
\
size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \
size_t total_tasks, size_t target_parallelism); \
\

View File

@ -62,16 +62,17 @@ namespace HWY_NAMESPACE {
using FloatPtr = hwy::AlignedFreeUniquePtr<float[]>;
void SetMat(const size_t offset, MatPtrT<float>& mat) {
template <typename T>
void SetMat(const size_t offset, MatPtrT<T>& mat) {
const size_t kOuter = mat.Extents().rows;
const size_t kInner = mat.Extents().cols;
const float i_scale = 1.0f / kInner;
const float j_scale = 1.0f / kOuter;
for (size_t i = 0; i < kOuter; ++i) {
float* row = mat.Row(i);
T* row = mat.Row(i);
for (size_t j = 0; j < kInner; ++j) {
row[j] =
static_cast<float>((i * kInner * i_scale + (j + offset) * j_scale));
row[j] = hwy::ConvertScalarTo<T>(
static_cast<float>((i * kInner * i_scale + (j + offset) * j_scale)));
}
}
}
@ -94,14 +95,15 @@ void AssertClose(const MatPtrT<float>& a, const MatPtrT<float>& b) {
if (rel_abs_delta > 0.0f) {
rel_abs_delta /= std::max(std::abs(a_row[c]), std::abs(b_row[c]));
}
EXPECT_LT(rel_abs_delta, 1e-5)
EXPECT_LT(rel_abs_delta, 1e-3)
<< "a[" << r << "," << c << "]=" << a_row[c] << ", b[" << r << ","
<< c << "]=" << b_row[c];
}
}
}
void TestFlashAttention(size_t target_parallelism) {
void TestFlashAttention(size_t target_parallelism,
AttentionImpl attention_impl) {
ThreadingArgs threading_args;
ThreadingContext ctx(threading_args);
constexpr size_t kOuter = 1024;
@ -131,7 +133,8 @@ void TestFlashAttention(size_t target_parallelism) {
const size_t batch_size = kOuter;
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
AttentionActivations attention_storage(config, layer_config, batch_size,
kOuter, runtime_config, ctx.allocator,
kOuter, runtime_config,
ctx.pools.MaxWorkers(), ctx.allocator,
row_ptrs);
AttentionActivationsPtrs attention(config, kOuter, attention_storage);
const size_t qkv_dim = layer_config.qkv_dim;
@ -142,7 +145,10 @@ void TestFlashAttention(size_t target_parallelism) {
const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
const size_t seq_len =
static_cast<size_t>(attention.div_seq_len.GetDivisor());
MaybeReshapeCache(qbatch.KV(0).kv_cache, qbatch.KV(0).k_cache);
MaybeReshapeCache(qbatch.KV(0).kv_cache, qbatch.KV(0).v_cache);
auto& kvc = qbatch.KV(0).kv_cache;
const size_t kFloatsPerTile = 2 * FloatsPerVector();
for (size_t h = 0; h < layer_config.heads; ++h) {
// Make strided views into the kv cache for
// this query and head.
@ -153,6 +159,17 @@ void TestFlashAttention(size_t target_parallelism) {
v.SetPtr(kvc.Row(0) + head_offset + qkv_dim, kvc.Stride());
SetMat(h + layer_config.heads, k);
SetMat(h + layer_config.heads * 2, v);
for (size_t p = 0; p < tokens.size(); ++p) {
KV_t* HWY_RESTRICT k_src = k.Row(p);
KV_t* HWY_RESTRICT k_dest = qbatch.KV(0).k_cache.Row(p / kFloatsPerTile) +
head_offset * kFloatsPerTile / 2 +
p % kFloatsPerTile * 2;
KV_t* HWY_RESTRICT v_dest = qbatch.KV(0).v_cache.Row(p / kFloatsPerTile) +
head_offset * kFloatsPerTile / 2 +
p % kFloatsPerTile * kFloatsPerTile;
TransposeKVCacheRow(k_src, k_dest, v_dest, qkv_dim);
}
}
SetMat(1, attention.q);
DotSoftmaxWeightedSum(tokens.size(), 0, layers.query_norm_scale, attention,
@ -167,18 +184,19 @@ void TestFlashAttention(size_t target_parallelism) {
tokens.size() * div_qbatch.GetDivisor() * layer_config.heads;
const size_t kVTileSize = GetVTileSize(kNF, kHeadGroups, tokens.size(),
total_tasks, target_parallelism);
printf("FlashAttention: target_parallelism=%zu, kNF=%zu, kVTileSize=%zu\n",
target_parallelism, kNF, kVTileSize);
printf("FlashAttention: parallelism=%zu, kNF=%zu, kVTileSize=%zu, mode %s\n",
target_parallelism, kNF, kVTileSize,
GetAttentionImplName(attention_impl).c_str());
FlashAttention(tokens.size(), target_parallelism, 0, layers.query_norm_scale,
attention, qbatch, ctx, AttentionImpl::kFlash);
attention, qbatch, ctx, attention_impl);
AssertClose(attention.att_out, *saved_att);
ctx.profiler.PrintResults();
}
void TestAttention() {
TestFlashAttention(8192);
TestFlashAttention(2048);
TestFlashAttention(256);
TestFlashAttention(8192, AttentionImpl::kFlash);
TestFlashAttention(2048, AttentionImpl::kFlash);
TestFlashAttention(256, AttentionImpl::kFlash);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)

View File

@ -2,11 +2,19 @@
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_STRUCTS_H_
#include <stddef.h>
#include <stdint.h>
#include <limits>
namespace gcpp {
// The vertical tile size in flash attention when register lanes correspond to
// K-timesteps, and the number of registers is 4 for 4 Q-rows.
static constexpr size_t k4xNFVTileSize = 4;
// The vertical tile size in flash attention when register lanes correspond to
// K-timesteps, and the number of registers is 8 for 8 Q-rows.
static constexpr size_t k8xNFVTileSize = 8;
// State for computing softmax in a streaming ("online") manner,
// avoiding large intermediate values by subtracting the running maximum.
// For a sequence x_1, ..., x_n:
@ -20,10 +28,44 @@ struct OnlineSoftmaxState {
float d = 0.0f;
};
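For reference, a minimal sketch of the streaming update this state supports (standard online softmax; the running maximum m lives in the unchanged part of the struct above, and the stand-in type below is hypothetical):

#include <algorithm>
#include <cmath>
#include <cstdio>

struct OnlineSoftmaxSketch {  // hypothetical stand-in for OnlineSoftmaxState
  float m = -1e30f;           // running maximum of the logits seen so far
  float d = 0.0f;             // running sum of exp(x_i - m)
};

// Folds one new logit x into the running state without large intermediates.
void Update(OnlineSoftmaxSketch& s, float x) {
  const float new_m = std::max(s.m, x);
  s.d = s.d * std::exp(s.m - new_m) + std::exp(x - new_m);
  s.m = new_m;
}

int main() {
  OnlineSoftmaxSketch s;
  const float logits[] = {1.0f, 3.0f, 2.0f};
  for (float x : logits) Update(s, x);
  // softmax(x_i) = exp(x_i - s.m) / s.d.
  for (float x : logits) std::printf("%f\n", std::exp(x - s.m) / s.d);
  return 0;
}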
static constexpr size_t kVTileSize4 = 4;
struct Tile4FlashState {
OnlineSoftmaxState row_states[kVTileSize4];
OnlineSoftmaxState row_states[k8xNFVTileSize];
};
// Parameters for a strip of tiles of flash attention. For processing a strip
// of tiles, each of 1, k4xNFVTileSize, or k8xNFVTileSize Q-rows, by NF
// k-positions. The total width of the strip might cover the entire sequence,
// or a part of it, depending on whether the strip has been split.
struct FlashAttentionParams {
// Vertical tile size: the number of Q rows in the tile, i.e. the number of
// entries used in each of the k8xNFVTileSize-sized arrays below.
uint32_t v_tile_size = 0;
// The minimum start position across all rows in the tile; determines the
// mask used for the tile.
uint32_t min_start_pos = std::numeric_limits<uint32_t>::max();
// The maximum last position across all rows in the tile; determines the
// mask used for the tile.
uint32_t max_last_pos = 0;
// Index into qbatch.KV; the same for each row in the tile.
uint32_t qi_index;
// Offset into the kv_cache; the same for each row in the tile.
uint32_t kv_offset;
// For the original (unsplit) task, the index of its first split task.
uint32_t split_index = 0;
// The index of this split when running split attention.
uint32_t i_of_n = 0;
// Offsets into original Q for each row in the tile.
uint32_t q_offsets[k8xNFVTileSize];
// Offsets into att_out for each row in the tile.
uint32_t out_offsets[k8xNFVTileSize];
// Start k-positions for each row in the tile.
uint32_t start_pos[k8xNFVTileSize];
// Last k-positions for each row in the tile. Inclusive.
uint32_t last_pos[k8xNFVTileSize];
// Row index into att_out.
uint32_t tq_idx[k8xNFVTileSize];
// Flash attention state for the tile.
Tile4FlashState end_state;
};
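As a toy illustration (an assumption about intended use, not code from this change) of how the per-tile bounds relate to the per-row arrays:

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical tile of 4 query rows (v_tile_size = 4) in a causal prefill.
  const uint32_t start_pos[4] = {0, 0, 0, 0};
  const uint32_t last_pos[4] = {10, 11, 12, 13};  // inclusive
  uint32_t min_start = UINT32_MAX;
  uint32_t max_last = 0;
  for (int r = 0; r < 4; ++r) {
    min_start = std::min(min_start, start_pos[r]);
    max_last = std::max(max_last, last_pos[r]);
  }
  // The whole tile scans k-positions [min_start, max_last]; rows whose own
  // [start_pos[r], last_pos[r]] range is narrower are masked per row.
  std::printf("tile scans k in [%u, %u]\n", min_start, max_last);
  return 0;
}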
} // namespace gcpp

View File

@ -43,6 +43,17 @@ static size_t CappedSeqLen(const ModelConfig& config,
KVCache::KVCache(const Extents2D& kv_extents, const Allocator& allocator)
: kv_cache("kv", kv_extents, allocator, MatPadding::kOdd),
// WARNING: the rows and cols of k_cache and v_cache will be modified
// before use!
// The rows will be reduced by a factor of 2xkFloatsPerVector, and the
// cols will be increased by 2xkFloatsPerVector on first use. This is to
// avoid making KVCache another class that has to be duplicated for each
// machine architecture, since kFloatsPerVector is architecture dependent.
// The change in shape is safe only if the padding is kPacked.
k_cache("k", Extents2D(kv_extents.rows, kv_extents.cols / 2), allocator,
MatPadding::kPacked),
v_cache("v", Extents2D(kv_extents.rows, kv_extents.cols / 2), allocator,
MatPadding::kPacked),
allocator_(allocator) {}
KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
@ -55,6 +66,8 @@ KVCache KVCache::Copy() {
KVCache copy(kv_cache.Extents(), allocator_);
CopyMat(kv_cache, copy.kv_cache);
CopyMat(k_cache, copy.k_cache);
CopyMat(v_cache, copy.v_cache);
return copy;
}

View File

@ -30,7 +30,7 @@
namespace gcpp {
using KV_t = float;
using KV_t = BF16;
// A non-owning view of a KVCache.
struct KVCachePtr {
@ -38,6 +38,8 @@ struct KVCachePtr {
size_t SeqLen() const;
MatPtrT<KV_t> kv_cache;
MatPtrT<KV_t> k_cache;
MatPtrT<KV_t> v_cache;
};
struct KVCache {
@ -52,10 +54,33 @@ struct KVCache {
}
MatStorageT<KV_t> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]
// The format of k_cache is: pairs of values from qkv_dim, over groups of
// 2 * kFloatsPerVector (= 2NF) sequence positions, in groups of qkv_dim/2
// pairs, in groups of kv_heads heads.
// This enables sequential loading of the data when filling 2 vectors with
// NF sequence elements of pairs of BF16 qkv values. The next vector then
// continues reading the rest of qkv.
// [seq_len / 2NF, layers * kv_heads * qkv_dim/2 * 2NF * 2]
MatStorageT<KV_t> k_cache;
// v_cache is formatted to allow sequential access to V during scaling and
// update of att_out.
// Originally [seq_len, layers * kv_heads * qkv_dim]
// v_cache is transposed to:
// [layers, kv_heads, seq_len, qkv_dim], reshaped to:
// [layers, kv_heads, seq_len/(2NF), 2NF, qkv_dim/(2NF), 2NF]
// then transposed to:
// [seq_len/(2NF), layers, kv_heads, qkv_dim/(2NF), 2NF, 2NF]
// and finally packed in a 2D MatStorageT as:
// [seq_len/(2NF), layers * kv_heads * qkv_dim/(2NF) * 2NF * 2NF]
// This allows sequential reads of 2NF registers each of 2NF BF16 values,
// repeatedly until all of qkv_dim is read.
MatStorageT<KV_t> v_cache;
KVCachePtr ToPtr() {
return KVCachePtr{
.kv_cache = kv_cache,
.k_cache = k_cache,
.v_cache = v_cache,
};
}
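For illustration, a minimal standalone sketch (assuming NF = 4 and omitting layer and kv_head offsets) of where element d of V at k-position pos lands under the v_cache layout described above:

#include <cstddef>
#include <cstdio>

int main() {
  const size_t kNF = 4;          // assumed FloatsPerVector()
  const size_t kTile = 2 * kNF;  // 2NF = 8
  const size_t qkv_dim = 16;     // toy dimension, a multiple of 2NF
  for (size_t pos = 0; pos < 2 * kTile; ++pos) {
    for (size_t d = 0; d < qkv_dim; ++d) {
      const size_t row = pos / kTile;
      // [qkv_dim/(2NF), 2NF (pos), 2NF (d)] within the head block.
      const size_t col = (d / kTile) * kTile * kTile +
                         (pos % kTile) * kTile + d % kTile;
      std::printf("V[pos=%zu][d=%zu] -> v_cache[%zu][%zu]\n", pos, d, row, col);
    }
  }
  return 0;
}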

View File

@ -614,267 +614,6 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(const float c,
});
}
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 63)>
HWY_INLINE HWY_MAYBE_UNUSED void Mul16(DF df, const VF scale, VF& sum0,
VF& sum1, VF& sum2, VF& sum3, VF& sum4,
VF& sum5, VF& sum6, VF& sum7, VF& sum8,
VF& sum9, VF& sum10, VF& sum11,
VF& sum12, VF& sum13, VF& sum14,
VF& sum15) {
sum0 = hn::Mul(sum0, hn::BroadcastLane<0>(scale));
sum1 = hn::Mul(sum1, hn::BroadcastLane<1>(scale));
sum2 = hn::Mul(sum2, hn::BroadcastLane<2>(scale));
sum3 = hn::Mul(sum3, hn::BroadcastLane<3>(scale));
sum4 = hn::Mul(sum4, hn::BroadcastLane<4>(scale));
sum5 = hn::Mul(sum5, hn::BroadcastLane<5>(scale));
sum6 = hn::Mul(sum6, hn::BroadcastLane<6>(scale));
sum7 = hn::Mul(sum7, hn::BroadcastLane<7>(scale));
sum8 = hn::Mul(sum8, hn::BroadcastLane<8>(scale));
sum9 = hn::Mul(sum9, hn::BroadcastLane<9>(scale));
sum10 = hn::Mul(sum10, hn::BroadcastLane<10>(scale));
sum11 = hn::Mul(sum11, hn::BroadcastLane<11>(scale));
sum12 = hn::Mul(sum12, hn::BroadcastLane<12>(scale));
sum13 = hn::Mul(sum13, hn::BroadcastLane<13>(scale));
sum14 = hn::Mul(sum14, hn::BroadcastLane<14>(scale));
sum15 = hn::Mul(sum15, hn::BroadcastLane<15>(scale));
}
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 63)>
HWY_INLINE HWY_MAYBE_UNUSED void Mul16(DF df, const VF scale, VF& sum0,
VF& sum1, VF& sum2, VF& sum3, VF& sum4,
VF& sum5, VF& sum6, VF& sum7, VF& sum8,
VF& sum9, VF& sum10, VF& sum11,
VF& sum12, VF& sum13, VF& sum14,
VF& sum15) {}
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 31)>
HWY_INLINE HWY_MAYBE_UNUSED void Mul8(DF df, const VF scale, VF& sum0, VF& sum1,
VF& sum2, VF& sum3, VF& sum4, VF& sum5,
VF& sum6, VF& sum7) {
sum0 = hn::Mul(sum0, hn::BroadcastLane<0>(scale));
sum1 = hn::Mul(sum1, hn::BroadcastLane<1>(scale));
sum2 = hn::Mul(sum2, hn::BroadcastLane<2>(scale));
sum3 = hn::Mul(sum3, hn::BroadcastLane<3>(scale));
sum4 = hn::Mul(sum4, hn::BroadcastLane<4>(scale));
sum5 = hn::Mul(sum5, hn::BroadcastLane<5>(scale));
sum6 = hn::Mul(sum6, hn::BroadcastLane<6>(scale));
sum7 = hn::Mul(sum7, hn::BroadcastLane<7>(scale));
}
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 31)>
HWY_INLINE HWY_MAYBE_UNUSED void Mul8(DF df, const VF scale, VF& sum0, VF& sum1,
VF& sum2, VF& sum3, VF& sum4, VF& sum5,
VF& sum6, VF& sum7) {}
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 63)>
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd16(
DF df, const VF common, const VF split, VF& sum0, VF& sum1, VF& sum2,
VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7, VF& sum8, VF& sum9,
VF& sum10, VF& sum11, VF& sum12, VF& sum13, VF& sum14, VF& sum15) {
sum0 = hn::MulAdd(common, hn::BroadcastLane<0>(split), sum0);
sum1 = hn::MulAdd(common, hn::BroadcastLane<1>(split), sum1);
sum2 = hn::MulAdd(common, hn::BroadcastLane<2>(split), sum2);
sum3 = hn::MulAdd(common, hn::BroadcastLane<3>(split), sum3);
sum4 = hn::MulAdd(common, hn::BroadcastLane<4>(split), sum4);
sum5 = hn::MulAdd(common, hn::BroadcastLane<5>(split), sum5);
sum6 = hn::MulAdd(common, hn::BroadcastLane<6>(split), sum6);
sum7 = hn::MulAdd(common, hn::BroadcastLane<7>(split), sum7);
sum8 = hn::MulAdd(common, hn::BroadcastLane<8>(split), sum8);
sum9 = hn::MulAdd(common, hn::BroadcastLane<9>(split), sum9);
sum10 = hn::MulAdd(common, hn::BroadcastLane<10>(split), sum10);
sum11 = hn::MulAdd(common, hn::BroadcastLane<11>(split), sum11);
sum12 = hn::MulAdd(common, hn::BroadcastLane<12>(split), sum12);
sum13 = hn::MulAdd(common, hn::BroadcastLane<13>(split), sum13);
sum14 = hn::MulAdd(common, hn::BroadcastLane<14>(split), sum14);
sum15 = hn::MulAdd(common, hn::BroadcastLane<15>(split), sum15);
}
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 63)>
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd16(
DF df, const VF common, const VF split, VF& sum0, VF& sum1, VF& sum2,
VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7, VF& sum8, VF& sum9,
VF& sum10, VF& sum11, VF& sum12, VF& sum13, VF& sum14, VF& sum15) {}
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 31)>
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd8(DF df, const VF common, const VF split,
VF& sum0, VF& sum1, VF& sum2, VF& sum3,
VF& sum4, VF& sum5, VF& sum6,
VF& sum7) {
sum0 = hn::MulAdd(common, hn::BroadcastLane<0>(split), sum0);
sum1 = hn::MulAdd(common, hn::BroadcastLane<1>(split), sum1);
sum2 = hn::MulAdd(common, hn::BroadcastLane<2>(split), sum2);
sum3 = hn::MulAdd(common, hn::BroadcastLane<3>(split), sum3);
sum4 = hn::MulAdd(common, hn::BroadcastLane<4>(split), sum4);
sum5 = hn::MulAdd(common, hn::BroadcastLane<5>(split), sum5);
sum6 = hn::MulAdd(common, hn::BroadcastLane<6>(split), sum6);
sum7 = hn::MulAdd(common, hn::BroadcastLane<7>(split), sum7);
}
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 31)>
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd8(DF df, const VF common, const VF split,
VF& sum0, VF& sum1, VF& sum2, VF& sum3,
VF& sum4, VF& sum5, VF& sum6,
VF& sum7) {}
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4(DF df, const VF common, const VF split,
VF& sum0, VF& sum1, VF& sum2,
VF& sum3) {
sum0 = hn::MulAdd(common, hn::BroadcastLane<0>(split), sum0);
sum1 = hn::MulAdd(common, hn::BroadcastLane<1>(split), sum1);
sum2 = hn::MulAdd(common, hn::BroadcastLane<2>(split), sum2);
sum3 = hn::MulAdd(common, hn::BroadcastLane<3>(split), sum3);
}
// For an 8xNF tile of float values in 8xNF-lane registers, multiplies 8 rows
// of V by the corresponding values in c0-c7 and adds them to NF rows of out,
// after first prescaling out by scale.
// The depth (size) must be a multiple of NF.
template <class DF, class VF = hn::Vec<DF>>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile(
DF df, const VF scale, const VF c0, const VF c1, const VF c2, const VF c3,
const VF c4, const VF c5, const VF c6, const VF c7, const MatPtrT<float>& v,
const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out,
const uint32_t* HWY_RESTRICT out_offsets, const size_t size) {
namespace hn = hwy::HWY_NAMESPACE;
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
size_t i = 0;
while (i + NF <= size) {
if HWY_LANES_CONSTEXPR (NF == 16) {
VF out0, out1, out2, out3, out4, out5, out6, out7;
VF out8, out9, out10, out11, out12, out13, out14, out15;
out0 = hn::Load(df, out + i + out_offsets[0]);
out1 = hn::Load(df, out + i + out_offsets[1]);
out2 = hn::Load(df, out + i + out_offsets[2]);
out3 = hn::Load(df, out + i + out_offsets[3]);
out4 = hn::Load(df, out + i + out_offsets[4]);
out5 = hn::Load(df, out + i + out_offsets[5]);
out6 = hn::Load(df, out + i + out_offsets[6]);
out7 = hn::Load(df, out + i + out_offsets[7]);
out8 = hn::Load(df, out + i + out_offsets[8]);
out9 = hn::Load(df, out + i + out_offsets[9]);
out10 = hn::Load(df, out + i + out_offsets[10]);
out11 = hn::Load(df, out + i + out_offsets[11]);
out12 = hn::Load(df, out + i + out_offsets[12]);
out13 = hn::Load(df, out + i + out_offsets[13]);
out14 = hn::Load(df, out + i + out_offsets[14]);
out15 = hn::Load(df, out + i + out_offsets[15]);
Mul16(df, scale, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15);
VF x0 = hn::Load(df, v.Row(pos[0]) + i);
MulAdd16(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15);
VF x1 = hn::Load(df, v.Row(pos[1]) + i);
MulAdd16(df, x1, c1, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15);
VF x2 = hn::Load(df, v.Row(pos[2]) + i);
MulAdd16(df, x2, c2, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15);
VF x3 = hn::Load(df, v.Row(pos[3]) + i);
MulAdd16(df, x3, c3, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15);
VF x4 = hn::Load(df, v.Row(pos[4]) + i);
MulAdd16(df, x4, c4, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15);
VF x5 = hn::Load(df, v.Row(pos[5]) + i);
MulAdd16(df, x5, c5, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15);
VF x6 = hn::Load(df, v.Row(pos[6]) + i);
MulAdd16(df, x6, c6, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15);
VF x7 = hn::Load(df, v.Row(pos[7]) + i);
MulAdd16(df, x7, c7, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15);
hn::Store(out0, df, out + i + out_offsets[0]);
hn::Store(out1, df, out + i + out_offsets[1]);
hn::Store(out2, df, out + i + out_offsets[2]);
hn::Store(out3, df, out + i + out_offsets[3]);
hn::Store(out4, df, out + i + out_offsets[4]);
hn::Store(out5, df, out + i + out_offsets[5]);
hn::Store(out6, df, out + i + out_offsets[6]);
hn::Store(out7, df, out + i + out_offsets[7]);
hn::Store(out8, df, out + i + out_offsets[8]);
hn::Store(out9, df, out + i + out_offsets[9]);
hn::Store(out10, df, out + i + out_offsets[10]);
hn::Store(out11, df, out + i + out_offsets[11]);
hn::Store(out12, df, out + i + out_offsets[12]);
hn::Store(out13, df, out + i + out_offsets[13]);
hn::Store(out14, df, out + i + out_offsets[14]);
hn::Store(out15, df, out + i + out_offsets[15]);
}
if HWY_LANES_CONSTEXPR (NF == 8) {
VF out0, out1, out2, out3, out4, out5, out6, out7;
out0 = hn::Load(df, out + i + out_offsets[0]);
out1 = hn::Load(df, out + i + out_offsets[1]);
out2 = hn::Load(df, out + i + out_offsets[2]);
out3 = hn::Load(df, out + i + out_offsets[3]);
out4 = hn::Load(df, out + i + out_offsets[4]);
out5 = hn::Load(df, out + i + out_offsets[5]);
out6 = hn::Load(df, out + i + out_offsets[6]);
out7 = hn::Load(df, out + i + out_offsets[7]);
Mul8(df, scale, out0, out1, out2, out3, out4, out5, out6, out7);
VF x0 = hn::Load(df, v.Row(pos[0]) + i);
MulAdd8(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7);
VF x1 = hn::Load(df, v.Row(pos[1]) + i);
MulAdd8(df, x1, c1, out0, out1, out2, out3, out4, out5, out6, out7);
VF x2 = hn::Load(df, v.Row(pos[2]) + i);
MulAdd8(df, x2, c2, out0, out1, out2, out3, out4, out5, out6, out7);
VF x3 = hn::Load(df, v.Row(pos[3]) + i);
MulAdd8(df, x3, c3, out0, out1, out2, out3, out4, out5, out6, out7);
VF x4 = hn::Load(df, v.Row(pos[4]) + i);
MulAdd8(df, x4, c4, out0, out1, out2, out3, out4, out5, out6, out7);
VF x5 = hn::Load(df, v.Row(pos[5]) + i);
MulAdd8(df, x5, c5, out0, out1, out2, out3, out4, out5, out6, out7);
VF x6 = hn::Load(df, v.Row(pos[6]) + i);
MulAdd8(df, x6, c6, out0, out1, out2, out3, out4, out5, out6, out7);
VF x7 = hn::Load(df, v.Row(pos[7]) + i);
MulAdd8(df, x7, c7, out0, out1, out2, out3, out4, out5, out6, out7);
hn::Store(out0, df, out + i + out_offsets[0]);
hn::Store(out1, df, out + i + out_offsets[1]);
hn::Store(out2, df, out + i + out_offsets[2]);
hn::Store(out3, df, out + i + out_offsets[3]);
hn::Store(out4, df, out + i + out_offsets[4]);
hn::Store(out5, df, out + i + out_offsets[5]);
hn::Store(out6, df, out + i + out_offsets[6]);
hn::Store(out7, df, out + i + out_offsets[7]);
}
if HWY_LANES_CONSTEXPR (NF == 4) {
VF out0, out1, out2, out3;
out0 = hn::Load(df, out + i + out_offsets[0]);
out1 = hn::Load(df, out + i + out_offsets[1]);
out2 = hn::Load(df, out + i + out_offsets[2]);
out3 = hn::Load(df, out + i + out_offsets[3]);
out0 = hn::Mul(out0, hn::BroadcastLane<0>(scale));
out1 = hn::Mul(out1, hn::BroadcastLane<1>(scale));
out2 = hn::Mul(out2, hn::BroadcastLane<2>(scale));
out3 = hn::Mul(out3, hn::BroadcastLane<3>(scale));
VF x0 = hn::Load(df, v.Row(pos[0]) + i);
MulAdd4(df, x0, c0, out0, out1, out2, out3);
VF x1 = hn::Load(df, v.Row(pos[1]) + i);
MulAdd4(df, x1, c1, out0, out1, out2, out3);
VF x2 = hn::Load(df, v.Row(pos[2]) + i);
MulAdd4(df, x2, c2, out0, out1, out2, out3);
VF x3 = hn::Load(df, v.Row(pos[3]) + i);
MulAdd4(df, x3, c3, out0, out1, out2, out3);
VF x4 = hn::Load(df, v.Row(pos[4]) + i);
MulAdd4(df, x4, c4, out0, out1, out2, out3);
VF x5 = hn::Load(df, v.Row(pos[5]) + i);
MulAdd4(df, x5, c5, out0, out1, out2, out3);
VF x6 = hn::Load(df, v.Row(pos[6]) + i);
MulAdd4(df, x6, c6, out0, out1, out2, out3);
VF x7 = hn::Load(df, v.Row(pos[7]) + i);
MulAdd4(df, x7, c7, out0, out1, out2, out3);
hn::Store(out0, df, out + i + out_offsets[0]);
hn::Store(out1, df, out + i + out_offsets[1]);
hn::Store(out2, df, out + i + out_offsets[2]);
hn::Store(out3, df, out + i + out_offsets[3]);
}
i += NF;
}
HWY_DASSERT(size == i);
}
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4(DF df, const VF common, const VF c0,
const VF c1, const VF c2, const VF c3,
@ -887,240 +626,134 @@ HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4(DF df, const VF common, const VF c0,
}
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4Lanes(DF df, const MatPtrT<float>& v,
const size_t* HWY_RESTRICT pos,
const size_t offset, const VF c0,
const VF c1, const VF c2,
const VF c3, VF& sum0, VF& sum1,
VF& sum2, VF& sum3) {
// TODO(rays): Check whether a transpose of c0-c3 is applicable and faster.
VF x0 = hn::Load(df, v.Row(pos[0]) + offset);
MulAdd4(df, x0, hn::BroadcastLane<0>(c0), hn::BroadcastLane<0>(c1),
hn::BroadcastLane<0>(c2), hn::BroadcastLane<0>(c3), sum0, sum1, sum2,
sum3);
VF x1 = hn::Load(df, v.Row(pos[1]) + offset);
MulAdd4(df, x1, hn::BroadcastLane<1>(c0), hn::BroadcastLane<1>(c1),
hn::BroadcastLane<1>(c2), hn::BroadcastLane<1>(c3), sum0, sum1, sum2,
sum3);
VF x2 = hn::Load(df, v.Row(pos[2]) + offset);
MulAdd4(df, x2, hn::BroadcastLane<2>(c0), hn::BroadcastLane<2>(c1),
hn::BroadcastLane<2>(c2), hn::BroadcastLane<2>(c3), sum0, sum1, sum2,
sum3);
VF x3 = hn::Load(df, v.Row(pos[3]) + offset);
MulAdd4(df, x3, hn::BroadcastLane<3>(c0), hn::BroadcastLane<3>(c1),
hn::BroadcastLane<3>(c2), hn::BroadcastLane<3>(c3), sum0, sum1, sum2,
sum3);
HWY_INLINE HWY_MAYBE_UNUSED void MulAddNLanesVT4(
DF df, const BF16* HWY_RESTRICT v, const float* HWY_RESTRICT c,
const size_t num_lanes, VF& sum0a, VF& sum1a, VF& sum2a, VF& sum3a,
VF& sum0b, VF& sum1b, VF& sum2b, VF& sum3b) {
using DBF = hn::ScalableTag<BF16>;
const DBF dbf;
using VBF = hn::Vec<DBF>;
const size_t kNF = hn::Lanes(df);
for (size_t lane = 0; lane < num_lanes; ++lane, v += 2 * kNF) {
VBF v0 = hn::Load(dbf, v);
VF c0 = hn::Set(df, *c++);
VF c1 = hn::Set(df, *c++);
VF c2 = hn::Set(df, *c++);
VF c3 = hn::Set(df, *c++);
VF v0a = hn::PromoteLowerTo(df, v0);
VF v0b = hn::PromoteUpperTo(df, v0);
MulAdd4(df, v0a, c0, c1, c2, c3, sum0a, sum1a, sum2a, sum3a);
MulAdd4(df, v0b, c0, c1, c2, c3, sum0b, sum1b, sum2b, sum3b);
}
}
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 31)>
HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond4Lanes(
DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
VF& sum0, VF& sum1, VF& sum2, VF& sum3) {
VF x4 = hn::Load(df, v.Row(pos[4]) + offset);
MulAdd4(df, x4, hn::BroadcastLane<4>(c0), hn::BroadcastLane<4>(c1),
hn::BroadcastLane<4>(c2), hn::BroadcastLane<4>(c3), sum0, sum1, sum2,
sum3);
VF x5 = hn::Load(df, v.Row(pos[5]) + offset);
MulAdd4(df, x5, hn::BroadcastLane<5>(c0), hn::BroadcastLane<5>(c1),
hn::BroadcastLane<5>(c2), hn::BroadcastLane<5>(c3), sum0, sum1, sum2,
sum3);
VF x6 = hn::Load(df, v.Row(pos[6]) + offset);
MulAdd4(df, x6, hn::BroadcastLane<6>(c0), hn::BroadcastLane<6>(c1),
hn::BroadcastLane<6>(c2), hn::BroadcastLane<6>(c3), sum0, sum1, sum2,
sum3);
VF x7 = hn::Load(df, v.Row(pos[7]) + offset);
MulAdd4(df, x7, hn::BroadcastLane<7>(c0), hn::BroadcastLane<7>(c1),
hn::BroadcastLane<7>(c2), hn::BroadcastLane<7>(c3), sum0, sum1, sum2,
sum3);
}
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 31)>
HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond4Lanes(
DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
VF& sum0, VF& sum1, VF& sum2, VF& sum3) {}
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 63)>
HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond8Lanes(
DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
VF& sum0, VF& sum1, VF& sum2, VF& sum3) {
VF x8 = hn::Load(df, v.Row(pos[8]) + offset);
MulAdd4(df, x8, hn::BroadcastLane<8>(c0), hn::BroadcastLane<8>(c1),
hn::BroadcastLane<8>(c2), hn::BroadcastLane<8>(c3), sum0, sum1, sum2,
sum3);
VF x9 = hn::Load(df, v.Row(pos[9]) + offset);
MulAdd4(df, x9, hn::BroadcastLane<9>(c0), hn::BroadcastLane<9>(c1),
hn::BroadcastLane<9>(c2), hn::BroadcastLane<9>(c3), sum0, sum1, sum2,
sum3);
VF x10 = hn::Load(df, v.Row(pos[10]) + offset);
MulAdd4(df, x10, hn::BroadcastLane<10>(c0), hn::BroadcastLane<10>(c1),
hn::BroadcastLane<10>(c2), hn::BroadcastLane<10>(c3), sum0, sum1,
sum2, sum3);
VF x11 = hn::Load(df, v.Row(pos[11]) + offset);
MulAdd4(df, x11, hn::BroadcastLane<11>(c0), hn::BroadcastLane<11>(c1),
hn::BroadcastLane<11>(c2), hn::BroadcastLane<11>(c3), sum0, sum1,
sum2, sum3);
VF x12 = hn::Load(df, v.Row(pos[12]) + offset);
MulAdd4(df, x12, hn::BroadcastLane<12>(c0), hn::BroadcastLane<12>(c1),
hn::BroadcastLane<12>(c2), hn::BroadcastLane<12>(c3), sum0, sum1,
sum2, sum3);
VF x13 = hn::Load(df, v.Row(pos[13]) + offset);
MulAdd4(df, x13, hn::BroadcastLane<13>(c0), hn::BroadcastLane<13>(c1),
hn::BroadcastLane<13>(c2), hn::BroadcastLane<13>(c3), sum0, sum1,
sum2, sum3);
VF x14 = hn::Load(df, v.Row(pos[14]) + offset);
MulAdd4(df, x14, hn::BroadcastLane<14>(c0), hn::BroadcastLane<14>(c1),
hn::BroadcastLane<14>(c2), hn::BroadcastLane<14>(c3), sum0, sum1,
sum2, sum3);
VF x15 = hn::Load(df, v.Row(pos[15]) + offset);
MulAdd4(df, x15, hn::BroadcastLane<15>(c0), hn::BroadcastLane<15>(c1),
hn::BroadcastLane<15>(c2), hn::BroadcastLane<15>(c3), sum0, sum1,
sum2, sum3);
}
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 63)>
HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond8Lanes(
DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
VF& sum0, VF& sum1, VF& sum2, VF& sum3) {}
// For an NFx4 tile of float values in 4xNF-lane registers, multiplies NF rows
// of V by the corresponding values in c0-c3 and adds them to NF rows of out,
// For a 2NFx4 tile of float values in 8 NF-lane registers, multiplies 2NF rows
// of V by the corresponding values in c00-c31 and adds them to 4 rows of out,
// after first prescaling out by scale.
// The depth (size) must be a multiple of NF.
// The depth (size) must be a multiple of 2NF.
template <class DF, class VF = hn::Vec<DF>>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile4(
DF df, const float* HWY_RESTRICT scales, const VF c0, const VF c1,
const VF c2, const VF c3, const MatPtrT<float>& v,
const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out,
HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddVT4Mem(
DF df, const float* HWY_RESTRICT scales, const VF c00, const VF c01,
const VF c10, const VF c11, const VF c20, const VF c21, const VF c30,
const VF c31, const MatPtrT<BF16>& v, const size_t* HWY_RESTRICT pos,
size_t num_lanes, float* HWY_RESTRICT out,
const uint32_t* HWY_RESTRICT out_offsets, const size_t size) {
namespace hn = hwy::HWY_NAMESPACE;
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
constexpr size_t kMaxNF = hn::MaxLanes(df);
const BF16* HWY_RESTRICT v_bf = v.Row(pos[0] / (2 * NF));
HWY_DASSERT(pos[0] % (2 * NF) == 0);
HWY_ALIGN float c_mem[8 * kMaxNF];
hn::StoreInterleaved4(c00, c10, c20, c30, df, c_mem);
hn::StoreInterleaved4(c01, c11, c21, c31, df, c_mem + 4 * NF);
size_t i = 0;
while (i + NF <= size) {
VF out0, out1, out2, out3;
out0 = hn::Load(df, out + i + out_offsets[0]);
out1 = hn::Load(df, out + i + out_offsets[1]);
out2 = hn::Load(df, out + i + out_offsets[2]);
out3 = hn::Load(df, out + i + out_offsets[3]);
out0 = hn::Mul(out0, hn::Set(df, scales[0]));
out1 = hn::Mul(out1, hn::Set(df, scales[1]));
out2 = hn::Mul(out2, hn::Set(df, scales[2]));
out3 = hn::Mul(out3, hn::Set(df, scales[3]));
MulAdd4Lanes(df, v, pos, i, c0, c1, c2, c3, out0, out1, out2, out3);
if HWY_LANES_CONSTEXPR (NF >= 8) {
MulAddSecond4Lanes(df, v, pos, i, c0, c1, c2, c3, out0, out1, out2, out3);
if HWY_LANES_CONSTEXPR (NF >= 16) {
MulAddSecond8Lanes(df, v, pos, i, c0, c1, c2, c3, out0, out1, out2,
out3);
}
}
hn::Store(out0, df, out + i + out_offsets[0]);
hn::Store(out1, df, out + i + out_offsets[1]);
hn::Store(out2, df, out + i + out_offsets[2]);
hn::Store(out3, df, out + i + out_offsets[3]);
i += NF;
while (i + NF * 2 <= size) {
VF out0a, out1a, out2a, out3a, out0b, out1b, out2b, out3b;
out0a = hn::Load(df, out + i + out_offsets[0]);
out1a = hn::Load(df, out + i + out_offsets[1]);
out2a = hn::Load(df, out + i + out_offsets[2]);
out3a = hn::Load(df, out + i + out_offsets[3]);
VF scale0 = hn::Set(df, scales[0]);
VF scale1 = hn::Set(df, scales[1]);
VF scale2 = hn::Set(df, scales[2]);
VF scale3 = hn::Set(df, scales[3]);
out0a = hn::Mul(out0a, scale0);
out1a = hn::Mul(out1a, scale1);
out2a = hn::Mul(out2a, scale2);
out3a = hn::Mul(out3a, scale3);
out0b = hn::Load(df, out + i + NF + out_offsets[0]);
out1b = hn::Load(df, out + i + NF + out_offsets[1]);
out2b = hn::Load(df, out + i + NF + out_offsets[2]);
out3b = hn::Load(df, out + i + NF + out_offsets[3]);
out0b = hn::Mul(out0b, scale0);
out1b = hn::Mul(out1b, scale1);
out2b = hn::Mul(out2b, scale2);
out3b = hn::Mul(out3b, scale3);
MulAddNLanesVT4(df, v_bf, c_mem, HWY_MIN(num_lanes, 2 * NF), out0a, out1a,
out2a, out3a, out0b, out1b, out2b, out3b);
hn::Store(out0a, df, out + i + out_offsets[0]);
hn::Store(out1a, df, out + i + out_offsets[1]);
hn::Store(out2a, df, out + i + out_offsets[2]);
hn::Store(out3a, df, out + i + out_offsets[3]);
hn::Store(out0b, df, out + i + NF + out_offsets[0]);
hn::Store(out1b, df, out + i + NF + out_offsets[1]);
hn::Store(out2b, df, out + i + NF + out_offsets[2]);
hn::Store(out3b, df, out + i + NF + out_offsets[3]);
i += NF * 2;
v_bf += 4 * NF * NF;
}
HWY_DASSERT(size == i);
}
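A scalar reference sketch of what the kernel above computes, with plain arrays standing in for the Highway registers (c[lane][r] corresponds to lane `lane` of the c*0/c*1 pair for output row r; num_lanes is assumed to be at most 2*NF):

#include <cstddef>
#include <cstdint>

// Equivalent (up to BF16 rounding of V) to MulByConstAndAddVT4Mem:
// out_row_r = scales[r] * out_row_r + sum over lanes of c[lane][r] * v[lane].
void MulByConstAndAddVT4Scalar(const float* scales, const float (*c)[4],
                               size_t num_lanes, const float* const* v,
                               float* out, const uint32_t* out_offsets,
                               size_t size) {
  for (size_t r = 0; r < 4; ++r) {
    float* out_row = out + out_offsets[r];
    for (size_t i = 0; i < size; ++i) {
      float acc = scales[r] * out_row[i];  // prescale the existing output
      for (size_t lane = 0; lane < num_lanes; ++lane) {
        acc += c[lane][r] * v[lane][i];  // V row for k-position pos[0] + lane
      }
      out_row[i] = acc;
    }
  }
}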
// Prescales NF rows of out by scale, then multiplies 1 row of V by the
// corresponding values in c0 and adds them to the NF rows of out.
// The depth (size) must be a multiple of NF.
template <class DF, class VF = hn::Vec<DF>>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
DF df, const VF scale, const VF c0, const MatPtrT<float>& v,
const size_t pos, float* HWY_RESTRICT out,
const uint32_t* HWY_RESTRICT out_offsets, const size_t size) {
HWY_INLINE HWY_MAYBE_UNUSED void MulAddNLanesVT1(DF df,
const BF16* HWY_RESTRICT v,
const float* HWY_RESTRICT c,
const size_t num_lanes,
VF& sum0a, VF& sum0b) {
using DBF = hn::ScalableTag<BF16>;
const DBF dbf;
using VBF = hn::Vec<DBF>;
const size_t kNF = hn::Lanes(df);
for (size_t lane = 0; lane < num_lanes; ++lane, v += 2 * kNF) {
VBF v0 = hn::Load(dbf, v);
VF c0 = hn::Set(df, *c++);
VF v0a = hn::PromoteLowerTo(df, v0);
VF v0b = hn::PromoteUpperTo(df, v0);
sum0a = hn::MulAdd(v0a, c0, sum0a);
sum0b = hn::MulAdd(v0b, c0, sum0b);
}
}
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddVT1Mem(
DF df, const float* HWY_RESTRICT scales, const VF c00, const VF c01,
const MatPtrT<BF16>& v, const size_t* HWY_RESTRICT pos, size_t num_lanes,
float* HWY_RESTRICT out, const uint32_t* HWY_RESTRICT out_offsets,
const size_t size) {
namespace hn = hwy::HWY_NAMESPACE;
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
constexpr size_t kMaxNF = hn::MaxLanes(df);
const BF16* HWY_RESTRICT v_bf = v.Row(pos[0] / (2 * NF));
HWY_DASSERT(pos[0] % (2 * NF) == 0);
HWY_ALIGN float c_mem[2 * kMaxNF];
hn::Store(c00, df, c_mem);
hn::Store(c01, df, c_mem + NF);
size_t i = 0;
while (i + NF <= size) {
if HWY_LANES_CONSTEXPR (NF == 16) {
VF out0, out1, out2, out3, out4, out5, out6, out7;
VF out8, out9, out10, out11, out12, out13, out14, out15;
out0 = hn::Load(df, out + i + out_offsets[0]);
out1 = hn::Load(df, out + i + out_offsets[1]);
out2 = hn::Load(df, out + i + out_offsets[2]);
out3 = hn::Load(df, out + i + out_offsets[3]);
out4 = hn::Load(df, out + i + out_offsets[4]);
out5 = hn::Load(df, out + i + out_offsets[5]);
out6 = hn::Load(df, out + i + out_offsets[6]);
out7 = hn::Load(df, out + i + out_offsets[7]);
out8 = hn::Load(df, out + i + out_offsets[8]);
out9 = hn::Load(df, out + i + out_offsets[9]);
out10 = hn::Load(df, out + i + out_offsets[10]);
out11 = hn::Load(df, out + i + out_offsets[11]);
out12 = hn::Load(df, out + i + out_offsets[12]);
out13 = hn::Load(df, out + i + out_offsets[13]);
out14 = hn::Load(df, out + i + out_offsets[14]);
out15 = hn::Load(df, out + i + out_offsets[15]);
Mul16(df, scale, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15);
VF x0 = hn::Load(df, v.Row(pos) + i);
MulAdd16(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15);
hn::Store(out0, df, out + i + out_offsets[0]);
hn::Store(out1, df, out + i + out_offsets[1]);
hn::Store(out2, df, out + i + out_offsets[2]);
hn::Store(out3, df, out + i + out_offsets[3]);
hn::Store(out4, df, out + i + out_offsets[4]);
hn::Store(out5, df, out + i + out_offsets[5]);
hn::Store(out6, df, out + i + out_offsets[6]);
hn::Store(out7, df, out + i + out_offsets[7]);
hn::Store(out8, df, out + i + out_offsets[8]);
hn::Store(out9, df, out + i + out_offsets[9]);
hn::Store(out10, df, out + i + out_offsets[10]);
hn::Store(out11, df, out + i + out_offsets[11]);
hn::Store(out12, df, out + i + out_offsets[12]);
hn::Store(out13, df, out + i + out_offsets[13]);
hn::Store(out14, df, out + i + out_offsets[14]);
hn::Store(out15, df, out + i + out_offsets[15]);
}
if HWY_LANES_CONSTEXPR (NF == 8) {
VF out0, out1, out2, out3, out4, out5, out6, out7;
out0 = hn::Load(df, out + i + out_offsets[0]);
out1 = hn::Load(df, out + i + out_offsets[1]);
out2 = hn::Load(df, out + i + out_offsets[2]);
out3 = hn::Load(df, out + i + out_offsets[3]);
out4 = hn::Load(df, out + i + out_offsets[4]);
out5 = hn::Load(df, out + i + out_offsets[5]);
out6 = hn::Load(df, out + i + out_offsets[6]);
out7 = hn::Load(df, out + i + out_offsets[7]);
Mul8(df, scale, out0, out1, out2, out3, out4, out5, out6, out7);
VF x0 = hn::Load(df, v.Row(pos) + i);
MulAdd8(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7);
hn::Store(out0, df, out + i + out_offsets[0]);
hn::Store(out1, df, out + i + out_offsets[1]);
hn::Store(out2, df, out + i + out_offsets[2]);
hn::Store(out3, df, out + i + out_offsets[3]);
hn::Store(out4, df, out + i + out_offsets[4]);
hn::Store(out5, df, out + i + out_offsets[5]);
hn::Store(out6, df, out + i + out_offsets[6]);
hn::Store(out7, df, out + i + out_offsets[7]);
}
if HWY_LANES_CONSTEXPR (NF == 4) {
VF out0, out1, out2, out3;
out0 = hn::Load(df, out + i + out_offsets[0]);
out1 = hn::Load(df, out + i + out_offsets[1]);
out2 = hn::Load(df, out + i + out_offsets[2]);
out3 = hn::Load(df, out + i + out_offsets[3]);
out0 = hn::Mul(out0, hn::BroadcastLane<0>(scale));
out1 = hn::Mul(out1, hn::BroadcastLane<1>(scale));
out2 = hn::Mul(out2, hn::BroadcastLane<2>(scale));
out3 = hn::Mul(out3, hn::BroadcastLane<3>(scale));
VF x0 = hn::Load(df, v.Row(pos) + i);
MulAdd4(df, x0, c0, out0, out1, out2, out3);
hn::Store(out0, df, out + i + out_offsets[0]);
hn::Store(out1, df, out + i + out_offsets[1]);
hn::Store(out2, df, out + i + out_offsets[2]);
hn::Store(out3, df, out + i + out_offsets[3]);
}
i += NF;
while (i + NF * 2 <= size) {
VF out0a, out0b;
out0a = hn::Load(df, out + i + out_offsets[0]);
VF scale0 = hn::Set(df, scales[0]);
out0a = hn::Mul(out0a, scale0);
out0b = hn::Load(df, out + i + NF + out_offsets[0]);
out0b = hn::Mul(out0b, scale0);
MulAddNLanesVT1(df, v_bf, c_mem, HWY_MIN(num_lanes, 2 * NF), out0a, out0b);
hn::Store(out0a, df, out + i + out_offsets[0]);
hn::Store(out0b, df, out + i + NF + out_offsets[0]);
i += NF * 2;
v_bf += 4 * NF * NF;
}
HWY_DASSERT(size == i);
}

View File

@ -202,6 +202,17 @@ class MatPtr : public IFields {
override_rows_ = static_cast<uint32_t>(rows);
}
// Changes the number of rows and columns without reallocating the memory.
// Increases cols by factor and reduces rows by factor.
// The rows must be divisible by factor and the matrix must be packed.
void ReshapePackedRowsToCols(size_t factor) {
HWY_ASSERT(IsPacked());
HWY_ASSERT(private_rows_ % factor == 0);
private_rows_ /= factor;
cols_ *= factor;
stride_ *= factor;
}
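For example (a hypothetical standalone sketch of just the shape arithmetic, not the MatPtr class): a packed 8x4 matrix reshaped with factor = 2 becomes 4x8, and because packed storage means stride == cols, element (r, c) of the old shape ends up at (r / 2, (r % 2) * 4 + c) in the new shape:

#include <cstddef>
#include <cstdio>

int main() {
  const size_t rows = 8, cols = 4, factor = 2;  // toy packed matrix
  for (size_t r = 0; r < rows; ++r) {
    for (size_t c = 0; c < cols; ++c) {
      const size_t idx = r * cols + c;             // linear index (packed)
      const size_t new_r = idx / (cols * factor);  // == r / factor
      const size_t new_c = idx % (cols * factor);  // == (r % factor) * cols + c
      std::printf("(%zu,%zu) -> (%zu,%zu)\n", r, c, new_r, new_c);
    }
  }
  return 0;
}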
// Offset by which to advance pointers to the next row.
size_t Stride() const { return stride_; }

View File

@ -106,7 +106,8 @@ template <typename T>
void FillMatPtrT(MatPtrT<T>& mat) {
for (int i = 0; i < mat.Rows(); ++i) {
for (int j = 0; j < mat.Cols(); ++j) {
mat.Row(i)[j] = hwy::Unpredictable1() * 0.01f * (i + j + 1);
mat.Row(i)[j] =
hwy::ConvertScalarTo<T>(hwy::Unpredictable1() * 0.01f * (i + j + 1));
}
}
}

View File

@ -17,14 +17,14 @@ const char* ZoneName(Zones zone) {
return "FlashAttention.Inclusive";
case Zones::kFlashAttentionRmsNormAndPositionalEncoding:
return "FlashAttention.RMSNormAndPositionalEncoding";
case Zones::kFlashAttentionSingleFlashAttention:
return "FlashAttention.SingleFlashAttention";
case Zones::kFlashAttentionTileFlashAttention:
return "FlashAttention.TileFlashAttention";
case Zones::kFlashAttentionTileFlashAttention1:
return "FlashAttention.TileFlashAttention1";
case Zones::kFlashAttentionTileFlashAttention4:
return "FlashAttention.TileFlashAttention4";
case Zones::kFlashAttentionTransposeQ:
return "FlashAttention.TransposeQ";
case Zones::kFlashAttentionTileFlashAttention8:
return "FlashAttention.TileFlashAttention8";
case Zones::kFlashAttentionCombineSplit:
return "FlashAttention.CombineSplit";
case Zones::kGenActivation:
return "Gen.Activation";
case Zones::kGenActivationFused:

View File

@ -14,10 +14,10 @@ enum class Zones { // Keep sorted
kFlashAttentionFlashAttention,
kFlashAttentionInclusive,
kFlashAttentionRmsNormAndPositionalEncoding,
kFlashAttentionSingleFlashAttention,
kFlashAttentionTileFlashAttention,
kFlashAttentionTileFlashAttention1,
kFlashAttentionTileFlashAttention4,
kFlashAttentionTransposeQ,
kFlashAttentionTileFlashAttention8,
kFlashAttentionCombineSplit,
kGenActivation,
kGenActivationFused,
kGenAttention,