mirror of https://github.com/google/gemma.cpp.git
Added a smaller tile size to flash attention for smaller batch sizes
PiperOrigin-RevId: 813226193
This commit is contained in:
parent 4974f24832
commit 2f6cbde8ff
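As a rough illustration of what the new target_parallelism argument controls, a minimal standalone sketch of the tile-size choice follows (the thresholds mirror RoundToSuitablePowerOf2 and the kVTileSize computation in the diff below; ChooseVTileSize is a hypothetical name and kNF is passed in rather than queried from Highway, so this is an approximation of the logic, not the shipped code):

#include <algorithm>
#include <cstddef>
#include <cstdio>

// Mirrors RoundToSuitablePowerOf2 from the diff below.
static size_t RoundToSuitablePowerOf2(size_t n) {
  if (n < 4) return 1;
  if (n < 8) return 4;
  if (n < 16) return 8;
  if (n < 32) return 16;
  return 32;
}

// Sketch of the tile-size choice in FlashAttention: prefer kNF rows, then 4,
// then 1, subject to keeping at least target_parallelism parallel tasks.
static size_t ChooseVTileSize(size_t total_tasks, size_t head_groups,
                              size_t num_tokens, size_t kNF,
                              size_t target_parallelism) {
  const size_t max_equal_k = RoundToSuitablePowerOf2(head_groups * num_tokens);
  const size_t min_tile = (total_tasks / 4 >= target_parallelism) ? 4 : 1;
  return (kNF <= max_equal_k && total_tasks / kNF >= target_parallelism)
             ? kNF
             : std::min(min_tile, max_equal_k);
}

int main() {
  // Small decode batch: few tasks, so the tile shrinks all the way to 1 row.
  printf("%zu\n", ChooseVTileSize(/*total_tasks=*/32, /*head_groups=*/4,
                                  /*num_tokens=*/1, /*kNF=*/16,
                                  /*target_parallelism=*/64));  // prints 1
  // Large prefill batch: plenty of tasks, so the full kNF-row tile is used.
  printf("%zu\n", ChooseVTileSize(4096, 4, 64, 16, 64));  // prints 16
  return 0;
}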
@@ -548,6 +548,7 @@ cc_library(
":gemma_args",
":kv_cache",
":mat",
":matmul",
":matmul_env",
":model_store",
":ops",
@@ -358,7 +358,8 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
env.ctx);
} else {
FlashAttention(num_tokens, layer_idx, layer, activations, qbatch, env.ctx);
FlashAttention(num_tokens, /*target_parallelism=*/64, layer_idx, layer,
activations, qbatch, env.ctx);
}
SumHeads(layer, activations, env);
}
@@ -44,6 +44,7 @@
// After highway.h
#include "compression/compress-inl.h"
#include "gemma/attention.h"
#include "ops/matmul-inl.h"
#include "ops/ops-inl.h"

HWY_BEFORE_NAMESPACE();
@@ -114,6 +115,27 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
}
}

// Handles a single v row of flash attention for a single q.k dot product.
void HWY_INLINE SingleFlashAttentionStep(
float x, float cap, float& old_max, float& old_d,
const float* HWY_RESTRICT v, const size_t v_cols,
float* HWY_RESTRICT att_out, hwy::Profiler& p, const size_t worker) {
if (cap > 0.0f) {
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x.
x = cap * std::tanh(x / cap);
}
float m = std::max(x, old_max);
x = std::exp(x - m);
float scale = old_d * std::exp(old_max - m);
old_d = x + scale;
old_max = m;
float one_over_d = 1.0f / old_d;
scale *= one_over_d;
x *= one_over_d;
MulByConst(scale, att_out, v_cols, p, worker);
MulByConstAndAdd(x, v, att_out, v_cols, p, worker);
}

// Calculates the complete attention outputs for a single row of q.
void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
const float* HWY_RESTRICT q, const MatPtrT<KV_t>& k,
@@ -136,21 +158,8 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
const size_t pos_mod = activations.div_seq_len.Remainder(pos);
float x = Dot(q, k.Row(pos_mod), k.Cols());
if (activations.config.att_cap > 0.0f) {
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x.
x = activations.config.att_cap *
std::tanh(x / activations.config.att_cap);
}
float m_new = std::max(m, x);
float scale = d * std::exp(m - m_new);
x = std::exp(x - m_new);
m = m_new;
d = scale + x;
float one_over_d = 1.0f / d;
x *= one_over_d;
scale *= one_over_d;
MulByConst(scale, att_out, v.Cols(), p, worker);
MulByConstAndAdd(x, v.Row(pos_mod), att_out, v.Cols(), p, worker);
SingleFlashAttentionStep(x, activations.config.att_cap, m, d,
v.Row(pos_mod), v.Cols(), att_out, p, worker);
}
}
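The loop above and SingleFlashAttentionStep implement the streaming (online) softmax update. A minimal standalone check, assuming only the standard library and scalar stand-ins for the V rows, that the streaming form reproduces the naive softmax-weighted sum:

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

int main() {
  const std::vector<float> x = {0.5f, 2.0f, -1.0f, 3.0f};   // q.k dot products
  const std::vector<float> v = {1.0f, -2.0f, 0.25f, 4.0f};  // 1-element V rows

  // Naive two-pass softmax(x) dotted with v.
  const float max_x = *std::max_element(x.begin(), x.end());
  float denom = 0.0f;
  for (float xi : x) denom += std::exp(xi - max_x);
  float naive = 0.0f;
  for (size_t i = 0; i < x.size(); ++i)
    naive += std::exp(x[i] - max_x) / denom * v[i];

  // Streaming form, one timestep at a time. out is kept fully normalized, so
  // each step first rescales the previous contribution, then adds the new one.
  float m = -std::numeric_limits<float>::max() / 2.0f;
  float d = 0.0f;
  float out = 0.0f;
  for (size_t i = 0; i < x.size(); ++i) {
    const float m_new = std::max(x[i], m);
    float e = std::exp(x[i] - m_new);
    float scale = d * std::exp(m - m_new);
    d = e + scale;
    m = m_new;
    const float one_over_d = 1.0f / d;
    scale *= one_over_d;            // MulByConst(scale, att_out, ...)
    e *= one_over_d;                // MulByConstAndAdd(e, v_row, att_out, ...)
    out = out * scale + e * v[i];
  }

  assert(std::fabs(out - naive) < 1e-5f);
  printf("naive=%f streaming=%f\n", naive, out);
  return 0;
}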
@@ -167,7 +176,8 @@ VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets,
return hn::LoadU(df, results);
}

// Returns an 8xNF tile of Q.K dot products, in single precision.
// Returns an NF Q rows by 8 K rows tile of Q.K dot products, in single
// precision.
// This is the result of NF rows of Q against 8 K timesteps, with positions
// given by k_pos[0..7]. Q has been transposed so that the NF rows are read in
// consecutive elements, and other columns by adding q_stride.
@@ -240,8 +250,9 @@ VF HWY_INLINE ElementwiseSumOf8(DF df, const VF& x0, const VF& x1, const VF& x2,
return hn::Add(sum0, sum2);
}

// Sweeps a tile of 8xNF accumulators from start_pos to min_last_pos, then
// sweeps the remaining timesteps in the range (min_last_pos, max_last_pos].
// Sweeps a tile of NF Q rows by 8 K timesteps accumulators from start_pos to
// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
// max_last_pos].
void TileFlashAttention(
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
const StridedView<float>& qT, const MatPtrT<KV_t>& k,
@@ -260,7 +271,7 @@ void TileFlashAttention(
using DI = hn::ScalableTag<uint32_t>;
const DI di;
using VI = hn::Vec<DI>;
const int kVTileSize = hn::MaxLanes(df);
const int kVTileSize = hn::Lanes(df);
for (int i = 0; i < kVTileSize; ++i) {
hwy::ZeroBytes(att_out.Row(0) + out_offsets[i],
v.Cols() * sizeof(att_out.Row(0)[0]));
@@ -348,38 +359,217 @@ void TileFlashAttention(
}
}

// Returns a 4 Q rows by NF K tile of Q.K dot products, in single precision.
// This is the result of 4 rows of Q against NF K timesteps, with positions
// given by k_offsets[0..NF]. Q has been transposed so that the 4 rows are read
// in consecutive elements, and other columns by adding q_stride.
template <class DF, class VF = hn::Vec<DF>>
void QDotKTilex4(DF df, const float* HWY_RESTRICT q, const size_t q_stride,
const MatPtrT<KV_t>& k, const int32_t* HWY_RESTRICT k_offsets,
hwy::Profiler& p, const size_t worker, VF& sum0, VF& sum1,
VF& sum2, VF& sum3) {
sum0 = hn::Zero(df);
sum1 = hn::Zero(df);
sum2 = hn::Zero(df);
sum3 = hn::Zero(df);
const float* HWY_RESTRICT k_base = k.Row(0);
using DI = hn::ScalableTag<int32_t>;
const DI di;
using VI = hn::Vec<DI>;
VI k_offsets_vec = hn::LoadU(di, k_offsets);
for (size_t i = 0; i < k.Cols(); ++i) {
VF k_vec = hn::GatherIndex(df, k_base + i, k_offsets_vec);
VF q_0 = hn::Set(df, q[0]);
sum0 = hn::MulAdd(q_0, k_vec, sum0);
VF q_1 = hn::Set(df, q[1]);
sum1 = hn::MulAdd(q_1, k_vec, sum1);
VF q_2 = hn::Set(df, q[2]);
sum2 = hn::MulAdd(q_2, k_vec, sum2);
VF q_3 = hn::Set(df, q[3]);
sum3 = hn::MulAdd(q_3, k_vec, sum3);
q += q_stride;
}
}
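What QDotKTilex4 computes, as a scalar reference (a sketch with plain arrays standing in for MatPtrT and the Highway gather; QDotKTilex4Ref is a hypothetical helper, not part of the commit):

#include <cstddef>
#include <cstdint>
#include <vector>

// sums[r][j] = dot(Q row r, the K row whose first element is at
// k_base + k_offsets[j]). qT holds Q transposed so the 4 row values for
// column i are adjacent at qT[i * q_stride + 0..3].
std::vector<std::vector<float>> QDotKTilex4Ref(
    const float* qT, size_t q_stride, size_t k_cols, const float* k_base,
    const int32_t* k_offsets, size_t NF) {
  std::vector<std::vector<float>> sums(4, std::vector<float>(NF, 0.0f));
  for (size_t i = 0; i < k_cols; ++i) {
    for (size_t j = 0; j < NF; ++j) {
      const float k_ij = k_base[i + k_offsets[j]];  // the gathered K element
      for (size_t r = 0; r < 4; ++r) sums[r][j] += qT[i * q_stride + r] * k_ij;
    }
  }
  return sums;
}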

// Handles NF v rows of flash attention for NF q.k dot products from one q row.
template <class DF, class VF = hn::Vec<DF>>
float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
float& old_d) {
float m = hn::ReduceMax(df, x);
m = std::max(m, old_max);
x = hn::Exp(df, x - hn::Set(df, m));
float scale = old_d * std::exp(old_max - m);
old_d = hn::ReduceSum(df, x) + scale;
old_max = m;
float one_over_d = 1.0f / old_d;
scale *= one_over_d;
x = hn::Mul(x, hn::Set(df, one_over_d));
return scale;
}

// Sweeps a tile of 4 Q rows by NF K timesteps accumulators from start_pos to
// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
// max_last_pos].
void TileFlashAttention4(
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
const StridedView<float>& qT, const MatPtrT<KV_t>& k,
const size_t start_pos, const uint32_t* HWY_RESTRICT last_pos,
const size_t min_last_pos, const size_t max_last_pos,
const MatPtrT<KV_t>& v, const size_t layer_idx,
const LayerWeightsPtrs& layer, const AttentionActivations& activations,
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets,
hwy::Profiler& p, const size_t worker) {
static const auto zone = p.AddZone("Gen.Attention.TileFlashAttention4");
PROFILER_ZONE3(p, worker, zone);
using DF = hn::ScalableTag<float>;
const DF df;
using VF = hn::Vec<DF>;
constexpr size_t kMaxNF = hn::MaxLanes(df);
const size_t kHTileSize = hn::Lanes(df);
HWY_DASSERT(kHTileSize <= kMaxNF);
constexpr size_t kVTileSize = 4;
float scales[kVTileSize];
for (size_t i = 0; i < kVTileSize; ++i) {
hwy::ZeroBytes(att_out.Row(0) + out_offsets[i],
v.Cols() * sizeof(att_out.Row(0)[0]));
}
float old_m0 = -std::numeric_limits<float>::max() / 2.0f;
float old_m1 = -std::numeric_limits<float>::max() / 2.0f;
float old_m2 = -std::numeric_limits<float>::max() / 2.0f;
float old_m3 = -std::numeric_limits<float>::max() / 2.0f;
float old_d0 = 0.0f;
float old_d1 = 0.0f;
float old_d2 = 0.0f;
float old_d3 = 0.0f;
const float* HWY_RESTRICT qT_row = qT.Row(0);
const size_t qT_stride = qT.Stride();
size_t position = start_pos;
while (position + kHTileSize - 1 <= min_last_pos) {
int32_t k_offsets[kMaxNF];
size_t v_pos[kMaxNF];
for (size_t i = 0; i < kHTileSize; ++i) {
v_pos[i] = activations.div_seq_len.Remainder(position + i);
k_offsets[i] = k.Row(v_pos[i]) - k.Row(0);
}
VF x0, x1, x2, x3;
QDotKTilex4(df, qT_row, qT_stride, k, k_offsets, p, worker, x0, x1, x2, x3);
if (activations.config.att_cap > 0.0f) {
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile.
VF cap = hn::Set(df, activations.config.att_cap);
VF one_over_cap = hn::Div(hn::Set(df, 1.0f), cap);
x0 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x0, one_over_cap)));
x1 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x1, one_over_cap)));
x2 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x2, one_over_cap)));
x3 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x3, one_over_cap)));
}
scales[0] = SingleFlashAttentionRowVector(df, x0, old_m0, old_d0);
scales[1] = SingleFlashAttentionRowVector(df, x1, old_m1, old_d1);
scales[2] = SingleFlashAttentionRowVector(df, x2, old_m2, old_d2);
scales[3] = SingleFlashAttentionRowVector(df, x3, old_m3, old_d3);
MulByConstAndAddTile4(df, scales, x0, x1, x2, x3, v, v_pos, att_out.Row(0),
out_offsets, v.Cols(), p, worker);
position += kHTileSize;
}
while (position <= max_last_pos) {
size_t k_pos = activations.div_seq_len.Remainder(position);
if (position <= last_pos[0]) {
// Past the last position, x0 doesn't count.
float x0 = Dot(q.Row(0) + q_offsets[0], k.Row(k_pos), k.Cols());
SingleFlashAttentionStep(x0, activations.config.att_cap, old_m0, old_d0,
v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[0], p, worker);
}
if (position <= last_pos[1]) {
// Past the last position, x1 doesn't count.
float x1 = Dot(q.Row(0) + q_offsets[1], k.Row(k_pos), k.Cols());
SingleFlashAttentionStep(x1, activations.config.att_cap, old_m1, old_d1,
v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[1], p, worker);
}
if (position <= last_pos[2]) {
// Past the last position, x2 doesn't count.
float x2 = Dot(q.Row(0) + q_offsets[2], k.Row(k_pos), k.Cols());
SingleFlashAttentionStep(x2, activations.config.att_cap, old_m2, old_d2,
v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[2], p, worker);
}
if (position <= last_pos[3]) {
// Past the last position, x3 doesn't count.
float x3 = Dot(q.Row(0) + q_offsets[3], k.Row(k_pos), k.Cols());
SingleFlashAttentionStep(x3, activations.config.att_cap, old_m3, old_d3,
v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[3], p, worker);
}
++position;
}
}

// Rounds n to a number that can be used as the number of Q rows in a tile
// of flash attention.
static size_t RoundToSuitablePowerOf2(size_t n) {
if (n < 4) return 1;
if (n < 8) return 4;
if (n < 16) return 8;
if (n < 32) return 16;
return 32;
}
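For example, the thresholds above map a requested row count onto the supported tile heights as follows (a worked illustration, not part of the diff):
  RoundToSuitablePowerOf2(3)  == 1   // fewer than 4 q rows: single-row mode
  RoundToSuitablePowerOf2(6)  == 4   // 4..7 rows: the new 4-row tile
  RoundToSuitablePowerOf2(12) == 8
  RoundToSuitablePowerOf2(20) == 16
  RoundToSuitablePowerOf2(48) == 32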

// The nominal aim of attention is to combine 3 inputs Q[L,D], K[L,D], V[L,D]
// into a single output O[L,D].
// Conventional attention first computes A[L,L] = Q . KT
// followed by A = softmax(A) (over individual rows).
// Then A is multiplied by V to get O[L,D].
// For each row of O, this takes a read of one row of Q L times, all of K,
// 3 write/reads of one row of A, read all of V, an read.write the one row of O
// 3 write/reads of one row of A, read all of V, and read/write the one row of O
// L times. Ignoring the computation for now, and focusing just on memory,
// the one row of O takes L(4D+3) reads and L(D+3) writes.
// For the whole of Q, this is L^2(4D+3) reads and L^2(D+3) writes.
//
// Flash attention fuses these operations together, and (where possible)
// computes NF rows of the result using 8 accumulator registers and two more to
// keep running results. NF is the number of float lanes in a register, being 16
// for AVX3. The softmax is converted to streaming form using the
// algortihm from:
// Flash attention fuses these operations together, and has 3 operating modes:
// 1. NF rows of the result computed using tiles of registers of shape NFx8.
// 2. 4 rows of the result computed using tiles of registers of shape 4xNF.
// 3. One row (of Q and the result) at a time.
// In all cases the intermediate result (Q.KT) is never stored to memory.
// NF is the number of float lanes in a register, being 16 for AVX3. The softmax
// is converted to streaming form using the algorithm from:
// https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf.
// Q is transposed to Q_T[D,L] to make the dot product computation efficient.
// QDotKTileFloat computes 8xNF rows of Q.K dot products in one go, reducing
// reads of Q by 8 and reads of K by NF. The streaming softmax is computed
// entirely in registers, and a further NF registers to accumulate the results
// of the product of the softmax and V, reduce the number of reads of V by NF,
// and the reads/writes of O by 8.
//
// In mode 1:
// QDotKTileFloat computes NF Q rows x 8 K timesteps of Q.K dot products in one
// go, reducing reads of Q by 8 and reads of K by NF. The streaming softmax is
// computed entirely in registers, and a further NF registers to accumulate the
// results of the product of the softmax and V, reduce the number of reads of V
// by NF, and the reads/writes of O by 8.
// The reads are thus reduced to 2DL^2(1/8+1/NF) and writes reduced to DL^2/8,
// which on AVX3 is an overall reduction by about a factor of 10.
// Mode 1 can only be accessed if there is a large Qbatch size, or in multi-turn
// prefill, since in other cases, there is either a single K timestep (prefill)
// or a single num_heads set of Q rows (decode).
//
// In mode 2, the 4 rows of Q are computed against NF K timesteps in a tile,
// reducing the reads of Q by NF, and the reads of K by 4. The softmax and
// accumulation of the result is done in registers, cutting the reads of V by 4.
// The reads/writes of O are reduced by a factor of NF.
// The overall reduction is limited by the need to use gather to load K.
// Transposing K would be possible, but is complicated by the wraparound.
// Mode 2 can be used in all cases when there are at least 4 attention heads,
// but it may be preferable to use mode 3 when the batch size is small to
// maximise parallelism.
//
// In mode 3, a single row of Q is computed against a single K timestep at a
// time, using SingleFlashAttention. In this case there is no reduction in the
// reads of Q or K, or V, or O, but the reads/writes of the intermediate A are
// still eliminated.
//
// A further complication is that real attention is not as simple as documented
// in the paper and above. There are multiple query heads, differing KV, and
// different sequence lengths, so a lot of the work in FlashAttention is making
// sure that a collection of q rows can use the TileFlashAttention path.
void FlashAttention(const size_t num_tokens, const size_t layer_idx,
const LayerWeightsPtrs& layer,
// sure that a collection of q rows with the same KV and sequence length are
// grouped together so that mode 1 or 2 can be used, and choosing which of the
// 3 modes to use for best efficiency.
void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
const size_t layer_idx, const LayerWeightsPtrs& layer,
AttentionActivations& activations, QBatch& qbatch,
ThreadingContext& ctx) {
static const auto zone = ctx.profiler.AddZone("Gen.Attention.FlashAttention");
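As a worked illustration of the reduction claimed in the comment block above, using its own formulas with an assumed D = 256 (qkv_dim) and NF = 16 (AVX3):
  naive reads for all of Q:  L^2 (4D + 3) = 1027 L^2
  mode-1 reads:              2 D L^2 (1/8 + 1/NF) = 512 L^2 * 0.1875 = 96 L^2
  ratio:                     1027 / 96 is roughly 10.7, i.e. the "about a factor of 10" quoted above.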
@@ -392,15 +582,28 @@ void FlashAttention(const size_t num_tokens, const size_t layer_idx,
// A "head group" in the context of GQA refers to a collection of query
// heads that share the same key and value heads.
const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;

using DF = hn::ScalableTag<float>;
const DF df;
constexpr size_t kVTileSize = hn::MaxLanes(df);
const size_t cache_layer_size = layer_config.CacheLayerSize();
const size_t seq_len =
static_cast<size_t>(activations.div_seq_len.GetDivisor());
const size_t token_batch = num_tokens * div_qbatch.GetDivisor();
const size_t total_tasks = token_batch * layer_config.heads;

using DF = hn::ScalableTag<float>;
const DF df;
const size_t kNF = hn::Lanes(df);
constexpr size_t kMaxNF = hn::MaxLanes(df);
HWY_DASSERT(kNF <= kMaxNF);
// The vertical tile size is determined by the ability to use tiling and the
// target_parallelism. In practice the possible tile sizes in order of
// preference for efficiency are kNF, 4, 1, where kNF is likely to be 4, 8 or
// 16. The final tile size is chosen to be the largest possible that allows
// for target_parallelism parallel tasks.
const size_t kMaxEqualK = RoundToSuitablePowerOf2(kHeadGroups * num_tokens);
const size_t kMinTileSize = (total_tasks / 4 >= target_parallelism) ? 4 : 1;
const size_t kVTileSize =
(kNF <= kMaxEqualK && total_tasks / kNF >= target_parallelism)
? kNF
: std::min(kMinTileSize, kMaxEqualK);
// q has shape [batch, qbatch][head, qkv_dim].
// We transpose it to [qkv_dim][qbatch, head, batch] in order to make the
// maximum possible number of consecutive columns have the same KV matrices.
@@ -416,26 +619,26 @@ void FlashAttention(const size_t num_tokens, const size_t layer_idx,
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
PROFILER_ZONE3(ctx.profiler, worker, zone);
// Offsets into original Q for each row in the tile.
uint32_t q_offsets[kVTileSize];
uint32_t q_offsets[kMaxNF];
// Offsets into att_out for each row in the tile.
uint32_t out_offsets[kVTileSize];
uint32_t out_offsets[kMaxNF];
// Start positions for each row in the tile.
size_t start_positions[kVTileSize];
size_t start_positions[kMaxNF];
// Last positions for each row in the tile. Inclusive.
uint32_t last_pos[kVTileSize];
uint32_t last_pos[kMaxNF];
// min and max last positions across all rows in the tile determine when
// TileFlashAttention switches to single vector mode to handle the
// ragged sequence lengths.
size_t min_last_pos = std::numeric_limits<size_t>::max();
size_t max_last_pos = 0;
// Indices into the qbatch.KV for each row in the tile.
size_t qi_indices[kVTileSize];
size_t qi_indices[kMaxNF];
// Indices into the kv_cache for each row in the tile.
size_t kv_offsets[kVTileSize];
size_t kv_offsets[kMaxNF];
// first_task is [qbatch, head, token].
const size_t first_task = task * kVTileSize;
const size_t last_task = first_task + kVTileSize - 1;
bool use_tile_attention = last_task < total_tasks;
bool use_tile_attention = kVTileSize > 1 && last_task < total_tasks;
for (size_t offset = 0;
offset < kVTileSize && first_task + offset < total_tasks; ++offset) {
const size_t batch_idx = div_tokens.Remainder(first_task + offset);
@@ -486,15 +689,26 @@ void FlashAttention(const size_t num_tokens, const size_t layer_idx,
kv_cache.Stride());
if (use_tile_attention) {
// To avoid duplicating the code to setup K and V, the call to
// TileFlashAttention is inside the loop over tasks, even thought it
// TileFlashAttention is inside the loop over tasks, even though it
// handles all rows in the task at once.
StridedView<float> qT =
StridedView<float>(activations.q_T.Row(0) + first_task, kVTileSize,
activations.q_T.Stride());
TileFlashAttention(
activations.q, q_offsets, qT, k, start_positions[offset], last_pos,
min_last_pos, max_last_pos, v, layer_idx, layer, activations,
activations.att_out, out_offsets, ctx.profiler, worker);
if (kVTileSize == kNF) {
TileFlashAttention(activations.q, q_offsets, qT, k,
start_positions[offset], last_pos, min_last_pos,
max_last_pos, v, layer_idx, layer, activations,
activations.att_out, out_offsets, ctx.profiler,
worker);
} else if (kVTileSize == 4) {
TileFlashAttention4(activations.q, q_offsets, qT, k,
start_positions[offset], last_pos, min_last_pos,
max_last_pos, v, layer_idx, layer, activations,
activations.att_out, out_offsets, ctx.profiler,
worker);
} else {
HWY_UNREACHABLE;
}
break;
} else {
SingleFlashAttention(start_positions[offset], last_pos[offset],
@@ -42,8 +42,8 @@ namespace gcpp {
float* HWY_RESTRICT att_out, hwy::Profiler& p, \
size_t worker); \
\
void FlashAttention(size_t num_tokens, size_t layer_idx, \
const LayerWeightsPtrs& layer, \
void FlashAttention(size_t num_tokens, size_t target_parallelism, \
size_t layer_idx, const LayerWeightsPtrs& layer, \
AttentionActivations& activations, QBatch& qbatch, \
ThreadingContext& ctx); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
@@ -98,13 +98,14 @@ void AssertClose(const MatPtrT<float>& a, const MatPtrT<float>& b) {
}
}

void TestAttention() {
void TestFlashAttention(size_t target_parallelism) {
ThreadingArgs threading_args;
ThreadingContext ctx(threading_args);
// hwy::ThreadPool& pool = ctx.pools.Pool();
constexpr size_t kOuter = 1024;
constexpr size_t kInner = 256;
ModelConfig config(Model::GEMMA2_2B, Type::kF32, PromptWrapping::GEMMA_PT);
config.att_cap = 1024.0f;
TensorInfoRegistry tensor_info_registry(config);
const LayerConfig& layer_config = config.layer_configs[0];
const LayerWeightsPtrs layers(0, layer_config, tensor_info_registry);
@@ -149,10 +150,17 @@ void TestAttention() {
// Copy the output to saved_att to allow for comparison.
auto saved_att = MakeCopyOfMat(attention.att_out, ctx.allocator);
SetMat(1, attention.q);
FlashAttention(tokens.size(), 0, layers, attention, qbatch, ctx);
FlashAttention(tokens.size(), target_parallelism, 0, layers, attention,
qbatch, ctx);
AssertClose(attention.att_out, *saved_att);
}

void TestAttention() {
TestFlashAttention(8192);
TestFlashAttention(2048);
TestFlashAttention(256);
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
ops/ops-inl.h
@@ -747,7 +747,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile(
static const auto zone = p.AddZone("Ops.MulByConstAndAdd");
PROFILER_ZONE3(p, worker, zone);
namespace hn = hwy::HWY_NAMESPACE;
HWY_LANES_CONSTEXPR size_t NF = hn::MaxLanes(df);
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);

size_t i = 0;
while (i + NF <= size) {
@@ -882,8 +882,162 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile(
}
i += NF;
}
const size_t remaining = size - i;
HWY_DASSERT(remaining == 0);
HWY_DASSERT(size == i);
}

template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4(DF df, const VF common, const VF c0,
const VF c1, const VF c2, const VF c3,
VF& sum0, VF& sum1, VF& sum2,
VF& sum3) {
sum0 = hn::MulAdd(common, c0, sum0);
sum1 = hn::MulAdd(common, c1, sum1);
sum2 = hn::MulAdd(common, c2, sum2);
sum3 = hn::MulAdd(common, c3, sum3);
}

template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4Lanes(DF df, const MatPtrT<float>& v,
const size_t* HWY_RESTRICT pos,
const size_t offset, const VF c0,
const VF c1, const VF c2,
const VF c3, VF& sum0, VF& sum1,
VF& sum2, VF& sum3) {
// TODO(rays): Check whether a transpose of c0-c3 is applicable and faster.
VF x0 = hn::Load(df, v.Row(pos[0]) + offset);
MulAdd4(df, x0, hn::BroadcastLane<0>(c0), hn::BroadcastLane<0>(c1),
hn::BroadcastLane<0>(c2), hn::BroadcastLane<0>(c3), sum0, sum1, sum2,
sum3);
VF x1 = hn::Load(df, v.Row(pos[1]) + offset);
MulAdd4(df, x1, hn::BroadcastLane<1>(c0), hn::BroadcastLane<1>(c1),
hn::BroadcastLane<1>(c2), hn::BroadcastLane<1>(c3), sum0, sum1, sum2,
sum3);
VF x2 = hn::Load(df, v.Row(pos[2]) + offset);
MulAdd4(df, x2, hn::BroadcastLane<2>(c0), hn::BroadcastLane<2>(c1),
hn::BroadcastLane<2>(c2), hn::BroadcastLane<2>(c3), sum0, sum1, sum2,
sum3);
VF x3 = hn::Load(df, v.Row(pos[3]) + offset);
MulAdd4(df, x3, hn::BroadcastLane<3>(c0), hn::BroadcastLane<3>(c1),
hn::BroadcastLane<3>(c2), hn::BroadcastLane<3>(c3), sum0, sum1, sum2,
sum3);
}

template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 31)>
HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond4Lanes(
DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
VF& sum0, VF& sum1, VF& sum2, VF& sum3) {
VF x4 = hn::Load(df, v.Row(pos[4]) + offset);
MulAdd4(df, x4, hn::BroadcastLane<4>(c0), hn::BroadcastLane<4>(c1),
hn::BroadcastLane<4>(c2), hn::BroadcastLane<4>(c3), sum0, sum1, sum2,
sum3);
VF x5 = hn::Load(df, v.Row(pos[5]) + offset);
MulAdd4(df, x5, hn::BroadcastLane<5>(c0), hn::BroadcastLane<5>(c1),
hn::BroadcastLane<5>(c2), hn::BroadcastLane<5>(c3), sum0, sum1, sum2,
sum3);
VF x6 = hn::Load(df, v.Row(pos[6]) + offset);
MulAdd4(df, x6, hn::BroadcastLane<6>(c0), hn::BroadcastLane<6>(c1),
hn::BroadcastLane<6>(c2), hn::BroadcastLane<6>(c3), sum0, sum1, sum2,
sum3);
VF x7 = hn::Load(df, v.Row(pos[7]) + offset);
MulAdd4(df, x7, hn::BroadcastLane<7>(c0), hn::BroadcastLane<7>(c1),
hn::BroadcastLane<7>(c2), hn::BroadcastLane<7>(c3), sum0, sum1, sum2,
sum3);
}

template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 31)>
HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond4Lanes(
DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
VF& sum0, VF& sum1, VF& sum2, VF& sum3) {}

template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_GT_D(DF, 63)>
HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond8Lanes(
DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
VF& sum0, VF& sum1, VF& sum2, VF& sum3) {
VF x8 = hn::Load(df, v.Row(pos[8]) + offset);
MulAdd4(df, x8, hn::BroadcastLane<8>(c0), hn::BroadcastLane<8>(c1),
hn::BroadcastLane<8>(c2), hn::BroadcastLane<8>(c3), sum0, sum1, sum2,
sum3);
VF x9 = hn::Load(df, v.Row(pos[9]) + offset);
MulAdd4(df, x9, hn::BroadcastLane<9>(c0), hn::BroadcastLane<9>(c1),
hn::BroadcastLane<9>(c2), hn::BroadcastLane<9>(c3), sum0, sum1, sum2,
sum3);
VF x10 = hn::Load(df, v.Row(pos[10]) + offset);
MulAdd4(df, x10, hn::BroadcastLane<10>(c0), hn::BroadcastLane<10>(c1),
hn::BroadcastLane<10>(c2), hn::BroadcastLane<10>(c3), sum0, sum1,
sum2, sum3);
VF x11 = hn::Load(df, v.Row(pos[11]) + offset);
MulAdd4(df, x11, hn::BroadcastLane<11>(c0), hn::BroadcastLane<11>(c1),
hn::BroadcastLane<11>(c2), hn::BroadcastLane<11>(c3), sum0, sum1,
sum2, sum3);
VF x12 = hn::Load(df, v.Row(pos[12]) + offset);
MulAdd4(df, x12, hn::BroadcastLane<12>(c0), hn::BroadcastLane<12>(c1),
hn::BroadcastLane<12>(c2), hn::BroadcastLane<12>(c3), sum0, sum1,
sum2, sum3);
VF x13 = hn::Load(df, v.Row(pos[13]) + offset);
MulAdd4(df, x13, hn::BroadcastLane<13>(c0), hn::BroadcastLane<13>(c1),
hn::BroadcastLane<13>(c2), hn::BroadcastLane<13>(c3), sum0, sum1,
sum2, sum3);
VF x14 = hn::Load(df, v.Row(pos[14]) + offset);
MulAdd4(df, x14, hn::BroadcastLane<14>(c0), hn::BroadcastLane<14>(c1),
hn::BroadcastLane<14>(c2), hn::BroadcastLane<14>(c3), sum0, sum1,
sum2, sum3);
VF x15 = hn::Load(df, v.Row(pos[15]) + offset);
MulAdd4(df, x15, hn::BroadcastLane<15>(c0), hn::BroadcastLane<15>(c1),
hn::BroadcastLane<15>(c2), hn::BroadcastLane<15>(c3), sum0, sum1,
sum2, sum3);
}

template <class DF, class VF = hn::Vec<DF>, HWY_IF_V_SIZE_LE_D(DF, 63)>
HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond8Lanes(
DF df, const MatPtrT<float>& v, const size_t* HWY_RESTRICT pos,
const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3,
VF& sum0, VF& sum1, VF& sum2, VF& sum3) {}

// For an NFx4 tile of float values in 4xNF-lane registers, multiplies NF rows
// of V by the corresponding values in c0-c3 and adds them to the 4 rows of out,
// after first prescaling each row of out by its entry in scales.
// The depth (size) must be a multiple of NF.
template <class DF, class VF = hn::Vec<DF>>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile4(
DF df, const float* HWY_RESTRICT scales, const VF c0, const VF c1,
const VF c2, const VF c3, const MatPtrT<float>& v,
const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out,
const uint32_t* HWY_RESTRICT out_offsets, const size_t size,
hwy::Profiler& p, const size_t worker) {
static const auto zone = p.AddZone("Ops.MulByConstAndAddTile4");
PROFILER_ZONE3(p, worker, zone);
namespace hn = hwy::HWY_NAMESPACE;
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);

size_t i = 0;
while (i + NF <= size) {
VF out0, out1, out2, out3;
out0 = hn::Load(df, out + i + out_offsets[0]);
out1 = hn::Load(df, out + i + out_offsets[1]);
out2 = hn::Load(df, out + i + out_offsets[2]);
out3 = hn::Load(df, out + i + out_offsets[3]);
out0 = hn::Mul(out0, hn::Set(df, scales[0]));
out1 = hn::Mul(out1, hn::Set(df, scales[1]));
out2 = hn::Mul(out2, hn::Set(df, scales[2]));
out3 = hn::Mul(out3, hn::Set(df, scales[3]));
MulAdd4Lanes(df, v, pos, i, c0, c1, c2, c3, out0, out1, out2, out3);
if HWY_LANES_CONSTEXPR (NF >= 8) {
MulAddSecond4Lanes(df, v, pos, i, c0, c1, c2, c3, out0, out1, out2, out3);
if HWY_LANES_CONSTEXPR (NF >= 16) {
MulAddSecond8Lanes(df, v, pos, i, c0, c1, c2, c3, out0, out1, out2,
out3);
}
}
hn::Store(out0, df, out + i + out_offsets[0]);
hn::Store(out1, df, out + i + out_offsets[1]);
hn::Store(out2, df, out + i + out_offsets[2]);
hn::Store(out3, df, out + i + out_offsets[3]);
i += NF;
}
HWY_DASSERT(size == i);
}
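A scalar reference for what MulByConstAndAddTile4 computes (a sketch with nested vectors standing in for MatPtrT and the Highway registers; MulByConstAndAddTile4Ref is a hypothetical helper, not part of the commit). Each output row r is first scaled by scales[r], then accumulates the NF selected rows of V weighted by the lanes of c_r:

#include <cstddef>
#include <vector>

void MulByConstAndAddTile4Ref(
    const float* scales,                       // [4] prescale per output row
    const std::vector<std::vector<float>>& c,  // [4][NF] per-row weights
    const std::vector<std::vector<float>>& V,  // V rows, each of length size
    const size_t* pos, size_t NF,              // NF indices of V rows
    std::vector<std::vector<float>>& out) {    // [4][size] output rows
  for (size_t r = 0; r < 4; ++r) {
    for (size_t col = 0; col < out[r].size(); ++col) {
      float acc = scales[r] * out[r][col];
      for (size_t i = 0; i < NF; ++i) acc += c[r][i] * V[pos[i]][col];
      out[r][col] = acc;
    }
  }
}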

// Prescales NF rows of out by scale, then multiplies 1 row of V by the
@@ -898,11 +1052,11 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
static const auto zone = p.AddZone("Ops.MulByConstAndAdd");
PROFILER_ZONE3(p, worker, zone);
namespace hn = hwy::HWY_NAMESPACE;
const size_t NF = hn::MaxLanes(df);
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);

size_t i = 0;
while (i + NF <= size) {
if constexpr (NF == 16) {
if HWY_LANES_CONSTEXPR (NF == 16) {
VF out0, out1, out2, out3, out4, out5, out6, out7;
VF out8, out9, out10, out11, out12, out13, out14, out15;
out0 = hn::Load(df, out + i + out_offsets[0]);
@@ -921,22 +1075,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
out13 = hn::Load(df, out + i + out_offsets[13]);
out14 = hn::Load(df, out + i + out_offsets[14]);
out15 = hn::Load(df, out + i + out_offsets[15]);
out0 = hn::Mul(out0, hn::BroadcastLane<0>(scale));
out1 = hn::Mul(out1, hn::BroadcastLane<1>(scale));
out2 = hn::Mul(out2, hn::BroadcastLane<2>(scale));
out3 = hn::Mul(out3, hn::BroadcastLane<3>(scale));
out4 = hn::Mul(out4, hn::BroadcastLane<4>(scale));
out5 = hn::Mul(out5, hn::BroadcastLane<5>(scale));
out6 = hn::Mul(out6, hn::BroadcastLane<6>(scale));
out7 = hn::Mul(out7, hn::BroadcastLane<7>(scale));
out8 = hn::Mul(out8, hn::BroadcastLane<8>(scale));
out9 = hn::Mul(out9, hn::BroadcastLane<9>(scale));
out10 = hn::Mul(out10, hn::BroadcastLane<10>(scale));
out11 = hn::Mul(out11, hn::BroadcastLane<11>(scale));
out12 = hn::Mul(out12, hn::BroadcastLane<12>(scale));
out13 = hn::Mul(out13, hn::BroadcastLane<13>(scale));
out14 = hn::Mul(out14, hn::BroadcastLane<14>(scale));
out15 = hn::Mul(out15, hn::BroadcastLane<15>(scale));
Mul16(df, scale, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15);
VF x0 = hn::Load(df, v.Row(pos) + i);
MulAdd16(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7, out8,
out9, out10, out11, out12, out13, out14, out15);
@@ -956,7 +1096,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
hn::Store(out13, df, out + i + out_offsets[13]);
hn::Store(out14, df, out + i + out_offsets[14]);
hn::Store(out15, df, out + i + out_offsets[15]);
} else if constexpr (NF == 8) {
}
if HWY_LANES_CONSTEXPR (NF == 8) {
VF out0, out1, out2, out3, out4, out5, out6, out7;
out0 = hn::Load(df, out + i + out_offsets[0]);
out1 = hn::Load(df, out + i + out_offsets[1]);
@@ -966,14 +1107,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
out5 = hn::Load(df, out + i + out_offsets[5]);
out6 = hn::Load(df, out + i + out_offsets[6]);
out7 = hn::Load(df, out + i + out_offsets[7]);
out0 = hn::Mul(out0, hn::BroadcastLane<0>(scale));
out1 = hn::Mul(out1, hn::BroadcastLane<1>(scale));
out2 = hn::Mul(out2, hn::BroadcastLane<2>(scale));
out3 = hn::Mul(out3, hn::BroadcastLane<3>(scale));
out4 = hn::Mul(out4, hn::BroadcastLane<4>(scale));
out5 = hn::Mul(out5, hn::BroadcastLane<5>(scale));
out6 = hn::Mul(out6, hn::BroadcastLane<6>(scale));
out7 = hn::Mul(out7, hn::BroadcastLane<7>(scale));
Mul8(df, scale, out0, out1, out2, out3, out4, out5, out6, out7);
VF x0 = hn::Load(df, v.Row(pos) + i);
MulAdd8(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7);
hn::Store(out0, df, out + i + out_offsets[0]);
@@ -984,7 +1118,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
hn::Store(out5, df, out + i + out_offsets[5]);
hn::Store(out6, df, out + i + out_offsets[6]);
hn::Store(out7, df, out + i + out_offsets[7]);
} else if constexpr (NF == 4) {
}
if HWY_LANES_CONSTEXPR (NF == 4) {
VF out0, out1, out2, out3;
out0 = hn::Load(df, out + i + out_offsets[0]);
out1 = hn::Load(df, out + i + out_offsets[1]);
@@ -1000,13 +1135,10 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
hn::Store(out1, df, out + i + out_offsets[1]);
hn::Store(out2, df, out + i + out_offsets[2]);
hn::Store(out3, df, out + i + out_offsets[3]);
} else {
HWY_DASSERT(false);
}
i += NF;
}
const size_t remaining = size - i;
HWY_DASSERT(remaining == 0);
HWY_DASSERT(size == i);
}

// See below for a specialized version for top-1 sampling.