diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index 0b58c39..864a59f 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -63,6 +63,11 @@ HWY_BEFORE_NAMESPACE();
 namespace gcpp {
 namespace HWY_NAMESPACE {
 
+// Different functions use different naming conventions for the number of
+// tokens. Functions that are query-independent, such as RMSNorm*, call the
+// count `num_interleaved`. Functions that are query-dependent, such as
+// `Attention`, use separate `num_tokens` and `num_queries`.
+
 template <class TConfig>
 HWY_NOINLINE void GriffinRecurrent(
     size_t batch_start, size_t num_tokens, size_t num_queries, size_t layer,
@@ -191,21 +196,21 @@ HWY_NOINLINE void GriffinRecurrent(
 }
 
 template <class TConfig, typename T>
-HWY_NOINLINE void PostQK(T* HWY_RESTRICT t, size_t pos, size_t layer) {
+HWY_NOINLINE void PostQK(T* HWY_RESTRICT inout, size_t pos, size_t layer) {
   constexpr size_t kQKVDim = TConfig::kQKVDim;
   // PostQKType::Rope
-  Rope(t, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+  Rope(inout, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
 }
 
 template <class TConfig>
-HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
-                            size_t num_queries, size_t layer,
-                            Activations& activations,
-                            const CompressedLayer<TConfig>* layer_weights,
-                            const std::vector<KVCache*>& kv_caches,
-                            hwy::ThreadPool& pool) {
+HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
+                                 size_t num_queries, size_t layer,
+                                 Activations& activations,
+                                 const CompressedLayer<TConfig>* layer_weights,
+                                 const std::vector<KVCache*>& kv_caches,
+                                 hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Attention");
-  HWY_DASSERT(batch_and_query_start % num_queries == 0);
+  HWY_DASSERT(interleaved_start % num_queries == 0);
   constexpr size_t kQKVDim = TConfig::kQKVDim;
   constexpr size_t kQStride = Activations::QStride();
   constexpr size_t kCachePosSize = CachePosSize()();
@@ -218,8 +223,8 @@ HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
   // Multi-Head Attention a.k.a. "use_qkv_einsum".
   constexpr bool kIsMHA = Activations::IsMHA();
   static_assert(!kIsMHA || TConfig::kInterleaveQKV);  // MHA => interleaved
-  const size_t batch_start = batch_and_query_start / num_queries;
-  const size_t num_tokens_and_queries = num_tokens * num_queries;
+  const size_t batch_start = interleaved_start / num_queries;
+  const size_t num_interleaved = num_tokens * num_queries;
 
   // For the computation of Q, K, and V, it is useful to remember that
   // qkv_einsum_w has shape [(kHeads + kKVHeads * 2), kKQVDim, kModelDim]
@@ -229,16 +234,16 @@ HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
   // If MHA, this also computes KV, which we copy to the KV cache below.
   const float scale = layer_weights->qkv_einsum_w.scale();
   MatMul_4x4_Batch(
-      num_tokens_and_queries, activations.pre_att_rms_out.All(),
+      num_interleaved, activations.pre_att_rms_out.All(),
       layer_weights->qkv_einsum_w.data(), scale, activations.q.All(), pool);
 
   // Compute KV if not MHA.
   if constexpr (!kIsMHA) {
-    for (size_t batch_and_query_idx = 0;
-         batch_and_query_idx < num_tokens_and_queries; ++batch_and_query_idx) {
-      const float* x = activations.pre_att_rms_out.Batch(batch_and_query_idx);
-      const size_t query_idx = batch_and_query_idx % num_queries;
-      const size_t batch_idx = batch_and_query_idx / num_queries;
+    for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
+         ++interleaved_idx) {
+      const float* x = activations.pre_att_rms_out.Batch(interleaved_idx);
+      const size_t query_idx = interleaved_idx % num_queries;
+      const size_t batch_idx = interleaved_idx / num_queries;
       KVCache& kv_cache = *kv_caches[query_idx];
       const size_t pos = batch_start + batch_idx;
       const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
@@ -255,12 +260,12 @@ HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
 
   // Apply positional encodings for K (and copy KV to cache if MHA).
   pool.Run(
-      0, kKVHeads * num_tokens_and_queries,
+      0, kKVHeads * num_interleaved,
       [&](uint64_t task, size_t thread) HWY_ATTR {
         const size_t head = task % kKVHeads;
-        const size_t batch_and_query_idx = task / kKVHeads;
-        const size_t query_idx = batch_and_query_idx % num_queries;
-        const size_t batch_idx = batch_and_query_idx / num_queries;
+        const size_t interleaved_idx = task / kKVHeads;
+        const size_t query_idx = interleaved_idx % num_queries;
+        const size_t batch_idx = interleaved_idx / num_queries;
         const size_t pos = batch_start + batch_idx;
         const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
         const size_t kv_offset = cache_pos * kCachePosSize +
@@ -270,7 +275,7 @@ HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
         if constexpr (kIsMHA) {
           // For MHA, copy KV into the KV cache from scratch space (see above).
           const float* HWY_RESTRICT q =
-              activations.q.Batch(batch_and_query_idx) + head * kQStride;
+              activations.q.Batch(interleaved_idx) + head * kQStride;
           // Skip past the Q part of `q`, and copy KV to `kv`.
           hwy::CopyBytes(q + kQKVDim, kv, 2 * kQKVDim * sizeof(float));
         }
@@ -281,69 +286,70 @@ HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
                 "query heads must be a multiple of key-value heads");
   constexpr size_t kGroupHeads = kHeads / kKVHeads;
   // For each head (token, query), compute Q.K, softmax, and weighted V.
-  pool.Run(0, kHeads * num_tokens_and_queries,
-           [&](uint64_t task, size_t thread) HWY_ATTR {
-             const size_t head = task % kHeads;
-             const size_t batch_and_query_idx = task / kHeads;
-             const size_t query_idx = batch_and_query_idx % num_queries;
-             const size_t batch_idx = batch_and_query_idx / num_queries;
-             const size_t head_offset = (head / kGroupHeads) * kQKVDim * 2;
-             KVCache& kv_cache = *kv_caches[query_idx];
-             float* HWY_RESTRICT q =
-                 activations.q.Batch(batch_and_query_idx) + head * kQStride;
+  pool.Run(
+      0, kHeads * num_interleaved, [&](uint64_t task, size_t thread) HWY_ATTR {
+        const size_t head = task % kHeads;
+        const size_t interleaved_idx = task / kHeads;
+        const size_t query_idx = interleaved_idx % num_queries;
+        const size_t batch_idx = interleaved_idx / num_queries;
+        const size_t head_offset = (head / kGroupHeads) * kQKVDim * 2;
+        KVCache& kv_cache = *kv_caches[query_idx];
+        float* HWY_RESTRICT q =
+            activations.q.Batch(interleaved_idx) + head * kQStride;
 
-             // Apply rope and scaling to Q.
-             const size_t pos = batch_start + batch_idx;
-             PostQK<TConfig>(q, pos, layer);
-             MulByConst(kQueryScale, q, kQKVDim);
+        // Apply rope and scaling to Q.
+        const size_t pos = batch_start + batch_idx;
+        PostQK<TConfig>(q, pos, layer);
+        MulByConst(kQueryScale, q, kQKVDim);
 
-             // Compute Q.K scores, yielding "logits" (or scores) in head_att.
-             float* HWY_RESTRICT head_att =
-                 activations.att.Batch(batch_and_query_idx) + head * kSeqLen;
-             const size_t start_pos =
-                 pos - std::min(TConfig::kAttentionWindowSizes[layer] - 1, pos);
-             for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
-               const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
-               const size_t kv_offset =
-                   cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
-               const float* HWY_RESTRICT k2 = kv_cache.kv_cache.get() + kv_offset;
-               const float score = Dot(q, k2, kQKVDim);
-               head_att[pos2 % kSeqLen] = score;
-             }
+        // Compute Q.K scores, yielding "logits" (or scores) in head_att.
+        float* HWY_RESTRICT head_att =
+            activations.att.Batch(interleaved_idx) + head * kSeqLen;
+        const size_t start_pos =
+            pos - std::min(TConfig::kAttentionWindowSizes[layer] - 1, pos);
+        for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
+          const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
+          const size_t kv_offset =
+              cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
+          const float* HWY_RESTRICT k2 = kv_cache.kv_cache.get() + kv_offset;
+          const float score = Dot(q, k2, kQKVDim);
+          head_att[pos2 % kSeqLen] = score;
+        }
 
-             // SoftMax. May be preceded by SoftCap. Yields "probabilities" in head_att.
-             const size_t head_att_len = std::min(pos + 1, kSeqLen);
-             if constexpr (TConfig::kAttCap > 0.0f) {
-               LogitsSoftCap(TConfig::kAttCap, head_att, head_att_len);
-             }
-             Softmax(head_att, head_att_len);
+        // SoftMax. May be preceded by SoftCap. Yields "probabilities" in
+        // head_att.
+        const size_t head_att_len = std::min(pos + 1, kSeqLen);
+        if constexpr (TConfig::kAttCap > 0.0f) {
+          LogitsSoftCap(TConfig::kAttCap, head_att, head_att_len);
+        }
+        Softmax(head_att, head_att_len);
 
-             // Summation of v (kv_cache) weighted by probs (head_att)
-             // into "encoded" (att_out). Compare gemma/modules.py:
-             // encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
-             float* HWY_RESTRICT att_out =
-                 activations.att_out.Batch(batch_and_query_idx) + head * kQKVDim;
-             hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
-             for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
-               const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
-               const size_t kv_offset =
-                   cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
-               float* HWY_RESTRICT v2 = kv_cache.kv_cache.get() + kv_offset + kQKVDim;
-               MulByConstAndAdd(head_att[pos2 % kSeqLen], v2, att_out, kQKVDim);
-             }
-           });
+        // Summation of v (kv_cache) weighted by probs (head_att)
+        // into "encoded" (att_out). Compare gemma/modules.py:
+        // encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
+        float* HWY_RESTRICT att_out =
+            activations.att_out.Batch(interleaved_idx) + head * kQKVDim;
+        hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
+        for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
+          const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
+          const size_t kv_offset =
+              cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
+          float* HWY_RESTRICT v2 =
+              kv_cache.kv_cache.get() + kv_offset + kQKVDim;
+          MulByConstAndAdd(head_att[pos2 % kSeqLen], v2, att_out, kQKVDim);
+        }
+      });
 
   // Sum encoded (att_out) over num_heads and head_dim (kQKVDim)
   // into output (layer_out). Compare gemma/modules.py:
   // attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', encoded)
-  for (size_t batch_and_query_idx = 0;
-       batch_and_query_idx < num_tokens_and_queries; ++batch_and_query_idx) {
+  for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
+       ++interleaved_idx) {
     // TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after
     // rearranging the weights.
-    float* HWY_RESTRICT att_out =
-        activations.att_out.Batch(batch_and_query_idx);
+    float* HWY_RESTRICT att_out = activations.att_out.Batch(interleaved_idx);
     float* HWY_RESTRICT layer_out =
-        activations.att_post2.Batch(batch_and_query_idx);
+        activations.att_post2.Batch(interleaved_idx);
     // Head 0 (and potentially biases) -> layer_out.
     // attn_vec_einsum_w has shape [kHeads, kQKVDim, kModelDim].
     MatVecT(
@@ -364,6 +370,28 @@ HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
   }
 }
 
+template <class TConfig>
+HWY_NOINLINE void Attention(LayerAttentionType type, size_t interleaved_start,
+                            size_t num_tokens, size_t num_queries, size_t layer,
+                            Activations& activations,
+                            const CompressedLayer<TConfig>* layer_weights,
+                            const std::vector<KVCache*>& kv_caches,
+                            hwy::ThreadPool& pool) {
+  if (type == LayerAttentionType::kGemma) {
+    GemmaAttention(interleaved_start, num_tokens, num_queries, layer,
+                   activations, layer_weights, kv_caches, pool);
+  } else {
+    // Only reached if the model is Griffin. `if constexpr` prevents generating
+    // this code for non-Griffin models.
+    if constexpr (TConfig::kGriffinLayers > 0) {
+      HWY_ASSERT(num_queries == 1);
+      GriffinRecurrent(interleaved_start, num_tokens, num_queries,
+                       layer, activations, layer_weights, kv_caches,
+                       pool);
+    }
+  }
+}
+
 template <typename T>
 HWY_NOINLINE void Activation(T* HWY_RESTRICT c1, T* HWY_RESTRICT c2,
                              size_t count) {
@@ -377,7 +405,7 @@ HWY_NOINLINE void Activation(T* HWY_RESTRICT c1, T* HWY_RESTRICT c2,
 }
 
 template <class TConfig>
-HWY_NOINLINE void FFW(Activations& activations, size_t num_tokens,
+HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
                       const CompressedLayer<TConfig>* layer_weights,
                       hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.FFW");
@@ -388,7 +416,7 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_tokens,
   // elements in memory, repeated kFFHiddenDim times.
   constexpr size_t kColsA = kModelDim;
   constexpr size_t kColsB = kFFHiddenDim;
-  HWY_DASSERT(num_tokens <= activations.bf_pre_ffw_rms_out.BatchSize());
+  HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize());
   const auto A = activations.bf_pre_ffw_rms_out.All();
   const float scale = layer_weights->gating_einsum_w.scale();
   const auto B1 = layer_weights->gating_einsum_w.data();
@@ -400,18 +428,18 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_tokens,
   const auto bias2 = bias1 + kFFHiddenDim;
 
   // Will go through GELU.
-  MatMul_4x4_Batch_Add(num_tokens, A, B1, scale, C1,
-                       bias1, pool);
+  MatMul_4x4_Batch_Add(num_interleaved, A, B1, scale,
+                       C1, bias1, pool);
 
   // What to multiply by.
-  MatMul_4x4_Batch_Add(num_tokens, A, B2, scale, C2,
-                       bias2, pool);
+  MatMul_4x4_Batch_Add(num_interleaved, A, B2, scale,
+                       C2, bias2, pool);
 
   // Activation (Gelu) and multiply by gate. Store activations in C1.
-  Activation(C1, C2, kFFHiddenDim * num_tokens);
+  Activation(C1, C2, kFFHiddenDim * num_interleaved);
   // Hidden layer -> output layer.
   MatMul_4x4_Batch_Add(
-      num_tokens, C1, layer_weights->linear_w.data(),
+      num_interleaved, C1, layer_weights->linear_w.data(),
       layer_weights->linear_w.scale(), activations.ffw_out.All(),
       layer_weights->ffw_output_biases.data_scale1(), pool);
 }
@@ -440,11 +468,18 @@ HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
 
 template <class TConfig, typename T>
 HWY_NOINLINE void ResidualConnection(
-    size_t num_tokens_and_queries, T* HWY_RESTRICT other, T* HWY_RESTRICT x,
+    size_t num_interleaved, T* HWY_RESTRICT other, T* HWY_RESTRICT x,
     const CompressedLayer<TConfig>* layer_weights, bool is_attention) {
   constexpr size_t kModelDim = TConfig::kModelDim;
   // ResidualType::Add
-  AddFromBatched(num_tokens_and_queries, other, x, kModelDim);
+  AddFromBatched(num_interleaved, other, x, kModelDim);
+}
+
+template <class TConfig, typename WeightT, typename InOutT>
+void PostNorm(size_t num_interleaved, const WeightT* weights, InOutT* inout) {
+  if (TConfig::kPostNorm == PostNormType::Scale) {
+    RMSNormInplaceBatched(num_interleaved, weights, inout, TConfig::kModelDim);
+  }
 }
 
 template <class TConfig>
@@ -453,46 +488,37 @@ HWY_NOINLINE void TransformerLayer(
     const CompressedLayer<TConfig>* layer_weights, Activations& activations,
     const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool) {
   constexpr size_t kModelDim = TConfig::kModelDim;
-  const size_t num_tokens_and_queries = num_tokens * num_queries;
+  const size_t num_interleaved = num_tokens * num_queries;
   auto type = TConfig::kLayerConfig[layer];
   size_t layer_of_type =
       NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
-  RMSNormBatched(num_tokens_and_queries, activations.x.All(),
+
+  RMSNormBatched(num_interleaved, activations.x.All(),
                  layer_weights->pre_attention_norm_scale.data_scale1(),
                  activations.pre_att_rms_out.All(), kModelDim);
-  if (type == LayerAttentionType::kGemma) {
-    Attention(pos, num_tokens, num_queries, layer_of_type, activations,
-              layer_weights, kv_caches, pool);
-  } else {
-    // Only reached if the model is Griffin. `if constexpr` prevents generating
-    // this code for non-Griffin models.
-    if constexpr (TConfig::kGriffinLayers > 0) {
-      HWY_ASSERT(num_queries == 1);
-      GriffinRecurrent(pos, num_tokens, num_queries, layer_of_type,
-                       activations, layer_weights, kv_caches, pool);
-    }
-  }
-  if (TConfig::kPostNorm == PostNormType::Scale) {
-    RMSNormInplaceBatched(
-        num_tokens_and_queries,
-        layer_weights->post_attention_norm_scale.data_scale1(),
-        activations.att_post2.All(), kModelDim);
-  }
+  Attention(type, pos, num_tokens, num_queries, layer_of_type,
+            activations, layer_weights, kv_caches, pool);
 
-  ResidualConnection(num_tokens_and_queries,
-                     activations.att_post2.All(), activations.x.All(),
-                     layer_weights, /*is_attention=*/true);
-  RMSNormBatched(num_tokens_and_queries, activations.x.All(),
+  PostNorm<TConfig>(num_interleaved,
+                    layer_weights->post_attention_norm_scale.data_scale1(),
+                    activations.att_post2.All());
+
+  ResidualConnection(num_interleaved, activations.att_post2.All(),
+                     activations.x.All(), layer_weights,
+                     /*is_attention=*/true);
+
+  RMSNormBatched(num_interleaved, activations.x.All(),
                  layer_weights->pre_ffw_norm_scale.data_scale1(),
                  activations.bf_pre_ffw_rms_out.All(), kModelDim);
-  FFW(activations, num_tokens_and_queries, layer_weights, pool);
-  if (TConfig::kPostNorm == PostNormType::Scale) {
-    RMSNormInplaceBatched(num_tokens_and_queries,
-                          layer_weights->post_ffw_norm_scale.data_scale1(),
-                          activations.ffw_out.All(), kModelDim);
-  }
-  ResidualConnection(num_tokens_and_queries, activations.ffw_out.All(),
+
+  FFW(activations, num_interleaved, layer_weights, pool);
+
+  PostNorm<TConfig>(num_interleaved,
+                    layer_weights->post_ffw_norm_scale.data_scale1(),
+                    activations.ffw_out.All());
+
+  ResidualConnection(num_interleaved, activations.ffw_out.All(),
                      activations.x.All(), layer_weights,
                      /*is_attention=*/false);
 }
@@ -669,16 +695,15 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
                               const std::vector<KVCache*>& kv_caches,
                               hwy::ThreadPool& pool,
                               const LayersOutputFunc& layers_output) {
-  const size_t num_tokens_and_queries = num_tokens * num_queries;
+  const size_t num_interleaved = num_tokens * num_queries;
   if (layers_output) {
-    for (size_t token_idx = 0; token_idx < num_tokens_and_queries;
-         ++token_idx) {
+    for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
       float token_f = tokens[token_idx];
       layers_output(pos + token_idx, "Tokens", &token_f, 1);
     }
   }
   constexpr size_t kModelDim = TConfig::kModelDim;
-  for (size_t token_idx = 0; token_idx < num_tokens_and_queries; ++token_idx) {
+  for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
     EmbedToken(tokens[token_idx], token_idx, pos, weights, activations);
   }
 
@@ -690,20 +715,17 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
 
     if (layers_output) {
      const std::string block_name = "blocks." + std::to_string(layer);
-      for (size_t token_idx = 0; token_idx < num_tokens_and_queries;
-           ++token_idx) {
+      for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
         layers_output(pos + token_idx, block_name,
                       activations.x.Batch(token_idx), kModelDim);
       }
     }
   }
-  RMSNormInplaceBatched(num_tokens_and_queries,
-                        weights.final_norm_scale.data_scale1(),
+  RMSNormInplaceBatched(num_interleaved, weights.final_norm_scale.data_scale1(),
                         activations.x.All(), kModelDim);
 
   if (layers_output) {
-    for (size_t token_idx = 0; token_idx < num_tokens_and_queries;
-         ++token_idx) {
+    for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
       layers_output(pos + token_idx, "final_norm",
                     activations.x.Batch(token_idx), kModelDim);
     }
diff --git a/gemma/weights.h b/gemma/weights.h
index b6f72c2..cfc981c 100644
--- a/gemma/weights.h
+++ b/gemma/weights.h
@@ -268,17 +268,24 @@ void ForEachTensor(RawWeightsPtr raw_weights,
       GEMMA_CALL_FUNC("gr_a", griffin.a);
     }
     GEMMA_CALL_FUNC("pre_att_ns", pre_attention_norm_scale);
+
+    // For conditionally-included tensors, the else branch must ensure their
+    // scale is initialized, because wrapper functions call data_scale1 even if
+    // the tensor turns out to be unused. If unused, the arrays are zero-length
+    // and data() returns a non-null but unusable pointer.
+
     if (TConfig::kPostNorm == PostNormType::Scale) {
       GEMMA_CALL_FUNC("post_att_ns", post_attention_norm_scale);
       GEMMA_CALL_FUNC("post_ff_ns", post_ffw_norm_scale);
+    } else {
+      c_layer->post_attention_norm_scale.set_scale(1.0f);
+      c_layer->post_ffw_norm_scale.set_scale(1.0f);
     }
 
     if (TConfig::kFFBiases) {
       GEMMA_CALL_FUNC("ffw_gat_b", ffw_gating_biases);
       GEMMA_CALL_FUNC("ffw_out_b", ffw_output_biases);
     } else {
-      // Ensure initialized so we can call data_scale1, which happens even if
-      // the tensor turns out to be unused.
      c_layer->ffw_gating_biases.set_scale(1.0f);
       c_layer->ffw_output_biases.set_scale(1.0f);
     }
@@ -287,8 +294,6 @@ void ForEachTensor(RawWeightsPtr raw_weights,
     if (TConfig::kSoftmaxAttnOutputBiases) {
       GEMMA_CALL_FUNC("attn_ob", attention_output_biases);
     } else {
-      // Ensure initialized so we can call data_scale1, which happens even if
-      // the tensor turns out to be unused.
       c_layer->attention_output_biases.set_scale(1.0f);
     }
   }
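
The renames above standardize on `num_interleaved` for the flattened token-times-query count, with query_idx = interleaved_idx % num_queries and batch_idx = interleaved_idx / num_queries. The following standalone sketch (hypothetical sizes, separate from the patch itself) shows that decomposition in isolation:

#include <cstddef>
#include <cstdio>

int main() {
  // Hypothetical sizes; in the patch these come from the prefill/decode batch.
  const size_t num_tokens = 3;
  const size_t num_queries = 2;
  const size_t num_interleaved = num_tokens * num_queries;
  for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
       ++interleaved_idx) {
    // Same decomposition as in GemmaAttention and the pool.Run lambdas above.
    const size_t query_idx = interleaved_idx % num_queries;
    const size_t batch_idx = interleaved_idx / num_queries;
    std::printf("interleaved_idx=%zu -> batch_idx=%zu, query_idx=%zu\n",
                interleaved_idx, batch_idx, query_idx);
  }
  return 0;
}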