1.64x batch=1 prefill speedup: nested parallelization for Attention (DotSoftmaxWeightedSum)

Also fix TSan error in matmul (atomic_flag instead of static bool).

PiperOrigin-RevId: 770241705
Jan Wassenberg 2025-06-11 11:28:14 -07:00 committed by Copybara-Service
parent c027a45a2e
commit 01cdefeda7
2 changed files with 72 additions and 53 deletions
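For readers skimming the diff below: the speedup comes from replacing one flat thread pool over head x token x query tasks with a three-level split that follows the pool hierarchy (packages, clusters per package, threads per cluster). The following standalone sketch uses sequential loops as stand-ins for the three pool levels; the names in the comments refer to the real calls in the diff, while the function name and the partitioning arithmetic here are illustrative only, not the project's API.

#include <cstddef>

// Sketch of the new work decomposition. Each loop level corresponds to one
// pool level in the diff below: packages, then (token, query) items per
// package's clusters, then heads on the threads of one cluster.
void DotSoftmaxScheduleSketch(size_t num_tokens, size_t num_queries,
                              size_t num_heads, size_t num_packages) {
  const size_t num_tq = num_tokens * num_queries;
  for (size_t pkg = 0; pkg < num_packages; ++pkg) {  // pools.AllPackages()
    // StaticPartition: a contiguous slice of (token, query) items per package.
    const size_t begin = num_tq * pkg / num_packages;
    const size_t end = num_tq * (pkg + 1) / num_packages;
    for (size_t tq = begin; tq < end; ++tq) {  // pools.AllClusters(pkg_idx)
      const size_t query_idx = tq % num_queries;  // div_queries.Remainder(tq)
      const size_t batch_idx = tq / num_queries;  // div_queries.Divide(tq)
      for (size_t head = 0; head < num_heads; ++head) {  // pools.Cluster(...)
        // SingleDotSoftmaxWeightedSum for (query_idx, batch_idx, head) would
        // run here, on a thread of the cluster that owns this tq slice.
        (void)query_idx;
        (void)batch_idx;
      }
    }
  }
}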


@@ -22,6 +22,7 @@
 #include "gemma/gemma.h"
 #include "gemma/gemma_args.h"
 #include "gemma/weights.h"
+#include "util/threading.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/profiler.h"
@@ -111,7 +112,8 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
   }
 }
 
-// Calculates the attention outputs for a single q.
+// Calculates the attention outputs for a single q, which may be updated
+// in place for RMSNorm.
 void SingleDotSoftmaxWeightedSum(
     const size_t pos, const size_t start_pos, const size_t last_pos,
     float* HWY_RESTRICT q, const MatPtrT<float>& k, const MatPtrT<float>& v,
@@ -158,9 +160,10 @@ void DotSoftmaxWeightedSum(const size_t num_tokens,
                            const LayerWeightsPtrs& layer,
                            Activations& activations, const KVCaches& kv_caches,
                            NestedPools& pools) {
-  const size_t num_queries = queries_pos.size();
-  const LayerConfig& layer_config = layer.layer_config;
   PROFILER_ZONE("Gen.Attention.DotSoftmax");
+  const hwy::Divisor div_queries(queries_pos.size());
+  const LayerConfig& layer_config = layer.layer_config;
+  const size_t qkv_dim = layer_config.qkv_dim;
 
   // A "head group" in the context of GQA refers to a collection of query
   // heads that share the same key and value heads.
@@ -170,37 +173,24 @@ void DotSoftmaxWeightedSum(const size_t num_tokens,
   const size_t seq_len =
       static_cast<size_t>(activations.div_seq_len.GetDivisor());
 
-  // For each head (token, query), compute Q.K, softmax, and weighted V.
-  // TODO: nested parallelism to use more threads.
-  pools.Pool(0).Run(
-      0, layer_config.heads * num_tokens * num_queries,
-      [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
-        const size_t head = task % layer_config.heads;
-        const size_t interleaved_idx = task / layer_config.heads;
-        const size_t query_idx = interleaved_idx % num_queries;
-        const size_t batch_idx = interleaved_idx / num_queries;
-        const size_t qkv_dim = layer_config.qkv_dim;
-        const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
-        float* HWY_RESTRICT q =
-            activations.q.Row(interleaved_idx) + head * qkv_dim;
-        float* HWY_RESTRICT att =
-            activations.att.Row(interleaved_idx) + head * seq_len;
-        float* HWY_RESTRICT att_out =
-            activations.att_out.Row(interleaved_idx) + head * qkv_dim;
-
-        // Make strided views into the kv cache entries for the current
-        // query and head.
-        auto& kv_cache = kv_caches[query_idx].kv_cache;
-        const size_t kv_head_offset =
-            layer_idx * cache_layer_size + head_offset;
-        MatPtrT<float> k("k_view", Extents2D(seq_len, qkv_dim));
-        k.SetPtr(kv_cache.Row(0) + kv_head_offset, kv_cache.Stride());
-        MatPtrT<float> v("v_view", Extents2D(seq_len, qkv_dim));
-        v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());
-
-        // Find the token position in the query and calculate the range
-        // of cache positions to attend to.
-        const size_t pos = queries_pos[query_idx] + batch_idx;
-        const size_t start_pos =
-            StartPos(pos, activations.weights_config, layer_idx);
+  // For each head/token/query, compute Q.K, softmax, and weighted V.
+  // Statically partition token/query across packages.
+  const size_t num_tq = num_tokens * div_queries.GetDivisor();
+  const IndexRangePartition tq_ranges =
+      StaticPartition(IndexRange(0, num_tq), pools.NumPackages(), 1);
+  ParallelizeOneRange(
+      tq_ranges, pools.AllPackages(),
+      [&](const IndexRange& tq_range, const size_t pkg_idx) {
+        pools.AllClusters(pkg_idx).Run(
+            tq_range.begin(), tq_range.end(),
+            [&](const size_t tq_idx, const size_t cluster_idx) {
+              const size_t query_idx = div_queries.Remainder(tq_idx);
+              const size_t batch_idx = div_queries.Divide(tq_idx);
+              auto& kv_cache = kv_caches[query_idx].kv_cache;
+
+              // Find the token position in the query and calculate
+              // the range of cache positions to attend to.
+              const size_t pos = queries_pos[query_idx] + batch_idx;
+              const size_t start_pos =
+                  StartPos(pos, activations.weights_config, layer_idx);
@@ -211,9 +201,37 @@ void DotSoftmaxWeightedSum(const size_t num_tokens,
                 last_pos = prefix_end - 1;
               }
 
-        SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v,
-                                    layer_idx, layer, activations, att,
-                                    att_out);
+              pools.Cluster(pkg_idx, cluster_idx)
+                  .Run(
+                      0, layer_config.heads,
+                      [&](const size_t head, size_t thread) HWY_ATTR {
+                        const size_t head_offset =
+                            (head / kHeadGroups) * qkv_dim * 2;
+                        float* HWY_RESTRICT q =
+                            activations.q.Row(tq_idx) + head * qkv_dim;
+                        float* HWY_RESTRICT att =
+                            activations.att.Row(tq_idx) + head * seq_len;
+                        float* HWY_RESTRICT att_out =
+                            activations.att_out.Row(tq_idx) + head * qkv_dim;
+
+                        // Make strided read-only views into the kv cache for
+                        // this query and head.
+                        const size_t kv_head_offset =
+                            layer_idx * cache_layer_size + head_offset;
+                        MatPtrT<float> k("k_view", Extents2D(seq_len, qkv_dim));
+                        k.SetPtr(kv_cache.Row(0) + kv_head_offset,
+                                 kv_cache.Stride());
+                        MatPtrT<float> v("v_view", Extents2D(seq_len, qkv_dim));
+                        v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim,
+                                 kv_cache.Stride());
+
+                        SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q,
+                                                    k, v, layer_idx, layer,
+                                                    activations, att, att_out);
+                      });
+            });
       });
 }
@@ -223,8 +241,8 @@ static HWY_INLINE void ComputeQKV(
     const LayerWeightsPtrs& layer, Activations& activations,
     const KVCaches& kv_caches, const int flags, MatMulEnv& env) {
   PROFILER_ZONE("Gen.Attention.QKV");
-  const size_t num_queries = queries_pos.size();
-  const size_t num_interleaved = num_tokens * num_queries;
+  const hwy::Divisor div_queries(queries_pos.size());
+  const size_t num_interleaved = num_tokens * div_queries.GetDivisor();
   const LayerConfig& layer_config = layer.layer_config;
   const size_t qkv_dim = layer_config.qkv_dim;
   const size_t kv_heads = layer_config.kv_heads;
@@ -242,8 +260,8 @@ static HWY_INLINE void ComputeQKV(
                                          layer.qkv_einsum_w2.Rows()));
   for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
        ++interleaved_idx) {
-    const size_t query_idx = interleaved_idx % num_queries;
-    const size_t batch_idx = interleaved_idx / num_queries;
+    const size_t query_idx = div_queries.Remainder(interleaved_idx);
+    const size_t batch_idx = div_queries.Divide(interleaved_idx);
     const size_t cache_pos =
         activations.div_seq_len.Remainder(queries_pos[query_idx] + batch_idx);
     env.row_ptrs[0][interleaved_idx] = reinterpret_cast<uint8_t*>(
@@ -255,14 +273,15 @@ static HWY_INLINE void ComputeQKV(
                /*add=*/nullptr, env, kv_rows);
 
   // Apply positional encodings for K.
-  // TODO: 2D parallelism to use more threads.
+  // Note that 2D parallelism is not worth the fork/join overhead because the
+  // tasks are very lightweight.
   env.ctx.pools.Pool(0).Run(
       0, kv_heads * num_interleaved,
       [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
         const size_t head = task % kv_heads;
         const size_t interleaved_idx = task / kv_heads;
-        const size_t query_idx = interleaved_idx % num_queries;
-        const size_t batch_idx = interleaved_idx / num_queries;
+        const size_t query_idx = div_queries.Remainder(interleaved_idx);
+        const size_t batch_idx = div_queries.Divide(interleaved_idx);
         const size_t pos = queries_pos[query_idx] + batch_idx;
         const size_t cache_pos = activations.div_seq_len.Remainder(pos);
         auto& kv_cache = kv_caches[query_idx].kv_cache;
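The ComputeQKV hunks above replace `% num_queries` and `/ num_queries` with `div_queries.Remainder()` and `div_queries.Divide()`: hwy::Divisor precomputes the divisor once so the per-task index split avoids a hardware divide in the hot loop. A minimal sketch of the same interleaved-index convention, written with plain operators (the helper and example names are hypothetical):

#include <cassert>
#include <cstddef>

// Interleaved batch indexing as used above: query varies fastest, so
// interleaved_idx = batch_idx * num_queries + query_idx. hwy::Divisor's
// Remainder()/Divide() recover the same two values more cheaply.
void SplitInterleavedIdx(size_t interleaved_idx, size_t num_queries,
                         size_t& query_idx, size_t& batch_idx) {
  query_idx = interleaved_idx % num_queries;  // div_queries.Remainder(...)
  batch_idx = interleaved_idx / num_queries;  // div_queries.Divide(...)
}

// Example: with 3 queries, interleaved index 7 maps to query 1, token/batch 2.
void ExampleSplit() {
  size_t q, b;
  SplitInterleavedIdx(/*interleaved_idx=*/7, /*num_queries=*/3, q, b);
  assert(q == 1 && b == 2);
}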


@@ -21,6 +21,7 @@
 #include <stdint.h>
 #include <stdio.h>
 
+#include <atomic>
 #include <vector>
 
 #include "util/allocator.h"
@@ -401,9 +402,8 @@ static size_t NPMultiple(const Allocator& allocator, size_t N,
   }
   // This happens in tests with small N, hence do not assert.
   if (N % (np_multiple * num_packages) && N >= 128) {
-    static bool warned = false;
-    if (!warned) {
-      warned = true;
+    static std::atomic_flag warned = ATOMIC_FLAG_INIT;
+    if (!warned.test_and_set()) {
       HWY_WARN(
           "NPMultiple: N=%zu still not divisible by np_multiple=%zu * "
           "num_packages=%zu\n",