Change the (old) attention path to disallow wraparound, now enforced via a debug assertion.

Add a shared kU64PerLine constant.

PiperOrigin-RevId: 828072451
Jan Wassenberg 2025-11-04 11:52:07 -08:00 committed by Copybara-Service
parent 3a63a12624
commit a344a70c59
6 changed files with 48 additions and 63 deletions
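
In sketch form, the first change drops modular ("wraparound") indexing of KV-cache positions in favor of a debug assertion that the position is already in range. The snippet below is only an illustration, not code from the repository: the standalone CachePos helper is hypothetical, while hwy::Divisor, HWY_DASSERT, and the --seq_len flag come from the diff below (the Highway header path is assumed).

#include <cstddef>

#include "hwy/base.h"  // HWY_DASSERT, hwy::Divisor (assumed header)

// Before this commit: the cache row was div_seq_len.Remainder(pos), i.e.
// positions wrapped around. After: the caller must choose --seq_len large
// enough that pos is already < seq_len, checked in debug builds.
static size_t CachePos(size_t pos, const hwy::Divisor& div_seq_len) {
  HWY_DASSERT(pos < static_cast<size_t>(div_seq_len.GetDivisor()));
  return pos;  // previously: div_seq_len.Remainder(pos)
}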

View File

@@ -159,6 +159,10 @@ struct AttentionActivationsPtrs {
     // `inv_timescale*` are not batched.
   }
 
+  size_t SeqLen() const {
+    return static_cast<size_t>(div_seq_len.GetDivisor());
+  }
+
   const ModelConfig& config;
   MatPtrT<float> q;
   MatPtrT<BF16> q_bf;

View File

@@ -66,21 +66,13 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
   CompressTraits<BF16>::Compress(df, q, qkv_dim, tls, MakeSpan(q_bf, qkv_dim),
                                  0);
 
-  if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
-    // Slightly faster: no wraparound.
-    for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-      const float score =
-          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos), qkv_dim);
-      att[pos] = score;
-    }
-  } else {
-    for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-      const size_t pos_modulo = div_seq_len.Remainder(pos);
-      const float score =
-          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_modulo), qkv_dim);
-      att[pos_modulo] = score;
-    }
-  }
+  // --seq_len must be large enough to avoid wraparound.
+  HWY_DASSERT(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()));
+  for (size_t pos = start_pos; pos <= last_pos; ++pos) {
+    const float score =
+        Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos), qkv_dim);
+    att[pos] = score;
+  }
 }
 
 void PositionalEncodingQK(float* qk, const size_t layer_idx,
@@ -114,26 +106,14 @@ static HWY_INLINE void WeightedSumV(
     const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT att,
     const MatPtrT<KV_t>& v, float* HWY_RESTRICT att_out, ThreadingContext& ctx,
     const size_t worker) {
-  if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
-    // Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if
-    // we supported non-transposed B.
-    // TODO: 2..4x unroll
-    MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), ctx,
-                 worker);
-    for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
-      MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols());
-    }
-  } else {
-    {
-      const size_t pos_mod = div_seq_len.Remainder(start_pos);
-      MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), ctx,
-                   worker);
-    }
-    for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
-      const size_t pos_mod = div_seq_len.Remainder(pos);
-      MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols());
-    }
-  }
+  // --seq_len must be large enough to avoid wraparound.
+  HWY_DASSERT(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()));
+  // TODO: replace with MatMul(att, v) after it supports non-transposed B.
+  MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), ctx,
+               worker);
+  for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
+    MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols());
+  }
 }
 
 // Calculates the attention outputs for a single q, which may be updated
@@ -146,9 +126,10 @@ void SingleDotSoftmaxWeightedSum(
     float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) {
   const float att_cap = activations.config.att_cap;
   const float query_scale = activations.query_scale;
-  const size_t seq_len =
-      static_cast<size_t>(activations.div_seq_len.GetDivisor());
+  // --seq_len must be large enough to avoid wraparound.
+  HWY_DASSERT(last_pos < activations.SeqLen());
+
   const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
   // Apply rope and scaling to Q.
   if (query_norm_scale.HasPtr()) {
     CallUpcasted(&query_norm_scale, [&](const auto* weights_t) {
@@ -163,8 +144,7 @@ void SingleDotSoftmaxWeightedSum(
   QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, ctx, worker);
 
   // SoftMax with optional SoftCap yields "probabilities" in att.
-  const size_t att_len = HWY_MIN(last_pos + 1, seq_len);
-  const Logits logits(att, att_len);
+  const Logits logits(att, last_pos + 1);
   MaybeLogitsSoftCap(att_cap, logits, ctx, worker);
   Softmax(logits, ctx, worker, /*temperature=*/1.0f);
@@ -194,8 +174,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
   const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
   const size_t cache_layer_size = layer_config.CacheLayerSize();
-  const size_t seq_len =
-      static_cast<size_t>(activations.div_seq_len.GetDivisor());
+  const size_t seq_len = activations.SeqLen();
 
   // All layers should have the same number of heads.
   HWY_DASSERT(activations.div_heads.GetDivisor() == layer_config.heads);
@@ -284,8 +263,10 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
          ++interleaved_idx) {
       const size_t qi = div_qbatch.Remainder(interleaved_idx);
       const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
-      const size_t cache_pos =
-          activations.div_seq_len.Remainder(qbatch.Pos(qi) + batch_idx);
+      const size_t cache_pos = qbatch.Pos(qi) + batch_idx;
+      // --seq_len must be large enough to avoid wraparound.
+      HWY_DASSERT(cache_pos < activations.SeqLen());
       env.row_ptrs[0][interleaved_idx] = reinterpret_cast<uint8_t*>(
           qbatch.KV(qi).kv_cache.Row(cache_pos) + layer_idx * cache_layer_size);
     }
@@ -304,8 +285,9 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
     const size_t interleaved_idx = task / kv_heads;
     const size_t qi = div_qbatch.Remainder(interleaved_idx);
     const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
-    const size_t pos = qbatch.Pos(qi) + batch_idx;
-    const size_t cache_pos = activations.div_seq_len.Remainder(pos);
+    const size_t cache_pos = qbatch.Pos(qi) + batch_idx;
+    // --seq_len must be large enough to avoid wraparound.
+    HWY_DASSERT(cache_pos < activations.SeqLen());
     auto& kv_cache = qbatch.KV(qi).kv_cache;
     KV_t* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
                             layer_idx * cache_layer_size +
@@ -325,7 +307,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
     }
 
     PositionalEncodingQK(kv_f32, layer_idx, activations, env.ctx, worker,
-                         pos, /*mul=*/1.0f);
+                         cache_pos, /*mul=*/1.0f);
     CompressPerThread tls;
     Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
   });

View File

@@ -716,7 +716,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
       size_t last = pos;
       const size_t prefix_end = qbatch.PrefixEnd(qi);
       if (prefix_end > 0 && prefix_end - 1 > last) {
-        // last_pos in QDotK and WeightedSumV is inclusive.
+        // last_pos in `TileFlashAttention` is inclusive.
         last = prefix_end - 1;
       }
       last_pos[offset] = last;

View File

@@ -33,6 +33,9 @@ namespace gcpp {
 // For hwy::BitSet4096. Note that KVs are extremely large for such batches.
 HWY_INLINE_VAR constexpr size_t kMaxBatchSize = 4096;
 
+// Multiplier so a u64 occupies an entire cache line; avoids false sharing.
+HWY_INLINE_VAR constexpr size_t kU64PerLine = HWY_ALIGNMENT / sizeof(uint64_t);
+
 enum class Tristate : int32_t { kFalse = 0, kTrue = 1, kDefault = -1 };
 
 static inline const char* ToString(Tristate t) {
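
The intended usage pattern, in sketch form (it mirrors the TunePool change in the next file): each worker writes only the first u64 of its own cache line, so concurrent updates never contend on a shared line. Here kAlignment and SumPerWorkerCounts are hypothetical stand-ins; the real code uses HWY_ALIGNMENT and hwy::ThreadPool.

#include <cstddef>
#include <cstdint>
#include <vector>

constexpr size_t kAlignment = 64;  // hypothetical stand-in for HWY_ALIGNMENT
constexpr size_t kU64PerLine = kAlignment / sizeof(uint64_t);

// Per-worker counters spaced one cache line apart to avoid false sharing.
uint64_t SumPerWorkerCounts(size_t num_workers) {
  std::vector<uint64_t> counters(num_workers * kU64PerLine);
  // ... each worker `thread` increments counters[thread * kU64PerLine] ...
  uint64_t total = 0;
  for (size_t i = 0; i < num_workers; ++i) {
    total += counters[i * kU64PerLine];  // reduce the first u64 of each line
  }
  return total;
}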

View File

@@ -43,12 +43,9 @@ static void TunePool(hwy::PoolWaitMode wait_mode, hwy::ThreadPool& pool) {
   const size_t num_tasks[4] = {HWY_MAX(1, num_workers / 2), num_workers * 1,
                                num_workers * 5, num_workers * 20};
 
-  // Count tasks executed to ensure workers aren't optimized out. One per
-  // cache line to avoid false sharing.
-  const size_t kSizePerLine = HWY_ALIGNMENT / sizeof(size_t);
-  std::vector<size_t> counters(num_workers * kSizePerLine);
-  size_t prev_total = 0;  // avoids having to reset counters.
+  // Count tasks executed to ensure workers aren't optimized out.
+  std::vector<uint64_t> counters(num_workers * kU64PerLine);
+  uint64_t prev_total = 0;  // avoids having to reset counters.
 
   hwy::RandomState rng;
   for (size_t rep = 0; rep < 500; ++rep) {
@@ -63,13 +60,13 @@ static void TunePool(hwy::PoolWaitMode wait_mode, hwy::ThreadPool& pool) {
     pool.Run(begin, end, [&](uint64_t task, size_t thread) {
       HWY_ASSERT(begin <= task && task < end);
       HWY_ASSERT(thread < num_workers);
-      counters[thread * kSizePerLine]++;
+      counters[thread * kU64PerLine]++;
     });
 
     // Reduce count and ensure it matches the expected number of tasks.
-    size_t total = 0;
+    uint64_t total = 0;
     for (size_t i = 0; i < num_workers; ++i) {
-      total += counters[i * kSizePerLine];
+      total += counters[i * kU64PerLine];
     }
     const size_t expected = end - begin;
     HWY_ASSERT(total == prev_total + expected);

View File

@@ -202,8 +202,7 @@ TEST(ThreadingTest, TestStaticPartition) {
   }
 }
 
-static constexpr size_t kU64PerThread = HWY_ALIGNMENT / sizeof(size_t);
-static uint64_t outputs[hwy::kMaxLogicalProcessors * kU64PerThread];
+static uint64_t outputs[hwy::kMaxLogicalProcessors * kU64PerLine];
 
 std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
   // Governs duration of test; avoid timeout in debug builds.
@@ -217,7 +216,7 @@ std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
   const double t0 = hwy::platform::Now();
   for (size_t reps = 0; reps < 1200; ++reps) {
     pool.Run(0, pool.NumWorkers(), kCaller, [&](uint64_t task, size_t thread) {
-      outputs[thread * kU64PerThread] = base + thread;
+      outputs[thread * kU64PerLine] = base + thread;
     });
     hwy::PreventElision(outputs[base]);
    if (pool.AutoTuneComplete()) break;
@@ -258,7 +257,7 @@ std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
     const uint64_t t0 = hwy::timer::Start();
     pool.Run(0, pool.NumWorkers(), kCaller,
              [&](uint64_t task, size_t thread) {
-               outputs[thread * kU64PerThread] = base + thread;
+               outputs[thread * kU64PerLine] = base + thread;
              });
     const uint64_t t1 = hwy::timer::Stop();
     times.push_back(t1 - t0);
@@ -268,7 +267,7 @@ std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
     const uint64_t t0 = hwy::timer::Start();
     pool.Run(0, pool.NumWorkers(), kCaller,
              [&](uint64_t task, size_t thread) {
-               outputs[thread * kU64PerThread] = base + thread;
+               outputs[thread * kU64PerLine] = base + thread;
              });
     const uint64_t t1 = hwy::timer::Start();
     times.push_back(t1 - t0);
@@ -315,10 +314,10 @@ TEST(ThreadingTest, BenchJoin) {
 
   // Verify outputs to ensure the measured code is not a no-op.
   for (size_t lp = 0; lp < pool.NumWorkers(); ++lp) {
-    HWY_ASSERT(outputs[lp * kU64PerThread] >= 1);
-    HWY_ASSERT(outputs[lp * kU64PerThread] <= 1 + pool.NumWorkers());
-    for (size_t i = 1; i < kU64PerThread; ++i) {
-      HWY_ASSERT(outputs[lp * kU64PerThread + i] == 0);
+    HWY_ASSERT(outputs[lp * kU64PerLine] >= 1);
+    HWY_ASSERT(outputs[lp * kU64PerLine] <= 1 + pool.NumWorkers());
+    for (size_t i = 1; i < kU64PerLine; ++i) {
+      HWY_ASSERT(outputs[lp * kU64PerLine + i] == 0);
     }
   }
 };