Merge pull request #177 from szabadka:gemma2

PiperOrigin-RevId: 630388843
Copybara-Service 2024-05-03 07:52:27 -07:00
commit 8ed22e52bf
2 changed files with 198 additions and 156 deletions

@@ -247,7 +247,6 @@ struct CompressTraits<hwy::bfloat16_t> {
     using VU32 = hn::VFromD<decltype(du32)>;
     const VU32 odd = Set(du32, 0xFFFF0000u);
-    VF32 be0, bo0, be1, bo1;
     for (size_t i = 0; i < num; /* i += 2 * N */) {
       const auto interleaved0 = hn::LoadU(dbf16, in + in_ofs + i);
       const VF32 ae0 = Load(df32, vec_aligned + i);

@@ -460,7 +460,7 @@ KVCache CreateKVCacheT() {
   constexpr size_t kConv1dWidth = Config::kConv1dWidth;
   return CreateKVCache(
       Config::kGemmaLayers * Config::kKVHeads * Config::kQKVDim,
-      Config::kSeqLen,
+      Config::kSeqLen + kPrefillBatchSize,
       Config::kGriffinLayers * (kConv1dWidth == 0 ? 0 : kConv1dWidth - 1) *
           Config::kModelDim,
       Config::kGriffinLayers * Config::kModelDim);
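
Note: the cache now reserves Config::kSeqLen + kPrefillBatchSize position slots because later hunks in this commit index it with pos % (kSeqLen + kPrefillBatchSize), so absolute positions wrap around a ring buffer slightly larger than the sequence length. A minimal standalone sketch of that mapping; the constants below are illustrative, not the real model config:

#include <cstddef>
#include <cstdio>

// Illustrative values only; the real ones come from the model's Config.
constexpr size_t kSeqLen = 8;
constexpr size_t kPrefillBatchSize = 4;
constexpr size_t kCacheSlots = kSeqLen + kPrefillBatchSize;

// Maps an absolute token position to a KV-cache row, as in the new code.
constexpr size_t CacheRow(size_t pos) { return pos % kCacheSlots; }

int main() {
  for (size_t pos = 0; pos < 2 * kCacheSlots; ++pos) {
    std::printf("pos %2zu -> cache row %2zu\n", pos, CacheRow(pos));
  }
  return 0;
}
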
@@ -569,34 +569,38 @@ namespace HWY_NAMESPACE {
 template <size_t kBatchSize, typename LayerT, class TConfig>
 HWY_NOINLINE void GriffinRecurrent(
-    size_t batch_start, size_t batch_idx, size_t layer,
+    size_t batch_start, size_t num_tokens, size_t layer,
     Activations<TConfig, kBatchSize>& activations, const LayerT* layer_weights,
     KVCache& kv_cache, hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Griffin");
   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
-  HWY_DASSERT(batch_idx < kBatchSize);
+  HWY_DASSERT(num_tokens <= kBatchSize);
   static constexpr size_t kModelDim =
       gcpp::Activations<TConfig, kBatchSize>::kModelDim;
   static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
   static constexpr size_t kHeads = TConfig::kHeads;
   static constexpr bool kAdd = true;
-  const size_t batch_offset = batch_idx * kModelDim;
-  const size_t pos = batch_start + batch_idx;

   // X / Y linear layers.
-  float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
-  float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
-  TwoMatVecAdd<kAdd, kModelDim, kModelDim>(
-      layer_weights->griffin.linear_x_w, layer_weights->griffin.linear_y_w, 0,
-      activations.pre_att_rms_out.data() + batch_offset,
-      /*add0=*/layer_weights->griffin.linear_x_biases.data(),
-      /*add1=*/layer_weights->griffin.linear_y_biases.data(), /*out0=*/x,
-      /*out1=*/y, pool);
-  Gelu(y, kModelDim);
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    const size_t batch_offset = batch_idx * kModelDim;
+    float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
+    float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
+    TwoMatVecAdd<kAdd, kModelDim, kModelDim>(
+        layer_weights->griffin.linear_x_w, layer_weights->griffin.linear_y_w, 0,
+        activations.pre_att_rms_out.data() + batch_offset,
+        /*add0=*/layer_weights->griffin.linear_x_biases.data(),
+        /*add1=*/layer_weights->griffin.linear_y_biases.data(), /*out0=*/x,
+        /*out1=*/y, pool);
+    Gelu(y, kModelDim);
+  }

   // Conv1D.
-  {
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    const size_t batch_offset = batch_idx * kModelDim;
+    const size_t pos = batch_start + batch_idx;
+    float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
     HWY_FULL(float) df;
     HWY_DASSERT(kModelDim % Lanes(df) == 0);
     const size_t layer_offset = layer * kModelDim * (kConv1dWidth - 1);
@@ -611,14 +615,15 @@ HWY_NOINLINE void GriffinRecurrent(
     }
     for (size_t i = 0; i < kModelDim; i += Lanes(df)) {
       auto xv = hn::Load(df, x + i);
-      auto accum0 = hn::Load(df, layer_weights->griffin.conv_biases.data() + i);
+      auto accum0 =
+          hn::Load(df, layer_weights->griffin.conv_biases.data() + i);
       auto accum1 = hn::Zero(df);
       static_assert(kConv1dWidth % 2 == 0, "Conv width must be even");
       for (size_t l = 0; 2 * l < kConv1dWidth; l++) {
         auto wv0 = hn::Load(df, layer_weights->griffin.conv_w.data() +
                                 (kConv1dWidth - 1 - 2 * l) * kModelDim + i);
         auto wv1 = hn::Load(df, layer_weights->griffin.conv_w.data() +
                                 (kConv1dWidth - 2 - 2 * l) * kModelDim + i);
         accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0);
         accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1);
       }
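
Note: the temporal convolution accumulates its kConv1dWidth taps in two interleaved accumulators (accum0 over even cache slots, accum1 over odd ones), which is why the width must be even. A scalar sketch of the same pairing for a single channel, assuming the two accumulators are summed afterwards as in the surrounding code; the width and values below are illustrative:

#include <array>
#include <cstddef>
#include <cstdio>

constexpr size_t kConv1dWidth = 4;  // illustrative; must be even, as asserted
static_assert(kConv1dWidth % 2 == 0, "Conv width must be even");

// Pairwise tap accumulation mirroring accum0/accum1 in the diff, for one
// channel: w[k] is the k-th filter tap, cache[j] the j-th cached input.
float Conv1dChannel(const std::array<float, kConv1dWidth>& w,
                    const std::array<float, kConv1dWidth>& cache, float bias) {
  float accum0 = bias;
  float accum1 = 0.0f;
  for (size_t l = 0; 2 * l < kConv1dWidth; l++) {
    accum0 += w[kConv1dWidth - 1 - 2 * l] * cache[l * 2];
    accum1 += w[kConv1dWidth - 2 - 2 * l] * cache[l * 2 + 1];
  }
  return accum0 + accum1;  // assumed final combination of the two accumulators
}

int main() {
  const std::array<float, kConv1dWidth> w{0.1f, 0.2f, 0.3f, 0.4f};
  const std::array<float, kConv1dWidth> cache{1.0f, 2.0f, 3.0f, 4.0f};
  std::printf("conv output: %f\n", Conv1dChannel(w, cache, 0.5f));
  return 0;
}
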
@@ -628,68 +633,79 @@ HWY_NOINLINE void GriffinRecurrent(
   }

   // RGLRU
-  float* HWY_RESTRICT gate_x = activations.griffin_gate_x.data() + batch_offset;
-  float* HWY_RESTRICT a = activations.griffin_multiplier.data() + batch_offset;
-  float* HWY_RESTRICT rnn_state =
-      kv_cache.rglru_cache.get() + layer * kModelDim;
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    const size_t batch_offset = batch_idx * kModelDim;
+    const size_t pos = batch_start + batch_idx;
+    float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
+    float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
+    float* HWY_RESTRICT gate_x =
+        activations.griffin_gate_x.data() + batch_offset;
+    float* HWY_RESTRICT a =
+        activations.griffin_multiplier.data() + batch_offset;
+    float* HWY_RESTRICT rnn_state =
+        kv_cache.rglru_cache.get() + layer * kModelDim;
     pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
       constexpr size_t kHeadDim = kModelDim / kHeads;
       constexpr size_t kMatrixSize = kHeadDim * kHeadDim;
       size_t head_offset = head * kHeadDim;
       TwoOfsMatVecAddLoop<kAdd, kHeadDim, kHeadDim>(
           layer_weights->griffin.gate_w, kMatrixSize * head,
           kMatrixSize * (kHeads + head), x + head_offset,
           /*add0=*/layer_weights->griffin.gate_biases.data() + head_offset,
           /*add1=*/layer_weights->griffin.gate_biases.data() + kModelDim +
               head_offset,
           /*out0=*/gate_x + head_offset, /*out1=*/a + head_offset);
       Sigmoid(gate_x + head_offset, kHeadDim);
       Sigmoid(a + head_offset, kHeadDim);
       const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
                               HWY_ATTR { return hn::Mul(x, gate_x); };
       hn::Transform1(D(), a + head_offset, kHeadDim,
                      layer_weights->griffin.a.data() + head_offset, fn_mul);
       hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
                      fn_mul);
       // RNN scan
       HWY_FULL(float) df;
       HWY_DASSERT(kHeadDim % Lanes(df) == 0);
       for (size_t i = 0; i < kHeadDim; i += Lanes(df)) {
         auto log_a = hn::Load(df, a + head_offset + i);
         auto gated_x = hn::Load(df, x + head_offset + i);
         auto rnn = hn::Load(df, rnn_state + head_offset + i);
         auto a = hn::Exp(df, log_a);
         auto x_multiplier = hn::Sqrt(hn::NegMulAdd(a, a, hn::Set(df, 1.0)));
         if (pos == 0) {
           x_multiplier = hn::Set(df, 1.0);
         }
         auto new_x = hn::MulAdd(x_multiplier, gated_x, hn::Mul(a, rnn));
         hn::Store(new_x, df, rnn_state + head_offset + i);
         // Join branches.
         auto yv = hn::Load(df, y + head_offset + i);
         auto pre_out = hn::Mul(yv, new_x);
         hn::Store(pre_out, df, x + head_offset + i);
       }
     });
+  }

   // Final linear layer.
-  float* out_ptr = activations.att_post2.data() + batch_idx * kModelDim;
-  MatVecAdd<kAdd, kModelDim, kModelDim>(
-      layer_weights->griffin.linear_out_w, 0, x,
-      layer_weights->griffin.linear_out_biases.data(),
-      activations.even_odd.data(), out_ptr, pool);
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    const size_t batch_offset = batch_idx * kModelDim;
+    float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
+    float* out_ptr = activations.att_post2.data() + batch_idx * kModelDim;
+    MatVecAdd<kAdd, kModelDim, kModelDim>(
+        layer_weights->griffin.linear_out_w, 0, x,
+        layer_weights->griffin.linear_out_biases.data(),
+        activations.even_odd.data(), out_ptr, pool);
+  }
 }

 template <size_t kBatchSize, typename LayerT, class TConfig>
-HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
+HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
                             Activations<TConfig, kBatchSize>& activations,
                             const LayerT* layer_weights, KVCache& kv_cache,
                             hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Attention");
-  const size_t pos = batch_start + batch_idx;
-  HWY_DASSERT(batch_idx < kBatchSize);
+  HWY_DASSERT(num_tokens <= kBatchSize);
   static constexpr size_t kQKVDim = gcpp::Activations<TConfig, 1>::kQKVDim;
   static constexpr size_t kCachePosSize =
       gcpp::Activations<TConfig, kBatchSize>::kCachePosSize;
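
Note: inside the RNN scan above, each lane computes a = exp(log_a), scales the gated input by sqrt(1 - a^2) (skipped at pos == 0), updates the recurrent state as new_x = x_multiplier * gated_x + a * rnn, and then multiplies by the y branch. A scalar sketch of one element of that update; this is a hypothetical standalone helper, not part of the codebase:

#include <cmath>
#include <cstddef>
#include <cstdio>

// One element of the RG-LRU update from the diff, written in scalar form.
// log_a: gated log-coefficient, gated_x: gated input, y: the GELU branch,
// rnn: previous recurrent state (updated in place), pos: token position.
float RglruElement(float log_a, float gated_x, float y, float& rnn,
                   size_t pos) {
  const float a = std::exp(log_a);
  float x_multiplier = std::sqrt(1.0f - a * a);
  if (pos == 0) x_multiplier = 1.0f;  // no input scaling on the first token
  const float new_x = x_multiplier * gated_x + a * rnn;
  rnn = new_x;                        // stored back into rglru_cache
  return y * new_x;                   // the "Join branches" step
}

int main() {
  float rnn = 0.0f;
  for (size_t pos = 0; pos < 4; ++pos) {
    std::printf("out=%f\n", RglruElement(-0.5f, 1.0f, 1.0f, rnn, pos));
  }
  return 0;
}
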
@@ -699,47 +715,43 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
       gcpp::Activations<TConfig, kBatchSize>::kModelDim;
   static constexpr size_t kHeads = TConfig::kHeads;
   static constexpr size_t kKVHeads = TConfig::kKVHeads;
+  static constexpr size_t kSeqLen = TConfig::kSeqLen;
   static const float kQueryScale =
       static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
-  size_t cache_pos = pos;
-  size_t cache_num = pos + 1;
-  if constexpr (TConfig::kUseLocalAttention) {
-    cache_pos %= TConfig::kSeqLen;
-    cache_num = std::min(cache_num, static_cast<size_t>(TConfig::kSeqLen));
-  }
-  float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
-  auto Attn = [&](float* q, uint64_t head, size_t head_offset,
-                  size_t thread) HWY_ATTR {
+  auto Attn = [&](float* q, uint64_t head, size_t head_offset, size_t batch_idx,
+                  size_t thread) HWY_ATTR {
+    const size_t pos = batch_start + batch_idx;
     // Calculate scores
     float* HWY_RESTRICT head_att = activations.att.data() +
-                                   head * TConfig::kSeqLen +
-                                   batch_idx * kHeads * kQKVDim;
+                                   head * kSeqLen +
+                                   batch_idx * kHeads * kSeqLen;
     Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
     MulByConst(kQueryScale, q, kQKVDim);

     // Compute Q dot K scores
-    for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
-      const size_t cache_offset =
-          pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
-      const float* HWY_RESTRICT k2 = kv_cache.kv_cache.get() + cache_offset;
+    const size_t start_pos = pos - std::min(kSeqLen - 1, pos);
+    for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
+      const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
+      const size_t kv_offset = cache_pos * kCachePosSize +
+                               layer * kCacheLayerSize + head_offset;
+      const float* HWY_RESTRICT k2 = kv_cache.kv_cache.get() + kv_offset;
       const float score = Dot(q, k2, kQKVDim);
-      head_att[pos2] = score;
+      head_att[pos2 % kSeqLen] = score;
    }
-    Softmax(head_att, cache_num);
+    Softmax(head_att, std::min(pos + 1, kSeqLen));

     // Weighted summation
     float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim +
                                   batch_idx * kHeads * kQKVDim;
     hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
-    for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
-      const size_t cache_offset =
-          pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
-      float* HWY_RESTRICT v2 = kv_cache.kv_cache.get() + cache_offset + kQKVDim;
-      MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
+    for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
+      const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
+      const size_t kv_offset = cache_pos * kCachePosSize +
+                               layer * kCacheLayerSize + head_offset;
+      float* HWY_RESTRICT v2 = kv_cache.kv_cache.get() + kv_offset + kQKVDim;
+      MulByConstAndAdd(head_att[pos2 % kSeqLen], v2, att_out, kQKVDim);
     }
   };
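
Note: the Attn lambda now walks an explicit attention window [start_pos, pos] with start_pos = pos - min(kSeqLen - 1, pos), reads K/V at pos2 % (kSeqLen + kPrefillBatchSize), writes each score at slot pos2 % kSeqLen, and runs the softmax over min(pos + 1, kSeqLen) entries. A small index-only sketch of those computations; the sizes and the position below are illustrative:

#include <algorithm>
#include <cstddef>
#include <cstdio>

// Illustrative sizes; the real values come from TConfig.
constexpr size_t kSeqLen = 8;
constexpr size_t kPrefillBatchSize = 4;

int main() {
  const size_t pos = 10;  // current query position (illustrative)
  const size_t start_pos = pos - std::min(kSeqLen - 1, pos);
  for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
    const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);  // KV row
    const size_t att_slot = pos2 % kSeqLen;                         // score slot
    std::printf("pos2=%zu kv_row=%zu att_slot=%zu\n", pos2, cache_pos, att_slot);
  }
  const size_t softmax_len = std::min(pos + 1, kSeqLen);
  std::printf("softmax over %zu entries\n", softmax_len);
  return 0;
}
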
@@ -747,74 +759,99 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     // Multi-Head Attention
     static_assert(TConfig::kInterleaveQKV);
-    float* HWY_RESTRICT qkv =
-        activations.q.data() + batch_idx * kHeads * kQKVDim * 3;
-    MatVec<kHeads * kQKVDim * 3, kModelDim>(
-        layer_weights->qkv_einsum_w, 0, x, activations.even_odd.data(), qkv,
-        pool);
-    pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
-      float* HWY_RESTRICT q = qkv + head * kQKVDim * 3;
+    for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+      float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
+      float* HWY_RESTRICT qkv =
+          activations.q.data() + batch_idx * kHeads * kQKVDim * 3;
+      MatVec<kHeads * kQKVDim * 3, kModelDim>(
+          layer_weights->qkv_einsum_w, 0, x, activations.even_odd.data(), qkv,
+          pool);
+    }
+    const size_t num_tasks = kHeads * num_tokens;
+    pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
+      const size_t head = task % kHeads;
+      const size_t batch_idx = task / kHeads;
+      const size_t pos = batch_start + batch_idx;
+      float* HWY_RESTRICT q =
+          activations.q.data() + (batch_idx * kHeads + head) * kQKVDim * 3;
+      const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
       const size_t kv_offset = cache_pos * kCachePosSize +
                                layer * kCacheLayerSize + head * kQKVDim * 2;
       float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
       memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
       Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
-      Attn(q, head, head * kQKVDim * 2, thread);
+    });
+    pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
+      const size_t head = task % kHeads;
+      const size_t batch_idx = task / kHeads;
+      float* HWY_RESTRICT q =
+          activations.q.data() + (batch_idx * kHeads + head) * kQKVDim * 3;
+      Attn(q, head, head * kQKVDim * 2, batch_idx, thread);
     });
   } else {
     // Multi-Query Attention
-    float* HWY_RESTRICT q = activations.q.data() + batch_idx * kHeads * kQKVDim;
-    MatVec<kHeads * kQKVDim, kModelDim>(layer_weights->qkv_einsum_w, 0, x,
-                                        activations.even_odd.data(), q, pool);
-    float* HWY_RESTRICT kv = kv_cache.kv_cache.get() +
-                             cache_pos * kCachePosSize +
-                             layer * kCacheLayerSize;
-    MatVec<kQKVDim * 2, kModelDim>(layer_weights->qkv_einsum_w,
-                                   kHeads * kQKVDim * kModelDim, x,
-                                   activations.even_odd.data(), kv, pool);
-    Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
-    pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
-      Attn(q + head * kQKVDim, head, 0, thread);
+    for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+      const size_t pos = batch_start + batch_idx;
+      float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
+      float* HWY_RESTRICT q =
+          activations.q.data() + batch_idx * kHeads * kQKVDim;
+      MatVec<kHeads * kQKVDim, kModelDim>(layer_weights->qkv_einsum_w, 0, x,
+                                          activations.even_odd.data(), q, pool);
+      const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
+      const size_t kv_offset = cache_pos * kCachePosSize +
+                               layer * kCacheLayerSize;
+      float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
+      MatVec<kQKVDim * 2, kModelDim>(layer_weights->qkv_einsum_w,
+                                     kHeads * kQKVDim * kModelDim, x,
+                                     activations.even_odd.data(), kv, pool);
+      Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+    }
+    const size_t num_tasks = kHeads * num_tokens;
+    pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
+      const size_t head = task % kHeads;
+      const size_t batch_idx = task / kHeads;
+      float* HWY_RESTRICT q =
+          activations.q.data() + batch_idx * kHeads * kQKVDim;
+      Attn(q + head * kQKVDim, head, 0, batch_idx, thread);
     });
   }

+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
     // TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after
     // rearranging the weights.
     float* HWY_RESTRICT att_out =
         activations.att_out.data() + batch_idx * kHeads * kQKVDim;
     float* HWY_RESTRICT layer_out =
         activations.att_post2.data() + batch_idx * kModelDim;
     MatVecAdd<TConfig::kSoftmaxAttnOutputBiases, kModelDim, kQKVDim>(
         layer_weights->attn_vec_einsum_w, 0, att_out,
         layer_weights->attention_output_biases.data(),
         activations.even_odd.data(), layer_out, pool);
     for (size_t head = 1; head < kHeads; ++head) {
       float* HWY_RESTRICT head_out =
           activations.att_post1.data() + head * kBatchSize * kModelDim;
       MatVec<kModelDim, kQKVDim>(
           layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim,
           att_out + head * kQKVDim,
           activations.even_odd.data(), head_out, pool);
       AddFrom(head_out, layer_out, kModelDim);
     }
+  }
 }

 template <size_t kBatchSize, typename LayerT, typename TConfig>
 HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
-                      size_t batch_idx, const LayerT* layer_weights,
+                      size_t num_tokens, const LayerT* layer_weights,
                       hwy::ThreadPool& pool) {
-  HWY_DASSERT(batch_idx < kBatchSize);
+  HWY_DASSERT(num_tokens <= kBatchSize);
   static constexpr size_t kModelDim = TConfig::kModelDim;
   static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
-  const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
   float* HWY_RESTRICT even_odd = activations.even_odd.data();
-  {
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
     PROFILER_ZONE("Gen.FFW.GatedGELU");
     const hwy::bfloat16_t* HWY_RESTRICT vec =
         activations.bf_pre_ffw_rms_out.data() + batch_idx * kModelDim;
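
Note: both attention branches now flatten the (token, head) pairs into kHeads * num_tokens pool tasks and recover the indices with task % kHeads and task / kHeads. A minimal sketch of that decomposition, with the thread pool replaced by a plain loop and the counts chosen only for illustration:

#include <cstddef>
#include <cstdio>

int main() {
  // Illustrative counts; the real ones are TConfig::kHeads and the batch size.
  const size_t kHeads = 4;
  const size_t num_tokens = 3;
  const size_t num_tasks = kHeads * num_tokens;
  for (size_t task = 0; task < num_tasks; ++task) {
    const size_t head = task % kHeads;       // fastest-varying, as in the diff
    const size_t batch_idx = task / kHeads;  // token index within the batch
    std::printf("task=%zu -> head=%zu batch_idx=%zu\n", task, head, batch_idx);
  }
  return 0;
}
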
@@ -839,11 +876,15 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
                       HWY_ATTR { return hn::Mul(mul, Gelu(df, v)); });
   }

-  PROFILER_ZONE("Gen.FFW\\GatedGELU");
-  MatVecAdd<TConfig::kFFBiases, kModelDim, kFFHiddenDim>(
-      layer_weights->linear_w, 0, activations.ffw_hidden.data() + hidden_offset,
-      layer_weights->ffw_output_biases.data(), even_odd,
-      activations.ffw_out.data() + batch_idx * kModelDim, pool);
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    PROFILER_ZONE("Gen.FFW\\GatedGELU");
+    const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
+    MatVecAdd<TConfig::kFFBiases, kModelDim, kFFHiddenDim>(
+        layer_weights->linear_w, 0,
+        activations.ffw_hidden.data() + hidden_offset,
+        layer_weights->ffw_output_biases.data(), even_odd,
+        activations.ffw_out.data() + batch_idx * kModelDim, pool);
+  }
 }

 // `EmbeddingScaling` can be constexpr only if `Sqrt` and `hwy::ConvertScalarTo`
@@ -898,24 +939,26 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
             layer_weights->pre_attention_norm_scale.data(),
             activations.pre_att_rms_out.data() + token_idx * kModelDim,
             kModelDim);
-    if (type == LayerAttentionType::kGemma) {
-      Attention<kBatchSize>(pos, token_idx, layer_of_type, activations,
-                            layer_weights, kv_cache, pool);
-    } else {
-      GriffinRecurrent<kBatchSize>(pos, token_idx, layer_of_type, activations,
-                                   layer_weights, kv_cache, pool);
-    }
   }
-  // TODO: sink the loop into these functions, i.e. make them MatMul.
-  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
+  if (type == LayerAttentionType::kGemma) {
+    Attention<kBatchSize>(pos, num_tokens, layer_of_type, activations,
+                          layer_weights, kv_cache, pool);
+  } else {
+    GriffinRecurrent<kBatchSize>(pos, num_tokens, layer_of_type, activations,
+                                 layer_weights, kv_cache, pool);
+  }
+  pool.Run(0, num_tokens, [&](const uint64_t token_idx,
+                              size_t /*thread*/) HWY_ATTR {
     AddFrom(activations.att_post2.data() + token_idx * kModelDim,
             activations.x.data() + token_idx * kModelDim, kModelDim);
     RMSNorm(activations.x.data() + token_idx * kModelDim,
             layer_weights->pre_ffw_norm_scale.data(),
             activations.bf_pre_ffw_rms_out.data() + token_idx * kModelDim,
             kModelDim);
-    FFW<kBatchSize>(activations, token_idx, layer_weights, pool);
+  });
+  FFW<kBatchSize>(activations, num_tokens, layer_weights, pool);
+  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
     AddFrom(activations.ffw_out.data() + token_idx * kModelDim,
             activations.x.data() + token_idx * kModelDim, kModelDim);
   }
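
Note: Prefill is restructured so that each sub-step runs over the whole token batch before the next one starts: a per-token pre-attention RMSNorm loop, one batched Attention/GriffinRecurrent call, a parallel per-token residual-add plus pre-FFW RMSNorm, one batched FFW call, and a final per-token residual-add. A toy sketch of that phase ordering with stand-in functions; every name below is hypothetical, while the real code calls the templated Attention, GriffinRecurrent, and FFW with num_tokens:

#include <cstddef>
#include <vector>

// Hypothetical stand-ins for the per-token and batched steps in Prefill.
void PreAttNorm(std::vector<float>& /*x*/, size_t /*token_idx*/) {}
void AttentionBatch(std::vector<float>& /*x*/, size_t /*num_tokens*/) {}
void ResidualAndPreFFWNorm(std::vector<float>& /*x*/, size_t /*token_idx*/) {}
void FFWBatch(std::vector<float>& /*x*/, size_t /*num_tokens*/) {}
void ResidualAdd(std::vector<float>& /*x*/, size_t /*token_idx*/) {}

// Phase ordering as in the new Prefill: every phase finishes for all tokens
// before the next phase begins, so the batched calls see the whole batch.
void PrefillLayerSketch(std::vector<float>& x, size_t num_tokens) {
  for (size_t t = 0; t < num_tokens; ++t) PreAttNorm(x, t);
  AttentionBatch(x, num_tokens);
  for (size_t t = 0; t < num_tokens; ++t) ResidualAndPreFFWNorm(x, t);
  FFWBatch(x, num_tokens);
  for (size_t t = 0; t < num_tokens; ++t) ResidualAdd(x, t);
}

int main() {
  std::vector<float> x(64, 0.0f);
  PrefillLayerSketch(x, /*num_tokens=*/4);
  return 0;
}
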
@@ -957,16 +1000,16 @@ void Transformer(int token, size_t pos, const WeightArrayT& weights,
             layer_weights->pre_attention_norm_scale.data(),
             activations.pre_att_rms_out.data(), kModelDim);
     if (type == LayerAttentionType::kGemma) {
-      Attention<1>(pos, 0, layer_of_type, activations, layer_weights, kv_cache,
+      Attention<1>(pos, 1, layer_of_type, activations, layer_weights, kv_cache,
                    pool);
     } else {
-      GriffinRecurrent<1>(pos, 0, layer_of_type, activations, layer_weights,
+      GriffinRecurrent<1>(pos, 1, layer_of_type, activations, layer_weights,
                           kv_cache, pool);
     }
     AddFrom(activations.att_post2.data(), activations.x.data(), kModelDim);
     RMSNorm(activations.x.data(), layer_weights->pre_ffw_norm_scale.data(),
             activations.bf_pre_ffw_rms_out.data(), kModelDim);
-    FFW<1>(activations, /* batch_idx = */ 0, layer_weights, pool);
+    FFW<1>(activations, /* num_tokens = */ 1, layer_weights, pool);
     AddFrom(activations.ffw_out.data(), activations.x.data(), kModelDim);
     if (layers_output != nullptr) {
       std::string block_name = "blocks." + std::to_string(layer);