mirror of https://github.com/google/gemma.cpp.git
1.64x batch=1 prefill speedup: nested parallelization for Attention
(DotSoftmaxWeightedSum).

Also fix TSAN error in matmul (atomic_flag instead of static bool).

PiperOrigin-RevId: 770241705
parent c027a45a2e
commit 01cdefeda7
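For context, the new structure parallelizes the attention work at several levels: token/query rows are statically partitioned across packages, each package's clusters iterate over their share of rows, and the heads of each row run on that cluster's worker pool. Below is a simplified standalone sketch of the outer two levels only, using plain std::thread instead of gemma.cpp's NestedPools/StaticPartition/ParallelizeOneRange; ProcessHead is a hypothetical stand-in for the per-head Q.K/softmax/weighted-V work.

// Simplified sketch (std::thread only); gemma.cpp uses NestedPools,
// StaticPartition and ParallelizeOneRange from util/threading.h instead.
#include <cstddef>
#include <cstdio>
#include <thread>
#include <vector>

// Hypothetical stand-in for the per-head Q.K, softmax, weighted-V work.
void ProcessHead(size_t tq_idx, size_t head) {
  std::printf("token/query %zu, head %zu\n", tq_idx, head);
}

// Outer level: statically partition the flattened token*query range across
// "packages"; inner level: each package iterates over all heads.
void NestedAttention(size_t num_tq, size_t num_heads, size_t num_packages) {
  std::vector<std::thread> packages;
  packages.reserve(num_packages);
  for (size_t pkg = 0; pkg < num_packages; ++pkg) {
    // Contiguous range of token/query indices owned by this package.
    const size_t begin = pkg * num_tq / num_packages;
    const size_t end = (pkg + 1) * num_tq / num_packages;
    packages.emplace_back([=] {
      for (size_t tq = begin; tq < end; ++tq) {
        // The real code further spreads heads over the package's cluster
        // pools; here they simply run sequentially within the package.
        for (size_t head = 0; head < num_heads; ++head) {
          ProcessHead(tq, head);
        }
      }
    });
  }
  for (auto& t : packages) t.join();
}

int main() {
  NestedAttention(/*num_tq=*/4, /*num_heads=*/2, /*num_packages=*/2);
  return 0;
}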
@@ -22,6 +22,7 @@
 #include "gemma/gemma.h"
 #include "gemma/gemma_args.h"
 #include "gemma/weights.h"
+#include "util/threading.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/profiler.h"
@@ -111,7 +112,8 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
   }
 }
 
-// Calculates the attention outputs for a single q.
+// Calculates the attention outputs for a single q, which may be updated
+// in place for RMSNorm.
 void SingleDotSoftmaxWeightedSum(
     const size_t pos, const size_t start_pos, const size_t last_pos,
     float* HWY_RESTRICT q, const MatPtrT<float>& k, const MatPtrT<float>& v,
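As a reminder of what SingleDotSoftmaxWeightedSum computes per head, here is a hedged scalar sketch: Q.K over the attended positions, softmax, then a weighted sum of V rows. It assumes row-major K/V of shape [seq_len][qkv_dim] and an att buffer pre-sized to at least last_pos + 1, and it omits everything the real function does beyond the basic math (for example the in-place q update mentioned in the new comment).

// Scalar reference sketch of Q.K -> softmax -> weighted sum over V for one
// query vector and one head. Simplified: no positional encoding or capping.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

void SingleDotSoftmaxWeightedSumRef(
    size_t start_pos, size_t last_pos, const std::vector<float>& q,
    const std::vector<std::vector<float>>& k,  // [seq_len][qkv_dim]
    const std::vector<std::vector<float>>& v,  // [seq_len][qkv_dim]
    std::vector<float>& att,                   // caller pre-sizes to seq_len
    std::vector<float>& att_out) {             // [qkv_dim] output
  const size_t qkv_dim = q.size();
  // Q.K for each attended position.
  for (size_t pos = start_pos; pos <= last_pos; ++pos) {
    float dot = 0.0f;
    for (size_t d = 0; d < qkv_dim; ++d) dot += q[d] * k[pos][d];
    att[pos] = dot;
  }
  // Softmax over the attended range (max-subtracted for stability).
  float max = att[start_pos];
  for (size_t pos = start_pos; pos <= last_pos; ++pos) {
    max = std::max(max, att[pos]);
  }
  float sum = 0.0f;
  for (size_t pos = start_pos; pos <= last_pos; ++pos) {
    att[pos] = std::exp(att[pos] - max);
    sum += att[pos];
  }
  for (size_t pos = start_pos; pos <= last_pos; ++pos) att[pos] /= sum;
  // Weighted sum of V rows.
  att_out.assign(qkv_dim, 0.0f);
  for (size_t pos = start_pos; pos <= last_pos; ++pos) {
    for (size_t d = 0; d < qkv_dim; ++d) att_out[d] += att[pos] * v[pos][d];
  }
}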
@@ -158,9 +160,10 @@ void DotSoftmaxWeightedSum(const size_t num_tokens,
                            const LayerWeightsPtrs& layer,
                            Activations& activations, const KVCaches& kv_caches,
                            NestedPools& pools) {
-  const size_t num_queries = queries_pos.size();
-  const LayerConfig& layer_config = layer.layer_config;
   PROFILER_ZONE("Gen.Attention.DotSoftmax");
+  const hwy::Divisor div_queries(queries_pos.size());
+  const LayerConfig& layer_config = layer.layer_config;
+  const size_t qkv_dim = layer_config.qkv_dim;
 
   // A "head group" in the context of GQA refers to a collection of query
   // heads that share the same key and value heads.
@@ -170,37 +173,24 @@ void DotSoftmaxWeightedSum(const size_t num_tokens,
   const size_t seq_len =
       static_cast<size_t>(activations.div_seq_len.GetDivisor());
 
-  // For each head (token, query), compute Q.K, softmax, and weighted V.
-  // TODO: nested parallelism to use more threads.
-  pools.Pool(0).Run(
-      0, layer_config.heads * num_tokens * num_queries,
-      [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
-        const size_t head = task % layer_config.heads;
-        const size_t interleaved_idx = task / layer_config.heads;
-        const size_t query_idx = interleaved_idx % num_queries;
-        const size_t batch_idx = interleaved_idx / num_queries;
-        const size_t qkv_dim = layer_config.qkv_dim;
-        const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
-
-        float* HWY_RESTRICT q =
-            activations.q.Row(interleaved_idx) + head * qkv_dim;
-        float* HWY_RESTRICT att =
-            activations.att.Row(interleaved_idx) + head * seq_len;
-        float* HWY_RESTRICT att_out =
-            activations.att_out.Row(interleaved_idx) + head * qkv_dim;
-
-        // Make strided views into the kv cache entries for the current
-        // query and head.
+  // For each head/token/query, compute Q.K, softmax, and weighted V.
+  // Statically partition token/query across packages.
+  const size_t num_tq = num_tokens * div_queries.GetDivisor();
+  const IndexRangePartition tq_ranges =
+      StaticPartition(IndexRange(0, num_tq), pools.NumPackages(), 1);
+  ParallelizeOneRange(
+      tq_ranges, pools.AllPackages(),
+      [&](const IndexRange& tq_range, const size_t pkg_idx) {
+        pools.AllClusters(pkg_idx).Run(
+            tq_range.begin(), tq_range.end(),
+            [&](const size_t tq_idx, const size_t cluster_idx) {
+              const size_t query_idx = div_queries.Remainder(tq_idx);
+              const size_t batch_idx = div_queries.Divide(tq_idx);
               auto& kv_cache = kv_caches[query_idx].kv_cache;
-        const size_t kv_head_offset =
-            layer_idx * cache_layer_size + head_offset;
-        MatPtrT<float> k("k_view", Extents2D(seq_len, qkv_dim));
-        k.SetPtr(kv_cache.Row(0) + kv_head_offset, kv_cache.Stride());
-        MatPtrT<float> v("v_view", Extents2D(seq_len, qkv_dim));
-        v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());
 
-        // Find the token position in the query and calculate the range
-        // of cache positions to attend to.
+              // Find the token position in the query and calculate
+              // the range of cache positions to attend to.
               const size_t pos = queries_pos[query_idx] + batch_idx;
               const size_t start_pos =
                   StartPos(pos, activations.weights_config, layer_idx);
@@ -211,9 +201,37 @@ void DotSoftmaxWeightedSum(const size_t num_tokens,
                last_pos = prefix_end - 1;
              }
 
-        SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v,
-                                    layer_idx, layer, activations, att,
-                                    att_out);
+              pools.Cluster(pkg_idx, cluster_idx)
+                  .Run(
+                      0, layer_config.heads,
+                      [&](const size_t head, size_t thread) HWY_ATTR {
+                        const size_t head_offset =
+                            (head / kHeadGroups) * qkv_dim * 2;
+
+                        float* HWY_RESTRICT q =
+                            activations.q.Row(tq_idx) + head * qkv_dim;
+
+                        float* HWY_RESTRICT att =
+                            activations.att.Row(tq_idx) + head * seq_len;
+                        float* HWY_RESTRICT att_out =
+                            activations.att_out.Row(tq_idx) + head * qkv_dim;
+
+                        // Make strided read-only views into the kv cache for
+                        // this query and head.
+                        const size_t kv_head_offset =
+                            layer_idx * cache_layer_size + head_offset;
+                        MatPtrT<float> k("k_view", Extents2D(seq_len, qkv_dim));
+                        k.SetPtr(kv_cache.Row(0) + kv_head_offset,
+                                 kv_cache.Stride());
+                        MatPtrT<float> v("v_view", Extents2D(seq_len, qkv_dim));
+                        v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim,
+                                 kv_cache.Stride());
+
+                        SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q,
+                                                    k, v, layer_idx, layer,
+                                                    activations, att, att_out);
+                      });
+            });
      });
 }
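A note on the index math used above and in ComputeQKV below: the flattened token/query index interleaves queries fastest, and hwy::Divisor::Remainder/Divide recover (query_idx, batch_idx) without a hardware division per task. A minimal sketch of that mapping, with the magic-number optimization of hwy::Divisor omitted:

// tq_idx interleaves (batch_idx, query_idx) with query varying fastest, so
// tq_idx = batch_idx * num_queries + query_idx. hwy::Divisor precomputes a
// reciprocal so Divide/Remainder avoid a division per task; this plain
// version just shows the mapping.
#include <cassert>
#include <cstddef>

struct QueryIndex {
  size_t query_idx;  // which query (prompt) in the batch
  size_t batch_idx;  // which token within that query
};

inline QueryIndex Decompose(size_t tq_idx, size_t num_queries) {
  return {tq_idx % num_queries, tq_idx / num_queries};
}

int main() {
  const size_t num_queries = 3, num_tokens = 4;
  for (size_t tq = 0; tq < num_tokens * num_queries; ++tq) {
    const QueryIndex qi = Decompose(tq, num_queries);
    // Round-trip check: recomposing must give back the flattened index.
    assert(qi.batch_idx * num_queries + qi.query_idx == tq);
  }
  return 0;
}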
@@ -223,8 +241,8 @@ static HWY_INLINE void ComputeQKV(
     const LayerWeightsPtrs& layer, Activations& activations,
     const KVCaches& kv_caches, const int flags, MatMulEnv& env) {
   PROFILER_ZONE("Gen.Attention.QKV");
-  const size_t num_queries = queries_pos.size();
-  const size_t num_interleaved = num_tokens * num_queries;
+  const hwy::Divisor div_queries(queries_pos.size());
+  const size_t num_interleaved = num_tokens * div_queries.GetDivisor();
   const LayerConfig& layer_config = layer.layer_config;
   const size_t qkv_dim = layer_config.qkv_dim;
   const size_t kv_heads = layer_config.kv_heads;
@@ -242,8 +260,8 @@ static HWY_INLINE void ComputeQKV(
                                  layer.qkv_einsum_w2.Rows()));
   for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
        ++interleaved_idx) {
-    const size_t query_idx = interleaved_idx % num_queries;
-    const size_t batch_idx = interleaved_idx / num_queries;
+    const size_t query_idx = div_queries.Remainder(interleaved_idx);
+    const size_t batch_idx = div_queries.Divide(interleaved_idx);
     const size_t cache_pos =
         activations.div_seq_len.Remainder(queries_pos[query_idx] + batch_idx);
     env.row_ptrs[0][interleaved_idx] = reinterpret_cast<uint8_t*>(
@@ -255,14 +273,15 @@ static HWY_INLINE void ComputeQKV(
                 /*add=*/nullptr, env, kv_rows);
 
   // Apply positional encodings for K.
-  // TODO: 2D parallelism to use more threads.
+  // Note that 2D parallelism is not worth the fork/join overhead because the
+  // tasks are very lightweight.
   env.ctx.pools.Pool(0).Run(
       0, kv_heads * num_interleaved,
       [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
         const size_t head = task % kv_heads;
         const size_t interleaved_idx = task / kv_heads;
-        const size_t query_idx = interleaved_idx % num_queries;
-        const size_t batch_idx = interleaved_idx / num_queries;
+        const size_t query_idx = div_queries.Remainder(interleaved_idx);
+        const size_t batch_idx = div_queries.Divide(interleaved_idx);
         const size_t pos = queries_pos[query_idx] + batch_idx;
         const size_t cache_pos = activations.div_seq_len.Remainder(pos);
         auto& kv_cache = kv_caches[query_idx].kv_cache;
@@ -21,6 +21,7 @@
 #include <stdint.h>
 #include <stdio.h>
 
+#include <atomic>
 #include <vector>
 
 #include "util/allocator.h"
@@ -401,9 +402,8 @@ static size_t NPMultiple(const Allocator& allocator, size_t N,
   }
   // This happens in tests with small N, hence do not assert.
   if (N % (np_multiple * num_packages) && N >= 128) {
-    static bool warned = false;
-    if (!warned) {
-      warned = true;
+    static std::atomic_flag warned = ATOMIC_FLAG_INIT;
+    if (!warned.test_and_set()) {
       HWY_WARN(
           "NPMultiple: N=%zu still not divisible by np_multiple=%zu * "
           "num_packages=%zu\n",
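The matmul change above is the TSAN fix: several threads could read and write the plain static bool warned concurrently, which is a data race. std::atomic_flag::test_and_set() atomically sets the flag and returns its previous value, so exactly one thread takes the warning branch. A standalone sketch of the warn-once pattern, with fprintf standing in for HWY_WARN:

// Warn-once pattern: test_and_set() returns false only for the first caller,
// so exactly one thread prints and there is no racy read-modify-write.
#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

void MaybeWarn() {
  static std::atomic_flag warned = ATOMIC_FLAG_INIT;
  if (!warned.test_and_set()) {
    std::fprintf(stderr, "warning: condition hit (reported once)\n");
  }
}

int main() {
  std::vector<std::thread> threads;
  for (int i = 0; i < 8; ++i) threads.emplace_back(MaybeWarn);
  for (auto& t : threads) t.join();
  return 0;
}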