diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 28de0c8..40a7ac5 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -460,7 +460,7 @@ KVCache CreateKVCacheT() {
   constexpr size_t kConv1dWidth = Config::kConv1dWidth;
   return CreateKVCache(
       Config::kGemmaLayers * Config::kKVHeads * Config::kQKVDim,
-      Config::kSeqLen,
+      Config::kSeqLen + kPrefillBatchSize,
       Config::kGriffinLayers * (kConv1dWidth == 0 ? 0 : kConv1dWidth - 1) *
           Config::kModelDim,
       Config::kGriffinLayers * Config::kModelDim);
@@ -569,34 +569,39 @@ namespace HWY_NAMESPACE {
 
 template <size_t kBatchSize, typename LayerT, class TConfig>
 HWY_NOINLINE void GriffinRecurrent(
-    size_t batch_start, size_t batch_idx, size_t layer,
+    size_t batch_start, size_t num_tokens, size_t layer,
     Activations<TConfig, kBatchSize>& activations, const LayerT* layer_weights,
     KVCache& kv_cache, hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Griffin");
   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
-  HWY_DASSERT(batch_idx < kBatchSize);
+  HWY_DASSERT(num_tokens <= kBatchSize);
   static constexpr size_t kModelDim =
       gcpp::Activations<TConfig, kBatchSize>::kModelDim;
   static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
   static constexpr size_t kHeads = TConfig::kHeads;
   static constexpr bool kAdd = true;
-  const size_t batch_offset = batch_idx * kModelDim;
-  const size_t pos = batch_start + batch_idx;
 
   // X / Y linear layers.
-  float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
-  float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
-  TwoMatVecAdd<kAdd, kModelDim, kModelDim>(
-      layer_weights->griffin.linear_x_w, layer_weights->griffin.linear_y_w, 0,
-      activations.pre_att_rms_out.data() + batch_offset,
-      /*add0=*/layer_weights->griffin.linear_x_biases.data(),
-      /*add1=*/layer_weights->griffin.linear_y_biases.data(), /*out0=*/x,
-      /*out1=*/y, pool);
-  Gelu(y, kModelDim);
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    const size_t batch_offset = batch_idx * kModelDim;
+    const size_t pos = batch_start + batch_idx;
+    float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
+    float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
+    TwoMatVecAdd<kAdd, kModelDim, kModelDim>(
+        layer_weights->griffin.linear_x_w, layer_weights->griffin.linear_y_w, 0,
+        activations.pre_att_rms_out.data() + batch_offset,
+        /*add0=*/layer_weights->griffin.linear_x_biases.data(),
+        /*add1=*/layer_weights->griffin.linear_y_biases.data(), /*out0=*/x,
+        /*out1=*/y, pool);
+    Gelu(y, kModelDim);
+  }
 
   // Conv1D.
-  {
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    const size_t batch_offset = batch_idx * kModelDim;
+    const size_t pos = batch_start + batch_idx;
+    float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
     HWY_FULL(float) df;
     HWY_DASSERT(kModelDim % Lanes(df) == 0);
     const size_t layer_offset = layer * kModelDim * (kConv1dWidth - 1);
@@ -611,14 +616,15 @@ HWY_NOINLINE void GriffinRecurrent(
     }
     for (size_t i = 0; i < kModelDim; i += Lanes(df)) {
       auto xv = hn::Load(df, x + i);
-      auto accum0 = hn::Load(df, layer_weights->griffin.conv_biases.data() + i);
+      auto accum0 =
+          hn::Load(df, layer_weights->griffin.conv_biases.data() + i);
       auto accum1 = hn::Zero(df);
       static_assert(kConv1dWidth % 2 == 0, "Conv width must be even");
       for (size_t l = 0; 2 * l < kConv1dWidth; l++) {
         auto wv0 = hn::Load(df, layer_weights->griffin.conv_w.data() +
-                                (kConv1dWidth - 1 - 2 * l) * kModelDim + i);
+                                    (kConv1dWidth - 1 - 2 * l) * kModelDim + i);
         auto wv1 = hn::Load(df, layer_weights->griffin.conv_w.data() +
-                                (kConv1dWidth - 2 - 2 * l) * kModelDim + i);
+                                    (kConv1dWidth - 2 - 2 * l) * kModelDim + i);
         accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0);
         accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1);
       }
@@ -628,68 +634,80 @@ HWY_NOINLINE void GriffinRecurrent(
   }
 
   // RGLRU
-  float* HWY_RESTRICT gate_x = activations.griffin_gate_x.data() + batch_offset;
-  float* HWY_RESTRICT a = activations.griffin_multiplier.data() + batch_offset;
-  float* HWY_RESTRICT rnn_state =
-      kv_cache.rglru_cache.get() + layer * kModelDim;
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    const size_t batch_offset = batch_idx * kModelDim;
+    const size_t pos = batch_start + batch_idx;
+    float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
+    float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
+    float* HWY_RESTRICT gate_x =
+        activations.griffin_gate_x.data() + batch_offset;
+    float* HWY_RESTRICT a =
+        activations.griffin_multiplier.data() + batch_offset;
+    float* HWY_RESTRICT rnn_state =
+        kv_cache.rglru_cache.get() + layer * kModelDim;
 
-  pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
-    constexpr size_t kHeadDim = kModelDim / kHeads;
-    constexpr size_t kMatrixSize = kHeadDim * kHeadDim;
-    size_t head_offset = head * kHeadDim;
-    TwoOfsMatVecAddLoop<kAdd, kHeadDim, kHeadDim>(
-        layer_weights->griffin.gate_w, kMatrixSize * head,
-        kMatrixSize * (kHeads + head), x + head_offset,
-        /*add0=*/layer_weights->griffin.gate_biases.data() + head_offset,
-        /*add1=*/layer_weights->griffin.gate_biases.data() + kModelDim +
-            head_offset,
-        /*out0=*/gate_x + head_offset, /*out1=*/a + head_offset);
-    Sigmoid(gate_x + head_offset, kHeadDim);
-    Sigmoid(a + head_offset, kHeadDim);
-    const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
-                            HWY_ATTR { return hn::Mul(x, gate_x); };
-    hn::Transform1(D(), a + head_offset, kHeadDim,
-                   layer_weights->griffin.a.data() + head_offset, fn_mul);
-    hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
-                   fn_mul);
-    // RNN scan
-    HWY_FULL(float) df;
-    HWY_DASSERT(kHeadDim % Lanes(df) == 0);
-    for (size_t i = 0; i < kHeadDim; i += Lanes(df)) {
-      auto log_a = hn::Load(df, a + head_offset + i);
-      auto gated_x = hn::Load(df, x + head_offset + i);
-      auto rnn = hn::Load(df, rnn_state + head_offset + i);
-      auto a = hn::Exp(df, log_a);
-      auto x_multiplier = hn::Sqrt(hn::NegMulAdd(a, a, hn::Set(df, 1.0)));
-      if (pos == 0) {
-        x_multiplier = hn::Set(df, 1.0);
+    pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
+      constexpr size_t kHeadDim = kModelDim / kHeads;
+      constexpr size_t kMatrixSize = kHeadDim * kHeadDim;
+      size_t head_offset = head * kHeadDim;
+      TwoOfsMatVecAddLoop<kAdd, kHeadDim, kHeadDim>(
+          layer_weights->griffin.gate_w, kMatrixSize * head,
+          kMatrixSize * (kHeads + head), x + head_offset,
+          /*add0=*/layer_weights->griffin.gate_biases.data() + head_offset,
+          /*add1=*/layer_weights->griffin.gate_biases.data() + kModelDim +
+              head_offset,
+          /*out0=*/gate_x + head_offset, /*out1=*/a + head_offset);
+      Sigmoid(gate_x + head_offset, kHeadDim);
+      Sigmoid(a + head_offset, kHeadDim);
+      const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
+                              HWY_ATTR { return hn::Mul(x, gate_x); };
+      hn::Transform1(D(), a + head_offset, kHeadDim,
+                     layer_weights->griffin.a.data() + head_offset, fn_mul);
+      hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
+                     fn_mul);
+      // RNN scan
+      HWY_FULL(float) df;
+      HWY_DASSERT(kHeadDim % Lanes(df) == 0);
+      for (size_t i = 0; i < kHeadDim; i += Lanes(df)) {
+        auto log_a = hn::Load(df, a + head_offset + i);
+        auto gated_x = hn::Load(df, x + head_offset + i);
+        auto rnn = hn::Load(df, rnn_state + head_offset + i);
+        auto a = hn::Exp(df, log_a);
+        auto x_multiplier = hn::Sqrt(hn::NegMulAdd(a, a, hn::Set(df, 1.0)));
+        if (pos == 0) {
+          x_multiplier = hn::Set(df, 1.0);
+        }
+        auto new_x = hn::MulAdd(x_multiplier, gated_x, hn::Mul(a, rnn));
+        hn::Store(new_x, df, rnn_state + head_offset + i);
+
+        // Join branches.
+        auto yv = hn::Load(df, y + head_offset + i);
+        auto pre_out = hn::Mul(yv, new_x);
+        hn::Store(pre_out, df, x + head_offset + i);
       }
-      auto new_x = hn::MulAdd(x_multiplier, gated_x, hn::Mul(a, rnn));
-      hn::Store(new_x, df, rnn_state + head_offset + i);
-
-      // Join branches.
-      auto yv = hn::Load(df, y + head_offset + i);
-      auto pre_out = hn::Mul(yv, new_x);
-      hn::Store(pre_out, df, x + head_offset + i);
-    }
-  });
+    });
+  }
 
   // Final linear layer.
-  float* out_ptr = activations.att_post2.data() + batch_idx * kModelDim;
-  MatVecAdd<kModelDim, kModelDim>(
-      layer_weights->griffin.linear_out_w, 0, x,
-      layer_weights->griffin.linear_out_biases.data(),
-      activations.even_odd.data(), out_ptr, pool);
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    const size_t batch_offset = batch_idx * kModelDim;
+    const size_t pos = batch_start + batch_idx;
+    float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
+    float* out_ptr = activations.att_post2.data() + batch_idx * kModelDim;
+    MatVecAdd<kModelDim, kModelDim>(
+        layer_weights->griffin.linear_out_w, 0, x,
+        layer_weights->griffin.linear_out_biases.data(),
+        activations.even_odd.data(), out_ptr, pool);
+  }
 }
 
 template <size_t kBatchSize, typename LayerT, class TConfig>
-HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
+HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
                             Activations<TConfig, kBatchSize>& activations,
                             const LayerT* layer_weights, KVCache& kv_cache,
                             hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Attention");
-  const size_t pos = batch_start + batch_idx;
-  HWY_DASSERT(batch_idx < kBatchSize);
+  HWY_DASSERT(num_tokens <= kBatchSize);
   static constexpr size_t kQKVDim = gcpp::Activations<TConfig, kBatchSize>::kQKVDim;
   static constexpr size_t kCachePosSize =
       gcpp::Activations<TConfig, kBatchSize>::kCachePosSize;
@@ -699,47 +717,43 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
       gcpp::Activations<TConfig, kBatchSize>::kModelDim;
   static constexpr size_t kHeads = TConfig::kHeads;
   static constexpr size_t kKVHeads = TConfig::kKVHeads;
+  static constexpr size_t kSeqLen = TConfig::kSeqLen;
   static const float kQueryScale =
       static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
 
-  size_t cache_pos = pos;
-  size_t cache_num = pos + 1;
-  if constexpr (TConfig::kUseLocalAttention) {
-    cache_pos %= TConfig::kSeqLen;
-    cache_num = std::min(cache_num, static_cast<size_t>(TConfig::kSeqLen));
-  }
-
-  float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
-
-  auto Attn = [&](float* q, uint64_t head, size_t head_offset,
+  auto Attn = [&](float* q, uint64_t head, size_t head_offset, size_t batch_idx,
                  size_t thread) HWY_ATTR {
+    const size_t pos = batch_start + batch_idx;
     // Calculate scores
     float* HWY_RESTRICT head_att = activations.att.data() +
-                                   head * TConfig::kSeqLen +
-                                   batch_idx * kHeads * kQKVDim;
+                                   head * kSeqLen +
+                                   batch_idx * kHeads * kSeqLen;
    Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
    MulByConst(kQueryScale, q, kQKVDim);
 
     // Compute Q dot K scores
-    for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
-      const size_t cache_offset =
-          pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
-      const float* HWY_RESTRICT k2 = kv_cache.kv_cache.get() + cache_offset;
+    const size_t start_pos = pos - std::min(kSeqLen - 1, pos);
+    for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
+      const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
+      const size_t kv_offset = cache_pos * kCachePosSize +
+                               layer * kCacheLayerSize + head_offset;
+      const float* HWY_RESTRICT k2 = kv_cache.kv_cache.get() + kv_offset;
       const float score = Dot(q, k2, kQKVDim);
-      head_att[pos2] = score;
+      head_att[pos2 % kSeqLen] = score;
     }
-    Softmax(head_att, cache_num);
+    Softmax(head_att, std::min(pos + 1, kSeqLen));
 
     // Weighted summation
     float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim +
                                   batch_idx * kHeads * kQKVDim;
     hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
-    for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
-      const size_t cache_offset =
-          pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
-      float* HWY_RESTRICT v2 = kv_cache.kv_cache.get() + cache_offset + kQKVDim;
-      MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
+    for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
+      const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
+      const size_t kv_offset = cache_pos * kCachePosSize +
+                               layer * kCacheLayerSize + head_offset;
+      float* HWY_RESTRICT v2 = kv_cache.kv_cache.get() + kv_offset + kQKVDim;
+      MulByConstAndAdd(head_att[pos2 % kSeqLen], v2, att_out, kQKVDim);
    }
  };
 
@@ -747,74 +761,99 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     // Multi-Head Attention
     static_assert(TConfig::kInterleaveQKV);
-    float* HWY_RESTRICT qkv =
-        activations.q.data() + batch_idx * kHeads * kQKVDim * 3;
-    MatVec<kHeads * kQKVDim * 3, kModelDim>(
-        layer_weights->qkv_einsum_w, 0, x, activations.even_odd.data(), qkv,
-        pool);
-
-    pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
-      float* HWY_RESTRICT q = qkv + head * kQKVDim * 3;
+    for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+      float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
+      float* HWY_RESTRICT qkv =
+          activations.q.data() + batch_idx * kHeads * kQKVDim * 3;
+      MatVec<kHeads * kQKVDim * 3, kModelDim>(
+          layer_weights->qkv_einsum_w, 0, x, activations.even_odd.data(), qkv,
+          pool);
+    }
+    const size_t num_tasks = kHeads * num_tokens;
+    pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
+      const size_t head = task % kHeads;
+      const size_t batch_idx = task / kHeads;
+      const size_t pos = batch_start + batch_idx;
+      float* HWY_RESTRICT q =
+          activations.q.data() + (batch_idx * kHeads + head) * kQKVDim * 3;
+      const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
       const size_t kv_offset = cache_pos * kCachePosSize +
                                layer * kCacheLayerSize + head * kQKVDim * 2;
       float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
-      memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
       Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
-      Attn(q, head, head * kQKVDim * 2, thread);
+    });
+    pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
+      const size_t head = task % kHeads;
+      const size_t batch_idx = task / kHeads;
+      float* HWY_RESTRICT q =
+          activations.q.data() + (batch_idx * kHeads + head) * kQKVDim * 3;
+      Attn(q, head, head * kQKVDim * 2, batch_idx, thread);
     });
   } else {
     // Multi-Query Attention
-    float* HWY_RESTRICT q = activations.q.data() + batch_idx * kHeads * kQKVDim;
-    MatVec<kHeads * kQKVDim, kModelDim>(layer_weights->qkv_einsum_w, 0, x,
-                                        activations.even_odd.data(), q, pool);
+    for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+      const size_t pos = batch_start + batch_idx;
+      float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
-    float* HWY_RESTRICT kv = kv_cache.kv_cache.get() +
-                             cache_pos * kCachePosSize +
-                             layer * kCacheLayerSize;
-    MatVec<kQKVDim * 2, kModelDim>(layer_weights->qkv_einsum_w,
-                                   kHeads * kQKVDim * kModelDim, x,
-                                   activations.even_odd.data(), kv, pool);
+      float* HWY_RESTRICT q =
+          activations.q.data() + batch_idx * kHeads * kQKVDim;
+      MatVec<kHeads * kQKVDim, kModelDim>(layer_weights->qkv_einsum_w, 0, x,
                                          activations.even_odd.data(), q, pool);
-    Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
-
-    pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
-      Attn(q + head * kQKVDim, head, 0, thread);
+      const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
+      const size_t kv_offset = cache_pos * kCachePosSize +
+                               layer * kCacheLayerSize;
+      float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
+      MatVec<kQKVDim * 2, kModelDim>(layer_weights->qkv_einsum_w,
+                                     kHeads * kQKVDim * kModelDim, x,
+                                     activations.even_odd.data(), kv, pool);
+      Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+    }
+    const size_t num_tasks = kHeads * num_tokens;
+    pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
+      const size_t head = task % kHeads;
+      const size_t batch_idx = task / kHeads;
+      float* HWY_RESTRICT q =
+          activations.q.data() + batch_idx * kHeads * kQKVDim;
+      Attn(q + head * kQKVDim, head, 0, batch_idx, thread);
    });
  }
 
-  // TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after
-  // rearranging the weights.
-  float* HWY_RESTRICT att_out =
-      activations.att_out.data() + batch_idx * kHeads * kQKVDim;
-  float* HWY_RESTRICT layer_out =
-      activations.att_post2.data() + batch_idx * kModelDim;
-  MatVecAdd<kModelDim, kQKVDim>(
-      layer_weights->attn_vec_einsum_w, 0, att_out,
-      layer_weights->attention_output_biases.data(),
-      activations.even_odd.data(), layer_out, pool);
-  for (size_t head = 1; head < kHeads; ++head) {
-    float* HWY_RESTRICT head_out =
-        activations.att_post1.data() + head * kBatchSize * kModelDim;
-    MatVec<kModelDim, kQKVDim>(
-        layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim,
-        att_out + head * kQKVDim,
-        activations.even_odd.data(), head_out, pool);
-    AddFrom(head_out, layer_out, kModelDim);
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    // TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after
+    // rearranging the weights.
+    float* HWY_RESTRICT att_out =
+        activations.att_out.data() + batch_idx * kHeads * kQKVDim;
+    float* HWY_RESTRICT layer_out =
+        activations.att_post2.data() + batch_idx * kModelDim;
+    MatVecAdd<kModelDim, kQKVDim>(
+        layer_weights->attn_vec_einsum_w, 0, att_out,
+        layer_weights->attention_output_biases.data(),
+        activations.even_odd.data(), layer_out, pool);
+    for (size_t head = 1; head < kHeads; ++head) {
+      float* HWY_RESTRICT head_out =
+          activations.att_post1.data() + head * kBatchSize * kModelDim;
+      MatVec<kModelDim, kQKVDim>(
+          layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim,
+          att_out + head * kQKVDim,
+          activations.even_odd.data(), head_out, pool);
+      AddFrom(head_out, layer_out, kModelDim);
+    }
   }
 }
 
 template <size_t kBatchSize, typename LayerT, class TConfig>
 HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
-                      size_t batch_idx, const LayerT* layer_weights,
+                      size_t num_tokens, const LayerT* layer_weights,
                       hwy::ThreadPool& pool) {
   HWY_DASSERT(batch_idx < kBatchSize);
   static constexpr size_t kModelDim = TConfig::kModelDim;
   static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
-  const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
   float* HWY_RESTRICT even_odd = activations.even_odd.data();
 
-  {
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
     PROFILER_ZONE("Gen.FFW.GatedGELU");
     const hwy::bfloat16_t* HWY_RESTRICT vec =
         activations.bf_pre_ffw_rms_out.data() + batch_idx * kModelDim;
@@ -839,11 +878,15 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
                        HWY_ATTR { return hn::Mul(mul, Gelu(df, v)); });
   }
 
-  PROFILER_ZONE("Gen.FFW\\GatedGELU");
-  MatVecAdd<kModelDim, kFFHiddenDim>(
-      layer_weights->linear_w, 0, activations.ffw_hidden.data() + hidden_offset,
-      layer_weights->ffw_output_biases.data(), even_odd,
-      activations.ffw_out.data() + batch_idx * kModelDim, pool);
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    PROFILER_ZONE("Gen.FFW\\GatedGELU");
+    const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
+    MatVecAdd<kModelDim, kFFHiddenDim>(
+        layer_weights->linear_w, 0,
+        activations.ffw_hidden.data() + hidden_offset,
+        layer_weights->ffw_output_biases.data(), even_odd,
+        activations.ffw_out.data() + batch_idx * kModelDim, pool);
+  }
 }
 
 // `EmbeddingScaling` can be constexpr only if `Sqrt` and `hwy::ConvertScalarTo`
@@ -898,24 +941,26 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
               layer_weights->pre_attention_norm_scale.data(),
               activations.pre_att_rms_out.data() + token_idx * kModelDim,
               kModelDim);
-      if (type == LayerAttentionType::kGemma) {
-        Attention(pos, token_idx, layer_of_type, activations,
-                  layer_weights, kv_cache, pool);
-      } else {
-        GriffinRecurrent(pos, token_idx, layer_of_type, activations,
-                         layer_weights, kv_cache, pool);
-      }
+    }
+    if (type == LayerAttentionType::kGemma) {
+      Attention(pos, num_tokens, layer_of_type, activations,
+                layer_weights, kv_cache, pool);
+    } else {
+      GriffinRecurrent(pos, num_tokens, layer_of_type, activations,
+                       layer_weights, kv_cache, pool);
+    }
     }
 
-    // TODO: sink the loop into these functions, i.e. make them MatMul.
-    for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
+    pool.Run(0, num_tokens, [&](const uint64_t token_idx,
+                                size_t /*thread*/) HWY_ATTR {
       AddFrom(activations.att_post2.data() + token_idx * kModelDim,
               activations.x.data() + token_idx * kModelDim, kModelDim);
       RMSNorm(activations.x.data() + token_idx * kModelDim,
               layer_weights->pre_ffw_norm_scale.data(),
               activations.bf_pre_ffw_rms_out.data() + token_idx * kModelDim,
               kModelDim);
-      FFW(activations, token_idx, layer_weights, pool);
+    });
+    FFW(activations, num_tokens, layer_weights, pool);
+    for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
       AddFrom(activations.ffw_out.data() + token_idx * kModelDim,
               activations.x.data() + token_idx * kModelDim, kModelDim);
     }
@@ -957,16 +1002,16 @@ void Transformer(int token, size_t pos, const WeightArrayT& weights,
             layer_weights->pre_attention_norm_scale.data(),
             activations.pre_att_rms_out.data(), kModelDim);
     if (type == LayerAttentionType::kGemma) {
-      Attention<1>(pos, 0, layer_of_type, activations, layer_weights, kv_cache,
+      Attention<1>(pos, 1, layer_of_type, activations, layer_weights, kv_cache,
                    pool);
     } else {
-      GriffinRecurrent<1>(pos, 0, layer_of_type, activations, layer_weights,
+      GriffinRecurrent<1>(pos, 1, layer_of_type, activations, layer_weights,
                           kv_cache, pool);
     }
     AddFrom(activations.att_post2.data(), activations.x.data(), kModelDim);
     RMSNorm(activations.x.data(), layer_weights->pre_ffw_norm_scale.data(),
             activations.bf_pre_ffw_rms_out.data(), kModelDim);
-    FFW<1>(activations, /* batch_idx = */ 0, layer_weights, pool);
+    FFW<1>(activations, /* num_tokens = */ 1, layer_weights, pool);
     AddFrom(activations.ffw_out.data(), activations.x.data(), kModelDim);
     if (layers_output != nullptr) {
       std::string block_name = "blocks." + std::to_string(layer);