diff --git a/gemma/attention.cc b/gemma/attention.cc
index 55ac12b..c0cce57 100644
--- a/gemma/attention.cc
+++ b/gemma/attention.cc
@@ -22,6 +22,7 @@
 #include "gemma/gemma.h"
 #include "gemma/gemma_args.h"
 #include "gemma/weights.h"
+#include "util/threading.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/profiler.h"

@@ -111,7 +112,8 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
   }
 }

-// Calculates the attention outputs for a single q.
+// Calculates the attention outputs for a single q, which may be updated
+// in place for RMSNorm.
 void SingleDotSoftmaxWeightedSum(
     const size_t pos, const size_t start_pos, const size_t last_pos,
     float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
@@ -158,9 +160,10 @@ void DotSoftmaxWeightedSum(const size_t num_tokens,
                            const LayerWeightsPtrs& layer,
                            Activations& activations,
                            const KVCaches& kv_caches, NestedPools& pools) {
-  const size_t num_queries = queries_pos.size();
-  const LayerConfig& layer_config = layer.layer_config;
   PROFILER_ZONE("Gen.Attention.DotSoftmax");
+  const hwy::Divisor div_queries(queries_pos.size());
+  const LayerConfig& layer_config = layer.layer_config;
+  const size_t qkv_dim = layer_config.qkv_dim;

   // A "head group" in the context of GQA refers to a collection of query
   // heads that share the same key and value heads.
@@ -170,50 +173,65 @@ void DotSoftmaxWeightedSum(const size_t num_tokens,
   const size_t seq_len =
       static_cast<size_t>(activations.div_seq_len.GetDivisor());

-  // For each head (token, query), compute Q.K, softmax, and weighted V.
-  // TODO: nested parallelism to use more threads.
-  pools.Pool(0).Run(
-      0, layer_config.heads * num_tokens * num_queries,
-      [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
-        const size_t head = task % layer_config.heads;
-        const size_t interleaved_idx = task / layer_config.heads;
-        const size_t query_idx = interleaved_idx % num_queries;
-        const size_t batch_idx = interleaved_idx / num_queries;
-        const size_t qkv_dim = layer_config.qkv_dim;
-        const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
+  // For each head/token/query, compute Q.K, softmax, and weighted V.

-        float* HWY_RESTRICT q =
-            activations.q.Row(interleaved_idx) + head * qkv_dim;
-        float* HWY_RESTRICT att =
-            activations.att.Row(interleaved_idx) + head * seq_len;
-        float* HWY_RESTRICT att_out =
-            activations.att_out.Row(interleaved_idx) + head * qkv_dim;
+  // Statically partition token/query across packages.
+  const size_t num_tq = num_tokens * div_queries.GetDivisor();
+  const IndexRangePartition tq_ranges =
+      StaticPartition(IndexRange(0, num_tq), pools.NumPackages(), 1);
+  ParallelizeOneRange(
+      tq_ranges, pools.AllPackages(),
+      [&](const IndexRange& tq_range, const size_t pkg_idx) {
+        pools.AllClusters(pkg_idx).Run(
+            tq_range.begin(), tq_range.end(),
+            [&](const size_t tq_idx, const size_t cluster_idx) {
+              const size_t query_idx = div_queries.Remainder(tq_idx);
+              const size_t batch_idx = div_queries.Divide(tq_idx);
+              auto& kv_cache = kv_caches[query_idx].kv_cache;

-        // Make strided views into the kv cache entries for the current
-        // query and head.
-        auto& kv_cache = kv_caches[query_idx].kv_cache;
-        const size_t kv_head_offset =
-            layer_idx * cache_layer_size + head_offset;
-        MatPtrT<KV_t> k("k_view", Extents2D(seq_len, qkv_dim));
-        k.SetPtr(kv_cache.Row(0) + kv_head_offset, kv_cache.Stride());
-        MatPtrT<KV_t> v("v_view", Extents2D(seq_len, qkv_dim));
-        v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());
+              // Find the token position in the query and calculate
+              // the range of cache positions to attend to.
+              const size_t pos = queries_pos[query_idx] + batch_idx;
+              const size_t start_pos =
+                  StartPos(pos, activations.weights_config, layer_idx);
+              size_t last_pos = pos;
+              const size_t prefix_end = queries_prefix_end[query_idx];
+              if (prefix_end > 0 && prefix_end - 1 > last_pos) {
+                // last_pos in QDotK and WeightedSumV is inclusive.
+                last_pos = prefix_end - 1;
+              }

-        // Find the token position in the query and calculate the range
-        // of cache positions to attend to.
-        const size_t pos = queries_pos[query_idx] + batch_idx;
-        const size_t start_pos =
-            StartPos(pos, activations.weights_config, layer_idx);
-        size_t last_pos = pos;
-        const size_t prefix_end = queries_prefix_end[query_idx];
-        if (prefix_end > 0 && prefix_end - 1 > last_pos) {
-          // last_pos in QDotK and WeightedSumV is inclusive.
-          last_pos = prefix_end - 1;
-        }
+              pools.Cluster(pkg_idx, cluster_idx)
+                  .Run(
+                      0, layer_config.heads,
+                      [&](const size_t head, size_t thread) HWY_ATTR {
+                        const size_t head_offset =
+                            (head / kHeadGroups) * qkv_dim * 2;

-        SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v,
-                                    layer_idx, layer, activations, att,
-                                    att_out);
+                        float* HWY_RESTRICT q =
+                            activations.q.Row(tq_idx) + head * qkv_dim;
+
+                        float* HWY_RESTRICT att =
+                            activations.att.Row(tq_idx) + head * seq_len;
+                        float* HWY_RESTRICT att_out =
+                            activations.att_out.Row(tq_idx) + head * qkv_dim;
+
+                        // Make strided read-only views into the kv cache for
+                        // this query and head.
+                        const size_t kv_head_offset =
+                            layer_idx * cache_layer_size + head_offset;
+                        MatPtrT<KV_t> k("k_view", Extents2D(seq_len, qkv_dim));
+                        k.SetPtr(kv_cache.Row(0) + kv_head_offset,
+                                 kv_cache.Stride());
+                        MatPtrT<KV_t> v("v_view", Extents2D(seq_len, qkv_dim));
+                        v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim,
+                                 kv_cache.Stride());
+
+                        SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q,
+                                                    k, v, layer_idx, layer,
+                                                    activations, att, att_out);
+                      });
+        });
       });
 }

@@ -223,8 +241,8 @@ static HWY_INLINE void ComputeQKV(
     const LayerWeightsPtrs& layer, Activations& activations,
     const KVCaches& kv_caches, const int flags, MatMulEnv& env) {
   PROFILER_ZONE("Gen.Attention.QKV");
-  const size_t num_queries = queries_pos.size();
-  const size_t num_interleaved = num_tokens * num_queries;
+  const hwy::Divisor div_queries(queries_pos.size());
+  const size_t num_interleaved = num_tokens * div_queries.GetDivisor();
   const LayerConfig& layer_config = layer.layer_config;
   const size_t qkv_dim = layer_config.qkv_dim;
   const size_t kv_heads = layer_config.kv_heads;
@@ -242,8 +260,8 @@ static HWY_INLINE void ComputeQKV(
                                              layer.qkv_einsum_w2.Rows()));
   for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
        ++interleaved_idx) {
-    const size_t query_idx = interleaved_idx % num_queries;
-    const size_t batch_idx = interleaved_idx / num_queries;
+    const size_t query_idx = div_queries.Remainder(interleaved_idx);
+    const size_t batch_idx = div_queries.Divide(interleaved_idx);
     const size_t cache_pos =
         activations.div_seq_len.Remainder(queries_pos[query_idx] + batch_idx);
     env.row_ptrs[0][interleaved_idx] = reinterpret_cast<uint8_t*>(
         kv_cache.Row(cache_pos) + layer_idx * cache_layer_size);
@@ -255,14 +273,15 @@ static HWY_INLINE void ComputeQKV(
                /*add=*/nullptr, env, kv_rows);

   // Apply positional encodings for K.
-  // TODO: 2D parallelism to use more threads.
+  // Note that 2D parallelism is not worth the fork/join overhead because the
+  // tasks are very lightweight.
   env.ctx.pools.Pool(0).Run(
       0, kv_heads * num_interleaved,
       [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
         const size_t head = task % kv_heads;
         const size_t interleaved_idx = task / kv_heads;
-        const size_t query_idx = interleaved_idx % num_queries;
-        const size_t batch_idx = interleaved_idx / num_queries;
+        const size_t query_idx = div_queries.Remainder(interleaved_idx);
+        const size_t batch_idx = div_queries.Divide(interleaved_idx);
         const size_t pos = queries_pos[query_idx] + batch_idx;
         const size_t cache_pos = activations.div_seq_len.Remainder(pos);
         auto& kv_cache = kv_caches[query_idx].kv_cache;
diff --git a/ops/matmul.cc b/ops/matmul.cc
index 21a1b91..7e83830 100644
--- a/ops/matmul.cc
+++ b/ops/matmul.cc
@@ -21,6 +21,7 @@
 #include
 #include
+#include <atomic>
 #include

 #include "util/allocator.h"
@@ -401,9 +402,8 @@ static size_t NPMultiple(const Allocator& allocator, size_t N,
   }
   // This happens in tests with small N, hence do not assert.
   if (N % (np_multiple * num_packages) && N >= 128) {
-    static bool warned = false;
-    if (!warned) {
-      warned = true;
+    static std::atomic_flag warned = ATOMIC_FLAG_INIT;
+    if (!warned.test_and_set()) {
       HWY_WARN(
           "NPMultiple: N=%zu still not divisible by np_multiple=%zu * "
           "num_packages=%zu\n",
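
Note on the parallelization scheme: DotSoftmaxWeightedSum previously ran one flat thread pool over all head x token x query tasks. The new code uses three levels that mirror the hardware topology: token/query ranges are statically partitioned across packages, the clusters within each package dynamically pick up token/query indices from their package's range, and each cluster's own pool then parallelizes over heads. A simplified, self-contained sketch of the outer level using std::thread; Partition here is a stand-in for StaticPartition in util/threading.h, not the real API:

#include <algorithm>
#include <cstddef>
#include <thread>
#include <utility>
#include <vector>

// Stand-in for StaticPartition: split [0, n) into at most num_parts
// contiguous ranges of nearly equal size (rounding multiple = 1).
static std::vector<std::pair<size_t, size_t>> Partition(size_t n,
                                                        size_t num_parts) {
  std::vector<std::pair<size_t, size_t>> ranges;
  const size_t chunk = (n + num_parts - 1) / num_parts;
  for (size_t begin = 0; begin < n; begin += chunk) {
    ranges.emplace_back(begin, std::min(begin + chunk, n));
  }
  return ranges;
}

int main() {
  const size_t num_tq = 10;       // num_tokens * num_queries
  const size_t num_packages = 2;  // hypothetical topology
  std::vector<std::thread> workers;
  for (const auto& range : Partition(num_tq, num_packages)) {
    // One static range per package; the real code fans out further,
    // across clusters and then across heads within a cluster.
    workers.emplace_back([range] {
      for (size_t tq_idx = range.first; tq_idx < range.second; ++tq_idx) {
        // ... per-token/query work: Q.K, softmax, weighted V ...
      }
    });
  }
  for (std::thread& t : workers) t.join();
  return 0;
}

The static outer split presumably keeps token/query work local to a package, while the two dynamic inner levels balance load within it.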
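
Both DotSoftmaxWeightedSum and ComputeQKV now wrap the query count in hwy::Divisor, which precomputes a multiply-shift so that decomposing an interleaved index into (batch_idx, query_idx) needs no hardware division in the per-task code. A minimal sketch of the equivalence (the divisor value 3 is made up for illustration):

#include <cstdint>
#include <cstdio>

#include "hwy/base.h"  // hwy::Divisor

int main() {
  const hwy::Divisor div_queries(3);  // made-up query count
  for (uint32_t idx = 0; idx < 9; ++idx) {
    // Same results as idx % 3 and idx / 3, but via a precomputed
    // multiply-shift rather than a hardware divide per iteration.
    printf("idx=%u -> query=%u batch=%u\n", idx, div_queries.Remainder(idx),
           div_queries.Divide(idx));
  }
  return 0;
}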
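
The k_view/v_view objects are views, not copies: each addresses a seq_len x qkv_dim window inside the much wider per-position kv cache rows, with the cache's full row width as the stride and a base offset selecting the layer and head. A generic illustration of the idea; StridedView and all sizes here are hypothetical, not the MatPtrT API:

#include <cstddef>
#include <vector>

// Hypothetical minimal strided view; MatPtrT::SetPtr in the diff plays
// the same role over the kv cache. Row r starts at ptr + r * stride.
struct StridedView {
  const float* ptr;
  size_t rows, cols, stride;  // stride >= cols, in elements
  const float* Row(size_t r) const { return ptr + r * stride; }
};

int main() {
  // Assumed layout: one cache row per position holds all layers/heads;
  // the view reads only qkv_dim of the stride-wide row.
  const size_t seq_len = 4, qkv_dim = 8, stride = 64;
  const size_t kv_head_offset = 16;  // selects layer and head
  std::vector<float> kv_cache(seq_len * stride, 0.0f);
  StridedView k{kv_cache.data() + kv_head_offset, seq_len, qkv_dim, stride};
  StridedView v{kv_cache.data() + kv_head_offset + qkv_dim, seq_len, qkv_dim,
                stride};
  (void)k.Row(2);  // K vector of this head at position 2
  (void)v.Row(2);  // V follows K within the same head's slot
  return 0;
}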
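
Finally, the matmul.cc change makes the one-time warning thread-safe: with concurrent callers, the plain static bool was a data race and could print more than once. std::atomic_flag::test_and_set atomically sets the flag and returns the previous value, so exactly one caller observes false. A small demonstration:

#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

// Exactly one thread observes test_and_set() == false, so the warning
// fires once even under concurrent calls, unlike the racy static bool.
static void WarnOnce() {
  static std::atomic_flag warned = ATOMIC_FLAG_INIT;
  if (!warned.test_and_set()) {
    fprintf(stderr, "warning: printed exactly once\n");
  }
}

int main() {
  std::vector<std::thread> threads;
  for (int i = 0; i < 8; ++i) threads.emplace_back(WarnOnce);
  for (std::thread& t : threads) t.join();
  return 0;
}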