Cleanup: add wrapper functions and rename vars to interleaved

Simplifies the TransformerLayer function.
Use interleaved* names instead of *_and_queries.

PiperOrigin-RevId: 653929449
Jan Wassenberg 2024-07-19 02:03:36 -07:00 committed by Copybara-Service
parent 12016d31c3
commit 5844e6a1e5
2 changed files with 156 additions and 129 deletions
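
Background for the rename: `num_interleaved` counts rows in which the queries of a batch are interleaved, so row `i` belongs to query `i % num_queries` and token position `i / num_queries` — the same `%` / `/` arithmetic used throughout the diff below. A minimal standalone C++ sketch of that mapping (the loop bounds and printout are illustrative only, not part of the commit):

// Sketch (not from the commit): how an interleaved index maps back to a
// (token, query) pair, matching the arithmetic in GemmaAttention.
#include <cstddef>
#include <cstdio>

int main() {
  const size_t num_tokens = 3, num_queries = 2;   // illustrative values
  const size_t num_interleaved = num_tokens * num_queries;
  for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
       ++interleaved_idx) {
    const size_t query_idx = interleaved_idx % num_queries;  // which query
    const size_t batch_idx = interleaved_idx / num_queries;  // which token row
    std::printf("interleaved %zu -> token %zu, query %zu\n", interleaved_idx,
                batch_idx, query_idx);
  }
  return 0;
}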

View File

@@ -63,6 +63,11 @@ HWY_BEFORE_NAMESPACE();
 namespace gcpp {
 namespace HWY_NAMESPACE {
 
+// Different functions use different naming conventions for the number of
+// tokens. Functions that are query-independent, such as RMSNorm*, call the
+// count `num_interleaved`. Functions that are query-dependent, such as
+// `Attention`, use separate `num_tokens` and `num_queries`.
+
 template <class TConfig>
 HWY_NOINLINE void GriffinRecurrent(
     size_t batch_start, size_t num_tokens, size_t num_queries, size_t layer,
@@ -191,21 +196,21 @@ HWY_NOINLINE void GriffinRecurrent(
 }
 
 template <class TConfig, typename T>
-HWY_NOINLINE void PostQK(T* HWY_RESTRICT t, size_t pos, size_t layer) {
+HWY_NOINLINE void PostQK(T* HWY_RESTRICT inout, size_t pos, size_t layer) {
   constexpr size_t kQKVDim = TConfig::kQKVDim;
   // PostQKType::Rope
-  Rope(t, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+  Rope(inout, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
 }
 
 template <class TConfig>
-HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
-                            size_t num_queries, size_t layer,
-                            Activations& activations,
-                            const CompressedLayer<TConfig>* layer_weights,
-                            const std::vector<KVCache*>& kv_caches,
-                            hwy::ThreadPool& pool) {
+HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
+                                 size_t num_queries, size_t layer,
+                                 Activations& activations,
+                                 const CompressedLayer<TConfig>* layer_weights,
+                                 const std::vector<KVCache*>& kv_caches,
+                                 hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Attention");
-  HWY_DASSERT(batch_and_query_start % num_queries == 0);
+  HWY_DASSERT(interleaved_start % num_queries == 0);
   constexpr size_t kQKVDim = TConfig::kQKVDim;
   constexpr size_t kQStride = Activations::QStride<TConfig>();
   constexpr size_t kCachePosSize = CachePosSize<TConfig>()();
@@ -218,8 +223,8 @@ HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
   // Multi-Head Attention a.k.a. "use_qkv_einsum".
   constexpr bool kIsMHA = Activations::IsMHA<TConfig>();
   static_assert(!kIsMHA || TConfig::kInterleaveQKV);  // MHA => interleaved
-  const size_t batch_start = batch_and_query_start / num_queries;
-  const size_t num_tokens_and_queries = num_tokens * num_queries;
+  const size_t batch_start = interleaved_start / num_queries;
+  const size_t num_interleaved = num_tokens * num_queries;
 
   // For the computation of Q, K, and V, it is useful to remember that
   // qkv_einsum_w has shape [(kHeads + kKVHeads * 2), kKQVDim, kModelDim]
@@ -229,16 +234,16 @@ HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
   // If MHA, this also computes KV, which we copy to the KV cache below.
   const float scale = layer_weights->qkv_einsum_w.scale();
   MatMul_4x4_Batch<kModelDim, kHeads * kQStride>(
-      num_tokens_and_queries, activations.pre_att_rms_out.All(),
+      num_interleaved, activations.pre_att_rms_out.All(),
       layer_weights->qkv_einsum_w.data(), scale, activations.q.All(), pool);
 
   // Compute KV if not MHA.
   if constexpr (!kIsMHA) {
-    for (size_t batch_and_query_idx = 0;
-         batch_and_query_idx < num_tokens_and_queries; ++batch_and_query_idx) {
-      const float* x = activations.pre_att_rms_out.Batch(batch_and_query_idx);
-      const size_t query_idx = batch_and_query_idx % num_queries;
-      const size_t batch_idx = batch_and_query_idx / num_queries;
+    for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
+         ++interleaved_idx) {
+      const float* x = activations.pre_att_rms_out.Batch(interleaved_idx);
+      const size_t query_idx = interleaved_idx % num_queries;
+      const size_t batch_idx = interleaved_idx / num_queries;
       KVCache& kv_cache = *kv_caches[query_idx];
       const size_t pos = batch_start + batch_idx;
       const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
@@ -255,12 +260,12 @@ HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
   // Apply positional encodings for K (and copy KV to cache if MHA).
   pool.Run(
-      0, kKVHeads * num_tokens_and_queries,
+      0, kKVHeads * num_interleaved,
       [&](uint64_t task, size_t thread) HWY_ATTR {
         const size_t head = task % kKVHeads;
-        const size_t batch_and_query_idx = task / kKVHeads;
-        const size_t query_idx = batch_and_query_idx % num_queries;
-        const size_t batch_idx = batch_and_query_idx / num_queries;
+        const size_t interleaved_idx = task / kKVHeads;
+        const size_t query_idx = interleaved_idx % num_queries;
+        const size_t batch_idx = interleaved_idx / num_queries;
         const size_t pos = batch_start + batch_idx;
         const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
         const size_t kv_offset = cache_pos * kCachePosSize +
@@ -270,7 +275,7 @@ HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
         if constexpr (kIsMHA) {
           // For MHA, copy KV into the KV cache from scratch space (see above).
           const float* HWY_RESTRICT q =
-              activations.q.Batch(batch_and_query_idx) + head * kQStride;
+              activations.q.Batch(interleaved_idx) + head * kQStride;
           // Skip past the Q part of `q`, and copy KV to `kv`.
           hwy::CopyBytes(q + kQKVDim, kv, 2 * kQKVDim * sizeof(float));
         }
@@ -281,69 +286,70 @@ HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
                 "query heads must be a multiple of key-value heads");
   constexpr size_t kGroupHeads = kHeads / kKVHeads;
   // For each head (token, query), compute Q.K, softmax, and weighted V.
-  pool.Run(0, kHeads * num_tokens_and_queries,
-           [&](uint64_t task, size_t thread) HWY_ATTR {
+  pool.Run(
+      0, kHeads * num_interleaved, [&](uint64_t task, size_t thread) HWY_ATTR {
         const size_t head = task % kHeads;
-        const size_t batch_and_query_idx = task / kHeads;
-        const size_t query_idx = batch_and_query_idx % num_queries;
-        const size_t batch_idx = batch_and_query_idx / num_queries;
+        const size_t interleaved_idx = task / kHeads;
+        const size_t query_idx = interleaved_idx % num_queries;
+        const size_t batch_idx = interleaved_idx / num_queries;
         const size_t head_offset = (head / kGroupHeads) * kQKVDim * 2;
         KVCache& kv_cache = *kv_caches[query_idx];
         float* HWY_RESTRICT q =
-            activations.q.Batch(batch_and_query_idx) + head * kQStride;
+            activations.q.Batch(interleaved_idx) + head * kQStride;
         // Apply rope and scaling to Q.
         const size_t pos = batch_start + batch_idx;
         PostQK<TConfig>(q, pos, layer);
         MulByConst(kQueryScale, q, kQKVDim);
         // Compute Q.K scores, yielding "logits" (or scores) in head_att.
         float* HWY_RESTRICT head_att =
-            activations.att.Batch(batch_and_query_idx) + head * kSeqLen;
+            activations.att.Batch(interleaved_idx) + head * kSeqLen;
         const size_t start_pos =
             pos - std::min(TConfig::kAttentionWindowSizes[layer] - 1, pos);
         for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
           const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
           const size_t kv_offset =
               cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
           const float* HWY_RESTRICT k2 = kv_cache.kv_cache.get() + kv_offset;
           const float score = Dot(q, k2, kQKVDim);
           head_att[pos2 % kSeqLen] = score;
         }
-        // SoftMax. May be preceded by SoftCap. Yields "probabilities" in head_att.
+        // SoftMax. May be preceded by SoftCap. Yields "probabilities" in
+        // head_att.
         const size_t head_att_len = std::min(pos + 1, kSeqLen);
         if constexpr (TConfig::kAttCap > 0.0f) {
           LogitsSoftCap(TConfig::kAttCap, head_att, head_att_len);
         }
         Softmax(head_att, head_att_len);
         // Summation of v (kv_cache) weighted by probs (head_att)
         // into "encoded" (att_out). Compare gemma/modules.py:
         // encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
         float* HWY_RESTRICT att_out =
-            activations.att_out.Batch(batch_and_query_idx) + head * kQKVDim;
+            activations.att_out.Batch(interleaved_idx) + head * kQKVDim;
         hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
         for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
           const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
           const size_t kv_offset =
               cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
-          float* HWY_RESTRICT v2 = kv_cache.kv_cache.get() + kv_offset + kQKVDim;
+          float* HWY_RESTRICT v2 =
+              kv_cache.kv_cache.get() + kv_offset + kQKVDim;
           MulByConstAndAdd(head_att[pos2 % kSeqLen], v2, att_out, kQKVDim);
         }
       });
 
   // Sum encoded (att_out) over num_heads and head_dim (kQKVDim)
   // into output (layer_out). Compare gemma/modules.py:
   // attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', encoded)
-  for (size_t batch_and_query_idx = 0;
-       batch_and_query_idx < num_tokens_and_queries; ++batch_and_query_idx) {
+  for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
+       ++interleaved_idx) {
     // TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after
     // rearranging the weights.
-    float* HWY_RESTRICT att_out =
-        activations.att_out.Batch(batch_and_query_idx);
+    float* HWY_RESTRICT att_out = activations.att_out.Batch(interleaved_idx);
     float* HWY_RESTRICT layer_out =
-        activations.att_post2.Batch(batch_and_query_idx);
+        activations.att_post2.Batch(interleaved_idx);
     // Head 0 (and potentially biases) -> layer_out.
     // attn_vec_einsum_w has shape [kHeads, kQKVDim, kModelDim].
     MatVecT</*kAdd=*/TConfig::kSoftmaxAttnOutputBiases, kModelDim, kQKVDim>(
@@ -364,6 +370,28 @@ HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
   }
 }
 
+template <class TConfig>
+HWY_NOINLINE void Attention(LayerAttentionType type, size_t interleaved_start,
+                            size_t num_tokens, size_t num_queries, size_t layer,
+                            Activations& activations,
+                            const CompressedLayer<TConfig>* layer_weights,
+                            const std::vector<KVCache*>& kv_caches,
+                            hwy::ThreadPool& pool) {
+  if (type == LayerAttentionType::kGemma) {
+    GemmaAttention<TConfig>(interleaved_start, num_tokens, num_queries, layer,
+                            activations, layer_weights, kv_caches, pool);
+  } else {
+    // Only reached if the model is Griffin. `if constexpr` prevents generating
+    // this code for non-Griffin models.
+    if constexpr (TConfig::kGriffinLayers > 0) {
+      HWY_ASSERT(num_queries == 1);
+      GriffinRecurrent<TConfig>(interleaved_start, num_tokens, num_queries,
+                                layer, activations, layer_weights, kv_caches,
+                                pool);
+    }
+  }
+}
+
 template <class TConfig, typename T>
 HWY_NOINLINE void Activation(T* HWY_RESTRICT c1, T* HWY_RESTRICT c2,
                              size_t count) {
@@ -377,7 +405,7 @@ HWY_NOINLINE void Activation(T* HWY_RESTRICT c1, T* HWY_RESTRICT c2,
 }
 
 template <class TConfig>
-HWY_NOINLINE void FFW(Activations& activations, size_t num_tokens,
+HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
                       const CompressedLayer<TConfig>* layer_weights,
                       hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.FFW");
@@ -388,7 +416,7 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_tokens,
   // elements in memory, repeated kFFHiddenDim times.
   constexpr size_t kColsA = kModelDim;
   constexpr size_t kColsB = kFFHiddenDim;
-  HWY_DASSERT(num_tokens <= activations.bf_pre_ffw_rms_out.BatchSize());
+  HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize());
   const auto A = activations.bf_pre_ffw_rms_out.All();
   const float scale = layer_weights->gating_einsum_w.scale();
   const auto B1 = layer_weights->gating_einsum_w.data();
@@ -400,18 +428,18 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_tokens,
   const auto bias2 = bias1 + kFFHiddenDim;
 
   // Will go through GELU.
-  MatMul_4x4_Batch_Add<kColsA, kColsB, kAddBias>(num_tokens, A, B1, scale, C1,
-                                                 bias1, pool);
+  MatMul_4x4_Batch_Add<kColsA, kColsB, kAddBias>(num_interleaved, A, B1, scale,
+                                                 C1, bias1, pool);
   // What to multiply by.
-  MatMul_4x4_Batch_Add<kColsA, kColsB, kAddBias>(num_tokens, A, B2, scale, C2,
-                                                 bias2, pool);
+  MatMul_4x4_Batch_Add<kColsA, kColsB, kAddBias>(num_interleaved, A, B2, scale,
+                                                 C2, bias2, pool);
 
   // Activation (Gelu) and multiply by gate. Store activations in C1.
-  Activation<TConfig>(C1, C2, kFFHiddenDim * num_tokens);
+  Activation<TConfig>(C1, C2, kFFHiddenDim * num_interleaved);
 
   // Hidden layer -> output layer.
   MatMul_4x4_Batch_Add<kFFHiddenDim, kModelDim, kAddBias>(
-      num_tokens, C1, layer_weights->linear_w.data(),
+      num_interleaved, C1, layer_weights->linear_w.data(),
       layer_weights->linear_w.scale(), activations.ffw_out.All(),
       layer_weights->ffw_output_biases.data_scale1(), pool);
 }
@@ -440,11 +468,18 @@ HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
 template <class TConfig, typename T>
 HWY_NOINLINE void ResidualConnection(
-    size_t num_tokens_and_queries, T* HWY_RESTRICT other, T* HWY_RESTRICT x,
+    size_t num_interleaved, T* HWY_RESTRICT other, T* HWY_RESTRICT x,
     const CompressedLayer<TConfig>* layer_weights, bool is_attention) {
   constexpr size_t kModelDim = TConfig::kModelDim;
   // ResidualType::Add
-  AddFromBatched(num_tokens_and_queries, other, x, kModelDim);
+  AddFromBatched(num_interleaved, other, x, kModelDim);
+}
+
+template <class TConfig, typename WeightT, typename InOutT>
+void PostNorm(size_t num_interleaved, const WeightT* weights, InOutT* inout) {
+  if (TConfig::kPostNorm == PostNormType::Scale) {
+    RMSNormInplaceBatched(num_interleaved, weights, inout, TConfig::kModelDim);
+  }
 }
 
 template <class TConfig>
@@ -453,46 +488,37 @@ HWY_NOINLINE void TransformerLayer(
     const CompressedLayer<TConfig>* layer_weights, Activations& activations,
     const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool) {
   constexpr size_t kModelDim = TConfig::kModelDim;
-  const size_t num_tokens_and_queries = num_tokens * num_queries;
+  const size_t num_interleaved = num_tokens * num_queries;
   auto type = TConfig::kLayerConfig[layer];
   size_t layer_of_type =
       NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
-  RMSNormBatched(num_tokens_and_queries, activations.x.All(),
+
+  RMSNormBatched(num_interleaved, activations.x.All(),
                  layer_weights->pre_attention_norm_scale.data_scale1(),
                  activations.pre_att_rms_out.All(), kModelDim);
-  if (type == LayerAttentionType::kGemma) {
-    Attention<TConfig>(pos, num_tokens, num_queries, layer_of_type, activations,
-                       layer_weights, kv_caches, pool);
-  } else {
-    // Only reached if the model is Griffin. `if constexpr` prevents generating
-    // this code for non-Griffin models.
-    if constexpr (TConfig::kGriffinLayers > 0) {
-      HWY_ASSERT(num_queries == 1);
-      GriffinRecurrent<TConfig>(pos, num_tokens, num_queries, layer_of_type,
-                                activations, layer_weights, kv_caches, pool);
-    }
-  }
-  if (TConfig::kPostNorm == PostNormType::Scale) {
-    RMSNormInplaceBatched(
-        num_tokens_and_queries,
-        layer_weights->post_attention_norm_scale.data_scale1(),
-        activations.att_post2.All(), kModelDim);
-  }
-  ResidualConnection<TConfig>(num_tokens_and_queries,
-                              activations.att_post2.All(), activations.x.All(),
-                              layer_weights, /*is_attention=*/true);
-  RMSNormBatched(num_tokens_and_queries, activations.x.All(),
+
+  Attention<TConfig>(type, pos, num_tokens, num_queries, layer_of_type,
+                     activations, layer_weights, kv_caches, pool);
+
+  PostNorm<TConfig>(num_interleaved,
+                    layer_weights->post_attention_norm_scale.data_scale1(),
+                    activations.att_post2.All());
+
+  ResidualConnection<TConfig>(num_interleaved, activations.att_post2.All(),
+                              activations.x.All(), layer_weights,
+                              /*is_attention=*/true);
+
+  RMSNormBatched(num_interleaved, activations.x.All(),
                  layer_weights->pre_ffw_norm_scale.data_scale1(),
                  activations.bf_pre_ffw_rms_out.All(), kModelDim);
-  FFW<TConfig>(activations, num_tokens_and_queries, layer_weights, pool);
-  if (TConfig::kPostNorm == PostNormType::Scale) {
-    RMSNormInplaceBatched(num_tokens_and_queries,
-                          layer_weights->post_ffw_norm_scale.data_scale1(),
-                          activations.ffw_out.All(), kModelDim);
-  }
-  ResidualConnection<TConfig>(num_tokens_and_queries, activations.ffw_out.All(),
+
+  FFW<TConfig>(activations, num_interleaved, layer_weights, pool);
+
+  PostNorm<TConfig>(num_interleaved,
+                    layer_weights->post_ffw_norm_scale.data_scale1(),
+                    activations.ffw_out.All());
+
+  ResidualConnection<TConfig>(num_interleaved, activations.ffw_out.All(),
                               activations.x.All(), layer_weights,
                               /*is_attention=*/false);
 }
@@ -669,16 +695,15 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
                               const std::vector<KVCache*>& kv_caches,
                               hwy::ThreadPool& pool,
                               const LayersOutputFunc& layers_output) {
-  const size_t num_tokens_and_queries = num_tokens * num_queries;
+  const size_t num_interleaved = num_tokens * num_queries;
   if (layers_output) {
-    for (size_t token_idx = 0; token_idx < num_tokens_and_queries;
-         ++token_idx) {
+    for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
       float token_f = tokens[token_idx];
       layers_output(pos + token_idx, "Tokens", &token_f, 1);
     }
   }
   constexpr size_t kModelDim = TConfig::kModelDim;
-  for (size_t token_idx = 0; token_idx < num_tokens_and_queries; ++token_idx) {
+  for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
     EmbedToken<TConfig>(tokens[token_idx], token_idx, pos, weights,
                         activations);
   }
@@ -690,20 +715,17 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
     if (layers_output) {
       const std::string block_name = "blocks." + std::to_string(layer);
-      for (size_t token_idx = 0; token_idx < num_tokens_and_queries;
-           ++token_idx) {
+      for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
         layers_output(pos + token_idx, block_name,
                       activations.x.Batch(token_idx), kModelDim);
       }
     }
   }
-  RMSNormInplaceBatched(num_tokens_and_queries,
-                        weights.final_norm_scale.data_scale1(),
+  RMSNormInplaceBatched(num_interleaved, weights.final_norm_scale.data_scale1(),
                         activations.x.All(), kModelDim);
   if (layers_output) {
-    for (size_t token_idx = 0; token_idx < num_tokens_and_queries;
-         ++token_idx) {
+    for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
       layers_output(pos + token_idx, "final_norm",
                     activations.x.Batch(token_idx), kModelDim);
     }

View File

@@ -268,17 +268,24 @@ void ForEachTensor(RawWeightsPtr raw_weights,
       GEMMA_CALL_FUNC("gr_a", griffin.a);
     }
     GEMMA_CALL_FUNC("pre_att_ns", pre_attention_norm_scale);
+
+    // For conditionally-included tensors, the else branch must ensure their
+    // scale is initialized, because wrapper functions call data_scale1 even if
+    // the tensor turns out to be unused. If unused, the arrays are zero-length
+    // and data() returns a non-null but unusable pointer.
     if (TConfig::kPostNorm == PostNormType::Scale) {
       GEMMA_CALL_FUNC("post_att_ns", post_attention_norm_scale);
       GEMMA_CALL_FUNC("post_ff_ns", post_ffw_norm_scale);
+    } else {
+      c_layer->post_attention_norm_scale.set_scale(1.0f);
+      c_layer->post_ffw_norm_scale.set_scale(1.0f);
     }
+
     if (TConfig::kFFBiases) {
       GEMMA_CALL_FUNC("ffw_gat_b", ffw_gating_biases);
       GEMMA_CALL_FUNC("ffw_out_b", ffw_output_biases);
     } else {
-      // Ensure initialized so we can call data_scale1, which happens even if
-      // the tensor turns out to be unused.
      c_layer->ffw_gating_biases.set_scale(1.0f);
       c_layer->ffw_output_biases.set_scale(1.0f);
     }
@@ -287,8 +294,6 @@ void ForEachTensor(RawWeightsPtr raw_weights,
     if (TConfig::kSoftmaxAttnOutputBiases) {
       GEMMA_CALL_FUNC("attn_ob", attention_output_biases);
     } else {
-      // Ensure initialized so we can call data_scale1, which happens even if
-      // the tensor turns out to be unused.
       c_layer->attention_output_biases.set_scale(1.0f);
     }
   }