mirror of https://github.com/google/gemma.cpp.git
1.64x batch=1 prefill speedup: nested parallelization for Attention
(DotSoftmaxWeightedSum).

Also fix tsan error in matmul (atomic_flag instead of static).

PiperOrigin-RevId: 770241705
This commit is contained in:
parent c027a45a2e
commit 01cdefeda7
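The diff below replaces the single flat pool (one task per head x token x query) with a two-level decomposition: the token x query range is statically partitioned across packages, each cluster within a package picks up a sub-range, and the cluster's own threads then split the heads. What follows is a minimal standalone sketch of that pattern, not the gemma.cpp implementation: it uses plain std::thread in place of Highway's NestedPools / StaticPartition / ParallelizeOneRange, and names such as NestedAttention, AttendOneHead, kNumPackages and kThreadsPerPackage are illustrative only. A real implementation reuses persistent per-cluster thread pools instead of spawning threads inside the loop.

// Sketch: two-level (nested) parallelization of attention work.
// Outer level: statically partition the token*query range across packages.
// Inner level: threads belonging to one package split the heads.
#include <cstddef>
#include <cstdio>
#include <thread>
#include <vector>

constexpr size_t kNumPackages = 2;        // hypothetical outer parallelism
constexpr size_t kThreadsPerPackage = 4;  // hypothetical inner parallelism

// Stand-in for the per-(token, query, head) Q.K, softmax, weighted-V work.
void AttendOneHead(size_t token, size_t query, size_t head) {
  std::printf("token %zu query %zu head %zu\n", token, query, head);
}

void NestedAttention(size_t num_tokens, size_t num_queries, size_t num_heads) {
  const size_t num_tq = num_tokens * num_queries;  // interleaved task count
  std::vector<std::thread> packages;
  for (size_t pkg = 0; pkg < kNumPackages; ++pkg) {
    // Static partition of [0, num_tq) for this package.
    const size_t begin = num_tq * pkg / kNumPackages;
    const size_t end = num_tq * (pkg + 1) / kNumPackages;
    packages.emplace_back([=] {
      for (size_t tq = begin; tq < end; ++tq) {
        const size_t query = tq % num_queries;  // cf. div_queries.Remainder(tq)
        const size_t token = tq / num_queries;  // cf. div_queries.Divide(tq)
        // Inner level: split the heads across this package's workers.
        // (A real pool would be persistent; threads here are for illustration.)
        std::vector<std::thread> workers;
        for (size_t w = 0; w < kThreadsPerPackage; ++w) {
          workers.emplace_back([=] {
            for (size_t head = w; head < num_heads; head += kThreadsPerPackage) {
              AttendOneHead(token, query, head);
            }
          });
        }
        for (auto& t : workers) t.join();
      }
    });
  }
  for (auto& t : packages) t.join();
}

int main() {
  NestedAttention(/*num_tokens=*/2, /*num_queries=*/1, /*num_heads=*/8);
}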
@@ -22,6 +22,7 @@
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
#include "gemma/weights.h"
#include "util/threading.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"

@@ -111,7 +112,8 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
}
}

// Calculates the attention outputs for a single q.
// Calculates the attention outputs for a single q, which may be updated
// in place for RMSNorm.
void SingleDotSoftmaxWeightedSum(
const size_t pos, const size_t start_pos, const size_t last_pos,
float* HWY_RESTRICT q, const MatPtrT<float>& k, const MatPtrT<float>& v,

@@ -158,9 +160,10 @@ void DotSoftmaxWeightedSum(const size_t num_tokens,
const LayerWeightsPtrs& layer,
Activations& activations, const KVCaches& kv_caches,
NestedPools& pools) {
const size_t num_queries = queries_pos.size();
const LayerConfig& layer_config = layer.layer_config;
PROFILER_ZONE("Gen.Attention.DotSoftmax");
const hwy::Divisor div_queries(queries_pos.size());
const LayerConfig& layer_config = layer.layer_config;
const size_t qkv_dim = layer_config.qkv_dim;

// A "head group" in the context of GQA refers to a collection of query
// heads that share the same key and value heads.

@@ -170,37 +173,24 @@ void DotSoftmaxWeightedSum(const size_t num_tokens,
const size_t seq_len =
static_cast<size_t>(activations.div_seq_len.GetDivisor());

// For each head (token, query), compute Q.K, softmax, and weighted V.
// TODO: nested parallelism to use more threads.
pools.Pool(0).Run(
0, layer_config.heads * num_tokens * num_queries,
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
const size_t head = task % layer_config.heads;
const size_t interleaved_idx = task / layer_config.heads;
const size_t query_idx = interleaved_idx % num_queries;
const size_t batch_idx = interleaved_idx / num_queries;
const size_t qkv_dim = layer_config.qkv_dim;
const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
// For each head/token/query, compute Q.K, softmax, and weighted V.

float* HWY_RESTRICT q =
activations.q.Row(interleaved_idx) + head * qkv_dim;
float* HWY_RESTRICT att =
activations.att.Row(interleaved_idx) + head * seq_len;
float* HWY_RESTRICT att_out =
activations.att_out.Row(interleaved_idx) + head * qkv_dim;

// Make strided views into the kv cache entries for the current
// query and head.
// Statically partition token/query across packages.
const size_t num_tq = num_tokens * div_queries.GetDivisor();
const IndexRangePartition tq_ranges =
StaticPartition(IndexRange(0, num_tq), pools.NumPackages(), 1);
ParallelizeOneRange(
tq_ranges, pools.AllPackages(),
[&](const IndexRange& tq_range, const size_t pkg_idx) {
pools.AllClusters(pkg_idx).Run(
tq_range.begin(), tq_range.end(),
[&](const size_t tq_idx, const size_t cluster_idx) {
const size_t query_idx = div_queries.Remainder(tq_idx);
const size_t batch_idx = div_queries.Divide(tq_idx);
auto& kv_cache = kv_caches[query_idx].kv_cache;
const size_t kv_head_offset =
layer_idx * cache_layer_size + head_offset;
MatPtrT<float> k("k_view", Extents2D(seq_len, qkv_dim));
k.SetPtr(kv_cache.Row(0) + kv_head_offset, kv_cache.Stride());
MatPtrT<float> v("v_view", Extents2D(seq_len, qkv_dim));
v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());

// Find the token position in the query and calculate the range
// of cache positions to attend to.
// Find the token position in the query and calculate
// the range of cache positions to attend to.
const size_t pos = queries_pos[query_idx] + batch_idx;
const size_t start_pos =
StartPos(pos, activations.weights_config, layer_idx);

@@ -211,9 +201,37 @@ void DotSoftmaxWeightedSum(const size_t num_tokens,
last_pos = prefix_end - 1;
}

SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v,
layer_idx, layer, activations, att,
att_out);
pools.Cluster(pkg_idx, cluster_idx)
.Run(
0, layer_config.heads,
[&](const size_t head, size_t thread) HWY_ATTR {
const size_t head_offset =
(head / kHeadGroups) * qkv_dim * 2;

float* HWY_RESTRICT q =
activations.q.Row(tq_idx) + head * qkv_dim;

float* HWY_RESTRICT att =
activations.att.Row(tq_idx) + head * seq_len;
float* HWY_RESTRICT att_out =
activations.att_out.Row(tq_idx) + head * qkv_dim;

// Make strided read-only views into the kv cache for
// this query and head.
const size_t kv_head_offset =
layer_idx * cache_layer_size + head_offset;
MatPtrT<float> k("k_view", Extents2D(seq_len, qkv_dim));
k.SetPtr(kv_cache.Row(0) + kv_head_offset,
kv_cache.Stride());
MatPtrT<float> v("v_view", Extents2D(seq_len, qkv_dim));
v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim,
kv_cache.Stride());

SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q,
k, v, layer_idx, layer,
activations, att, att_out);
});
});
});
}

@@ -223,8 +241,8 @@ static HWY_INLINE void ComputeQKV(
const LayerWeightsPtrs& layer, Activations& activations,
const KVCaches& kv_caches, const int flags, MatMulEnv& env) {
PROFILER_ZONE("Gen.Attention.QKV");
const size_t num_queries = queries_pos.size();
const size_t num_interleaved = num_tokens * num_queries;
const hwy::Divisor div_queries(queries_pos.size());
const size_t num_interleaved = num_tokens * div_queries.GetDivisor();
const LayerConfig& layer_config = layer.layer_config;
const size_t qkv_dim = layer_config.qkv_dim;
const size_t kv_heads = layer_config.kv_heads;

@@ -242,8 +260,8 @@ static HWY_INLINE void ComputeQKV(
layer.qkv_einsum_w2.Rows()));
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) {
const size_t query_idx = interleaved_idx % num_queries;
const size_t batch_idx = interleaved_idx / num_queries;
const size_t query_idx = div_queries.Remainder(interleaved_idx);
const size_t batch_idx = div_queries.Divide(interleaved_idx);
const size_t cache_pos =
activations.div_seq_len.Remainder(queries_pos[query_idx] + batch_idx);
env.row_ptrs[0][interleaved_idx] = reinterpret_cast<uint8_t*>(

@@ -255,14 +273,15 @@ static HWY_INLINE void ComputeQKV(
/*add=*/nullptr, env, kv_rows);

// Apply positional encodings for K.
// TODO: 2D parallelism to use more threads.
// Note that 2D parallelism is not worth the fork/join overhead because the
// tasks are very lightweight.
env.ctx.pools.Pool(0).Run(
0, kv_heads * num_interleaved,
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
const size_t head = task % kv_heads;
const size_t interleaved_idx = task / kv_heads;
const size_t query_idx = interleaved_idx % num_queries;
const size_t batch_idx = interleaved_idx / num_queries;
const size_t query_idx = div_queries.Remainder(interleaved_idx);
const size_t batch_idx = div_queries.Divide(interleaved_idx);
const size_t pos = queries_pos[query_idx] + batch_idx;
const size_t cache_pos = activations.div_seq_len.Remainder(pos);
auto& kv_cache = kv_caches[query_idx].kv_cache;

@@ -21,6 +21,7 @@
#include <stdint.h>
#include <stdio.h>

#include <atomic>
#include <vector>

#include "util/allocator.h"

@@ -401,9 +402,8 @@ static size_t NPMultiple(const Allocator& allocator, size_t N,
}
// This happens in tests with small N, hence do not assert.
if (N % (np_multiple * num_packages) && N >= 128) {
static bool warned = false;
if (!warned) {
warned = true;
static std::atomic_flag warned = ATOMIC_FLAG_INIT;
if (!warned.test_and_set()) {
HWY_WARN(
"NPMultiple: N=%zu still not divisible by np_multiple=%zu * "
"num_packages=%zu\n",

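The second part of the commit replaces a plain static bool guard in NPMultiple's warning path with a std::atomic_flag: several threads can enter the function concurrently, and an unsynchronized read/write of the bool is a data race that ThreadSanitizer reports, whereas test_and_set() is a single atomic read-modify-write, so at most one thread takes the warning branch. Below is a minimal standalone sketch of that pattern; WarnOnce is a hypothetical stand-in for the HWY_WARN path, not gemma.cpp code.

// Sketch: warn-once guard that is safe under ThreadSanitizer.
#include <atomic>
#include <cstddef>
#include <cstdio>
#include <thread>
#include <vector>

void WarnOnce(size_t n) {
  // Initialization of a function-local static is itself thread-safe (C++11);
  // only the subsequent reads/writes need to be atomic, hence atomic_flag.
  static std::atomic_flag warned = ATOMIC_FLAG_INIT;
  if (!warned.test_and_set()) {  // atomic read-modify-write: no data race
    std::fprintf(stderr, "N=%zu not divisible by package multiple\n", n);
  }
}

int main() {
  std::vector<std::thread> threads;
  for (int i = 0; i < 8; ++i) threads.emplace_back(WarnOnce, size_t{100});
  for (auto& t : threads) t.join();  // exactly one warning is printed
}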