diff --git a/gemma/activations.h b/gemma/activations.h
index cfd174f..a96c305 100644
--- a/gemma/activations.h
+++ b/gemma/activations.h
@@ -159,6 +159,10 @@ struct AttentionActivationsPtrs {
     // `inv_timescale*` are not batched.
   }
 
+  size_t SeqLen() const {
+    return static_cast<size_t>(div_seq_len.GetDivisor());
+  }
+
   const ModelConfig& config;
   MatPtrT<float> q;
   MatPtrT<BF16> q_bf;
diff --git a/gemma/attention.cc b/gemma/attention.cc
index 95d62cd..ad464e7 100644
--- a/gemma/attention.cc
+++ b/gemma/attention.cc
@@ -66,20 +66,12 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
   CompressTraits<BF16>::Compress(df, q, qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
 
-  if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
-    // Slightly faster: no wraparound.
-    for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-      const float score =
-          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos), qkv_dim);
-      att[pos] = score;
-    }
-  } else {
-    for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-      const size_t pos_modulo = div_seq_len.Remainder(pos);
-      const float score =
-          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_modulo), qkv_dim);
-      att[pos_modulo] = score;
-    }
+  // --seq_len must be large enough to avoid wraparound.
+  HWY_DASSERT(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()));
+  for (size_t pos = start_pos; pos <= last_pos; ++pos) {
+    const float score =
+        Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos), qkv_dim);
+    att[pos] = score;
   }
 }
 
@@ -114,25 +106,13 @@ static HWY_INLINE void WeightedSumV(
     const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT att,
     const MatPtrT<KV_t>& v, float* HWY_RESTRICT att_out, ThreadingContext& ctx,
     const size_t worker) {
-  if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
-    // Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if
-    // we supported non-transposed B.
-    // TODO: 2..4x unroll
-    MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), ctx,
-                 worker);
-    for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
-      MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols());
-    }
-  } else {
-    {
-      const size_t pos_mod = div_seq_len.Remainder(start_pos);
-      MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), ctx,
-                   worker);
-    }
-    for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
-      const size_t pos_mod = div_seq_len.Remainder(pos);
-      MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols());
-    }
+  // --seq_len must be large enough to avoid wraparound.
+  HWY_DASSERT(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()));
+  // TODO: replace with MatMul(att, v) after it supports non-transposed B.
+  MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), ctx,
+               worker);
+  for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
+    MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols());
   }
 }
 
@@ -146,9 +126,10 @@ void SingleDotSoftmaxWeightedSum(
     float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) {
   const float att_cap = activations.config.att_cap;
   const float query_scale = activations.query_scale;
-  const size_t seq_len =
-      static_cast<size_t>(activations.div_seq_len.GetDivisor());
+  // --seq_len must be large enough to avoid wraparound.
+  HWY_DASSERT(last_pos < activations.SeqLen());
   const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
 
+  // Apply rope and scaling to Q.
   if (query_norm_scale.HasPtr()) {
     CallUpcasted(&query_norm_scale, [&](const auto* weights_t) {
@@ -163,8 +144,7 @@ void SingleDotSoftmaxWeightedSum(
   QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, ctx, worker);
 
   // SoftMax with optional SoftCap yields "probabilities" in att.
-  const size_t att_len = HWY_MIN(last_pos + 1, seq_len);
-  const Logits logits(att, att_len);
+  const Logits logits(att, last_pos + 1);
   MaybeLogitsSoftCap(att_cap, logits, ctx, worker);
   Softmax(logits, ctx, worker, /*temperature=*/1.0f);
 
@@ -194,8 +174,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
   const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
   const size_t cache_layer_size = layer_config.CacheLayerSize();
-  const size_t seq_len =
-      static_cast<size_t>(activations.div_seq_len.GetDivisor());
+  const size_t seq_len = activations.SeqLen();
 
   // All layers should have the same number of heads.
   HWY_DASSERT(activations.div_heads.GetDivisor() == layer_config.heads);
@@ -284,8 +263,10 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
        ++interleaved_idx) {
     const size_t qi = div_qbatch.Remainder(interleaved_idx);
     const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
-    const size_t cache_pos =
-        activations.div_seq_len.Remainder(qbatch.Pos(qi) + batch_idx);
+    const size_t cache_pos = qbatch.Pos(qi) + batch_idx;
+    // --seq_len must be large enough to avoid wraparound.
+    HWY_DASSERT(cache_pos < activations.SeqLen());
+
     env.row_ptrs[0][interleaved_idx] = reinterpret_cast<uint8_t*>(
         qbatch.KV(qi).kv_cache.Row(cache_pos) + layer_idx * cache_layer_size);
   }
@@ -304,8 +285,9 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
     const size_t interleaved_idx = task / kv_heads;
     const size_t qi = div_qbatch.Remainder(interleaved_idx);
     const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
-    const size_t pos = qbatch.Pos(qi) + batch_idx;
-    const size_t cache_pos = activations.div_seq_len.Remainder(pos);
+    const size_t cache_pos = qbatch.Pos(qi) + batch_idx;
+    // --seq_len must be large enough to avoid wraparound.
+    HWY_DASSERT(cache_pos < activations.SeqLen());
     auto& kv_cache = qbatch.KV(qi).kv_cache;
     KV_t* HWY_RESTRICT kv =
         kv_cache.Row(cache_pos) + layer_idx * cache_layer_size +
@@ -325,7 +307,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
     }
 
     PositionalEncodingQK(kv_f32, layer_idx, activations, env.ctx, worker,
-                         pos, /*mul=*/1.0f);
+                         cache_pos, /*mul=*/1.0f);
     CompressPerThread tls;
     Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
   });
diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc
index 6548537..b9b0c8a 100644
--- a/gemma/flash_attention.cc
+++ b/gemma/flash_attention.cc
@@ -716,7 +716,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
     size_t last = pos;
     const size_t prefix_end = qbatch.PrefixEnd(qi);
     if (prefix_end > 0 && prefix_end - 1 > last) {
-      // last_pos in QDotK and WeightedSumV is inclusive.
+      // last_pos in `TileFlashAttention` is inclusive.
      last = prefix_end - 1;
     }
     last_pos[offset] = last;
diff --git a/util/basics.h b/util/basics.h
index 5a7f0d5..49996ba 100644
--- a/util/basics.h
+++ b/util/basics.h
@@ -33,6 +33,9 @@ namespace gcpp {
 // For hwy::BitSet4096. Note that KVs are extremely large for such batches.
 HWY_INLINE_VAR constexpr size_t kMaxBatchSize = 4096;
 
+// Multiplier so a u64 occupies an entire cache line; avoids false sharing.
+HWY_INLINE_VAR constexpr size_t kU64PerLine = HWY_ALIGNMENT / sizeof(uint64_t);
+
 enum class Tristate : int32_t { kFalse = 0, kTrue = 1, kDefault = -1 };
 
 static inline const char* ToString(Tristate t) {
diff --git a/util/threading_context.cc b/util/threading_context.cc
index e725ce3..d3aa74f 100644
--- a/util/threading_context.cc
+++ b/util/threading_context.cc
@@ -43,12 +43,9 @@ static void TunePool(hwy::PoolWaitMode wait_mode, hwy::ThreadPool& pool) {
   const size_t num_tasks[4] = {HWY_MAX(1, num_workers / 2), num_workers * 1,
                                num_workers * 5, num_workers * 20};
 
-  // Count tasks executed to ensure workers aren't optimized out. One per
-  // cache line to avoid false sharing.
-  const size_t kSizePerLine = HWY_ALIGNMENT / sizeof(size_t);
-
-  std::vector<size_t> counters(num_workers * kSizePerLine);
-  size_t prev_total = 0;  // avoids having to reset counters.
+  // Count tasks executed to ensure workers aren't optimized out.
+  std::vector<uint64_t> counters(num_workers * kU64PerLine);
+  uint64_t prev_total = 0;  // avoids having to reset counters.
 
   hwy::RandomState rng;
   for (size_t rep = 0; rep < 500; ++rep) {
@@ -63,13 +60,13 @@ static void TunePool(hwy::PoolWaitMode wait_mode, hwy::ThreadPool& pool) {
     pool.Run(begin, end, [&](uint64_t task, size_t thread) {
       HWY_ASSERT(begin <= task && task < end);
       HWY_ASSERT(thread < num_workers);
-      counters[thread * kSizePerLine]++;
+      counters[thread * kU64PerLine]++;
     });
 
     // Reduce count and ensure it matches the expected number of tasks.
-    size_t total = 0;
+    uint64_t total = 0;
     for (size_t i = 0; i < num_workers; ++i) {
-      total += counters[i * kSizePerLine];
+      total += counters[i * kU64PerLine];
     }
     const size_t expected = end - begin;
     HWY_ASSERT(total == prev_total + expected);
diff --git a/util/threading_test.cc b/util/threading_test.cc
index d6b98a3..b5d1858 100644
--- a/util/threading_test.cc
+++ b/util/threading_test.cc
@@ -202,8 +202,7 @@ TEST(ThreadingTest, TestStaticPartition) {
   }
 }
 
-static constexpr size_t kU64PerThread = HWY_ALIGNMENT / sizeof(size_t);
-static uint64_t outputs[hwy::kMaxLogicalProcessors * kU64PerThread];
+static uint64_t outputs[hwy::kMaxLogicalProcessors * kU64PerLine];
 
 std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
   // Governs duration of test; avoid timeout in debug builds.
@@ -217,7 +216,7 @@ std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
   const double t0 = hwy::platform::Now();
   for (size_t reps = 0; reps < 1200; ++reps) {
     pool.Run(0, pool.NumWorkers(), kCaller, [&](uint64_t task, size_t thread) {
-      outputs[thread * kU64PerThread] = base + thread;
+      outputs[thread * kU64PerLine] = base + thread;
     });
     hwy::PreventElision(outputs[base]);
     if (pool.AutoTuneComplete()) break;
@@ -258,7 +257,7 @@ std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
     const uint64_t t0 = hwy::timer::Start();
     pool.Run(0, pool.NumWorkers(), kCaller, [&](uint64_t task, size_t thread) {
-      outputs[thread * kU64PerThread] = base + thread;
+      outputs[thread * kU64PerLine] = base + thread;
     });
     const uint64_t t1 = hwy::timer::Stop();
     times.push_back(t1 - t0);
@@ -268,7 +267,7 @@ std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
     const uint64_t t0 = hwy::timer::Start();
     pool.Run(0, pool.NumWorkers(), kCaller, [&](uint64_t task, size_t thread) {
-      outputs[thread * kU64PerThread] = base + thread;
+      outputs[thread * kU64PerLine] = base + thread;
     });
     const uint64_t t1 = hwy::timer::Start();
     times.push_back(t1 - t0);
@@ -315,10 +314,10 @@ TEST(ThreadingTest, BenchJoin) {
     // Verify outputs to ensure the measured code is not a no-op.
    for (size_t lp = 0; lp < pool.NumWorkers(); ++lp) {
-      HWY_ASSERT(outputs[lp * kU64PerThread] >= 1);
-      HWY_ASSERT(outputs[lp * kU64PerThread] <= 1 + pool.NumWorkers());
-      for (size_t i = 1; i < kU64PerThread; ++i) {
-        HWY_ASSERT(outputs[lp * kU64PerThread + i] == 0);
+      HWY_ASSERT(outputs[lp * kU64PerLine] >= 1);
+      HWY_ASSERT(outputs[lp * kU64PerLine] <= 1 + pool.NumWorkers());
+      for (size_t i = 1; i < kU64PerLine; ++i) {
+        HWY_ASSERT(outputs[lp * kU64PerLine + i] == 0);
       }
     }
   };
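
Note on the attention changes above: with the Remainder()-based ring-buffer indexing removed, a KV-cache row index now equals the absolute token position, so the cache allocated from --seq_len must cover every position that will be written; otherwise the new HWY_DASSERTs fire in debug builds. Below is a minimal standalone sketch of that caller-side contract. CheckSeqLen, prompt_size, and max_generated_tokens are illustrative names only, not part of the gemma.cpp API.

#include <cstddef>
#include <cstdio>
#include <cstdlib>

// Sketch only: the highest cache row written is prompt_size +
// max_generated_tokens - 1, which must stay below seq_len now that
// wraparound is gone.
static void CheckSeqLen(size_t seq_len, size_t prompt_size,
                        size_t max_generated_tokens) {
  const size_t end_pos = prompt_size + max_generated_tokens;  // exclusive
  if (end_pos > seq_len) {
    fprintf(stderr,
            "--seq_len=%zu is too small for %zu prompt + %zu generated tokens\n",
            seq_len, prompt_size, max_generated_tokens);
    exit(1);
  }
}

int main() {
  CheckSeqLen(/*seq_len=*/4096, /*prompt_size=*/1024,
              /*max_generated_tokens=*/2048);
  return 0;
}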
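
The new kU64PerLine constant in util/basics.h captures the padding pattern now shared by TunePool and threading_test.cc: each worker's counter gets its own cache line, so concurrent writes from different workers never touch the same line (no false sharing). The following self-contained sketch shows the same idiom with plain std::thread instead of hwy::ThreadPool, purely for brevity; the 64-byte line size is an assumption here, whereas the patch derives it from HWY_ALIGNMENT.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <thread>
#include <vector>

// Assumed cache-line size; the patch uses HWY_ALIGNMENT / sizeof(uint64_t).
constexpr size_t kLineBytes = 64;
constexpr size_t kU64PerLine = kLineBytes / sizeof(uint64_t);

int main() {
  size_t num_workers = std::thread::hardware_concurrency();
  if (num_workers == 0) num_workers = 4;  // fallback if unknown

  // One counter per worker, each padded to a full cache line so neighboring
  // workers never write to the same line.
  std::vector<uint64_t> counters(num_workers * kU64PerLine, 0);

  std::vector<std::thread> threads;
  for (size_t t = 0; t < num_workers; ++t) {
    threads.emplace_back([&counters, t] {
      for (size_t i = 0; i < 1000000; ++i) counters[t * kU64PerLine]++;
    });
  }
  for (auto& th : threads) th.join();

  // Reduce by striding over the padded slots, as TunePool does.
  uint64_t total = 0;
  for (size_t t = 0; t < num_workers; ++t) total += counters[t * kU64PerLine];
  printf("total=%llu\n", static_cast<unsigned long long>(total));
  return 0;
}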