Use more parallelism in the attention block in prefill mode.

Move the loop over the tokens inside the attention block and then create
kHeads * num_tokens parallel tasks. This helps multi-threaded speed only for
the 2B Gemma model, but for consistency we move the loop over the tokens
inside the Griffin recurrent layer and the FFW layer as well. This also
prepares for using the MatMul operation later.
Benchmark results (summarization with 1600 tokens for prefill
and essay writing with 500 tokens for generation):
```
Prefill speed
Num threads      BEFORE       AFTER
         32   61.76 t/s   65.08 t/s
         64   89.46 t/s   98.62 t/s
```
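
As an aside, here is a minimal standalone sketch (not code from this commit) of the task decomposition described above: `pool.Run(0, kHeads * num_tokens, ...)` hands each worker a flat task index, which is split back into a (token, head) pair. The head count and token count below are made-up illustrative values, and a plain loop stands in for the Highway thread pool.

```cpp
// Sketch only: how a flat task index maps to a (token, head) pair, mirroring
// the pool.Run(0, kHeads * num_tokens, ...) calls in the diff below.
#include <cstddef>
#include <cstdio>

int main() {
  const size_t kHeads = 8;      // illustrative head count, not a real config value
  const size_t num_tokens = 4;  // illustrative prefill batch size
  const size_t num_tasks = kHeads * num_tokens;
  for (size_t task = 0; task < num_tasks; ++task) {
    const size_t head = task % kHeads;       // head index varies fastest
    const size_t batch_idx = task / kHeads;  // token within the prefill batch
    // In gemma.cc each such task runs on a thread-pool worker and computes
    // attention for one (token, head) pair; here we only print the mapping.
    std::printf("task %2zu -> token %zu, head %zu\n", task, batch_idx, head);
  }
  return 0;
}
```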
parent 6eeef2e2d9
commit 3d72f17261

gemma/gemma.cc (165 changed lines)
```diff
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -460,7 +460,7 @@ KVCache CreateKVCacheT() {
   constexpr size_t kConv1dWidth = Config::kConv1dWidth;
   return CreateKVCache(
       Config::kGemmaLayers * Config::kKVHeads * Config::kQKVDim,
-      Config::kSeqLen,
+      Config::kSeqLen + kPrefillBatchSize,
       Config::kGriffinLayers * (kConv1dWidth == 0 ? 0 : kConv1dWidth - 1) *
           Config::kModelDim,
       Config::kGriffinLayers * Config::kModelDim);
@@ -569,22 +569,23 @@ namespace HWY_NAMESPACE {
 
 template <size_t kBatchSize, typename LayerT, class TConfig>
 HWY_NOINLINE void GriffinRecurrent(
-    size_t batch_start, size_t batch_idx, size_t layer,
+    size_t batch_start, size_t num_tokens, size_t layer,
     Activations<TConfig, kBatchSize>& activations, const LayerT* layer_weights,
     KVCache& kv_cache, hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Griffin");
   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
-  HWY_DASSERT(batch_idx < kBatchSize);
+  HWY_DASSERT(num_tokens <= kBatchSize);
   static constexpr size_t kModelDim =
       gcpp::Activations<TConfig, kBatchSize>::kModelDim;
   static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
   static constexpr size_t kHeads = TConfig::kHeads;
   static constexpr bool kAdd = true;
-  const size_t batch_offset = batch_idx * kModelDim;
-  const size_t pos = batch_start + batch_idx;
 
   // X / Y linear layers.
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    const size_t batch_offset = batch_idx * kModelDim;
+    const size_t pos = batch_start + batch_idx;
     float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
     float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
     TwoMatVecAdd<kAdd, kModelDim, kModelDim>(
@@ -594,9 +595,13 @@ HWY_NOINLINE void GriffinRecurrent(
         /*add1=*/layer_weights->griffin.linear_y_biases.data(), /*out0=*/x,
         /*out1=*/y, pool);
     Gelu(y, kModelDim);
+  }
 
   // Conv1D.
-  {
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    const size_t batch_offset = batch_idx * kModelDim;
+    const size_t pos = batch_start + batch_idx;
+    float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
     HWY_FULL(float) df;
     HWY_DASSERT(kModelDim % Lanes(df) == 0);
     const size_t layer_offset = layer * kModelDim * (kConv1dWidth - 1);
@@ -611,7 +616,8 @@ HWY_NOINLINE void GriffinRecurrent(
     }
     for (size_t i = 0; i < kModelDim; i += Lanes(df)) {
       auto xv = hn::Load(df, x + i);
-      auto accum0 = hn::Load(df, layer_weights->griffin.conv_biases.data() + i);
+      auto accum0 =
+          hn::Load(df, layer_weights->griffin.conv_biases.data() + i);
       auto accum1 = hn::Zero(df);
       static_assert(kConv1dWidth % 2 == 0, "Conv width must be even");
       for (size_t l = 0; 2 * l < kConv1dWidth; l++) {
@@ -628,8 +634,15 @@ HWY_NOINLINE void GriffinRecurrent(
   }
 
   // RGLRU
-  float* HWY_RESTRICT gate_x = activations.griffin_gate_x.data() + batch_offset;
-  float* HWY_RESTRICT a = activations.griffin_multiplier.data() + batch_offset;
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    const size_t batch_offset = batch_idx * kModelDim;
+    const size_t pos = batch_start + batch_idx;
+    float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
+    float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
+    float* HWY_RESTRICT gate_x =
+        activations.griffin_gate_x.data() + batch_offset;
+    float* HWY_RESTRICT a =
+        activations.griffin_multiplier.data() + batch_offset;
     float* HWY_RESTRICT rnn_state =
         kv_cache.rglru_cache.get() + layer * kModelDim;
 
@@ -673,23 +686,28 @@ HWY_NOINLINE void GriffinRecurrent(
         hn::Store(pre_out, df, x + head_offset + i);
       }
     });
+  }
 
   // Final linear layer.
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    const size_t batch_offset = batch_idx * kModelDim;
+    const size_t pos = batch_start + batch_idx;
+    float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
     float* out_ptr = activations.att_post2.data() + batch_idx * kModelDim;
     MatVecAdd<kAdd, kModelDim, kModelDim>(
         layer_weights->griffin.linear_out_w, 0, x,
         layer_weights->griffin.linear_out_biases.data(),
         activations.even_odd.data(), out_ptr, pool);
+  }
 }
 
 template <size_t kBatchSize, typename LayerT, class TConfig>
-HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
+HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
                             Activations<TConfig, kBatchSize>& activations,
                             const LayerT* layer_weights, KVCache& kv_cache,
                             hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Attention");
-  const size_t pos = batch_start + batch_idx;
-  HWY_DASSERT(batch_idx < kBatchSize);
+  HWY_DASSERT(num_tokens <= kBatchSize);
   static constexpr size_t kQKVDim = gcpp::Activations<TConfig, 1>::kQKVDim;
   static constexpr size_t kCachePosSize =
       gcpp::Activations<TConfig, kBatchSize>::kCachePosSize;
@@ -699,47 +717,43 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
       gcpp::Activations<TConfig, kBatchSize>::kModelDim;
   static constexpr size_t kHeads = TConfig::kHeads;
   static constexpr size_t kKVHeads = TConfig::kKVHeads;
+  static constexpr size_t kSeqLen = TConfig::kSeqLen;
   static const float kQueryScale =
       static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
 
-  size_t cache_pos = pos;
-  size_t cache_num = pos + 1;
-  if constexpr (TConfig::kUseLocalAttention) {
-    cache_pos %= TConfig::kSeqLen;
-    cache_num = std::min(cache_num, static_cast<size_t>(TConfig::kSeqLen));
-  }
-
-  float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
-
-  auto Attn = [&](float* q, uint64_t head, size_t head_offset,
+  auto Attn = [&](float* q, uint64_t head, size_t head_offset, size_t batch_idx,
                   size_t thread) HWY_ATTR {
+    const size_t pos = batch_start + batch_idx;
     // Calculate scores
     float* HWY_RESTRICT head_att = activations.att.data() +
-                                   head * TConfig::kSeqLen +
-                                   batch_idx * kHeads * kQKVDim;
+                                   head * kSeqLen +
+                                   batch_idx * kHeads * kSeqLen;
 
     Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
    MulByConst(kQueryScale, q, kQKVDim);
 
     // Compute Q dot K scores
-    for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
-      const size_t cache_offset =
-          pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
-      const float* HWY_RESTRICT k2 = kv_cache.kv_cache.get() + cache_offset;
+    const size_t start_pos = pos - std::min(kSeqLen - 1, pos);
+    for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
+      const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
+      const size_t kv_offset = cache_pos * kCachePosSize +
+                               layer * kCacheLayerSize + head_offset;
+      const float* HWY_RESTRICT k2 = kv_cache.kv_cache.get() + kv_offset;
       const float score = Dot(q, k2, kQKVDim);
-      head_att[pos2] = score;
+      head_att[pos2 % kSeqLen] = score;
     }
-    Softmax(head_att, cache_num);
+    Softmax(head_att, std::min(pos + 1, kSeqLen));
 
     // Weighted summation
     float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim +
                                   batch_idx * kHeads * kQKVDim;
     hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
-    for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
-      const size_t cache_offset =
-          pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
-      float* HWY_RESTRICT v2 = kv_cache.kv_cache.get() + cache_offset + kQKVDim;
-      MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
+    for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
+      const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
+      const size_t kv_offset = cache_pos * kCachePosSize +
+                               layer * kCacheLayerSize + head_offset;
+      float* HWY_RESTRICT v2 = kv_cache.kv_cache.get() + kv_offset + kQKVDim;
+      MulByConstAndAdd(head_att[pos2 % kSeqLen], v2, att_out, kQKVDim);
     }
   };
 
@@ -747,42 +761,66 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     // Multi-Head Attention
     static_assert(TConfig::kInterleaveQKV);
 
+    for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+      float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
       float* HWY_RESTRICT qkv =
           activations.q.data() + batch_idx * kHeads * kQKVDim * 3;
       MatVec<kHeads * kQKVDim * 3, kModelDim>(
           layer_weights->qkv_einsum_w, 0, x, activations.even_odd.data(), qkv,
           pool);
-
-    pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
-      float* HWY_RESTRICT q = qkv + head * kQKVDim * 3;
+    }
+    const size_t num_tasks = kHeads * num_tokens;
+    pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
+      const size_t head = task % kHeads;
+      const size_t batch_idx = task / kHeads;
+      const size_t pos = batch_start + batch_idx;
+      float* HWY_RESTRICT q =
+          activations.q.data() + (batch_idx * kHeads + head) * kQKVDim * 3;
+      const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
       const size_t kv_offset = cache_pos * kCachePosSize +
                                layer * kCacheLayerSize + head * kQKVDim * 2;
       float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
 
       memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
       Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
-      Attn(q, head, head * kQKVDim * 2, thread);
     });
+    pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
+      const size_t head = task % kHeads;
+      const size_t batch_idx = task / kHeads;
+      float* HWY_RESTRICT q =
+          activations.q.data() + (batch_idx * kHeads + head) * kQKVDim * 3;
+      Attn(q, head, head * kQKVDim * 2, batch_idx, thread);
+    });
   } else {
     // Multi-Query Attention
-    float* HWY_RESTRICT q = activations.q.data() + batch_idx * kHeads * kQKVDim;
+    for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+      const size_t pos = batch_start + batch_idx;
+      float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
+
+      float* HWY_RESTRICT q =
+          activations.q.data() + batch_idx * kHeads * kQKVDim;
       MatVec<kHeads * kQKVDim, kModelDim>(layer_weights->qkv_einsum_w, 0, x,
                                           activations.even_odd.data(), q, pool);
 
-    float* HWY_RESTRICT kv = kv_cache.kv_cache.get() +
-                             cache_pos * kCachePosSize +
+      const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
+      const size_t kv_offset = cache_pos * kCachePosSize +
                                layer * kCacheLayerSize;
+      float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
       MatVec<kQKVDim * 2, kModelDim>(layer_weights->qkv_einsum_w,
                                      kHeads * kQKVDim * kModelDim, x,
                                      activations.even_odd.data(), kv, pool);
 
       Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
 
-    pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
-      Attn(q + head * kQKVDim, head, 0, thread);
+    }
+    const size_t num_tasks = kHeads * num_tokens;
+    pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
+      const size_t head = task % kHeads;
+      const size_t batch_idx = task / kHeads;
+      float* HWY_RESTRICT q =
+          activations.q.data() + batch_idx * kHeads * kQKVDim;
+      Attn(q + head * kQKVDim, head, 0, batch_idx, thread);
     });
   }
 
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
     // TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after
     // rearranging the weights.
     float* HWY_RESTRICT att_out =
@@ -803,18 +841,19 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
       AddFrom(head_out, layer_out, kModelDim);
     }
+  }
 }
 
 template <size_t kBatchSize, typename LayerT, typename TConfig>
 HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
-                      size_t batch_idx, const LayerT* layer_weights,
+                      size_t num_tokens, const LayerT* layer_weights,
                       hwy::ThreadPool& pool) {
-  HWY_DASSERT(batch_idx < kBatchSize);
   static constexpr size_t kModelDim = TConfig::kModelDim;
   static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
-  const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
   float* HWY_RESTRICT even_odd = activations.even_odd.data();
 
-  {
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
     PROFILER_ZONE("Gen.FFW.GatedGELU");
     const hwy::bfloat16_t* HWY_RESTRICT vec =
         activations.bf_pre_ffw_rms_out.data() + batch_idx * kModelDim;
@@ -839,12 +878,16 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
                 HWY_ATTR { return hn::Mul(mul, Gelu(df, v)); });
   }
 
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    PROFILER_ZONE("Gen.FFW\\GatedGELU");
+    const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
     MatVecAdd<TConfig::kFFBiases, kModelDim, kFFHiddenDim>(
-        layer_weights->linear_w, 0, activations.ffw_hidden.data() + hidden_offset,
+        layer_weights->linear_w, 0,
+        activations.ffw_hidden.data() + hidden_offset,
         layer_weights->ffw_output_biases.data(), even_odd,
         activations.ffw_out.data() + batch_idx * kModelDim, pool);
+  }
 }
 
 // `EmbeddingScaling` can be constexpr only if `Sqrt` and `hwy::ConvertScalarTo`
 // are both constexpr
@@ -898,24 +941,26 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
               layer_weights->pre_attention_norm_scale.data(),
               activations.pre_att_rms_out.data() + token_idx * kModelDim,
               kModelDim);
     }
     if (type == LayerAttentionType::kGemma) {
-      Attention<kBatchSize>(pos, token_idx, layer_of_type, activations,
+      Attention<kBatchSize>(pos, num_tokens, layer_of_type, activations,
                             layer_weights, kv_cache, pool);
     } else {
-      GriffinRecurrent<kBatchSize>(pos, token_idx, layer_of_type, activations,
+      GriffinRecurrent<kBatchSize>(pos, num_tokens, layer_of_type, activations,
                                    layer_weights, kv_cache, pool);
     }
   }
 
-  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
+  // TODO: sink the loop into these functions, i.e. make them MatMul.
+  pool.Run(0, num_tokens, [&](const uint64_t token_idx,
+                              size_t /*thread*/) HWY_ATTR {
     AddFrom(activations.att_post2.data() + token_idx * kModelDim,
             activations.x.data() + token_idx * kModelDim, kModelDim);
     RMSNorm(activations.x.data() + token_idx * kModelDim,
             layer_weights->pre_ffw_norm_scale.data(),
             activations.bf_pre_ffw_rms_out.data() + token_idx * kModelDim,
             kModelDim);
-    FFW<kBatchSize>(activations, token_idx, layer_weights, pool);
+  });
+  FFW<kBatchSize>(activations, num_tokens, layer_weights, pool);
+  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
     AddFrom(activations.ffw_out.data() + token_idx * kModelDim,
             activations.x.data() + token_idx * kModelDim, kModelDim);
   }
@@ -957,16 +1002,16 @@ void Transformer(int token, size_t pos, const WeightArrayT& weights,
             layer_weights->pre_attention_norm_scale.data(),
             activations.pre_att_rms_out.data(), kModelDim);
     if (type == LayerAttentionType::kGemma) {
-      Attention<1>(pos, 0, layer_of_type, activations, layer_weights, kv_cache,
+      Attention<1>(pos, 1, layer_of_type, activations, layer_weights, kv_cache,
                    pool);
     } else {
-      GriffinRecurrent<1>(pos, 0, layer_of_type, activations, layer_weights,
+      GriffinRecurrent<1>(pos, 1, layer_of_type, activations, layer_weights,
                           kv_cache, pool);
     }
     AddFrom(activations.att_post2.data(), activations.x.data(), kModelDim);
     RMSNorm(activations.x.data(), layer_weights->pre_ffw_norm_scale.data(),
             activations.bf_pre_ffw_rms_out.data(), kModelDim);
-    FFW<1>(activations, /* batch_idx = */ 0, layer_weights, pool);
+    FFW<1>(activations, /* num_tokens = */ 1, layer_weights, pool);
     AddFrom(activations.ffw_out.data(), activations.x.data(), kModelDim);
     if (layers_output != nullptr) {
       std::string block_name = "blocks." + std::to_string(layer);
```
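
A note on the indexing visible in the diff: the KV cache is now sized for kSeqLen + kPrefillBatchSize positions, a position is mapped to a cache slot with `pos % (kSeqLen + kPrefillBatchSize)`, and the per-head attention scores wrap at kSeqLen. A standalone sketch of that arithmetic, with made-up sizes (the real values come from the model config):

```cpp
// Sketch only: the modular slot arithmetic used by the new Attn lambda.
#include <cstddef>
#include <cstdio>

int main() {
  const size_t kSeqLen = 6;            // illustrative sequence length
  const size_t kPrefillBatchSize = 2;  // illustrative prefill batch size
  for (size_t pos = 0; pos < 10; ++pos) {
    const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);  // KV slot
    const size_t att_idx = pos % kSeqLen;  // slot in the per-head score buffer
    std::printf("pos %2zu -> kv slot %zu, att slot %zu\n", pos, cache_pos,
                att_idx);
  }
  return 0;
}
```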