mirror of https://github.com/google/gemma.cpp.git
parent
16c1b29b89
commit
a3d994915f
|
|
@ -548,7 +548,6 @@ cc_library(
|
|||
deps = [
|
||||
":basics",
|
||||
":configs",
|
||||
":flash_structs",
|
||||
":gemma_args",
|
||||
":kv_cache",
|
||||
":mat",
|
||||
|
|
@ -596,11 +595,6 @@ cc_test(
|
|||
|
||||
INTERNAL_DEPS = []
|
||||
|
||||
cc_library(
|
||||
name = "flash_structs",
|
||||
hdrs = ["gemma/flash_structs.h"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "attention",
|
||||
srcs = [
|
||||
|
|
@ -610,6 +604,7 @@ cc_library(
|
|||
hdrs = [
|
||||
"gemma/attention.h",
|
||||
"gemma/flash_attention.h",
|
||||
"gemma/flash_structs.h",
|
||||
],
|
||||
textual_hdrs = [
|
||||
"gemma/gemma-inl.h",
|
||||
|
|
@ -618,7 +613,6 @@ cc_library(
|
|||
":activations",
|
||||
":basics",
|
||||
":configs",
|
||||
":flash_structs",
|
||||
":kv_cache",
|
||||
":mat",
|
||||
":matmul",
|
||||
|
|
|
|||
|
|
@ -24,7 +24,6 @@
|
|||
#include <vector>
|
||||
|
||||
#include "gemma/configs.h" // ModelConfig
|
||||
#include "gemma/flash_structs.h"
|
||||
#include "gemma/gemma_args.h" // AttentionImpl
|
||||
#include "gemma/kv_cache.h"
|
||||
#include "gemma/tensor_stats.h"
|
||||
|
|
@ -53,13 +52,10 @@ struct AttentionActivations {
|
|||
AttentionActivations(
|
||||
const ModelConfig& config, const LayerConfig& layer_config,
|
||||
size_t batch_size, size_t seq_len, const RuntimeConfig& runtime_config,
|
||||
size_t max_workers, const Allocator& allocator,
|
||||
const Allocator& allocator,
|
||||
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
|
||||
: rep_factor(max_workers *
|
||||
AttentionActivations::kThreadReplicationFactor /
|
||||
layer_config.heads),
|
||||
// `vocab_size == 0` means it is for Vit part, VitAttention
|
||||
// is still MHA and does not use an external KV cache.
|
||||
: // `vocab_size == 0` means it is for Vit part, VitAttention is still
|
||||
// MHA and does not use an external KV cache.
|
||||
q(MatFactory("q", batch_size,
|
||||
config.vocab_size == 0
|
||||
? layer_config.heads * 3 * layer_config.qkv_dim
|
||||
|
|
@ -90,9 +86,6 @@ struct AttentionActivations {
|
|||
att_out(MatFactory("att_out", batch_size,
|
||||
layer_config.heads * layer_config.qkv_dim,
|
||||
allocator)),
|
||||
att_out_reps(MatFactory("att_out", batch_size * rep_factor,
|
||||
layer_config.heads * layer_config.qkv_dim,
|
||||
allocator)),
|
||||
softmax_max(MatFactory("softmax_max", batch_size, layer_config.heads,
|
||||
allocator)),
|
||||
softmax_d(
|
||||
|
|
@ -114,11 +107,6 @@ struct AttentionActivations {
|
|||
}
|
||||
return;
|
||||
}
|
||||
// This is a guess at the maximum number of params we might need to avoid
|
||||
// reallocations. The actual number of params is determined by the number of
|
||||
// query tiles, which is not known here.
|
||||
flash_params.reserve(batch_size * layer_config.heads);
|
||||
split_flash_params.reserve(batch_size * layer_config.heads);
|
||||
|
||||
// For MatMul outputs, precompute their row pointers.
|
||||
// If we forget any MatMul outputs here, debug builds print a warning but
|
||||
|
|
@ -142,10 +130,6 @@ struct AttentionActivations {
|
|||
pre_att_rms_out.OverrideRows(batch_size);
|
||||
att.OverrideRows(batch_size);
|
||||
att_out.OverrideRows(batch_size);
|
||||
att_out_reps.OverrideRows(batch_size * rep_factor);
|
||||
// There is no override for [split_]flash_params, because we reserved an
|
||||
// upper bound, and flash attention controls the actual size when it
|
||||
// calculates the size and number of tiles.
|
||||
softmax_max.OverrideRows(batch_size);
|
||||
softmax_d.OverrideRows(batch_size);
|
||||
att_sums.OverrideRows(batch_size);
|
||||
|
|
@ -153,15 +137,6 @@ struct AttentionActivations {
|
|||
// `inv_timescale*` are not batched.
|
||||
}
|
||||
|
||||
// Maximum factor by which we might scale-up work to maximize parallelism.
|
||||
size_t rep_factor = 1;
|
||||
// Parameters for flash attention. The size of the vector is somewhere between
|
||||
// the number of query rows and 1/8th of that.
|
||||
std::vector<Tile148Params> flash_params;
|
||||
// Parameters for flash attention, split by k-position. May be significantly
|
||||
// larger than flash_params in decode mode, when the number of query rows is
|
||||
// small.
|
||||
std::vector<Tile148Params> split_flash_params;
|
||||
MatStorageT<float> q; // query
|
||||
MatStorageT<BF16> q_bf;
|
||||
MatStorageT<BF16> q_T; // Transposed to maximize attention speed.
|
||||
|
|
@ -173,7 +148,6 @@ struct AttentionActivations {
|
|||
MatStorageT<float> pre_att_rms_out;
|
||||
MatStorageT<float> att; // attention vector
|
||||
MatStorageT<float> att_out; // attention output
|
||||
MatStorageT<float> att_out_reps; // attention output for each thread.
|
||||
MatStorageT<float> softmax_max; // see OnlineSoftmaxState
|
||||
MatStorageT<float> softmax_d; // see OnlineSoftmaxState
|
||||
// Accumulation of attention outputs over heads
|
||||
|
|
@ -190,26 +164,19 @@ struct AttentionActivations {
|
|||
// Rope
|
||||
MatStorageT<float> inv_timescale;
|
||||
MatStorageT<float> inv_timescale_global;
|
||||
// Replication factor to help evenly share work over threads.
|
||||
static constexpr size_t kThreadReplicationFactor = 4;
|
||||
};
|
||||
|
||||
// A non-owning view of AttentionActivations.
|
||||
struct AttentionActivationsPtrs {
|
||||
AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len,
|
||||
std::vector<Tile148Params>& flash_params,
|
||||
std::vector<Tile148Params>& split_flash_params)
|
||||
AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len)
|
||||
: config(config),
|
||||
flash_params(flash_params),
|
||||
split_flash_params(split_flash_params),
|
||||
div_seq_len(static_cast<uint32_t>(seq_len)),
|
||||
div_heads(static_cast<uint32_t>(config.layer_configs[0].heads)),
|
||||
query_scale(ChooseQueryScale(config)) {}
|
||||
|
||||
AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len,
|
||||
AttentionActivations& activations)
|
||||
: AttentionActivationsPtrs(config, seq_len, activations.flash_params,
|
||||
activations.split_flash_params) {
|
||||
const AttentionActivations& activations)
|
||||
: AttentionActivationsPtrs(config, seq_len) {
|
||||
q = activations.q;
|
||||
q_bf = activations.q_bf;
|
||||
q_T = activations.q_T;
|
||||
|
|
@ -219,7 +186,6 @@ struct AttentionActivationsPtrs {
|
|||
pre_att_rms_out = activations.pre_att_rms_out;
|
||||
att = activations.att;
|
||||
att_out = activations.att_out;
|
||||
att_out_reps = activations.att_out_reps;
|
||||
softmax_max = activations.softmax_max;
|
||||
softmax_d = activations.softmax_d;
|
||||
att_sums = activations.att_sums;
|
||||
|
|
@ -250,9 +216,6 @@ struct AttentionActivationsPtrs {
|
|||
}
|
||||
|
||||
const ModelConfig& config;
|
||||
// Parameters for flash attention.
|
||||
std::vector<Tile148Params>& flash_params;
|
||||
std::vector<Tile148Params>& split_flash_params;
|
||||
|
||||
// For the matrices below, the batch_size dimension is really qbatch.Size() *
|
||||
// token_batch_size, but in all known uses, one of those is 1. Specifically,
|
||||
|
|
@ -278,7 +241,6 @@ struct AttentionActivationsPtrs {
|
|||
// Attention output computed from att * V, size batch_size x (q_heads *
|
||||
// qkv_dim).
|
||||
MatPtrT<float> att_out;
|
||||
MatPtrT<float> att_out_reps;
|
||||
// The maximum logit value encountered when computing att_out from att,
|
||||
// size batch_size x q_heads . See OnlineSoftmaxState for details.
|
||||
// WARNING: Only filled in for AttentionImpl::kOld.
|
||||
|
|
@ -343,8 +305,7 @@ struct Activations {
|
|||
s_w_linear_w(config.num_layers, max_workers),
|
||||
attention_impl(runtime_config.attention_impl),
|
||||
attention_storage(config, layer_config, batch_size, seq_len,
|
||||
runtime_config, ctx.pools.MaxWorkers(), ctx.allocator,
|
||||
row_ptrs),
|
||||
runtime_config, ctx.allocator, row_ptrs),
|
||||
attention(config, seq_len, attention_storage) {
|
||||
HWY_ASSERT(batch_size != 0);
|
||||
|
||||
|
|
|
|||
|
|
@ -49,39 +49,6 @@ HWY_BEFORE_NAMESPACE();
|
|||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
|
||||
// Returns the number of floats per vector (aka NF).
|
||||
size_t FloatsPerVector() {
|
||||
using DF = hn::ScalableTag<float>;
|
||||
const DF df;
|
||||
return hn::Lanes(df);
|
||||
}
|
||||
|
||||
// The k-cache and v-cache are setup without knowing NF. So if it hasn't been
|
||||
// done already, reshape it to take NF into account.
|
||||
void MaybeReshapeCache(const MatPtrT<KV_t>& kv, MatPtrT<KV_t>& cache) {
|
||||
if (kv.Cols() > cache.Cols()) {
|
||||
cache.ReshapePackedRowsToCols(2 * FloatsPerVector());
|
||||
}
|
||||
}
|
||||
|
||||
// Transposes a single row of the kv cache into the k-cache and v-cache.
|
||||
void TransposeKVCacheRow(const KV_t* HWY_RESTRICT kv, KV_t* HWY_RESTRICT k,
|
||||
KV_t* HWY_RESTRICT v, size_t qkv_dim) {
|
||||
// This is inefficient, as the writes are scattered over cache lines, but it
|
||||
// is a tiny fraction of the overall computation, and it is linear in the
|
||||
// token length.
|
||||
const size_t kFloatsPerTile = 2 * FloatsPerVector();
|
||||
for (size_t i = 0; i < qkv_dim; i += 2) {
|
||||
k[i * kFloatsPerTile] = kv[i];
|
||||
k[i * kFloatsPerTile + 1] = kv[i + 1];
|
||||
}
|
||||
for (size_t i = 0; i < qkv_dim; i += kFloatsPerTile) {
|
||||
for (size_t j = 0; j < kFloatsPerTile; j++) {
|
||||
v[i * kFloatsPerTile + j] = kv[i + j + qkv_dim];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Computes Q.K scores, which are "logits" (or scores) stored to att.
|
||||
// `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
|
||||
static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
|
||||
|
|
@ -313,11 +280,6 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
|
|||
kv_rows.AttachRowPtrs(env.row_ptrs[0].get());
|
||||
CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
|
||||
/*add=*/nullptr, env, kv_rows);
|
||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||
MaybeReshapeCache(qbatch.KV(qi).kv_cache, qbatch.KV(qi).k_cache);
|
||||
MaybeReshapeCache(qbatch.KV(qi).kv_cache, qbatch.KV(qi).v_cache);
|
||||
}
|
||||
const size_t kFloatsPerVector = FloatsPerVector();
|
||||
|
||||
// Apply positional encodings for K.
|
||||
// Note that 2D parallelism is not worth the fork/join overhead because the
|
||||
|
|
@ -337,26 +299,6 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
|
|||
KV_t* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
|
||||
layer_idx * cache_layer_size +
|
||||
head * qkv_dim * 2;
|
||||
// Note that k_cache and v_cache are different shapes.
|
||||
// The innermost dimension of k is 2 values from qkv_dim because they
|
||||
// are going to be used in a BF16 dot product involving pairs of
|
||||
// values over NF k positions.
|
||||
// The innermost dimension of v is 2NF values from qkv_dim because they
|
||||
// will be loaded into a BF16 vector to be scaled and added to the
|
||||
// cached attention output in 2 NF-sized registers.
|
||||
// TODO(rays): factor out these calculations into functions.
|
||||
auto& k_cache = qbatch.KV(qi).k_cache;
|
||||
KV_t* HWY_RESTRICT k =
|
||||
k_cache.Row(cache_pos / (2 * kFloatsPerVector)) +
|
||||
(layer_idx * cache_layer_size + head * qkv_dim * 2) *
|
||||
kFloatsPerVector +
|
||||
(cache_pos % (2 * kFloatsPerVector)) * 2;
|
||||
auto& v_cache = qbatch.KV(qi).v_cache;
|
||||
KV_t* HWY_RESTRICT v =
|
||||
v_cache.Row(cache_pos / (2 * kFloatsPerVector)) +
|
||||
(layer_idx * cache_layer_size + head * qkv_dim * 2) *
|
||||
kFloatsPerVector +
|
||||
(cache_pos % (2 * kFloatsPerVector)) * 2 * kFloatsPerVector;
|
||||
|
||||
HWY_ALIGN float kv_f32[2 * kMaxQKVDim];
|
||||
const hn::ScalableTag<float> df;
|
||||
|
|
@ -377,10 +319,6 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
|
|||
/*mul=*/1.0f);
|
||||
CompressPerThread tls;
|
||||
Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
|
||||
// This is inefficient, as multiple threads are writing the same K
|
||||
// cache line, but the input is generated by a matmul, so it is
|
||||
// difficult to change, and it probably isn't significant.
|
||||
TransposeKVCacheRow(kv, k, v, qkv_dim);
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -403,8 +341,7 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
|
|||
} else {
|
||||
// * 2 does not help on Turin.
|
||||
FlashAttention(num_tokens,
|
||||
/*target_parallelism=*/env.ctx.pools.MaxWorkers() *
|
||||
AttentionActivations::kThreadReplicationFactor,
|
||||
/*target_parallelism=*/env.ctx.pools.MaxWorkers() * 1,
|
||||
layer_idx, layer.query_norm_scale, activations, qbatch,
|
||||
env.ctx, attention_impl);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,13 +31,6 @@ namespace gcpp {
|
|||
// Passed to HWY_VISIT_TARGETS; declares for one target.
|
||||
#define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \
|
||||
namespace NAMESPACE { \
|
||||
size_t FloatsPerVector(); \
|
||||
\
|
||||
void MaybeReshapeCache(const MatPtrT<KV_t>& kv, MatPtrT<KV_t>& cache); \
|
||||
\
|
||||
void TransposeKVCacheRow(const KV_t* HWY_RESTRICT kv, KV_t* HWY_RESTRICT k, \
|
||||
KV_t* HWY_RESTRICT v, size_t qkv_dim); \
|
||||
\
|
||||
void PositionalEncodingQK(float* qk, size_t layer_idx, \
|
||||
const AttentionActivationsPtrs& activations, \
|
||||
ThreadingContext& ctx, size_t worker, size_t pos, \
|
||||
|
|
|
|||
|
|
@ -1,10 +1,8 @@
|
|||
#include <cstddef>
|
||||
#include <cstdlib>
|
||||
#include <cstring> // strcmp
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
|
@ -107,8 +105,7 @@ struct TestAttentionState {
|
|||
tokens(num_tokens),
|
||||
attention_storage_(model_state.config, model_state.layer_config,
|
||||
batch_size, num_tokens, runtime_config,
|
||||
state.ctx.pools.MaxWorkers(), state.ctx.allocator,
|
||||
row_ptrs_),
|
||||
state.ctx.allocator, row_ptrs_),
|
||||
attention(model_state.config, num_tokens, attention_storage_) {
|
||||
for (size_t i = 0; i < qbatch_size; ++i) {
|
||||
kv_caches.emplace_back(model_state.config, inference_args,
|
||||
|
|
@ -146,7 +143,6 @@ struct TestAttentionState {
|
|||
};
|
||||
|
||||
double GetTolerance() {
|
||||
if (IsBF16<KV_t>()) return 1e-2;
|
||||
const char* target_name = hwy::TargetName(HWY_TARGET);
|
||||
if (strncmp(target_name, "AVX2", 4) == 0) {
|
||||
return 2e-2;
|
||||
|
|
@ -159,20 +155,6 @@ double GetTolerance() {
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool CompareArraySimilar(const T* expected, const T* actual, size_t count,
|
||||
const char* target_name, const char* filename,
|
||||
int line) {
|
||||
if constexpr (IsBF16<KV_t>()) {
|
||||
constexpr double kTolerance = 3e-2;
|
||||
return hwy::CompareArraySimilar(expected, actual, count, kTolerance,
|
||||
target_name, filename, line);
|
||||
} else {
|
||||
return hwy::CompareArraySimilar(expected, actual, count, GetTolerance(),
|
||||
target_name, filename, line);
|
||||
}
|
||||
}
|
||||
|
||||
template <size_t kNumTokens, size_t kQBatchSize, size_t kDims>
|
||||
void CompareAttSumsWithGolden(
|
||||
const AttentionActivationsPtrs& attention,
|
||||
|
|
@ -188,9 +170,9 @@ void CompareAttSumsWithGolden(
|
|||
for (size_t j = 0; j < kDims; ++j) {
|
||||
actual_row[j] = hwy::F32FromBF16(attention.att_sums.Row(i)[j]);
|
||||
}
|
||||
EXPECT_TRUE(CompareArraySimilar(golden[token_idx][qi], actual_row.get(),
|
||||
kDims, hwy::TargetName(HWY_TARGET),
|
||||
__FILE__, __LINE__))
|
||||
EXPECT_TRUE(hwy::CompareArraySimilar(
|
||||
golden[token_idx][qi], actual_row.get(), kDims, GetTolerance(),
|
||||
hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
|
||||
<< "att_sums mismatch for token_idx=" << token_idx << " qi=" << qi;
|
||||
}
|
||||
}
|
||||
|
|
@ -218,20 +200,19 @@ void CompareKVCacheWithGolden(
|
|||
|
||||
for (size_t token_idx = 0; token_idx < kNumTokens; ++token_idx) {
|
||||
for (size_t qi = 0; qi < kQBatchSize; ++qi) {
|
||||
const BF16* cache_row =
|
||||
const float* cache_row =
|
||||
kv_caches[qi].kv_cache.Row(start_offset + token_idx);
|
||||
for (size_t j = 0; j < kDims; ++j) {
|
||||
actual_k_row[j] = hwy::ConvertScalarTo<float>(cache_row[kv_offset + j]);
|
||||
actual_v_row[j] =
|
||||
hwy::ConvertScalarTo<float>(cache_row[kv_offset + qkv_dim + j]);
|
||||
actual_k_row[j] = cache_row[kv_offset + j];
|
||||
actual_v_row[j] = cache_row[kv_offset + qkv_dim + j];
|
||||
}
|
||||
EXPECT_TRUE(CompareArraySimilar(
|
||||
k_golden[token_idx][qi], actual_k_row.get(), kDims,
|
||||
EXPECT_TRUE(hwy::CompareArraySimilar(
|
||||
k_golden[token_idx][qi], actual_k_row.get(), kDims, GetTolerance(),
|
||||
hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
|
||||
<< "K cache mismatch for token_idx=" << token_idx << " qi=" << qi
|
||||
<< " kv_head=" << kv_head;
|
||||
EXPECT_TRUE(CompareArraySimilar(
|
||||
v_golden[token_idx][qi], actual_v_row.get(), kDims,
|
||||
EXPECT_TRUE(hwy::CompareArraySimilar(
|
||||
v_golden[token_idx][qi], actual_v_row.get(), kDims, GetTolerance(),
|
||||
hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
|
||||
<< "V cache mismatch for token_idx=" << token_idx << " qi=" << qi
|
||||
<< " kv_head=" << kv_head;
|
||||
|
|
@ -257,8 +238,8 @@ void CompareQVecsWithGolden(
|
|||
for (size_t j = 0; j < kDims; ++j) {
|
||||
actual_q_row[j] = q_row[head_offset + j];
|
||||
}
|
||||
EXPECT_TRUE(CompareArraySimilar(
|
||||
q_golden[token_idx][qi], actual_q_row.get(), kDims,
|
||||
EXPECT_TRUE(hwy::CompareArraySimilar(
|
||||
q_golden[token_idx][qi], actual_q_row.get(), kDims, GetTolerance(),
|
||||
hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
|
||||
<< "Q vec mismatch for token_idx=" << token_idx << " qi=" << qi
|
||||
<< " q_head=" << q_head;
|
||||
|
|
@ -282,46 +263,46 @@ const size_t kDimsToCompare = 17; // greater than AVX-512 vector of floats
|
|||
|
||||
// Layer 0
|
||||
const float kGoldenAttSums[kNumTokens][kQBatchSize][kDimsToCompare] = {
|
||||
{{46.5, 56.5, 10.0625, 65.5, -2.239375, 135, 15.8125, 51, -100, 52.5,
|
||||
{{46.5, 56.5, 10.0625, 65.5, -2.109375, 135, 15.8125, 51, -100, 52.5,
|
||||
26.875, 63, 3.34375, -67.5, 31.125, -190, 125},
|
||||
{-30.375, -17.875, 51.75, -78, -84, 6.40625, 15.375, 70, -22.875, 20.125,
|
||||
-14.9375, -109.5, 76, 9.25, -142, 29.5, -105}},
|
||||
{{-32.75, 38.25, 78.5, 107.5, 20.25, 197, -136, 42.5, -84, 25.625, 5.35875,
|
||||
{{-32.75, 38.25, 78.5, 107.5, 20.25, 197, -136, 42.5, -84, 25.625, 4.96875,
|
||||
128, 27.25, -161, 19.125, -58, 97.5},
|
||||
{-17.625, -15.375, 135, -13.4375, -3.343, -45.75, 29.625, 93, 18.625, 75.5,
|
||||
{-18.5, -18, 135, -13.4375, -6.625, -45.75, 29.625, 93, 18.625, 75.5,
|
||||
102.5, -184, 52.75, 83.5, -71, 46.5, -52}},
|
||||
{{-16.375, -61.5, -58.25, -27.375, -28, 71, -109.5, 60.25, 3.625, -29.125,
|
||||
6.4625, 150, 144, -155, -47.25, -98.5, 3.5625},
|
||||
{-19, -16.75, 129, 0.628925, -82, 123.5, 60.75, -36.75, -77, 26.625, 51,
|
||||
-66.5, -0.62165625, -46.5, -152, -2.9375, -81}},
|
||||
{{3.684375, 83, -41.75, 39.5, -203, 110, -76, 131, 1.0069375, -44.5, -63.75,
|
||||
{{-16.375, -61.5, -58.25, -27.375, -28, 71, -109.5, 60.25, 3.125, -29.125,
|
||||
6.90625, 150, 144, -155, -47.25, -98.5, 3.5625},
|
||||
{-19, -16.75, 129, 0.59765625, -82, 123.5, 60.75, -36.75, -77, 26.625, 51,
|
||||
-66.5, -0.84765625, -46.5, -152, -2.9375, -81}},
|
||||
{{3.984375, 83, -41.75, 39.5, -203, 110, -76, 131, 0.4609375, -44.5, -63.75,
|
||||
-46, -22, -19.375, -16.125, -148, 20.875},
|
||||
{-47, -17.5, 58, 81.5, 23.35, -30, -118, 44.25, -149, 22.5, 188, -66.5, 33,
|
||||
{-47, -19.5, 58, 81.5, 21.75, -30, -118, 44.25, -149, 22.5, 188, -66.5, 33,
|
||||
10.9375, -52.5, 23.25, 75}},
|
||||
{{64, -31, -89, -92.5, -11.1875, -54.75, -302, 4.213125, -108, 39.25,
|
||||
{{64, -31, -89, -92.5, -11.1875, -54.75, -302, 3.453125, -108, 39.25,
|
||||
-34.75, 18, -52, 100, -186, -75.5, 50.75},
|
||||
{7.1875, -80, -40, 32.25, -30.25, 90, -41, 44.25, -140, -2.2675, 82.5,
|
||||
{7.6875, -80, -40, 32.25, -30.25, 90, -41, 44.25, -140, -2.4375, 82.5,
|
||||
39.25, 65, 47.25, -89.5, -34.25, 137}},
|
||||
{{39.75, 17.875, 115, 38.75, -44, 139, -53.25, -23.875, -12.625, 38.5, 32.5,
|
||||
53.75, 109, 4.62375, 57.5, -20.5, 132},
|
||||
{143, 249, 4.9375, 1.33984375, 27.875, -5.84375, 30.25, -101.5, 65.5, 13.5,
|
||||
195, -10.0625, 97.5, 1.903125, -97.5, -100, -19.25}},
|
||||
{{39.75, 17.875, 115, 38.75, -44, 139, -53.25, -23.875, -13.0625, 38.5,
|
||||
32.5, 53.75, 109, 4.09375, 57.5, -20.5, 132},
|
||||
{143, 249, 5.09375, 0.83984375, 27.875, -5.84375, 30.25, -101.5, 65.5,
|
||||
13.5, 195, -10.0625, 97.5, 2.203125, -97.5, -100, -19.25}},
|
||||
{{-30.125, -169, -150, 58, -35.75, 22.75, 36.5, -32.25, -8.9375, 55.25,
|
||||
-117, 26.375, 39.5, 125, 66, 48.75, 20.75},
|
||||
{137, 3.85, 61.25, 37, -42.75, 240, 62, -164, 10.3125, 173, 174, 23.5,
|
||||
88.5, 48.5, -46.25, -35.5, 101.5}},
|
||||
{{-103, -41.5, 39, -52, -62.7, 121, -136, 99, 80, -47.5, 107.5, 43.75, 97.5,
|
||||
125, -53.5, -11.625, 262},
|
||||
{28.075, 6.64375, -36.75, -13.35, -27.5, 44.75, -67.5, -40.75, 71.5, 172,
|
||||
81, -28.5, -3.875, 111, -167, 59, 176}},
|
||||
{137, 5.25, 61.25, 37, -42.75, 240, 62, -164, 11.3125, 173, 174, 23.5,
|
||||
88.5, 48.5, -46.25, -36.75, 101.5}},
|
||||
{{-103, -47.5, 39, -48, -67.5, 121, -136, 99, 80, -47.5, 107.5, 48.75, 97.5,
|
||||
125, -53.5, -14.625, 262},
|
||||
{29.875, 7.34375, -36.75, -14.5, -27.5, 44.75, -67.5, -40.75, 71.5, 172,
|
||||
81, -27.25, -3.03125, 111, -167, 59, 176}},
|
||||
{{-37.25, 109.5, -26.125, -115.5, 108, 57.25, 1.3671875, 72, -122.5, 59.25,
|
||||
-52, -12.625, 43.25, 16.25, -41.75, 26.5, 70.5},
|
||||
{40.25, 53.25, -142, 78.5, 38, 4.625, -27.75, -134, -85, 107.5, 2.5, 93.5,
|
||||
{40.25, 53.25, -142, 78.5, 38, 4.3125, -27.75, -134, -85, 107.5, 2.5, 93.5,
|
||||
58.25, 173, -53.5, 25.125, 4.8125}},
|
||||
{{-8.4375, -35, -35.5, 131, -33.25, 106, 109.5, -92, -135, 80, 21.5,
|
||||
-17.125, 15.25, 143, -27, 103, 101},
|
||||
{-77, 40.75, -10.5, 33.25, -33, 104, -7.6875, 85.5, -40, 93, 61, 14.5625,
|
||||
8.55, -99.5, 14.6875, -12.25, 33}},
|
||||
{-77, 40.75, -10.125, 33.25, -33, 104, -7.6875, 85.5, -40, 93, 61, 14.5625,
|
||||
8.125, -99.5, 13.6875, -11.6875, 33}},
|
||||
};
|
||||
|
||||
// Layer 0, *K*V Head 0
|
||||
|
|
|
|||
|
|
@ -81,8 +81,8 @@ static inline bool EnumValid(LayerAttentionType type) {
|
|||
}
|
||||
|
||||
enum class AttentionImpl {
|
||||
kOld, // Previous Attention implementation
|
||||
kFlash, // Flash Attention (default)
|
||||
kOld,
|
||||
kFlash,
|
||||
kFlashTransposedQs,
|
||||
kFlashTransposedQsBF16,
|
||||
kSentinel,
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -50,6 +50,25 @@ namespace gcpp {
|
|||
float* HWY_RESTRICT att_out, \
|
||||
ThreadingContext& ctx, size_t worker); \
|
||||
\
|
||||
Tile4FlashState TileFlashAttention4( \
|
||||
const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets, \
|
||||
const MatPtrT<KV_t>& k, size_t start_pos, \
|
||||
const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos, \
|
||||
size_t max_last_pos, const MatPtrT<KV_t>& v, size_t layer_idx, \
|
||||
const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out, \
|
||||
const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx, \
|
||||
const size_t worker); \
|
||||
\
|
||||
void TileFlashAttention( \
|
||||
const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets, \
|
||||
const StridedView<BF16>& qT, const MatPtrT<KV_t>& k, \
|
||||
const size_t start_pos, const uint32_t* HWY_RESTRICT last_pos, \
|
||||
const size_t min_last_pos, const size_t max_last_pos, \
|
||||
const MatPtrT<KV_t>& v, const size_t layer_idx, \
|
||||
const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out, \
|
||||
const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx, \
|
||||
const size_t worker); \
|
||||
\
|
||||
size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \
|
||||
size_t total_tasks, size_t target_parallelism); \
|
||||
\
|
||||
|
|
@ -73,6 +92,7 @@ namespace gcpp {
|
|||
hwy::Span<const size_t> last_pos_per_query, const float att_cap, \
|
||||
MatPtrT<float>& att_out, float* HWY_RESTRICT exp_denominator_sums, \
|
||||
float* HWY_RESTRICT max_logits); \
|
||||
\
|
||||
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
|
||||
} // namespace NAMESPACE
|
||||
|
||||
|
|
|
|||
|
|
@ -62,17 +62,16 @@ namespace HWY_NAMESPACE {
|
|||
|
||||
using FloatPtr = hwy::AlignedFreeUniquePtr<float[]>;
|
||||
|
||||
template <typename T>
|
||||
void SetMat(const size_t offset, MatPtrT<T>& mat) {
|
||||
void SetMat(const size_t offset, MatPtrT<float>& mat) {
|
||||
const size_t kOuter = mat.Extents().rows;
|
||||
const size_t kInner = mat.Extents().cols;
|
||||
const float i_scale = 1.0f / kInner;
|
||||
const float j_scale = 1.0f / kOuter;
|
||||
for (size_t i = 0; i < kOuter; ++i) {
|
||||
T* HWY_RESTRICT row = mat.Row(i);
|
||||
float* HWY_RESTRICT row = mat.Row(i);
|
||||
for (size_t j = 0; j < kInner; ++j) {
|
||||
row[j] = hwy::ConvertScalarTo<T>(
|
||||
static_cast<float>((i * kInner * i_scale + (j + offset) * j_scale)));
|
||||
row[j] =
|
||||
static_cast<float>((i * kInner * i_scale + (j + offset) * j_scale));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -95,15 +94,14 @@ void AssertClose(const MatPtrT<float>& a, const MatPtrT<float>& b) {
|
|||
if (rel_abs_delta > 0.0f) {
|
||||
rel_abs_delta /= std::max(std::abs(a_row[c]), std::abs(b_row[c]));
|
||||
}
|
||||
EXPECT_LT(rel_abs_delta, 1e-3)
|
||||
EXPECT_LT(rel_abs_delta, 1e-5)
|
||||
<< "a[" << r << "," << c << "]=" << a_row[c] << ", b[" << r << ","
|
||||
<< c << "]=" << b_row[c];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void TestFlashAttention(size_t target_parallelism,
|
||||
AttentionImpl attention_impl) {
|
||||
void TestFlashAttention(size_t target_parallelism) {
|
||||
ThreadingArgs threading_args;
|
||||
ThreadingContext ctx(threading_args);
|
||||
constexpr size_t kOuter = 1024;
|
||||
|
|
@ -132,9 +130,9 @@ void TestFlashAttention(size_t target_parallelism,
|
|||
QBatch qbatch(/*start=*/0, /*max_size=*/kOuter, all_queries);
|
||||
const size_t batch_size = kOuter;
|
||||
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
|
||||
AttentionActivations attention_storage(
|
||||
config, layer_config, batch_size, kOuter, runtime_config,
|
||||
ctx.pools.MaxWorkers(), ctx.allocator, row_ptrs);
|
||||
AttentionActivations attention_storage(config, layer_config, batch_size,
|
||||
kOuter, runtime_config, ctx.allocator,
|
||||
row_ptrs);
|
||||
AttentionActivationsPtrs attention(config, kOuter, attention_storage);
|
||||
const size_t qkv_dim = layer_config.qkv_dim;
|
||||
ASSERT_EQ(qkv_dim, kInner);
|
||||
|
|
@ -144,10 +142,7 @@ void TestFlashAttention(size_t target_parallelism,
|
|||
const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
|
||||
const size_t seq_len =
|
||||
static_cast<size_t>(attention.div_seq_len.GetDivisor());
|
||||
MaybeReshapeCache(qbatch.KV(0).kv_cache, qbatch.KV(0).k_cache);
|
||||
MaybeReshapeCache(qbatch.KV(0).kv_cache, qbatch.KV(0).v_cache);
|
||||
auto& kvc = qbatch.KV(0).kv_cache;
|
||||
const size_t kFloatsPerTile = 2 * FloatsPerVector();
|
||||
for (size_t h = 0; h < layer_config.heads; ++h) {
|
||||
// Make strided views into the kv cache for
|
||||
// this query and head.
|
||||
|
|
@ -158,17 +153,6 @@ void TestFlashAttention(size_t target_parallelism,
|
|||
v.SetPtr(kvc.Row(0) + head_offset + qkv_dim, kvc.Stride());
|
||||
SetMat(h + layer_config.heads, k);
|
||||
SetMat(h + layer_config.heads * 2, v);
|
||||
for (size_t p = 0; p < tokens.size(); ++p) {
|
||||
KV_t* HWY_RESTRICT k_src = k.Row(p);
|
||||
KV_t* HWY_RESTRICT k_dest = qbatch.KV(0).k_cache.Row(p / kFloatsPerTile) +
|
||||
head_offset * kFloatsPerTile / 2 +
|
||||
p % kFloatsPerTile * 2;
|
||||
KV_t* HWY_RESTRICT v_dest = qbatch.KV(0).v_cache.Row(p / kFloatsPerTile) +
|
||||
head_offset * kFloatsPerTile / 2 +
|
||||
p % kFloatsPerTile * kFloatsPerTile;
|
||||
|
||||
TransposeKVCacheRow(k_src, k_dest, v_dest, qkv_dim);
|
||||
}
|
||||
}
|
||||
SetMat(1, attention.q);
|
||||
DotSoftmaxWeightedSum(tokens.size(), 0, layers.query_norm_scale, attention,
|
||||
|
|
@ -183,19 +167,18 @@ void TestFlashAttention(size_t target_parallelism,
|
|||
tokens.size() * div_qbatch.GetDivisor() * layer_config.heads;
|
||||
const size_t kVTileSize = GetVTileSize(kNF, kHeadGroups, tokens.size(),
|
||||
total_tasks, target_parallelism);
|
||||
printf("FlashAttention: parallelism=%zu, kNF=%zu, kVTileSize=%zu, mode %s\n",
|
||||
target_parallelism, kNF, kVTileSize,
|
||||
GetAttentionImplName(attention_impl).c_str());
|
||||
printf("FlashAttention: target_parallelism=%zu, kNF=%zu, kVTileSize=%zu\n",
|
||||
target_parallelism, kNF, kVTileSize);
|
||||
FlashAttention(tokens.size(), target_parallelism, 0, layers.query_norm_scale,
|
||||
attention, qbatch, ctx, attention_impl);
|
||||
attention, qbatch, ctx, AttentionImpl::kFlash);
|
||||
AssertClose(attention.att_out, *saved_att);
|
||||
ctx.profiler.PrintResults();
|
||||
}
|
||||
|
||||
void TestAttention() {
|
||||
TestFlashAttention(8192, AttentionImpl::kFlash);
|
||||
TestFlashAttention(2048, AttentionImpl::kFlash);
|
||||
TestFlashAttention(256, AttentionImpl::kFlash);
|
||||
TestFlashAttention(8192);
|
||||
TestFlashAttention(2048);
|
||||
TestFlashAttention(256);
|
||||
}
|
||||
|
||||
const std::vector<float> exp_denominator_sums_gold = {
|
||||
|
|
|
|||
|
|
@ -2,19 +2,11 @@
|
|||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_STRUCTS_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <limits>
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// The vertical tile size in flash attention when register lanes correspond to
|
||||
// K-timesteps, and the number of registers is 4 for 4 Q-rows.
|
||||
static constexpr size_t k4xNFVTileSize = 4;
|
||||
// The vertical tile size in flash attention when register lanes correspond to
|
||||
// K-timesteps, and the number of registers is 8 for 8 Q-rows.
|
||||
static constexpr size_t k8xNFVTileSize = 8;
|
||||
|
||||
// State for computing softmax in a streaming ("online") manner,
|
||||
// avoiding large intermediate values by subtracting the running maximum.
|
||||
// For a sequence x_1, ..., x_n:
|
||||
|
|
@ -28,46 +20,10 @@ struct OnlineSoftmaxState {
|
|||
float d = 0.0f;
|
||||
};
|
||||
|
||||
struct Tile4FlashState {
|
||||
OnlineSoftmaxState row_states[k8xNFVTileSize];
|
||||
};
|
||||
static constexpr size_t kVTileSize4 = 4;
|
||||
|
||||
// Parameters for a strip of tiles of flash attention. For processing a strip
|
||||
// of tiles, each of 1, k4xNFVTileSize, or k8xNFVTileSize Q-rows, by NF
|
||||
// k-positions. The total width of the strip might cover the entire sequence,
|
||||
// or a part of it, depending on whether the strip has been split.
|
||||
struct Tile148Params {
|
||||
// Vertical tile size gives the number used in the k8xNFVTileSize arrays.
|
||||
// It is the number of Q rows in the tile.
|
||||
uint32_t v_tile_size = 0;
|
||||
// min start position across all rows in the tile determines the
|
||||
// mask used for the tile.
|
||||
uint32_t min_start_pos = std::numeric_limits<uint32_t>::max();
|
||||
// max last position across all rows in the tile determines the mask
|
||||
// used for the tile.
|
||||
uint32_t max_last_pos = 0;
|
||||
// Index into the qbatch.KV is the same for each row in the tile.
|
||||
uint32_t qi_index;
|
||||
// Index into the kv_cache is the same for each row in the tile.
|
||||
uint32_t kv_offset;
|
||||
// In the original task, the index to the split tasks of the first split task.
|
||||
uint32_t split_index = 0;
|
||||
// The index of the split for running split attention.
|
||||
uint32_t i_of_n = 0;
|
||||
// The number of splits for running split attention.
|
||||
uint32_t n_of_n = 0;
|
||||
// Offsets into original Q for each row in the tile.
|
||||
uint32_t q_offsets[k8xNFVTileSize];
|
||||
// Offsets into att_out for each row in the tile.
|
||||
uint32_t out_offsets[k8xNFVTileSize];
|
||||
// Start k-positions for each row in the tile.
|
||||
uint32_t start_pos[k8xNFVTileSize];
|
||||
// Last k-positions for each row in the tile. Inclusive.
|
||||
uint32_t last_pos[k8xNFVTileSize];
|
||||
// Row index to att_out.
|
||||
uint32_t tq_idx[k8xNFVTileSize];
|
||||
// Flash attention state for the tile.
|
||||
Tile4FlashState end_state;
|
||||
struct Tile4FlashState {
|
||||
OnlineSoftmaxState row_states[kVTileSize4];
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -29,11 +29,6 @@
|
|||
|
||||
namespace gcpp {
|
||||
|
||||
// TODO: rays - Remove this once hwy is updated.
|
||||
#ifndef HWY_ARCH_MAX_BYTES
|
||||
#define HWY_ARCH_MAX_BYTES 256
|
||||
#endif
|
||||
|
||||
// Number of rows for KV cache. Note that both rows and cols are u32, and
|
||||
// the total number of elements can exceed 2^32.
|
||||
static size_t CappedSeqLen(const ModelConfig& config,
|
||||
|
|
@ -48,23 +43,6 @@ static size_t CappedSeqLen(const ModelConfig& config,
|
|||
|
||||
KVCache::KVCache(const Extents2D& kv_extents, const Allocator& allocator)
|
||||
: kv_cache("kv", kv_extents, allocator, MatPadding::kOdd),
|
||||
// WARNING: the rows and cols of k_cache and v_cache will be modified
|
||||
// before use!
|
||||
// The rows will be reduced by a factor of 2xkFloatsPerVector, and the
|
||||
// cols will be increased by 2xkFloatsPerVector on first use. This is to
|
||||
// avoid making KVCache another class that has to be duplicated for each
|
||||
// machine architecture, since kFloatsPerVector is architecture dependent.
|
||||
// The change is shape is safe only if the padding is kPacked.
|
||||
k_cache("k",
|
||||
Extents2D(HWY_MAX(kv_extents.rows,
|
||||
2 * HWY_ARCH_MAX_BYTES / sizeof(float)),
|
||||
kv_extents.cols / 2),
|
||||
allocator, MatPadding::kPacked),
|
||||
v_cache("v",
|
||||
Extents2D(HWY_MAX(kv_extents.rows,
|
||||
2 * HWY_ARCH_MAX_BYTES / sizeof(float)),
|
||||
kv_extents.cols / 2),
|
||||
allocator, MatPadding::kPacked),
|
||||
allocator_(allocator) {}
|
||||
|
||||
KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
|
||||
|
|
@ -79,16 +57,14 @@ KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
|
|||
: allocator_(allocator) {
|
||||
if (runtime_config.attention_impl == AttentionImpl::kFlashTransposedQs ||
|
||||
runtime_config.attention_impl == AttentionImpl::kFlashTransposedQsBF16
|
||||
|| ((runtime_config.attention_impl == AttentionImpl::kFlashTransposedQs
|
||||
) &&
|
||||
hwy::IsSame<KV_t, BF16>())) {
|
||||
) {
|
||||
const size_t num_tiles =
|
||||
hwy::DivCeil(CappedSeqLen(config, inference_args), kTileSize);
|
||||
tiled_seq_len = num_tiles * kTileSize;
|
||||
int tile_length = 2 * config.layer_configs[0].qkv_dim * kTileSize;
|
||||
Type kv_cache_type;
|
||||
if (runtime_config.attention_impl == AttentionImpl::kFlashTransposedQsBF16
|
||||
|| hwy::IsSame<KV_t, BF16>()) {
|
||||
) {
|
||||
kv_cache_type = runtime_config.kv_cache_type.value_or(Type::kBF16);
|
||||
} else {
|
||||
kv_cache_type = runtime_config.kv_cache_type.value_or(Type::kF32);
|
||||
|
|
@ -138,6 +114,9 @@ KVCache KVCache::Copy() {
|
|||
KVCache copy(kv_cache.Extents(), allocator_);
|
||||
|
||||
CopyMat(kv_cache, copy.kv_cache);
|
||||
|
||||
CopyMat(compact_kv_cache_ptr, copy.compact_kv_cache_ptr);
|
||||
copy.tiled_seq_len = tiled_seq_len;
|
||||
return copy;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@
|
|||
|
||||
namespace gcpp {
|
||||
|
||||
using KV_t = BF16;
|
||||
using KV_t = float;
|
||||
struct KVCache;
|
||||
|
||||
// A non-owning view of a KVCache.
|
||||
|
|
@ -40,8 +40,6 @@ struct KVCachePtr {
|
|||
|
||||
bool IsTiled() const;
|
||||
MatPtrT<KV_t> kv_cache;
|
||||
MatPtrT<KV_t> k_cache;
|
||||
MatPtrT<KV_t> v_cache;
|
||||
KVCache* cache = nullptr;
|
||||
};
|
||||
|
||||
|
|
@ -125,33 +123,11 @@ struct KVCache {
|
|||
// kv_head_ptrs[...].Rows().
|
||||
std::vector<MatPtr> kv_head_ptrs;
|
||||
MatStorageT<KV_t> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]
|
||||
// The format of k_cache indicates that there are pairs of values from
|
||||
// qkv_dim in groups of 2x kFloatsPerVector(=NF) elements from the sequence,
|
||||
// in groups of qkv_dim/2 elements in groups of kv_heads elements.
|
||||
// This enables sequential loading of the data when filling 2 vectors with
|
||||
// NF sequence elements of pairs of BF16 qkv values. The next vector then
|
||||
// continues reading the rest of qkv.
|
||||
// [seq_len / 2NF, layers * kv_heads * qkv_dim/2 * 2NF * 2]
|
||||
MatStorageT<KV_t> k_cache;
|
||||
// v_cache is formatted to allow sequential access to V during scaling and
|
||||
// update of att_out.
|
||||
// Originally [seq_len, layers * kv_heads * qkv_dim]
|
||||
// v_cache is transposed to:
|
||||
// [layers, kv_heads, seq_len, qkv_dim], reshaped to:
|
||||
// [layers, kv_heads, seq_len/(2NF), 2NF, qkv_dim/(2NF), 2NF]
|
||||
// then transposed to:
|
||||
// [seq_len/(2NF), layers, kv_heads, qkv_dim/(2NF), 2NF, 2NF]
|
||||
// and finally packed in a 2D MatStorageT as:
|
||||
// [seq_len/(2NF), layers * kv_heads * qkv_dim/(2NF) * 2NF * 2NF]
|
||||
// This allows sequential reads of 2NF registers each of 2NF BF16 values,
|
||||
// repeatedly until all of qkv_dim is read.
|
||||
MatStorageT<KV_t> v_cache;
|
||||
|
||||
KVCachePtr ToPtr() {
|
||||
return KVCachePtr{
|
||||
.kv_cache = kv_cache,
|
||||
.k_cache = k_cache,
|
||||
.v_cache = v_cache,
|
||||
.cache = this,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -98,7 +98,7 @@ struct AttentionTestEnv {
|
|||
}
|
||||
}
|
||||
} else if (kv_caches.back().compact_kv_cache_ptr.HasPtr()) {
|
||||
MatPtrT<KV_t> compact_kv_cache = kv_caches.back().compact_kv_cache_ptr;
|
||||
MatPtrT<float> compact_kv_cache = kv_caches.back().compact_kv_cache_ptr;
|
||||
FillMatPtrT(compact_kv_cache);
|
||||
} else {
|
||||
FillMatPtrT(kv_caches.back().kv_cache);
|
||||
|
|
@ -735,13 +735,12 @@ HWY_AFTER_NAMESPACE();
|
|||
namespace gcpp {
|
||||
HWY_BEFORE_TEST(TiledAttentionTest);
|
||||
HWY_EXPORT_AND_TEST_P(TiledAttentionTest, TestTransposeStridedQueries);
|
||||
// TODO() Fix the goldens for the change in KV_t to BF16
|
||||
// HWY_EXPORT_AND_TEST_P(TiledAttentionTest,
|
||||
// TestLocalAttentionForAllHeadsTokensAndBatch);
|
||||
HWY_EXPORT_AND_TEST_P(TiledAttentionTest,
|
||||
TestLocalAttentionForAllHeadsTokensAndBatch);
|
||||
HWY_EXPORT_AND_TEST_P(TiledAttentionTest, TestAttentionMultipleTokens);
|
||||
HWY_EXPORT_AND_TEST_P(TiledAttentionTest, TestAttentionMultipleTokensBF16);
|
||||
// HWY_EXPORT_AND_TEST_P(TiledAttentionTest,
|
||||
// TestAttentionMultipleTokensAttentionWindowSizeEdgeCase);
|
||||
HWY_EXPORT_AND_TEST_P(TiledAttentionTest,
|
||||
TestAttentionMultipleTokensAttentionWindowSizeEdgeCase);
|
||||
|
||||
HWY_AFTER_TEST();
|
||||
|
||||
|
|
|
|||
615
ops/ops-inl.h
615
ops/ops-inl.h
|
|
@ -613,6 +613,267 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(const float c,
|
|||
});
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 63)>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void Mul16(DF df, const VF scale, VF& sum0,
|
||||
VF& sum1, VF& sum2, VF& sum3, VF& sum4,
|
||||
VF& sum5, VF& sum6, VF& sum7, VF& sum8,
|
||||
VF& sum9, VF& sum10, VF& sum11,
|
||||
VF& sum12, VF& sum13, VF& sum14,
|
||||
VF& sum15) {
|
||||
sum0 = hn::Mul(sum0, hn::BroadcastLane<0>(scale));
|
||||
sum1 = hn::Mul(sum1, hn::BroadcastLane<1>(scale));
|
||||
sum2 = hn::Mul(sum2, hn::BroadcastLane<2>(scale));
|
||||
sum3 = hn::Mul(sum3, hn::BroadcastLane<3>(scale));
|
||||
sum4 = hn::Mul(sum4, hn::BroadcastLane<4>(scale));
|
||||
sum5 = hn::Mul(sum5, hn::BroadcastLane<5>(scale));
|
||||
sum6 = hn::Mul(sum6, hn::BroadcastLane<6>(scale));
|
||||
sum7 = hn::Mul(sum7, hn::BroadcastLane<7>(scale));
|
||||
sum8 = hn::Mul(sum8, hn::BroadcastLane<8>(scale));
|
||||
sum9 = hn::Mul(sum9, hn::BroadcastLane<9>(scale));
|
||||
sum10 = hn::Mul(sum10, hn::BroadcastLane<10>(scale));
|
||||
sum11 = hn::Mul(sum11, hn::BroadcastLane<11>(scale));
|
||||
sum12 = hn::Mul(sum12, hn::BroadcastLane<12>(scale));
|
||||
sum13 = hn::Mul(sum13, hn::BroadcastLane<13>(scale));
|
||||
sum14 = hn::Mul(sum14, hn::BroadcastLane<14>(scale));
|
||||
sum15 = hn::Mul(sum15, hn::BroadcastLane<15>(scale));
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 63)>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void Mul16(DF df, const VF scale, VF& sum0,
|
||||
VF& sum1, VF& sum2, VF& sum3, VF& sum4,
|
||||
VF& sum5, VF& sum6, VF& sum7, VF& sum8,
|
||||
VF& sum9, VF& sum10, VF& sum11,
|
||||
VF& sum12, VF& sum13, VF& sum14,
|
||||
VF& sum15) {}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 31)>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void Mul8(DF df, const VF scale, VF& sum0, VF& sum1,
|
||||
VF& sum2, VF& sum3, VF& sum4, VF& sum5,
|
||||
VF& sum6, VF& sum7) {
|
||||
sum0 = hn::Mul(sum0, hn::BroadcastLane<0>(scale));
|
||||
sum1 = hn::Mul(sum1, hn::BroadcastLane<1>(scale));
|
||||
sum2 = hn::Mul(sum2, hn::BroadcastLane<2>(scale));
|
||||
sum3 = hn::Mul(sum3, hn::BroadcastLane<3>(scale));
|
||||
sum4 = hn::Mul(sum4, hn::BroadcastLane<4>(scale));
|
||||
sum5 = hn::Mul(sum5, hn::BroadcastLane<5>(scale));
|
||||
sum6 = hn::Mul(sum6, hn::BroadcastLane<6>(scale));
|
||||
sum7 = hn::Mul(sum7, hn::BroadcastLane<7>(scale));
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 31)>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void Mul8(DF df, const VF scale, VF& sum0, VF& sum1,
|
||||
VF& sum2, VF& sum3, VF& sum4, VF& sum5,
|
||||
VF& sum6, VF& sum7) {}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 63)>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd16(
|
||||
DF df, const VF common, const VF split, VF& sum0, VF& sum1, VF& sum2,
|
||||
VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7, VF& sum8, VF& sum9,
|
||||
VF& sum10, VF& sum11, VF& sum12, VF& sum13, VF& sum14, VF& sum15) {
|
||||
sum0 = hn::MulAdd(common, hn::BroadcastLane<0>(split), sum0);
|
||||
sum1 = hn::MulAdd(common, hn::BroadcastLane<1>(split), sum1);
|
||||
sum2 = hn::MulAdd(common, hn::BroadcastLane<2>(split), sum2);
|
||||
sum3 = hn::MulAdd(common, hn::BroadcastLane<3>(split), sum3);
|
||||
sum4 = hn::MulAdd(common, hn::BroadcastLane<4>(split), sum4);
|
||||
sum5 = hn::MulAdd(common, hn::BroadcastLane<5>(split), sum5);
|
||||
sum6 = hn::MulAdd(common, hn::BroadcastLane<6>(split), sum6);
|
||||
sum7 = hn::MulAdd(common, hn::BroadcastLane<7>(split), sum7);
|
||||
sum8 = hn::MulAdd(common, hn::BroadcastLane<8>(split), sum8);
|
||||
sum9 = hn::MulAdd(common, hn::BroadcastLane<9>(split), sum9);
|
||||
sum10 = hn::MulAdd(common, hn::BroadcastLane<10>(split), sum10);
|
||||
sum11 = hn::MulAdd(common, hn::BroadcastLane<11>(split), sum11);
|
||||
sum12 = hn::MulAdd(common, hn::BroadcastLane<12>(split), sum12);
|
||||
sum13 = hn::MulAdd(common, hn::BroadcastLane<13>(split), sum13);
|
||||
sum14 = hn::MulAdd(common, hn::BroadcastLane<14>(split), sum14);
|
||||
sum15 = hn::MulAdd(common, hn::BroadcastLane<15>(split), sum15);
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 63)>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd16(
|
||||
DF df, const VF common, const VF split, VF& sum0, VF& sum1, VF& sum2,
|
||||
VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7, VF& sum8, VF& sum9,
|
||||
VF& sum10, VF& sum11, VF& sum12, VF& sum13, VF& sum14, VF& sum15) {}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 31)>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd8(DF df, const VF common, const VF split,
|
||||
VF& sum0, VF& sum1, VF& sum2, VF& sum3,
|
||||
VF& sum4, VF& sum5, VF& sum6,
|
||||
VF& sum7) {
|
||||
sum0 = hn::MulAdd(common, hn::BroadcastLane<0>(split), sum0);
|
||||
sum1 = hn::MulAdd(common, hn::BroadcastLane<1>(split), sum1);
|
||||
sum2 = hn::MulAdd(common, hn::BroadcastLane<2>(split), sum2);
|
||||
sum3 = hn::MulAdd(common, hn::BroadcastLane<3>(split), sum3);
|
||||
sum4 = hn::MulAdd(common, hn::BroadcastLane<4>(split), sum4);
|
||||
sum5 = hn::MulAdd(common, hn::BroadcastLane<5>(split), sum5);
|
||||
sum6 = hn::MulAdd(common, hn::BroadcastLane<6>(split), sum6);
|
||||
sum7 = hn::MulAdd(common, hn::BroadcastLane<7>(split), sum7);
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 31)>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd8(DF df, const VF common, const VF split,
|
||||
VF& sum0, VF& sum1, VF& sum2, VF& sum3,
|
||||
VF& sum4, VF& sum5, VF& sum6,
|
||||
VF& sum7) {}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4(DF df, const VF common, const VF split,
|
||||
VF& sum0, VF& sum1, VF& sum2,
|
||||
VF& sum3) {
|
||||
sum0 = hn::MulAdd(common, hn::BroadcastLane<0>(split), sum0);
|
||||
sum1 = hn::MulAdd(common, hn::BroadcastLane<1>(split), sum1);
|
||||
sum2 = hn::MulAdd(common, hn::BroadcastLane<2>(split), sum2);
|
||||
sum3 = hn::MulAdd(common, hn::BroadcastLane<3>(split), sum3);
|
||||
}
|
||||
|
||||
// For an 8xNF tile of float values in 8xNF-lane registers, multiplies 8 rows
|
||||
// of V by the corresponding values in c0-c7 and adds them to NF rows of out,
|
||||
// after first prescaling out by scale.
|
||||
// The depth (size) must be a multiple of NF.
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile(
|
||||
DF df, const VF scale, const VF c0, const VF c1, const VF c2, const VF c3,
|
||||
const VF c4, const VF c5, const VF c6, const VF c7, const MatPtrT<float>& v,
|
||||
const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out,
|
||||
const uint32_t* HWY_RESTRICT out_offsets, const size_t size) {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
|
||||
|
||||
size_t i = 0;
|
||||
while (i + NF <= size) {
|
||||
if HWY_LANES_CONSTEXPR (NF == 16) {
|
||||
VF out0, out1, out2, out3, out4, out5, out6, out7;
|
||||
VF out8, out9, out10, out11, out12, out13, out14, out15;
|
||||
out0 = hn::Load(df, out + i + out_offsets[0]);
|
||||
out1 = hn::Load(df, out + i + out_offsets[1]);
|
||||
out2 = hn::Load(df, out + i + out_offsets[2]);
|
||||
out3 = hn::Load(df, out + i + out_offsets[3]);
|
||||
out4 = hn::Load(df, out + i + out_offsets[4]);
|
||||
out5 = hn::Load(df, out + i + out_offsets[5]);
|
||||
out6 = hn::Load(df, out + i + out_offsets[6]);
|
||||
out7 = hn::Load(df, out + i + out_offsets[7]);
|
||||
out8 = hn::Load(df, out + i + out_offsets[8]);
|
||||
out9 = hn::Load(df, out + i + out_offsets[9]);
|
||||
out10 = hn::Load(df, out + i + out_offsets[10]);
|
||||
out11 = hn::Load(df, out + i + out_offsets[11]);
|
||||
out12 = hn::Load(df, out + i + out_offsets[12]);
|
||||
out13 = hn::Load(df, out + i + out_offsets[13]);
|
||||
out14 = hn::Load(df, out + i + out_offsets[14]);
|
||||
out15 = hn::Load(df, out + i + out_offsets[15]);
|
||||
Mul16(df, scale, out0, out1, out2, out3, out4, out5, out6, out7, out8,
|
||||
out9, out10, out11, out12, out13, out14, out15);
|
||||
VF x0 = hn::Load(df, v.Row(pos[0]) + i);
|
||||
MulAdd16(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7, out8,
|
||||
out9, out10, out11, out12, out13, out14, out15);
|
||||
VF x1 = hn::Load(df, v.Row(pos[1]) + i);
|
||||
MulAdd16(df, x1, c1, out0, out1, out2, out3, out4, out5, out6, out7, out8,
|
||||
out9, out10, out11, out12, out13, out14, out15);
|
||||
VF x2 = hn::Load(df, v.Row(pos[2]) + i);
|
||||
MulAdd16(df, x2, c2, out0, out1, out2, out3, out4, out5, out6, out7, out8,
|
||||
out9, out10, out11, out12, out13, out14, out15);
|
||||
VF x3 = hn::Load(df, v.Row(pos[3]) + i);
|
||||
MulAdd16(df, x3, c3, out0, out1, out2, out3, out4, out5, out6, out7, out8,
|
||||
out9, out10, out11, out12, out13, out14, out15);
|
||||
VF x4 = hn::Load(df, v.Row(pos[4]) + i);
|
||||
MulAdd16(df, x4, c4, out0, out1, out2, out3, out4, out5, out6, out7, out8,
|
||||
out9, out10, out11, out12, out13, out14, out15);
|
||||
VF x5 = hn::Load(df, v.Row(pos[5]) + i);
|
||||
MulAdd16(df, x5, c5, out0, out1, out2, out3, out4, out5, out6, out7, out8,
|
||||
out9, out10, out11, out12, out13, out14, out15);
|
||||
VF x6 = hn::Load(df, v.Row(pos[6]) + i);
|
||||
MulAdd16(df, x6, c6, out0, out1, out2, out3, out4, out5, out6, out7, out8,
|
||||
out9, out10, out11, out12, out13, out14, out15);
|
||||
VF x7 = hn::Load(df, v.Row(pos[7]) + i);
|
||||
MulAdd16(df, x7, c7, out0, out1, out2, out3, out4, out5, out6, out7, out8,
|
||||
out9, out10, out11, out12, out13, out14, out15);
|
||||
hn::Store(out0, df, out + i + out_offsets[0]);
|
||||
hn::Store(out1, df, out + i + out_offsets[1]);
|
||||
hn::Store(out2, df, out + i + out_offsets[2]);
|
||||
hn::Store(out3, df, out + i + out_offsets[3]);
|
||||
hn::Store(out4, df, out + i + out_offsets[4]);
|
||||
hn::Store(out5, df, out + i + out_offsets[5]);
|
||||
hn::Store(out6, df, out + i + out_offsets[6]);
|
||||
hn::Store(out7, df, out + i + out_offsets[7]);
|
||||
hn::Store(out8, df, out + i + out_offsets[8]);
|
||||
hn::Store(out9, df, out + i + out_offsets[9]);
|
||||
hn::Store(out10, df, out + i + out_offsets[10]);
|
||||
hn::Store(out11, df, out + i + out_offsets[11]);
|
||||
hn::Store(out12, df, out + i + out_offsets[12]);
|
||||
hn::Store(out13, df, out + i + out_offsets[13]);
|
||||
hn::Store(out14, df, out + i + out_offsets[14]);
|
||||
hn::Store(out15, df, out + i + out_offsets[15]);
|
||||
}
|
||||
if HWY_LANES_CONSTEXPR (NF == 8) {
|
||||
VF out0, out1, out2, out3, out4, out5, out6, out7;
|
||||
out0 = hn::Load(df, out + i + out_offsets[0]);
|
||||
out1 = hn::Load(df, out + i + out_offsets[1]);
|
||||
out2 = hn::Load(df, out + i + out_offsets[2]);
|
||||
out3 = hn::Load(df, out + i + out_offsets[3]);
|
||||
out4 = hn::Load(df, out + i + out_offsets[4]);
|
||||
out5 = hn::Load(df, out + i + out_offsets[5]);
|
||||
out6 = hn::Load(df, out + i + out_offsets[6]);
|
||||
out7 = hn::Load(df, out + i + out_offsets[7]);
|
||||
Mul8(df, scale, out0, out1, out2, out3, out4, out5, out6, out7);
|
||||
VF x0 = hn::Load(df, v.Row(pos[0]) + i);
|
||||
MulAdd8(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7);
|
||||
VF x1 = hn::Load(df, v.Row(pos[1]) + i);
|
||||
MulAdd8(df, x1, c1, out0, out1, out2, out3, out4, out5, out6, out7);
|
||||
VF x2 = hn::Load(df, v.Row(pos[2]) + i);
|
||||
MulAdd8(df, x2, c2, out0, out1, out2, out3, out4, out5, out6, out7);
|
||||
VF x3 = hn::Load(df, v.Row(pos[3]) + i);
|
||||
MulAdd8(df, x3, c3, out0, out1, out2, out3, out4, out5, out6, out7);
|
||||
VF x4 = hn::Load(df, v.Row(pos[4]) + i);
|
||||
MulAdd8(df, x4, c4, out0, out1, out2, out3, out4, out5, out6, out7);
|
||||
VF x5 = hn::Load(df, v.Row(pos[5]) + i);
|
||||
MulAdd8(df, x5, c5, out0, out1, out2, out3, out4, out5, out6, out7);
|
||||
VF x6 = hn::Load(df, v.Row(pos[6]) + i);
|
||||
MulAdd8(df, x6, c6, out0, out1, out2, out3, out4, out5, out6, out7);
|
||||
VF x7 = hn::Load(df, v.Row(pos[7]) + i);
|
||||
MulAdd8(df, x7, c7, out0, out1, out2, out3, out4, out5, out6, out7);
|
||||
hn::Store(out0, df, out + i + out_offsets[0]);
|
||||
hn::Store(out1, df, out + i + out_offsets[1]);
|
||||
hn::Store(out2, df, out + i + out_offsets[2]);
|
||||
hn::Store(out3, df, out + i + out_offsets[3]);
|
||||
hn::Store(out4, df, out + i + out_offsets[4]);
|
||||
hn::Store(out5, df, out + i + out_offsets[5]);
|
||||
hn::Store(out6, df, out + i + out_offsets[6]);
|
||||
hn::Store(out7, df, out + i + out_offsets[7]);
|
||||
}
|
||||
if HWY_LANES_CONSTEXPR (NF == 4) {
|
||||
VF out0, out1, out2, out3;
|
||||
out0 = hn::Load(df, out + i + out_offsets[0]);
|
||||
out1 = hn::Load(df, out + i + out_offsets[1]);
|
||||
out2 = hn::Load(df, out + i + out_offsets[2]);
|
||||
out3 = hn::Load(df, out + i + out_offsets[3]);
|
||||
out0 = hn::Mul(out0, hn::BroadcastLane<0>(scale));
|
||||
out1 = hn::Mul(out1, hn::BroadcastLane<1>(scale));
|
||||
out2 = hn::Mul(out2, hn::BroadcastLane<2>(scale));
|
||||
out3 = hn::Mul(out3, hn::BroadcastLane<3>(scale));
|
||||
VF x0 = hn::Load(df, v.Row(pos[0]) + i);
|
||||
MulAdd4(df, x0, c0, out0, out1, out2, out3);
|
||||
VF x1 = hn::Load(df, v.Row(pos[1]) + i);
|
||||
MulAdd4(df, x1, c1, out0, out1, out2, out3);
|
||||
VF x2 = hn::Load(df, v.Row(pos[2]) + i);
|
||||
MulAdd4(df, x2, c2, out0, out1, out2, out3);
|
||||
VF x3 = hn::Load(df, v.Row(pos[3]) + i);
|
||||
MulAdd4(df, x3, c3, out0, out1, out2, out3);
|
||||
VF x4 = hn::Load(df, v.Row(pos[4]) + i);
|
||||
MulAdd4(df, x4, c4, out0, out1, out2, out3);
|
||||
VF x5 = hn::Load(df, v.Row(pos[5]) + i);
|
||||
MulAdd4(df, x5, c5, out0, out1, out2, out3);
|
||||
VF x6 = hn::Load(df, v.Row(pos[6]) + i);
|
||||
MulAdd4(df, x6, c6, out0, out1, out2, out3);
|
||||
VF x7 = hn::Load(df, v.Row(pos[7]) + i);
|
||||
MulAdd4(df, x7, c7, out0, out1, out2, out3);
|
||||
hn::Store(out0, df, out + i + out_offsets[0]);
|
||||
hn::Store(out1, df, out + i + out_offsets[1]);
|
||||
hn::Store(out2, df, out + i + out_offsets[2]);
|
||||
hn::Store(out3, df, out + i + out_offsets[3]);
|
||||
}
|
||||
i += NF;
|
||||
}
|
||||
HWY_DASSERT(size == i);
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4(DF df, const VF common, const VF c0,
|
||||
const VF c1, const VF c2, const VF c3,
|
||||
|
|
@ -625,147 +886,145 @@ HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4(DF df, const VF common, const VF c0,
|
|||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void MulAddNLanesVT4(
|
||||
DF df, const BF16* HWY_RESTRICT v, const float* HWY_RESTRICT c,
|
||||
const size_t num_lanes, VF& sum0a, VF& sum1a, VF& sum2a, VF& sum3a,
|
||||
VF& sum0b, VF& sum1b, VF& sum2b, VF& sum3b) {
|
||||
using DBF = hn::ScalableTag<BF16>;
|
||||
const DBF dbf;
|
||||
using VBF = hn::Vec<DBF>;
|
||||
const size_t kNF = hn::Lanes(df);
|
||||
for (size_t lane = 0; lane < num_lanes; ++lane, v += 2 * kNF) {
|
||||
VBF v0 = hn::Load(dbf, v);
|
||||
VF c0 = hn::Set(df, *c++);
|
||||
VF c1 = hn::Set(df, *c++);
|
||||
VF c2 = hn::Set(df, *c++);
|
||||
VF c3 = hn::Set(df, *c++);
|
||||
VF v0a = hn::PromoteLowerTo(df, v0);
|
||||
VF v0b = hn::PromoteUpperTo(df, v0);
|
||||
MulAdd4(df, v0a, c0, c1, c2, c3, sum0a, sum1a, sum2a, sum3a);
|
||||
MulAdd4(df, v0b, c0, c1, c2, c3, sum0b, sum1b, sum2b, sum3b);
|
||||
}
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4Lanes(DF df, const MatPtrT<float>& v,
|
||||
const size_t* HWY_RESTRICT pos,
|
||||
const size_t offset, const VF c0,
|
||||
const VF c1, const VF c2,
|
||||
const VF c3, VF& sum0, VF& sum1,
|
||||
VF& sum2, VF& sum3) {
|
||||
// TODO(rays): Check whether a transpose of c0-c3 is applicable and faster.
|
||||
VF x0 = hn::Load(df, v.Row(pos[0]) + offset);
|
||||
MulAdd4(df, x0, hn::BroadcastLane<0>(c0), hn::BroadcastLane<0>(c1),
|
||||
hn::BroadcastLane<0>(c2), hn::BroadcastLane<0>(c3), sum0, sum1, sum2,
|
||||
sum3);
|
||||
VF x1 = hn::Load(df, v.Row(pos[1]) + offset);
|
||||
MulAdd4(df, x1, hn::BroadcastLane<1>(c0), hn::BroadcastLane<1>(c1),
|
||||
hn::BroadcastLane<1>(c2), hn::BroadcastLane<1>(c3), sum0, sum1, sum2,
|
||||
sum3);
|
||||
VF x2 = hn::Load(df, v.Row(pos[2]) + offset);
|
||||
MulAdd4(df, x2, hn::BroadcastLane<2>(c0), hn::BroadcastLane<2>(c1),
|
||||
hn::BroadcastLane<2>(c2), hn::BroadcastLane<2>(c3), sum0, sum1, sum2,
|
||||
sum3);
|
||||
VF x3 = hn::Load(df, v.Row(pos[3]) + offset);
|
||||
MulAdd4(df, x3, hn::BroadcastLane<3>(c0), hn::BroadcastLane<3>(c1),
|
||||
hn::BroadcastLane<3>(c2), hn::BroadcastLane<3>(c3), sum0, sum1, sum2,
|
||||
sum3);
|
||||
}
|
||||
|
||||
// For a 2NFx4 tile of float values in 8xNF-lane registers, multiplies 2NF rows
|
||||
// of V by the corresponding values in c00-c31 and adds them to 2NF rows of out,
|
||||
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 31)>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond4Lanes(
|
||||
DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
|
||||
const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
|
||||
VF& sum0, VF& sum1, VF& sum2, VF& sum3) {
|
||||
VF x4 = hn::Load(df, v.Row(pos[4]) + offset);
|
||||
MulAdd4(df, x4, hn::BroadcastLane<4>(c0), hn::BroadcastLane<4>(c1),
|
||||
hn::BroadcastLane<4>(c2), hn::BroadcastLane<4>(c3), sum0, sum1, sum2,
|
||||
sum3);
|
||||
VF x5 = hn::Load(df, v.Row(pos[5]) + offset);
|
||||
MulAdd4(df, x5, hn::BroadcastLane<5>(c0), hn::BroadcastLane<5>(c1),
|
||||
hn::BroadcastLane<5>(c2), hn::BroadcastLane<5>(c3), sum0, sum1, sum2,
|
||||
sum3);
|
||||
VF x6 = hn::Load(df, v.Row(pos[6]) + offset);
|
||||
MulAdd4(df, x6, hn::BroadcastLane<6>(c0), hn::BroadcastLane<6>(c1),
|
||||
hn::BroadcastLane<6>(c2), hn::BroadcastLane<6>(c3), sum0, sum1, sum2,
|
||||
sum3);
|
||||
VF x7 = hn::Load(df, v.Row(pos[7]) + offset);
|
||||
MulAdd4(df, x7, hn::BroadcastLane<7>(c0), hn::BroadcastLane<7>(c1),
|
||||
hn::BroadcastLane<7>(c2), hn::BroadcastLane<7>(c3), sum0, sum1, sum2,
|
||||
sum3);
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 31)>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond4Lanes(
|
||||
DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
|
||||
const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
|
||||
VF& sum0, VF& sum1, VF& sum2, VF& sum3) {}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 63)>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond8Lanes(
|
||||
DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
|
||||
const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
|
||||
VF& sum0, VF& sum1, VF& sum2, VF& sum3) {
|
||||
VF x8 = hn::Load(df, v.Row(pos[8]) + offset);
|
||||
MulAdd4(df, x8, hn::BroadcastLane<8>(c0), hn::BroadcastLane<8>(c1),
|
||||
hn::BroadcastLane<8>(c2), hn::BroadcastLane<8>(c3), sum0, sum1, sum2,
|
||||
sum3);
|
||||
VF x9 = hn::Load(df, v.Row(pos[9]) + offset);
|
||||
MulAdd4(df, x9, hn::BroadcastLane<9>(c0), hn::BroadcastLane<9>(c1),
|
||||
hn::BroadcastLane<9>(c2), hn::BroadcastLane<9>(c3), sum0, sum1, sum2,
|
||||
sum3);
|
||||
VF x10 = hn::Load(df, v.Row(pos[10]) + offset);
|
||||
MulAdd4(df, x10, hn::BroadcastLane<10>(c0), hn::BroadcastLane<10>(c1),
|
||||
hn::BroadcastLane<10>(c2), hn::BroadcastLane<10>(c3), sum0, sum1,
|
||||
sum2, sum3);
|
||||
VF x11 = hn::Load(df, v.Row(pos[11]) + offset);
|
||||
MulAdd4(df, x11, hn::BroadcastLane<11>(c0), hn::BroadcastLane<11>(c1),
|
||||
hn::BroadcastLane<11>(c2), hn::BroadcastLane<11>(c3), sum0, sum1,
|
||||
sum2, sum3);
|
||||
VF x12 = hn::Load(df, v.Row(pos[12]) + offset);
|
||||
MulAdd4(df, x12, hn::BroadcastLane<12>(c0), hn::BroadcastLane<12>(c1),
|
||||
hn::BroadcastLane<12>(c2), hn::BroadcastLane<12>(c3), sum0, sum1,
|
||||
sum2, sum3);
|
||||
VF x13 = hn::Load(df, v.Row(pos[13]) + offset);
|
||||
MulAdd4(df, x13, hn::BroadcastLane<13>(c0), hn::BroadcastLane<13>(c1),
|
||||
hn::BroadcastLane<13>(c2), hn::BroadcastLane<13>(c3), sum0, sum1,
|
||||
sum2, sum3);
|
||||
VF x14 = hn::Load(df, v.Row(pos[14]) + offset);
|
||||
MulAdd4(df, x14, hn::BroadcastLane<14>(c0), hn::BroadcastLane<14>(c1),
|
||||
hn::BroadcastLane<14>(c2), hn::BroadcastLane<14>(c3), sum0, sum1,
|
||||
sum2, sum3);
|
||||
VF x15 = hn::Load(df, v.Row(pos[15]) + offset);
|
||||
MulAdd4(df, x15, hn::BroadcastLane<15>(c0), hn::BroadcastLane<15>(c1),
|
||||
hn::BroadcastLane<15>(c2), hn::BroadcastLane<15>(c3), sum0, sum1,
|
||||
sum2, sum3);
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 63)>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond8Lanes(
|
||||
DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
|
||||
const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
|
||||
VF& sum0, VF& sum1, VF& sum2, VF& sum3) {}
|
||||
|
||||
// For an NFx4 tile of float values in 4xNF-lane registers, multiplies NF rows
|
||||
// of V by the corresponding values in c0-c3 and adds them to NF rows of out,
|
||||
// after first prescaling out by scale.
|
||||
// The depth (size) must be a multiple of 2NF.
|
||||
// The depth (size) must be a multiple of NF.
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddVT4Mem(
|
||||
DF df, const float* HWY_RESTRICT scales, const VF c00, const VF c01,
|
||||
const VF c10, const VF c11, const VF c20, const VF c21, const VF c30,
|
||||
const VF c31, const MatPtrT<BF16>& v, const size_t* HWY_RESTRICT pos,
|
||||
size_t num_lanes, float* HWY_RESTRICT out,
|
||||
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile4(
|
||||
DF df, const float* HWY_RESTRICT scales, const VF c0, const VF c1,
|
||||
const VF c2, const VF c3, const MatPtrT<float>& v,
|
||||
const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out,
|
||||
const uint32_t* HWY_RESTRICT out_offsets, const size_t size) {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
|
||||
constexpr size_t kMaxNF = hn::MaxLanes(df);
|
||||
const BF16* HWY_RESTRICT v_bf = v.Row(pos[0] / (2 * NF));
|
||||
HWY_DASSERT(pos[0] % (2 * NF) == 0);
|
||||
HWY_ALIGN float c_mem[8 * kMaxNF];
|
||||
hn::StoreInterleaved4(c00, c10, c20, c30, df, c_mem);
|
||||
hn::StoreInterleaved4(c01, c11, c21, c31, df, c_mem + 4 * NF);
|
||||
|
||||
size_t i = 0;
|
||||
while (i + NF * 2 <= size) {
|
||||
VF out0a, out1a, out2a, out3a, out0b, out1b, out2b, out3b;
|
||||
out0a = hn::Load(df, out + i + out_offsets[0]);
|
||||
out1a = hn::Load(df, out + i + out_offsets[1]);
|
||||
out2a = hn::Load(df, out + i + out_offsets[2]);
|
||||
out3a = hn::Load(df, out + i + out_offsets[3]);
|
||||
VF scale0 = hn::Set(df, scales[0]);
|
||||
VF scale1 = hn::Set(df, scales[1]);
|
||||
VF scale2 = hn::Set(df, scales[2]);
|
||||
VF scale3 = hn::Set(df, scales[3]);
|
||||
out0a = hn::Mul(out0a, scale0);
|
||||
out1a = hn::Mul(out1a, scale1);
|
||||
out2a = hn::Mul(out2a, scale2);
|
||||
out3a = hn::Mul(out3a, scale3);
|
||||
out0b = hn::Load(df, out + i + NF + out_offsets[0]);
|
||||
out1b = hn::Load(df, out + i + NF + out_offsets[1]);
|
||||
out2b = hn::Load(df, out + i + NF + out_offsets[2]);
|
||||
out3b = hn::Load(df, out + i + NF + out_offsets[3]);
|
||||
out0b = hn::Mul(out0b, scale0);
|
||||
out1b = hn::Mul(out1b, scale1);
|
||||
out2b = hn::Mul(out2b, scale2);
|
||||
out3b = hn::Mul(out3b, scale3);
|
||||
MulAddNLanesVT4(df, v_bf, c_mem, HWY_MIN(num_lanes, 2 * NF), out0a, out1a,
|
||||
out2a, out3a, out0b, out1b, out2b, out3b);
|
||||
hn::Store(out0a, df, out + i + out_offsets[0]);
|
||||
hn::Store(out1a, df, out + i + out_offsets[1]);
|
||||
hn::Store(out2a, df, out + i + out_offsets[2]);
|
||||
hn::Store(out3a, df, out + i + out_offsets[3]);
|
||||
hn::Store(out0b, df, out + i + NF + out_offsets[0]);
|
||||
hn::Store(out1b, df, out + i + NF + out_offsets[1]);
|
||||
hn::Store(out2b, df, out + i + NF + out_offsets[2]);
|
||||
hn::Store(out3b, df, out + i + NF + out_offsets[3]);
|
||||
i += NF * 2;
|
||||
v_bf += 4 * NF * NF;
|
||||
while (i + NF <= size) {
|
||||
VF out0, out1, out2, out3;
|
||||
out0 = hn::Load(df, out + i + out_offsets[0]);
|
||||
out1 = hn::Load(df, out + i + out_offsets[1]);
|
||||
out2 = hn::Load(df, out + i + out_offsets[2]);
|
||||
out3 = hn::Load(df, out + i + out_offsets[3]);
|
||||
out0 = hn::Mul(out0, hn::Set(df, scales[0]));
|
||||
out1 = hn::Mul(out1, hn::Set(df, scales[1]));
|
||||
out2 = hn::Mul(out2, hn::Set(df, scales[2]));
|
||||
out3 = hn::Mul(out3, hn::Set(df, scales[3]));
|
||||
MulAdd4Lanes(df, v, pos, i, c0, c1, c2, c3, out0, out1, out2, out3);
|
||||
if HWY_LANES_CONSTEXPR (NF >= 8) {
|
||||
MulAddSecond4Lanes(df, v, pos, i, c0, c1, c2, c3, out0, out1, out2, out3);
|
||||
if HWY_LANES_CONSTEXPR (NF >= 16) {
|
||||
MulAddSecond8Lanes(df, v, pos, i, c0, c1, c2, c3, out0, out1, out2,
|
||||
out3);
|
||||
}
|
||||
}
|
||||
hn::Store(out0, df, out + i + out_offsets[0]);
|
||||
hn::Store(out1, df, out + i + out_offsets[1]);
|
||||
hn::Store(out2, df, out + i + out_offsets[2]);
|
||||
hn::Store(out3, df, out + i + out_offsets[3]);
|
||||
i += NF;
|
||||
}
|
||||
HWY_DASSERT(size == i);
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void MulAddNLanesVT1(DF df,
|
||||
const BF16* HWY_RESTRICT v,
|
||||
const float* HWY_RESTRICT c,
|
||||
const size_t num_lanes,
|
||||
VF& sum0a, VF& sum0b) {
|
||||
using DBF = hn::ScalableTag<BF16>;
|
||||
const DBF dbf;
|
||||
using VBF = hn::Vec<DBF>;
|
||||
const size_t kNF = hn::Lanes(df);
|
||||
for (size_t lane = 0; lane < num_lanes; ++lane, v += 2 * kNF) {
|
||||
VBF v0 = hn::Load(dbf, v);
|
||||
VF c0 = hn::Set(df, *c++);
|
||||
VF v0a = hn::PromoteLowerTo(df, v0);
|
||||
VF v0b = hn::PromoteUpperTo(df, v0);
|
||||
sum0a = hn::MulAdd(v0a, c0, sum0a);
|
||||
sum0b = hn::MulAdd(v0b, c0, sum0b);
|
||||
}
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddVT1Mem(
|
||||
DF df, const float* HWY_RESTRICT scales, const VF c00, const VF c01,
|
||||
const MatPtrT<BF16>& v, const size_t* HWY_RESTRICT pos, size_t num_lanes,
|
||||
float* HWY_RESTRICT out, const uint32_t* HWY_RESTRICT out_offsets,
|
||||
const size_t size) {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
|
||||
constexpr size_t kMaxNF = hn::MaxLanes(df);
|
||||
const BF16* HWY_RESTRICT v_bf = v.Row(pos[0] / (2 * NF));
|
||||
HWY_DASSERT(pos[0] % (2 * NF) == 0);
|
||||
HWY_ALIGN float c_mem[2 * kMaxNF];
|
||||
hn::Store(c00, df, c_mem);
|
||||
hn::Store(c01, df, c_mem + NF);
|
||||
|
||||
size_t i = 0;
|
||||
while (i + NF * 2 <= size) {
|
||||
VF out0a, out0b;
|
||||
out0a = hn::Load(df, out + i + out_offsets[0]);
|
||||
VF scale0 = hn::Set(df, scales[0]);
|
||||
out0a = hn::Mul(out0a, scale0);
|
||||
out0b = hn::Load(df, out + i + NF + out_offsets[0]);
|
||||
out0b = hn::Mul(out0b, scale0);
|
||||
MulAddNLanesVT1(df, v_bf, c_mem, HWY_MIN(num_lanes, 2 * NF), out0a, out0b);
|
||||
hn::Store(out0a, df, out + i + out_offsets[0]);
|
||||
hn::Store(out0b, df, out + i + NF + out_offsets[0]);
|
||||
i += NF * 2;
|
||||
v_bf += 4 * NF * NF;
|
||||
}
|
||||
while (i < size) {
|
||||
float sum = out[i + out_offsets[0]] * scales[0];
|
||||
const BF16* HWY_RESTRICT v_local = v_bf;
|
||||
for (size_t lane = 0; lane < HWY_MIN(num_lanes, 2 * NF);
|
||||
++lane, v_local += 2 * NF) {
|
||||
sum += hwy::ConvertScalarTo<float>(*v_local) * c_mem[lane];
|
||||
}
|
||||
++i;
|
||||
++v_bf;
|
||||
}
|
||||
}
|
||||
|
||||
template <int32_t N, typename DF, class VF = hn::Vec<DF>>
|
||||
static HWY_INLINE void StoreUpTo8Times2(DF df, MatPtrT<float>& out,
|
||||
size_t start_col, VF out0_0, VF out0_1,
|
||||
|
|
@ -1211,6 +1470,104 @@ HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddTileUpTo8_BF16(
|
|||
HWY_DASSERT(qkv_dim == i);
|
||||
}
|
||||
|
||||
// Prescales NF rows of out by scale, then multiplies 1 row of V by the
|
||||
// corresponding values in c0 and adds them to the NF rows of out.
|
||||
// The depth (size) must be a multiple of NF.
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
|
||||
DF df, const VF scale, const VF c0, const MatPtrT<float>& v,
|
||||
const size_t pos, float* HWY_RESTRICT out,
|
||||
const uint32_t* HWY_RESTRICT out_offsets, const size_t size) {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
|
||||
|
||||
size_t i = 0;
|
||||
while (i + NF <= size) {
|
||||
if HWY_LANES_CONSTEXPR (NF == 16) {
|
||||
VF out0, out1, out2, out3, out4, out5, out6, out7;
|
||||
VF out8, out9, out10, out11, out12, out13, out14, out15;
|
||||
out0 = hn::Load(df, out + i + out_offsets[0]);
|
||||
out1 = hn::Load(df, out + i + out_offsets[1]);
|
||||
out2 = hn::Load(df, out + i + out_offsets[2]);
|
||||
out3 = hn::Load(df, out + i + out_offsets[3]);
|
||||
out4 = hn::Load(df, out + i + out_offsets[4]);
|
||||
out5 = hn::Load(df, out + i + out_offsets[5]);
|
||||
out6 = hn::Load(df, out + i + out_offsets[6]);
|
||||
out7 = hn::Load(df, out + i + out_offsets[7]);
|
||||
out8 = hn::Load(df, out + i + out_offsets[8]);
|
||||
out9 = hn::Load(df, out + i + out_offsets[9]);
|
||||
out10 = hn::Load(df, out + i + out_offsets[10]);
|
||||
out11 = hn::Load(df, out + i + out_offsets[11]);
|
||||
out12 = hn::Load(df, out + i + out_offsets[12]);
|
||||
out13 = hn::Load(df, out + i + out_offsets[13]);
|
||||
out14 = hn::Load(df, out + i + out_offsets[14]);
|
||||
out15 = hn::Load(df, out + i + out_offsets[15]);
|
||||
Mul16(df, scale, out0, out1, out2, out3, out4, out5, out6, out7, out8,
|
||||
out9, out10, out11, out12, out13, out14, out15);
|
||||
VF x0 = hn::Load(df, v.Row(pos) + i);
|
||||
MulAdd16(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7, out8,
|
||||
out9, out10, out11, out12, out13, out14, out15);
|
||||
hn::Store(out0, df, out + i + out_offsets[0]);
|
||||
hn::Store(out1, df, out + i + out_offsets[1]);
|
||||
hn::Store(out2, df, out + i + out_offsets[2]);
|
||||
hn::Store(out3, df, out + i + out_offsets[3]);
|
||||
hn::Store(out4, df, out + i + out_offsets[4]);
|
||||
hn::Store(out5, df, out + i + out_offsets[5]);
|
||||
hn::Store(out6, df, out + i + out_offsets[6]);
|
||||
hn::Store(out7, df, out + i + out_offsets[7]);
|
||||
hn::Store(out8, df, out + i + out_offsets[8]);
|
||||
hn::Store(out9, df, out + i + out_offsets[9]);
|
||||
hn::Store(out10, df, out + i + out_offsets[10]);
|
||||
hn::Store(out11, df, out + i + out_offsets[11]);
|
||||
hn::Store(out12, df, out + i + out_offsets[12]);
|
||||
hn::Store(out13, df, out + i + out_offsets[13]);
|
||||
hn::Store(out14, df, out + i + out_offsets[14]);
|
||||
hn::Store(out15, df, out + i + out_offsets[15]);
|
||||
}
|
||||
if HWY_LANES_CONSTEXPR (NF == 8) {
|
||||
VF out0, out1, out2, out3, out4, out5, out6, out7;
|
||||
out0 = hn::Load(df, out + i + out_offsets[0]);
|
||||
out1 = hn::Load(df, out + i + out_offsets[1]);
|
||||
out2 = hn::Load(df, out + i + out_offsets[2]);
|
||||
out3 = hn::Load(df, out + i + out_offsets[3]);
|
||||
out4 = hn::Load(df, out + i + out_offsets[4]);
|
||||
out5 = hn::Load(df, out + i + out_offsets[5]);
|
||||
out6 = hn::Load(df, out + i + out_offsets[6]);
|
||||
out7 = hn::Load(df, out + i + out_offsets[7]);
|
||||
Mul8(df, scale, out0, out1, out2, out3, out4, out5, out6, out7);
|
||||
VF x0 = hn::Load(df, v.Row(pos) + i);
|
||||
MulAdd8(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7);
|
||||
hn::Store(out0, df, out + i + out_offsets[0]);
|
||||
hn::Store(out1, df, out + i + out_offsets[1]);
|
||||
hn::Store(out2, df, out + i + out_offsets[2]);
|
||||
hn::Store(out3, df, out + i + out_offsets[3]);
|
||||
hn::Store(out4, df, out + i + out_offsets[4]);
|
||||
hn::Store(out5, df, out + i + out_offsets[5]);
|
||||
hn::Store(out6, df, out + i + out_offsets[6]);
|
||||
hn::Store(out7, df, out + i + out_offsets[7]);
|
||||
}
|
||||
if HWY_LANES_CONSTEXPR (NF == 4) {
|
||||
VF out0, out1, out2, out3;
|
||||
out0 = hn::Load(df, out + i + out_offsets[0]);
|
||||
out1 = hn::Load(df, out + i + out_offsets[1]);
|
||||
out2 = hn::Load(df, out + i + out_offsets[2]);
|
||||
out3 = hn::Load(df, out + i + out_offsets[3]);
|
||||
out0 = hn::Mul(out0, hn::BroadcastLane<0>(scale));
|
||||
out1 = hn::Mul(out1, hn::BroadcastLane<1>(scale));
|
||||
out2 = hn::Mul(out2, hn::BroadcastLane<2>(scale));
|
||||
out3 = hn::Mul(out3, hn::BroadcastLane<3>(scale));
|
||||
VF x0 = hn::Load(df, v.Row(pos) + i);
|
||||
MulAdd4(df, x0, c0, out0, out1, out2, out3);
|
||||
hn::Store(out0, df, out + i + out_offsets[0]);
|
||||
hn::Store(out1, df, out + i + out_offsets[1]);
|
||||
hn::Store(out2, df, out + i + out_offsets[2]);
|
||||
hn::Store(out3, df, out + i + out_offsets[3]);
|
||||
}
|
||||
i += NF;
|
||||
}
|
||||
HWY_DASSERT(size == i);
|
||||
}
|
||||
|
||||
// See below for a specialized version for top-1 sampling.
|
||||
// TODO: support bf16 logits using Decompress2.
|
||||
// Computes softmax probabilities for the given logits, normalizing in-place.
|
||||
|
|
|
|||
10
util/mat.h
10
util/mat.h
|
|
@ -202,16 +202,6 @@ class MatPtr : public IFields {
|
|||
override_rows_ = static_cast<uint32_t>(rows);
|
||||
}
|
||||
|
||||
// Changes the number of rows and columns without reallocating the memory.
|
||||
// Increases cols by factor and reduces rows by factor.
|
||||
// The rows must be divisible by factor and the matrix must be packed.
|
||||
void ReshapePackedRowsToCols(size_t factor) {
|
||||
HWY_ASSERT(IsPacked());
|
||||
private_rows_ = hwy::DivCeil(private_rows_, factor);
|
||||
cols_ *= factor;
|
||||
stride_ *= factor;
|
||||
}
|
||||
|
||||
// Offset by which to advance pointers to the next row.
|
||||
size_t Stride() const { return stride_; }
|
||||
|
||||
|
|
|
|||
|
|
@ -106,8 +106,7 @@ template <typename T>
|
|||
void FillMatPtrT(MatPtrT<T>& mat) {
|
||||
for (int i = 0; i < mat.Rows(); ++i) {
|
||||
for (int j = 0; j < mat.Cols(); ++j) {
|
||||
mat.Row(i)[j] =
|
||||
hwy::ConvertScalarTo<T>(hwy::Unpredictable1() * 0.01f * (i + j + 1));
|
||||
mat.Row(i)[j] = hwy::Unpredictable1() * 0.01f * (i + j + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,14 +17,14 @@ const char* ZoneName(Zones zone) {
|
|||
return "FlashAttention.Inclusive";
|
||||
case Zones::kFlashAttentionRmsNormAndPositionalEncoding:
|
||||
return "FlashAttention.RMSNormAndPositionalEncoding";
|
||||
case Zones::kFlashAttentionTileFlashAttention1:
|
||||
return "FlashAttention.TileFlashAttention1";
|
||||
case Zones::kFlashAttentionSingleFlashAttention:
|
||||
return "FlashAttention.SingleFlashAttention";
|
||||
case Zones::kFlashAttentionTileFlashAttention:
|
||||
return "FlashAttention.TileFlashAttention";
|
||||
case Zones::kFlashAttentionTileFlashAttention4:
|
||||
return "FlashAttention.TileFlashAttention4";
|
||||
case Zones::kFlashAttentionTileFlashAttention8:
|
||||
return "FlashAttention.TileFlashAttention8";
|
||||
case Zones::kFlashAttentionCombineSplit:
|
||||
return "FlashAttention.CombineSplit";
|
||||
case Zones::kFlashAttentionTransposeQ:
|
||||
return "FlashAttention.TransposeQ";
|
||||
case Zones::kGenActivation:
|
||||
return "Gen.Activation";
|
||||
case Zones::kGenActivationFused:
|
||||
|
|
|
|||
|
|
@ -14,10 +14,10 @@ enum class Zones { // Keep sorted
|
|||
kFlashAttentionFlashAttention,
|
||||
kFlashAttentionInclusive,
|
||||
kFlashAttentionRmsNormAndPositionalEncoding,
|
||||
kFlashAttentionTileFlashAttention1,
|
||||
kFlashAttentionSingleFlashAttention,
|
||||
kFlashAttentionTileFlashAttention,
|
||||
kFlashAttentionTileFlashAttention4,
|
||||
kFlashAttentionTileFlashAttention8,
|
||||
kFlashAttentionCombineSplit,
|
||||
kFlashAttentionTransposeQ,
|
||||
kGenActivation,
|
||||
kGenActivationFused,
|
||||
kGenAttention,
|
||||
|
|
|
|||
Loading…
Reference in New Issue