Use more parallelism in attention block in prefill mode.

Move the loop over the tokens inside the attention block and then run
kHeads * num_tokens parallel tasks on the thread pool.

This helps multi-threaded speed only for the 2B Gemma model, but for
consistency we move the loop over the tokens inside the Griffin recurrent
layer and the FFW layer as well. This is also a preparation for using the
MatMul operation later.
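
To illustrate the task decomposition described above, here is a minimal,
self-contained sketch (not part of the commit): a flat task index over
kHeads * num_tokens is decoded back into a (head, token) pair, in the same
way the pool.Run lambdas in the diff below do. `RunTasks` is a hypothetical
serial stand-in for hwy::ThreadPool::Run, and the constants are example
values only.

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>

// Hypothetical serial stand-in for hwy::ThreadPool::Run: executes tasks
// [0, num_tasks) and passes a worker index to each call.
void RunTasks(uint64_t num_tasks,
              const std::function<void(uint64_t, size_t)>& func) {
  for (uint64_t task = 0; task < num_tasks; ++task) {
    func(task, /*thread=*/0);  // A real pool distributes tasks across threads.
  }
}

int main() {
  constexpr size_t kHeads = 8;  // example value, not from the commit
  const size_t num_tokens = 4;  // prefill batch size, example value

  // One task per (token, head) pair instead of one task per head.
  RunTasks(kHeads * num_tokens, [&](uint64_t task, size_t /*thread*/) {
    const size_t head = static_cast<size_t>(task) % kHeads;
    const size_t batch_idx = static_cast<size_t>(task) / kHeads;
    // Per-head, per-token attention work would go here; with
    // kHeads * num_tokens independent tasks, more threads stay busy.
    (void)head;
    (void)batch_idx;
  });
  return 0;
}
```

Decoding `head = task % kHeads` and `batch_idx = task / kHeads` keeps all
heads of one token adjacent in the task order.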

Benchmark results (summarization with 1600 tokens for prefill
and essay writing with 500 tokens for generation):

```
                   Prefill speed
Num threads      BEFORE       AFTER
32               61.76 t/s    65.08 t/s
64               89.46 t/s    98.62 t/s
```
Zoltan Szabadka 2024-05-03 12:52:23 +00:00
parent 6eeef2e2d9
commit 3d72f17261
1 changed file with 199 additions and 154 deletions


@@ -460,7 +460,7 @@ KVCache CreateKVCacheT() {
   constexpr size_t kConv1dWidth = Config::kConv1dWidth;
   return CreateKVCache(
       Config::kGemmaLayers * Config::kKVHeads * Config::kQKVDim,
-      Config::kSeqLen,
+      Config::kSeqLen + kPrefillBatchSize,
       Config::kGriffinLayers * (kConv1dWidth == 0 ? 0 : kConv1dWidth - 1) *
           Config::kModelDim,
       Config::kGriffinLayers * Config::kModelDim);
@@ -569,22 +569,23 @@ namespace HWY_NAMESPACE {
 template <size_t kBatchSize, typename LayerT, class TConfig>
 HWY_NOINLINE void GriffinRecurrent(
-    size_t batch_start, size_t batch_idx, size_t layer,
+    size_t batch_start, size_t num_tokens, size_t layer,
     Activations<TConfig, kBatchSize>& activations, const LayerT* layer_weights,
     KVCache& kv_cache, hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Griffin");
   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
-  HWY_DASSERT(batch_idx < kBatchSize);
+  HWY_DASSERT(num_tokens <= kBatchSize);
   static constexpr size_t kModelDim =
       gcpp::Activations<TConfig, kBatchSize>::kModelDim;
   static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
   static constexpr size_t kHeads = TConfig::kHeads;
   static constexpr bool kAdd = true;
-  const size_t batch_offset = batch_idx * kModelDim;
-  const size_t pos = batch_start + batch_idx;
 
   // X / Y linear layers.
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    const size_t batch_offset = batch_idx * kModelDim;
+    const size_t pos = batch_start + batch_idx;
     float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
     float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
     TwoMatVecAdd<kAdd, kModelDim, kModelDim>(
@@ -594,9 +595,13 @@ HWY_NOINLINE void GriffinRecurrent(
         /*add1=*/layer_weights->griffin.linear_y_biases.data(), /*out0=*/x,
         /*out1=*/y, pool);
     Gelu(y, kModelDim);
+  }
 
   // Conv1D.
-  {
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    const size_t batch_offset = batch_idx * kModelDim;
+    const size_t pos = batch_start + batch_idx;
+    float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
     HWY_FULL(float) df;
     HWY_DASSERT(kModelDim % Lanes(df) == 0);
     const size_t layer_offset = layer * kModelDim * (kConv1dWidth - 1);
@@ -611,7 +616,8 @@ HWY_NOINLINE void GriffinRecurrent(
     }
     for (size_t i = 0; i < kModelDim; i += Lanes(df)) {
       auto xv = hn::Load(df, x + i);
-      auto accum0 = hn::Load(df, layer_weights->griffin.conv_biases.data() + i);
+      auto accum0 =
+          hn::Load(df, layer_weights->griffin.conv_biases.data() + i);
       auto accum1 = hn::Zero(df);
       static_assert(kConv1dWidth % 2 == 0, "Conv width must be even");
       for (size_t l = 0; 2 * l < kConv1dWidth; l++) {
@@ -628,8 +634,15 @@ HWY_NOINLINE void GriffinRecurrent(
   }
 
   // RGLRU
-  float* HWY_RESTRICT gate_x = activations.griffin_gate_x.data() + batch_offset;
-  float* HWY_RESTRICT a = activations.griffin_multiplier.data() + batch_offset;
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    const size_t batch_offset = batch_idx * kModelDim;
+    const size_t pos = batch_start + batch_idx;
+    float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
+    float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
+    float* HWY_RESTRICT gate_x =
+        activations.griffin_gate_x.data() + batch_offset;
+    float* HWY_RESTRICT a =
+        activations.griffin_multiplier.data() + batch_offset;
     float* HWY_RESTRICT rnn_state =
         kv_cache.rglru_cache.get() + layer * kModelDim;
@@ -673,23 +686,28 @@ HWY_NOINLINE void GriffinRecurrent(
         hn::Store(pre_out, df, x + head_offset + i);
       }
     });
+  }
 
   // Final linear layer.
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    const size_t batch_offset = batch_idx * kModelDim;
+    const size_t pos = batch_start + batch_idx;
+    float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
     float* out_ptr = activations.att_post2.data() + batch_idx * kModelDim;
     MatVecAdd<kAdd, kModelDim, kModelDim>(
         layer_weights->griffin.linear_out_w, 0, x,
         layer_weights->griffin.linear_out_biases.data(),
         activations.even_odd.data(), out_ptr, pool);
+  }
 }
 
 template <size_t kBatchSize, typename LayerT, class TConfig>
-HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
+HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
                             Activations<TConfig, kBatchSize>& activations,
                             const LayerT* layer_weights, KVCache& kv_cache,
                             hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Attention");
-  const size_t pos = batch_start + batch_idx;
-  HWY_DASSERT(batch_idx < kBatchSize);
+  HWY_DASSERT(num_tokens <= kBatchSize);
   static constexpr size_t kQKVDim = gcpp::Activations<TConfig, 1>::kQKVDim;
   static constexpr size_t kCachePosSize =
       gcpp::Activations<TConfig, kBatchSize>::kCachePosSize;
@@ -699,47 +717,43 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
       gcpp::Activations<TConfig, kBatchSize>::kModelDim;
   static constexpr size_t kHeads = TConfig::kHeads;
   static constexpr size_t kKVHeads = TConfig::kKVHeads;
+  static constexpr size_t kSeqLen = TConfig::kSeqLen;
   static const float kQueryScale =
       static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
-  size_t cache_pos = pos;
-  size_t cache_num = pos + 1;
-  if constexpr (TConfig::kUseLocalAttention) {
-    cache_pos %= TConfig::kSeqLen;
-    cache_num = std::min(cache_num, static_cast<size_t>(TConfig::kSeqLen));
-  }
-  float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
-  auto Attn = [&](float* q, uint64_t head, size_t head_offset,
+  auto Attn = [&](float* q, uint64_t head, size_t head_offset, size_t batch_idx,
                   size_t thread) HWY_ATTR {
+    const size_t pos = batch_start + batch_idx;
     // Calculate scores
     float* HWY_RESTRICT head_att = activations.att.data() +
-                                   head * TConfig::kSeqLen +
-                                   batch_idx * kHeads * kQKVDim;
+                                   head * kSeqLen +
+                                   batch_idx * kHeads * kSeqLen;
     Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
     MulByConst(kQueryScale, q, kQKVDim);
     // Compute Q dot K scores
-    for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
-      const size_t cache_offset =
-          pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
-      const float* HWY_RESTRICT k2 = kv_cache.kv_cache.get() + cache_offset;
+    const size_t start_pos = pos - std::min(kSeqLen - 1, pos);
+    for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
+      const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
+      const size_t kv_offset = cache_pos * kCachePosSize +
+                               layer * kCacheLayerSize + head_offset;
+      const float* HWY_RESTRICT k2 = kv_cache.kv_cache.get() + kv_offset;
       const float score = Dot(q, k2, kQKVDim);
-      head_att[pos2] = score;
+      head_att[pos2 % kSeqLen] = score;
     }
-    Softmax(head_att, cache_num);
+    Softmax(head_att, std::min(pos + 1, kSeqLen));
     // Weighted summation
     float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim +
                                   batch_idx * kHeads * kQKVDim;
     hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
-    for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
-      const size_t cache_offset =
-          pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
-      float* HWY_RESTRICT v2 = kv_cache.kv_cache.get() + cache_offset + kQKVDim;
-      MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
+    for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
+      const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
+      const size_t kv_offset = cache_pos * kCachePosSize +
+                               layer * kCacheLayerSize + head_offset;
+      float* HWY_RESTRICT v2 = kv_cache.kv_cache.get() + kv_offset + kQKVDim;
+      MulByConstAndAdd(head_att[pos2 % kSeqLen], v2, att_out, kQKVDim);
     }
   };
@@ -747,42 +761,66 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     // Multi-Head Attention
     static_assert(TConfig::kInterleaveQKV);
+    for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+      float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
       float* HWY_RESTRICT qkv =
           activations.q.data() + batch_idx * kHeads * kQKVDim * 3;
       MatVec<kHeads * kQKVDim * 3, kModelDim>(
           layer_weights->qkv_einsum_w, 0, x, activations.even_odd.data(), qkv,
           pool);
+    }
-    pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
-      float* HWY_RESTRICT q = qkv + head * kQKVDim * 3;
+    const size_t num_tasks = kHeads * num_tokens;
+    pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
+      const size_t head = task % kHeads;
+      const size_t batch_idx = task / kHeads;
+      const size_t pos = batch_start + batch_idx;
+      float* HWY_RESTRICT q =
+          activations.q.data() + (batch_idx * kHeads + head) * kQKVDim * 3;
+      const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
       const size_t kv_offset = cache_pos * kCachePosSize +
                                layer * kCacheLayerSize + head * kQKVDim * 2;
       float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
       memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
       Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
-      Attn(q, head, head * kQKVDim * 2, thread);
+    });
+    pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
+      const size_t head = task % kHeads;
+      const size_t batch_idx = task / kHeads;
+      float* HWY_RESTRICT q =
+          activations.q.data() + (batch_idx * kHeads + head) * kQKVDim * 3;
+      Attn(q, head, head * kQKVDim * 2, batch_idx, thread);
     });
   } else {
     // Multi-Query Attention
-    float* HWY_RESTRICT q = activations.q.data() + batch_idx * kHeads * kQKVDim;
+    for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+      const size_t pos = batch_start + batch_idx;
+      float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
+      float* HWY_RESTRICT q =
+          activations.q.data() + batch_idx * kHeads * kQKVDim;
       MatVec<kHeads * kQKVDim, kModelDim>(layer_weights->qkv_einsum_w, 0, x,
                                           activations.even_odd.data(), q, pool);
-      float* HWY_RESTRICT kv = kv_cache.kv_cache.get() +
-                               cache_pos * kCachePosSize +
+      const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
+      const size_t kv_offset = cache_pos * kCachePosSize +
                                layer * kCacheLayerSize;
+      float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
       MatVec<kQKVDim * 2, kModelDim>(layer_weights->qkv_einsum_w,
                                      kHeads * kQKVDim * kModelDim, x,
                                      activations.even_odd.data(), kv, pool);
       Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+    }
-    pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
-      Attn(q + head * kQKVDim, head, 0, thread);
+    const size_t num_tasks = kHeads * num_tokens;
+    pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
+      const size_t head = task % kHeads;
+      const size_t batch_idx = task / kHeads;
+      float* HWY_RESTRICT q =
+          activations.q.data() + batch_idx * kHeads * kQKVDim;
+      Attn(q + head * kQKVDim, head, 0, batch_idx, thread);
     });
   }
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
     // TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after
     // rearranging the weights.
     float* HWY_RESTRICT att_out =
@@ -802,19 +840,20 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
                              activations.even_odd.data(), head_out, pool);
       AddFrom(head_out, layer_out, kModelDim);
     }
+  }
 }
 
 template <size_t kBatchSize, typename LayerT, typename TConfig>
 HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
-                      size_t batch_idx, const LayerT* layer_weights,
+                      size_t num_tokens, const LayerT* layer_weights,
                       hwy::ThreadPool& pool) {
   HWY_DASSERT(batch_idx < kBatchSize);
   static constexpr size_t kModelDim = TConfig::kModelDim;
   static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
-  const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
   float* HWY_RESTRICT even_odd = activations.even_odd.data();
-  {
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
     PROFILER_ZONE("Gen.FFW.GatedGELU");
     const hwy::bfloat16_t* HWY_RESTRICT vec =
         activations.bf_pre_ffw_rms_out.data() + batch_idx * kModelDim;
@@ -839,11 +878,15 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
                    HWY_ATTR { return hn::Mul(mul, Gelu(df, v)); });
   }
 
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
     PROFILER_ZONE("Gen.FFW\\GatedGELU");
+    const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
     MatVecAdd<TConfig::kFFBiases, kModelDim, kFFHiddenDim>(
-        layer_weights->linear_w, 0, activations.ffw_hidden.data() + hidden_offset,
+        layer_weights->linear_w, 0,
+        activations.ffw_hidden.data() + hidden_offset,
         layer_weights->ffw_output_biases.data(), even_odd,
         activations.ffw_out.data() + batch_idx * kModelDim, pool);
+  }
 }
 
 // `EmbeddingScaling` can be constexpr only if `Sqrt` and `hwy::ConvertScalarTo`
@@ -898,24 +941,26 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
             layer_weights->pre_attention_norm_scale.data(),
             activations.pre_att_rms_out.data() + token_idx * kModelDim,
             kModelDim);
+  }
   if (type == LayerAttentionType::kGemma) {
-    Attention<kBatchSize>(pos, token_idx, layer_of_type, activations,
+    Attention<kBatchSize>(pos, num_tokens, layer_of_type, activations,
                           layer_weights, kv_cache, pool);
   } else {
-    GriffinRecurrent<kBatchSize>(pos, token_idx, layer_of_type, activations,
+    GriffinRecurrent<kBatchSize>(pos, num_tokens, layer_of_type, activations,
                                  layer_weights, kv_cache, pool);
   }
-  }
 
-  // TODO: sink the loop into these functions, i.e. make them MatMul.
-  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
+  pool.Run(0, num_tokens, [&](const uint64_t token_idx,
+                              size_t /*thread*/) HWY_ATTR {
     AddFrom(activations.att_post2.data() + token_idx * kModelDim,
             activations.x.data() + token_idx * kModelDim, kModelDim);
     RMSNorm(activations.x.data() + token_idx * kModelDim,
            layer_weights->pre_ffw_norm_scale.data(),
            activations.bf_pre_ffw_rms_out.data() + token_idx * kModelDim,
            kModelDim);
-    FFW<kBatchSize>(activations, token_idx, layer_weights, pool);
+  });
+  FFW<kBatchSize>(activations, num_tokens, layer_weights, pool);
+  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
     AddFrom(activations.ffw_out.data() + token_idx * kModelDim,
             activations.x.data() + token_idx * kModelDim, kModelDim);
   }
@@ -957,16 +1002,16 @@ void Transformer(int token, size_t pos, const WeightArrayT& weights,
             layer_weights->pre_attention_norm_scale.data(),
             activations.pre_att_rms_out.data(), kModelDim);
     if (type == LayerAttentionType::kGemma) {
-      Attention<1>(pos, 0, layer_of_type, activations, layer_weights, kv_cache,
+      Attention<1>(pos, 1, layer_of_type, activations, layer_weights, kv_cache,
                    pool);
     } else {
-      GriffinRecurrent<1>(pos, 0, layer_of_type, activations, layer_weights,
+      GriffinRecurrent<1>(pos, 1, layer_of_type, activations, layer_weights,
                           kv_cache, pool);
     }
    AddFrom(activations.att_post2.data(), activations.x.data(), kModelDim);
    RMSNorm(activations.x.data(), layer_weights->pre_ffw_norm_scale.data(),
            activations.bf_pre_ffw_rms_out.data(), kModelDim);
-    FFW<1>(activations, /* batch_idx = */ 0, layer_weights, pool);
+    FFW<1>(activations, /* num_tokens = */ 1, layer_weights, pool);
    AddFrom(activations.ffw_out.data(), activations.x.data(), kModelDim);
    if (layers_output != nullptr) {
      std::string block_name = "blocks." + std::to_string(layer);