diff --git a/BUILD.bazel b/BUILD.bazel index 8b0dcde..74f472f 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -548,6 +548,7 @@ cc_library( ":gemma_args", ":kv_cache", ":mat", + ":matmul", ":matmul_env", ":model_store", ":ops", diff --git a/gemma/attention.cc b/gemma/attention.cc index e894981..f404674 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -358,7 +358,8 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx, DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch, env.ctx); } else { - FlashAttention(num_tokens, layer_idx, layer, activations, qbatch, env.ctx); + FlashAttention(num_tokens, /*target_parallelism=*/64, layer_idx, layer, + activations, qbatch, env.ctx); } SumHeads(layer, activations, env); } diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index c65c57f..b93b58f 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -44,6 +44,7 @@ // After highway.h #include "compression/compress-inl.h" #include "gemma/attention.h" +#include "ops/matmul-inl.h" #include "ops/ops-inl.h" HWY_BEFORE_NAMESPACE(); @@ -114,6 +115,27 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch, } } +// Handles a single v row of flash attention for a single q.k dot product. +void HWY_INLINE SingleFlashAttentionStep( + float x, float cap, float& old_max, float& old_d, + const float* HWY_RESTRICT v, const size_t v_cols, + float* HWY_RESTRICT att_out, hwy::Profiler& p, const size_t worker) { + if (cap > 0.0f) { + // Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x. 
+ x = cap * std::tanh(x / cap); + } + float m = std::max(x, old_max); + x = std::exp(x - m); + float scale = old_d * std::exp(old_max - m); + old_d = x + scale; + old_max = m; + float one_over_d = 1.0f / old_d; + scale *= one_over_d; + x *= one_over_d; + MulByConst(scale, att_out, v_cols, p, worker); + MulByConstAndAdd(x, v, att_out, v_cols, p, worker); +} + // Calculates the complete attention outputs for a single row of q. void SingleFlashAttention(const size_t start_pos, const size_t last_pos, const float* HWY_RESTRICT q, const MatPtrT& k, @@ -136,21 +158,8 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos, for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) { const size_t pos_mod = activations.div_seq_len.Remainder(pos); float x = Dot(q, k.Row(pos_mod), k.Cols()); - if (activations.config.att_cap > 0.0f) { - // Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x. - x = activations.config.att_cap * - std::tanh(x / activations.config.att_cap); - } - float m_new = std::max(m, x); - float scale = d * std::exp(m - m_new); - x = std::exp(x - m_new); - m = m_new; - d = scale + x; - float one_over_d = 1.0f / d; - x *= one_over_d; - scale *= one_over_d; - MulByConst(scale, att_out, v.Cols(), p, worker); - MulByConstAndAdd(x, v.Row(pos_mod), att_out, v.Cols(), p, worker); + SingleFlashAttentionStep(x, activations.config.att_cap, m, d, + v.Row(pos_mod), v.Cols(), att_out, p, worker); } } @@ -167,7 +176,8 @@ VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets, return hn::LoadU(df, results); } -// Returns an 8xNF tile of Q.K dot products, in single precision. +// Returns an NF Q rows by 8 K rows tile of Q.K dot products, in single +// precision. // This is the result of NF rows of Q against 8 K timesteps, with positions // given by k_pos[0..7]. Q has been transposed so that the NF rows are read in // consecutive elements, and other columns by adding q_stride. 
@@ -240,8 +250,9 @@ VF HWY_INLINE ElementwiseSumOf8(DF df, const VF& x0, const VF& x1, const VF& x2, return hn::Add(sum0, sum2); } -// Sweeps a tile of 8xNF accumulators from start_pos to min_last_pos, then -// sweeps the remaining timesteps in the range (min_last_pos, max_last_pos]. +// Sweeps a tile of NF Q rows by 8 K timesteps accumulators from start_pos to +// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos, +// max_last_pos]. void TileFlashAttention( const MatPtrT& q, const uint32_t* HWY_RESTRICT q_offsets, const StridedView& qT, const MatPtrT& k, @@ -260,7 +271,7 @@ void TileFlashAttention( using DI = hn::ScalableTag; const DI di; using VI = hn::Vec; - const int kVTileSize = hn::MaxLanes(df); + const int kVTileSize = hn::Lanes(df); for (int i = 0; i < kVTileSize; ++i) { hwy::ZeroBytes(att_out.Row(0) + out_offsets[i], v.Cols() * sizeof(att_out.Row(0)[0])); @@ -348,38 +359,217 @@ void TileFlashAttention( } } +// Returns a 4 Q rows by NF K tile of Q.K dot products, in single precision. +// This is the result of 4 rows of Q against NF K timesteps, with positions +// given by k_offsets[0..NF]. Q has been transposed so that the 4 rows are read +// in consecutive elements, and other columns by adding q_stride. 
+template > +void QDotKTilex4(DF df, const float* HWY_RESTRICT q, const size_t q_stride, + const MatPtrT& k, const int32_t* HWY_RESTRICT k_offsets, + hwy::Profiler& p, const size_t worker, VF& sum0, VF& sum1, + VF& sum2, VF& sum3) { + sum0 = hn::Zero(df); + sum1 = hn::Zero(df); + sum2 = hn::Zero(df); + sum3 = hn::Zero(df); + const float* HWY_RESTRICT k_base = k.Row(0); + using DI = hn::ScalableTag; + const DI di; + using VI = hn::Vec; + VI k_offsets_vec = hn::LoadU(di, k_offsets); + for (size_t i = 0; i < k.Cols(); ++i) { + VF k_vec = hn::GatherIndex(df, k_base + i, k_offsets_vec); + VF q_0 = hn::Set(df, q[0]); + sum0 = hn::MulAdd(q_0, k_vec, sum0); + VF q_1 = hn::Set(df, q[1]); + sum1 = hn::MulAdd(q_1, k_vec, sum1); + VF q_2 = hn::Set(df, q[2]); + sum2 = hn::MulAdd(q_2, k_vec, sum2); + VF q_3 = hn::Set(df, q[3]); + sum3 = hn::MulAdd(q_3, k_vec, sum3); + q += q_stride; + } +} + +// Handles NF v rows of flash attention for NF q.k dot products from one q row. +template > +float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max, + float& old_d) { + float m = hn::ReduceMax(df, x); + m = std::max(m, old_max); + x = hn::Exp(df, x - hn::Set(df, m)); + float scale = old_d * std::exp(old_max - m); + old_d = hn::ReduceSum(df, x) + scale; + old_max = m; + float one_over_d = 1.0f / old_d; + scale *= one_over_d; + x = hn::Mul(x, hn::Set(df, one_over_d)); + return scale; +} + +// Sweeps a tile of 4 Q rows by NF K timesteps accumulators from start_pos to +// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos, +// max_last_pos]. 
+void TileFlashAttention4( + const MatPtrT& q, const uint32_t* HWY_RESTRICT q_offsets, + const StridedView& qT, const MatPtrT& k, + const size_t start_pos, const uint32_t* HWY_RESTRICT last_pos, + const size_t min_last_pos, const size_t max_last_pos, + const MatPtrT& v, const size_t layer_idx, + const LayerWeightsPtrs& layer, const AttentionActivations& activations, + MatPtrT& att_out, const uint32_t* HWY_RESTRICT out_offsets, + hwy::Profiler& p, const size_t worker) { + static const auto zone = p.AddZone("Gen.Attention.TileFlashAttention4"); + PROFILER_ZONE3(p, worker, zone); + using DF = hn::ScalableTag; + const DF df; + using VF = hn::Vec; + constexpr size_t kMaxNF = hn::MaxLanes(df); + const size_t kHTileSize = hn::Lanes(df); + HWY_DASSERT(kHTileSize <= kMaxNF); + constexpr size_t kVTileSize = 4; + float scales[kVTileSize]; + for (size_t i = 0; i < kVTileSize; ++i) { + hwy::ZeroBytes(att_out.Row(0) + out_offsets[i], + v.Cols() * sizeof(att_out.Row(0)[0])); + } + float old_m0 = -std::numeric_limits::max() / 2.0f; + float old_m1 = -std::numeric_limits::max() / 2.0f; + float old_m2 = -std::numeric_limits::max() / 2.0f; + float old_m3 = -std::numeric_limits::max() / 2.0f; + float old_d0 = 0.0f; + float old_d1 = 0.0f; + float old_d2 = 0.0f; + float old_d3 = 0.0f; + const float* HWY_RESTRICT qT_row = qT.Row(0); + const size_t qT_stride = qT.Stride(); + size_t position = start_pos; + while (position + kHTileSize - 1 <= min_last_pos) { + int32_t k_offsets[kMaxNF]; + size_t v_pos[kMaxNF]; + for (size_t i = 0; i < kHTileSize; ++i) { + v_pos[i] = activations.div_seq_len.Remainder(position + i); + k_offsets[i] = k.Row(v_pos[i]) - k.Row(0); + } + VF x0, x1, x2, x3; + QDotKTilex4(df, qT_row, qT_stride, k, k_offsets, p, worker, x0, x1, x2, x3); + if (activations.config.att_cap > 0.0f) { + // Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile. 
+ VF cap = hn::Set(df, activations.config.att_cap); + VF one_over_cap = hn::Div(hn::Set(df, 1.0f), cap); + x0 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x0, one_over_cap))); + x1 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x1, one_over_cap))); + x2 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x2, one_over_cap))); + x3 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x3, one_over_cap))); + } + scales[0] = SingleFlashAttentionRowVector(df, x0, old_m0, old_d0); + scales[1] = SingleFlashAttentionRowVector(df, x1, old_m1, old_d1); + scales[2] = SingleFlashAttentionRowVector(df, x2, old_m2, old_d2); + scales[3] = SingleFlashAttentionRowVector(df, x3, old_m3, old_d3); + MulByConstAndAddTile4(df, scales, x0, x1, x2, x3, v, v_pos, att_out.Row(0), + out_offsets, v.Cols(), p, worker); + position += kHTileSize; + } + while (position <= max_last_pos) { + size_t k_pos = activations.div_seq_len.Remainder(position); + if (position <= last_pos[0]) { + // Past the last position, x0 doesn't count. + float x0 = Dot(q.Row(0) + q_offsets[0], k.Row(k_pos), k.Cols()); + SingleFlashAttentionStep(x0, activations.config.att_cap, old_m0, old_d0, + v.Row(k_pos), v.Cols(), + att_out.Row(0) + out_offsets[0], p, worker); + } + if (position <= last_pos[1]) { + // Past the last position, x1 doesn't count. + float x1 = Dot(q.Row(0) + q_offsets[1], k.Row(k_pos), k.Cols()); + SingleFlashAttentionStep(x1, activations.config.att_cap, old_m1, old_d1, + v.Row(k_pos), v.Cols(), + att_out.Row(0) + out_offsets[1], p, worker); + } + if (position <= last_pos[2]) { + // Past the last position, x2 doesn't count. + float x2 = Dot(q.Row(0) + q_offsets[2], k.Row(k_pos), k.Cols()); + SingleFlashAttentionStep(x2, activations.config.att_cap, old_m2, old_d2, + v.Row(k_pos), v.Cols(), + att_out.Row(0) + out_offsets[2], p, worker); + } + if (position <= last_pos[3]) { + // Past the last position, x3 doesn't count. 
+ float x3 = Dot(q.Row(0) + q_offsets[3], k.Row(k_pos), k.Cols()); + SingleFlashAttentionStep(x3, activations.config.att_cap, old_m3, old_d3, + v.Row(k_pos), v.Cols(), + att_out.Row(0) + out_offsets[3], p, worker); + } + ++position; + } +} + +// Rounds n to a number that can be used as the number of Q rows in a tile +// of flash attention. +static size_t RoundToSuitablePowerOf2(size_t n) { + if (n < 4) return 1; + if (n < 8) return 4; + if (n < 16) return 8; + if (n < 32) return 16; + return 32; +} + // The nominal aim of attention is to combine 3 inputs Q[L,D], K[L,D], V[L,D] // into a single output O[L,D]. // Conventional attention first computes A[L,L] = Q . KT // followed by A = softmax(A) (over invididual rows). // Then A is multiplied by V to get O[L,D]. // For each row of O, this takes a read of one row of Q L times, all of K, -// 3 write/reads of one row of A, read all of V, an read.write the one row of O +// 3 write/reads of one row of A, read all of V, and read/write the one row of O // L times. Ignoring the computation for now, and focusing just on memory, // the one row of O takes L(4D+3) reads and L(D+3) writes. // For the whole of Q, this is L^2(4D+3) reads and L^2(D+3) writes. // -// Flash attention fuses these operations together, and (where possible) -// computes NF rows of the result using 8 accumulator registers and two more to -// keep running results. NF is the number of float lanes in a register, being 16 -// for AVX3. The softmax is converted to streaming form using the -// algortihm from: +// Flash attention fuses these operations together, and has 3 operating modes: +// 1. NF rows of the result computed using tiles of registers of shape NFx8. +// 2. 4 rows of the result computed using tiles of registers of shape 4xNF. +// 3. One row (of Q and the result) at a time. +// In all cases the intermediate result (Q.KT) is never stored to memory. +// NF is the number of float lanes in a register, being 16 for AVX3. 
The softmax +// is converted to streaming form using the algorithm from: // https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf. // Q is transposed to Q_T[D,L] to make the dot product computation efficient. -// QDotKTileFloat computes 8xNF rows of Q.K dot products in one go, reducing -// reads of Q by 8 and reads of K by NF. The streaming softmax is computed -// entirely in registers, and a further NF registers to accumulate the results -// of the product of the softmax and V, reduce the number of reads of V by NF, -// and the reads/writes of O by 8. +// +// In mode 1: +// QDotKTileFloat computes NF Q rows x 8 K timesteps of Q.K dot products in one +// go, reducing reads of Q by 8 and reads of K by NF. The streaming softmax is +// computed entirely in registers, and a further NF registers to accumulate the +// results of the product of the softmax and V, reduce the number of reads of V +// by NF, and the reads/writes of O by 8. // The reads are thus reduced to 2DL^2(1/8+1/NF) and writes reduced to DL^2/8, // which on AVX3 is an overall reduction by about a factor of 10. +// Mode 1 can only be accessed if there is a large Qbatch size, or in multi-turn +// prefill, since in other cases, there is either a single K timestep (prefill) +// or a single num_heads set of Q rows (decode). +// +// In mode 2, the 4 rows of Q are computed against NF K timesteps in a tile, +// reducing the reads of Q by NF, and the reads of K by 4. The softmax and +// accumulation of the result is done in registers, cutting the reads of V by 4. +// The reads/writes of O are reduced by a factor of NF. +// The overall reduction is limited by the need to use gather to load K. +// Transposing K would be possible, but is complicated by the wraparound. +// Mode 2 can be used in all cases when there are at least 4 attention heads, +// but it may be preferable to use mode 3 when the batch size is small to +// maximise parallelism. 
+// +// In mode 3, a single row of Q is computed against a single K timestep at a +// time, using SingleFlashAttention. In this case there is no reduction in the +// reads of Q or K, or V, or O, but the reads/writes of the intermediate A are +// still eliminated. // // A further complication is that real attention is not as simple as documented // in the paper and above. There are multiple query heads, differing KV, and // different sequence lengths, so a lot of the work in FlashAttention is making -// sure that a collection of q rows can use the TileFlashAttention path. -void FlashAttention(const size_t num_tokens, const size_t layer_idx, - const LayerWeightsPtrs& layer, +// sure that a collection of q rows with the same KV and sequence length are +// grouped together so that mode 1 or 2 can be used, and choosing which of the +// 3 modes to use for best efficiency. +void FlashAttention(const size_t num_tokens, const size_t target_parallelism, + const size_t layer_idx, const LayerWeightsPtrs& layer, AttentionActivations& activations, QBatch& qbatch, ThreadingContext& ctx) { static const auto zone = ctx.profiler.AddZone("Gen.Attention.FlashAttention"); @@ -392,15 +582,28 @@ void FlashAttention(const size_t num_tokens, const size_t layer_idx, // A "head group" in the context of GQA refers to a collection of query // heads that share the same key and value heads. 
const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads; - - using DF = hn::ScalableTag; - const DF df; - constexpr size_t kVTileSize = hn::MaxLanes(df); const size_t cache_layer_size = layer_config.CacheLayerSize(); const size_t seq_len = static_cast(activations.div_seq_len.GetDivisor()); const size_t token_batch = num_tokens * div_qbatch.GetDivisor(); const size_t total_tasks = token_batch * layer_config.heads; + + using DF = hn::ScalableTag; + const DF df; + const size_t kNF = hn::Lanes(df); + constexpr size_t kMaxNF = hn::MaxLanes(df); + HWY_DASSERT(kNF <= kMaxNF); + // The vertical tile size is determined by the ability to use tiling and the + // target_parallelism. In practice the possible tile sizes in order of + // preference for efficiency are kNF, 4, 1, where kNF is likely to be 4 8 or + // 16. The final tile size is chosen to be the largest possible that allows + // for target_parallelism parallel tasks. + const size_t kMaxEqualK = RoundToSuitablePowerOf2(kHeadGroups * num_tokens); + const size_t kMinTileSize = (total_tasks / 4 >= target_parallelism) ? 4 : 1; + const size_t kVTileSize = + (kNF <= kMaxEqualK && total_tasks / kNF >= target_parallelism) + ? kNF + : std::min(kMinTileSize, kMaxEqualK); // q has shape [batch, qbatch][head, qkv_dim]. // We transpose it to [qkv_dim][qbatch, head, batch] in order to make the // maximum possible number of consecutive columns have the same KV matrices. @@ -416,26 +619,26 @@ void FlashAttention(const size_t num_tokens, const size_t layer_idx, const auto func = [&](const size_t task, size_t worker) HWY_ATTR { PROFILER_ZONE3(ctx.profiler, worker, zone); // Offsets into original Q for each row in the tile. - uint32_t q_offsets[kVTileSize]; + uint32_t q_offsets[kMaxNF]; // Offsets into att_out for each row in the tile. - uint32_t out_offsets[kVTileSize]; + uint32_t out_offsets[kMaxNF]; // Start positions for each row in the tile. 
- size_t start_positions[kVTileSize]; + size_t start_positions[kMaxNF]; // Last positions for each row in the tile. Inclusive. - uint32_t last_pos[kVTileSize]; + uint32_t last_pos[kMaxNF]; // min and max last positions across all rows in the tile determines when // TileFlashAttention switches to single vector mode to handle the // ragged sequence lengths. size_t min_last_pos = std::numeric_limits::max(); size_t max_last_pos = 0; // Indices into the qbatch.KV for each row in the tile. - size_t qi_indices[kVTileSize]; + size_t qi_indices[kMaxNF]; // Indices into the kv_cache for each row in the tile. - size_t kv_offsets[kVTileSize]; + size_t kv_offsets[kMaxNF]; // first_task is [qbatch, head, token]. const size_t first_task = task * kVTileSize; const size_t last_task = first_task + kVTileSize - 1; - bool use_tile_attention = last_task < total_tasks; + bool use_tile_attention = kVTileSize > 1 && last_task < total_tasks; for (size_t offset = 0; offset < kVTileSize && first_task + offset < total_tasks; ++offset) { const size_t batch_idx = div_tokens.Remainder(first_task + offset); @@ -486,15 +689,26 @@ void FlashAttention(const size_t num_tokens, const size_t layer_idx, kv_cache.Stride()); if (use_tile_attention) { // To avoid duplicating the code to setup K and V, the call to - // TileFlashAttention is inside the loop over tasks, even thought it + // TileFlashAttention is inside the loop over tasks, even though it // handles all rows in the task at once. 
StridedView qT = StridedView(activations.q_T.Row(0) + first_task, kVTileSize, activations.q_T.Stride()); - TileFlashAttention( - activations.q, q_offsets, qT, k, start_positions[offset], last_pos, - min_last_pos, max_last_pos, v, layer_idx, layer, activations, - activations.att_out, out_offsets, ctx.profiler, worker); + if (kVTileSize == kNF) { + TileFlashAttention(activations.q, q_offsets, qT, k, + start_positions[offset], last_pos, min_last_pos, + max_last_pos, v, layer_idx, layer, activations, + activations.att_out, out_offsets, ctx.profiler, + worker); + } else if (kVTileSize == 4) { + TileFlashAttention4(activations.q, q_offsets, qT, k, + start_positions[offset], last_pos, min_last_pos, + max_last_pos, v, layer_idx, layer, activations, + activations.att_out, out_offsets, ctx.profiler, + worker); + } else { + HWY_UNREACHABLE; + } break; } else { SingleFlashAttention(start_positions[offset], last_pos[offset], diff --git a/gemma/flash_attention.h b/gemma/flash_attention.h index b505d6f..75e087a 100644 --- a/gemma/flash_attention.h +++ b/gemma/flash_attention.h @@ -42,8 +42,8 @@ namespace gcpp { float* HWY_RESTRICT att_out, hwy::Profiler& p, \ size_t worker); \ \ - void FlashAttention(size_t num_tokens, size_t layer_idx, \ - const LayerWeightsPtrs& layer, \ + void FlashAttention(size_t num_tokens, size_t target_parallelism, \ + size_t layer_idx, const LayerWeightsPtrs& layer, \ AttentionActivations& activations, QBatch& qbatch, \ ThreadingContext& ctx); \ /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ diff --git a/gemma/flash_attention_test.cc b/gemma/flash_attention_test.cc index d4d6380..7f8f31e 100644 --- a/gemma/flash_attention_test.cc +++ b/gemma/flash_attention_test.cc @@ -98,13 +98,14 @@ void AssertClose(const MatPtrT& a, const MatPtrT& b) { } } -void TestAttention() { +void TestFlashAttention(size_t target_parallelism) { ThreadingArgs threading_args; ThreadingContext ctx(threading_args); // hwy::ThreadPool& pool = ctx.pools.Pool(); constexpr 
size_t kOuter = 1024; constexpr size_t kInner = 256; ModelConfig config(Model::GEMMA2_2B, Type::kF32, PromptWrapping::GEMMA_PT); + config.att_cap = 1024.0f; TensorInfoRegistry tensor_info_registry(config); const LayerConfig& layer_config = config.layer_configs[0]; const LayerWeightsPtrs layers(0, layer_config, tensor_info_registry); @@ -149,10 +150,17 @@ void TestAttention() { // Copy the output to saved_att to allow for comparison. auto saved_att = MakeCopyOfMat(attention.att_out, ctx.allocator); SetMat(1, attention.q); - FlashAttention(tokens.size(), 0, layers, attention, qbatch, ctx); + FlashAttention(tokens.size(), target_parallelism, 0, layers, attention, + qbatch, ctx); AssertClose(attention.att_out, *saved_att); } +void TestAttention() { + TestFlashAttention(8192); + TestFlashAttention(2048); + TestFlashAttention(256); +} + // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace gcpp diff --git a/ops/ops-inl.h b/ops/ops-inl.h index ec73f66..a52c788 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -747,7 +747,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile( static const auto zone = p.AddZone("Ops.MulByConstAndAdd"); PROFILER_ZONE3(p, worker, zone); namespace hn = hwy::HWY_NAMESPACE; - HWY_LANES_CONSTEXPR size_t NF = hn::MaxLanes(df); + HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df); size_t i = 0; while (i + NF <= size) { @@ -882,8 +882,162 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile( } i += NF; } - const size_t remaining = size - i; - HWY_DASSERT(remaining == 0); + HWY_DASSERT(size == i); +} + +template > +HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4(DF df, const VF common, const VF c0, + const VF c1, const VF c2, const VF c3, + VF& sum0, VF& sum1, VF& sum2, + VF& sum3) { + sum0 = hn::MulAdd(common, c0, sum0); + sum1 = hn::MulAdd(common, c1, sum1); + sum2 = hn::MulAdd(common, c2, sum2); + sum3 = hn::MulAdd(common, c3, sum3); +} + +template > +HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4Lanes(DF 
df, const MatPtrT& v, + const size_t* HWY_RESTRICT pos, + const size_t offset, const VF c0, + const VF c1, const VF c2, + const VF c3, VF& sum0, VF& sum1, + VF& sum2, VF& sum3) { + // TODO(rays): Check whether a transpose of c0-c3 is applicable and faster. + VF x0 = hn::Load(df, v.Row(pos[0]) + offset); + MulAdd4(df, x0, hn::BroadcastLane<0>(c0), hn::BroadcastLane<0>(c1), + hn::BroadcastLane<0>(c2), hn::BroadcastLane<0>(c3), sum0, sum1, sum2, + sum3); + VF x1 = hn::Load(df, v.Row(pos[1]) + offset); + MulAdd4(df, x1, hn::BroadcastLane<1>(c0), hn::BroadcastLane<1>(c1), + hn::BroadcastLane<1>(c2), hn::BroadcastLane<1>(c3), sum0, sum1, sum2, + sum3); + VF x2 = hn::Load(df, v.Row(pos[2]) + offset); + MulAdd4(df, x2, hn::BroadcastLane<2>(c0), hn::BroadcastLane<2>(c1), + hn::BroadcastLane<2>(c2), hn::BroadcastLane<2>(c3), sum0, sum1, sum2, + sum3); + VF x3 = hn::Load(df, v.Row(pos[3]) + offset); + MulAdd4(df, x3, hn::BroadcastLane<3>(c0), hn::BroadcastLane<3>(c1), + hn::BroadcastLane<3>(c2), hn::BroadcastLane<3>(c3), sum0, sum1, sum2, + sum3); +} + +template , HWY_IF_V_SIZE_GT_D(DF, 31)> +HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond4Lanes( + DF df, const MatPtrT& v, const size_t* HWY_RESTRICT pos, + const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3, + VF& sum0, VF& sum1, VF& sum2, VF& sum3) { + VF x4 = hn::Load(df, v.Row(pos[4]) + offset); + MulAdd4(df, x4, hn::BroadcastLane<4>(c0), hn::BroadcastLane<4>(c1), + hn::BroadcastLane<4>(c2), hn::BroadcastLane<4>(c3), sum0, sum1, sum2, + sum3); + VF x5 = hn::Load(df, v.Row(pos[5]) + offset); + MulAdd4(df, x5, hn::BroadcastLane<5>(c0), hn::BroadcastLane<5>(c1), + hn::BroadcastLane<5>(c2), hn::BroadcastLane<5>(c3), sum0, sum1, sum2, + sum3); + VF x6 = hn::Load(df, v.Row(pos[6]) + offset); + MulAdd4(df, x6, hn::BroadcastLane<6>(c0), hn::BroadcastLane<6>(c1), + hn::BroadcastLane<6>(c2), hn::BroadcastLane<6>(c3), sum0, sum1, sum2, + sum3); + VF x7 = hn::Load(df, v.Row(pos[7]) + offset); + MulAdd4(df, x7, 
hn::BroadcastLane<7>(c0), hn::BroadcastLane<7>(c1), + hn::BroadcastLane<7>(c2), hn::BroadcastLane<7>(c3), sum0, sum1, sum2, + sum3); +} + +template , HWY_IF_V_SIZE_LE_D(DF, 31)> +HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond4Lanes( + DF df, const MatPtrT& v, const size_t* HWY_RESTRICT pos, + const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3, + VF& sum0, VF& sum1, VF& sum2, VF& sum3) {} + +template , HWY_IF_V_SIZE_GT_D(DF, 63)> +HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond8Lanes( + DF df, const MatPtrT& v, const size_t* HWY_RESTRICT pos, + const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3, + VF& sum0, VF& sum1, VF& sum2, VF& sum3) { + VF x8 = hn::Load(df, v.Row(pos[8]) + offset); + MulAdd4(df, x8, hn::BroadcastLane<8>(c0), hn::BroadcastLane<8>(c1), + hn::BroadcastLane<8>(c2), hn::BroadcastLane<8>(c3), sum0, sum1, sum2, + sum3); + VF x9 = hn::Load(df, v.Row(pos[9]) + offset); + MulAdd4(df, x9, hn::BroadcastLane<9>(c0), hn::BroadcastLane<9>(c1), + hn::BroadcastLane<9>(c2), hn::BroadcastLane<9>(c3), sum0, sum1, sum2, + sum3); + VF x10 = hn::Load(df, v.Row(pos[10]) + offset); + MulAdd4(df, x10, hn::BroadcastLane<10>(c0), hn::BroadcastLane<10>(c1), + hn::BroadcastLane<10>(c2), hn::BroadcastLane<10>(c3), sum0, sum1, + sum2, sum3); + VF x11 = hn::Load(df, v.Row(pos[11]) + offset); + MulAdd4(df, x11, hn::BroadcastLane<11>(c0), hn::BroadcastLane<11>(c1), + hn::BroadcastLane<11>(c2), hn::BroadcastLane<11>(c3), sum0, sum1, + sum2, sum3); + VF x12 = hn::Load(df, v.Row(pos[12]) + offset); + MulAdd4(df, x12, hn::BroadcastLane<12>(c0), hn::BroadcastLane<12>(c1), + hn::BroadcastLane<12>(c2), hn::BroadcastLane<12>(c3), sum0, sum1, + sum2, sum3); + VF x13 = hn::Load(df, v.Row(pos[13]) + offset); + MulAdd4(df, x13, hn::BroadcastLane<13>(c0), hn::BroadcastLane<13>(c1), + hn::BroadcastLane<13>(c2), hn::BroadcastLane<13>(c3), sum0, sum1, + sum2, sum3); + VF x14 = hn::Load(df, v.Row(pos[14]) + offset); + MulAdd4(df, x14, 
hn::BroadcastLane<14>(c0), hn::BroadcastLane<14>(c1), + hn::BroadcastLane<14>(c2), hn::BroadcastLane<14>(c3), sum0, sum1, + sum2, sum3); + VF x15 = hn::Load(df, v.Row(pos[15]) + offset); + MulAdd4(df, x15, hn::BroadcastLane<15>(c0), hn::BroadcastLane<15>(c1), + hn::BroadcastLane<15>(c2), hn::BroadcastLane<15>(c3), sum0, sum1, + sum2, sum3); +} + +template , HWY_IF_V_SIZE_LE_D(DF, 63)> +HWY_INLINE HWY_MAYBE_UNUSED void MulAddSecond8Lanes( + DF df, const MatPtrT& v, const size_t* HWY_RESTRICT pos, + const size_t offset, const VF c0, const VF c1, const VF c2, const VF c3, + VF& sum0, VF& sum1, VF& sum2, VF& sum3) {} + +// For an NFx4 tile of float values in 4xNF-lane registers, multiplies NF rows +// of V by the corresponding values in c0-c3 and adds them to NF rows of out, +// after first prescaling out by scale. +// The depth (size) must be a multiple of NF. +template > +HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile4( + DF df, const float* HWY_RESTRICT scales, const VF c0, const VF c1, + const VF c2, const VF c3, const MatPtrT& v, + const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out, + const uint32_t* HWY_RESTRICT out_offsets, const size_t size, + hwy::Profiler& p, const size_t worker) { + static const auto zone = p.AddZone("Ops.MulByConstAndAddTile4"); + PROFILER_ZONE3(p, worker, zone); + namespace hn = hwy::HWY_NAMESPACE; + HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df); + + size_t i = 0; + while (i + NF <= size) { + VF out0, out1, out2, out3; + out0 = hn::Load(df, out + i + out_offsets[0]); + out1 = hn::Load(df, out + i + out_offsets[1]); + out2 = hn::Load(df, out + i + out_offsets[2]); + out3 = hn::Load(df, out + i + out_offsets[3]); + out0 = hn::Mul(out0, hn::Set(df, scales[0])); + out1 = hn::Mul(out1, hn::Set(df, scales[1])); + out2 = hn::Mul(out2, hn::Set(df, scales[2])); + out3 = hn::Mul(out3, hn::Set(df, scales[3])); + MulAdd4Lanes(df, v, pos, i, c0, c1, c2, c3, out0, out1, out2, out3); + if HWY_LANES_CONSTEXPR (NF >= 8) { + 
MulAddSecond4Lanes(df, v, pos, i, c0, c1, c2, c3, out0, out1, out2, out3); + if HWY_LANES_CONSTEXPR (NF >= 16) { + MulAddSecond8Lanes(df, v, pos, i, c0, c1, c2, c3, out0, out1, out2, + out3); + } + } + hn::Store(out0, df, out + i + out_offsets[0]); + hn::Store(out1, df, out + i + out_offsets[1]); + hn::Store(out2, df, out + i + out_offsets[2]); + hn::Store(out3, df, out + i + out_offsets[3]); + i += NF; + } + HWY_DASSERT(size == i); } // Prescales NF rows of out by scale, then multiplies 1 row of V by the @@ -898,11 +1052,11 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector( static const auto zone = p.AddZone("Ops.MulByConstAndAdd"); PROFILER_ZONE3(p, worker, zone); namespace hn = hwy::HWY_NAMESPACE; - const size_t NF = hn::MaxLanes(df); + HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df); size_t i = 0; while (i + NF <= size) { - if constexpr (NF == 16) { + if HWY_LANES_CONSTEXPR (NF == 16) { VF out0, out1, out2, out3, out4, out5, out6, out7; VF out8, out9, out10, out11, out12, out13, out14, out15; out0 = hn::Load(df, out + i + out_offsets[0]); @@ -921,22 +1075,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector( out13 = hn::Load(df, out + i + out_offsets[13]); out14 = hn::Load(df, out + i + out_offsets[14]); out15 = hn::Load(df, out + i + out_offsets[15]); - out0 = hn::Mul(out0, hn::BroadcastLane<0>(scale)); - out1 = hn::Mul(out1, hn::BroadcastLane<1>(scale)); - out2 = hn::Mul(out2, hn::BroadcastLane<2>(scale)); - out3 = hn::Mul(out3, hn::BroadcastLane<3>(scale)); - out4 = hn::Mul(out4, hn::BroadcastLane<4>(scale)); - out5 = hn::Mul(out5, hn::BroadcastLane<5>(scale)); - out6 = hn::Mul(out6, hn::BroadcastLane<6>(scale)); - out7 = hn::Mul(out7, hn::BroadcastLane<7>(scale)); - out8 = hn::Mul(out8, hn::BroadcastLane<8>(scale)); - out9 = hn::Mul(out9, hn::BroadcastLane<9>(scale)); - out10 = hn::Mul(out10, hn::BroadcastLane<10>(scale)); - out11 = hn::Mul(out11, hn::BroadcastLane<11>(scale)); - out12 = hn::Mul(out12, hn::BroadcastLane<12>(scale)); - 
out13 = hn::Mul(out13, hn::BroadcastLane<13>(scale)); - out14 = hn::Mul(out14, hn::BroadcastLane<14>(scale)); - out15 = hn::Mul(out15, hn::BroadcastLane<15>(scale)); + Mul16(df, scale, out0, out1, out2, out3, out4, out5, out6, out7, out8, + out9, out10, out11, out12, out13, out14, out15); VF x0 = hn::Load(df, v.Row(pos) + i); MulAdd16(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7, out8, out9, out10, out11, out12, out13, out14, out15); @@ -956,7 +1096,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector( hn::Store(out13, df, out + i + out_offsets[13]); hn::Store(out14, df, out + i + out_offsets[14]); hn::Store(out15, df, out + i + out_offsets[15]); - } else if constexpr (NF == 8) { + } + if HWY_LANES_CONSTEXPR (NF == 8) { VF out0, out1, out2, out3, out4, out5, out6, out7; out0 = hn::Load(df, out + i + out_offsets[0]); out1 = hn::Load(df, out + i + out_offsets[1]); @@ -966,14 +1107,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector( out5 = hn::Load(df, out + i + out_offsets[5]); out6 = hn::Load(df, out + i + out_offsets[6]); out7 = hn::Load(df, out + i + out_offsets[7]); - out0 = hn::Mul(out0, hn::BroadcastLane<0>(scale)); - out1 = hn::Mul(out1, hn::BroadcastLane<1>(scale)); - out2 = hn::Mul(out2, hn::BroadcastLane<2>(scale)); - out3 = hn::Mul(out3, hn::BroadcastLane<3>(scale)); - out4 = hn::Mul(out4, hn::BroadcastLane<4>(scale)); - out5 = hn::Mul(out5, hn::BroadcastLane<5>(scale)); - out6 = hn::Mul(out6, hn::BroadcastLane<6>(scale)); - out7 = hn::Mul(out7, hn::BroadcastLane<7>(scale)); + Mul8(df, scale, out0, out1, out2, out3, out4, out5, out6, out7); VF x0 = hn::Load(df, v.Row(pos) + i); MulAdd8(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7); hn::Store(out0, df, out + i + out_offsets[0]); @@ -984,7 +1118,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector( hn::Store(out5, df, out + i + out_offsets[5]); hn::Store(out6, df, out + i + out_offsets[6]); hn::Store(out7, df, out + i + out_offsets[7]); - } else if 
constexpr (NF == 4) { + } + if HWY_LANES_CONSTEXPR (NF == 4) { VF out0, out1, out2, out3; out0 = hn::Load(df, out + i + out_offsets[0]); out1 = hn::Load(df, out + i + out_offsets[1]); @@ -1000,13 +1135,10 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector( hn::Store(out1, df, out + i + out_offsets[1]); hn::Store(out2, df, out + i + out_offsets[2]); hn::Store(out3, df, out + i + out_offsets[3]); - } else { - HWY_DASSERT(false); } i += NF; } - const size_t remaining = size - i; - HWY_DASSERT(remaining == 0); + HWY_DASSERT(size == i); } // See below for a specialized version for top-1 sampling.