diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 1599627..6659923 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -845,7 +845,7 @@ template HWY_NOINLINE void FFW(Activations& activations, size_t num_tokens, const LayerT* layer_weights, hwy::ThreadPool& pool) { - HWY_DASSERT(batch_idx < kBatchSize); + HWY_DASSERT(num_tokens <= kBatchSize); static constexpr size_t kModelDim = TConfig::kModelDim; static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; float* HWY_RESTRICT even_odd = activations.even_odd.data();