diff --git a/backprop/optimize_test.cc b/backprop/optimize_test.cc
index 23174a0..1707bad 100644
--- a/backprop/optimize_test.cc
+++ b/backprop/optimize_test.cc
@@ -126,6 +126,8 @@ TEST(OptimizeTest, GradientDescent) {
         info.model, prompt, gemma.Weights(), forward, inv_timescale, pool);
     CrossEntropyLossBackwardPass(info.model, prompt, gemma.Weights(), forward,
                                  grad, backward, inv_timescale, pool);
+    CallForModelAndWeight<ReshapeCompressedWeights>(
+        info.model, info.weight, gemma.MutableWeights(), pool);
     num_ok += verify(prompt) ? 1 : 0;
   }
   total_loss /= kBatchSize;
diff --git a/gemma/activations.h b/gemma/activations.h
index 72a0ab8..60d5610 100644
--- a/gemma/activations.h
+++ b/gemma/activations.h
@@ -74,10 +74,8 @@ struct Activations {
   RowVectorBatch pre_att_rms_out;
   RowVectorBatch att;      // attention vector
   RowVectorBatch att_out;  // attention output
-  // After linear transformation, shared by all heads
-  RowVectorBatch att_post1;
   // Accumulation of attention outputs over heads
-  RowVectorBatch att_post2;
+  RowVectorBatch att_sums;
 
   // Gated FFW
   RowVectorBatch bf_pre_ffw_rms_out;
@@ -144,8 +142,7 @@ struct Activations {
     pre_att_rms_out = RowVectorBatch(batch_size, kModelDim);
     att = RowVectorBatch(batch_size, kHeads * kSeqLen);
     att_out = RowVectorBatch(batch_size, kHeads * kQKVDim);
-    att_post1 = RowVectorBatch(1, kModelDim);
-    att_post2 = RowVectorBatch(batch_size, kModelDim);
+    att_sums = RowVectorBatch(batch_size, kModelDim);
 
     bf_pre_ffw_rms_out = RowVectorBatch(batch_size, kModelDim);
     C1 = RowVectorBatch(batch_size, kFFHiddenDim);
diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index 10d7350..c1c5ba4 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -28,7 +28,6 @@
 #include
 
 #include <algorithm>  // std::min
-#include <memory>     // std::unique_ptr
 #include
 #include
 #include
@@ -188,7 +187,7 @@ HWY_NOINLINE void GriffinRecurrent(
   // Final linear layer.
   for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
     float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx);
-    float* out_ptr = activations.att_post2.Batch(batch_idx);
+    float* out_ptr = activations.att_sums.Batch(batch_idx);
     MatVecAdd(
         layer_weights->griffin.linear_out_w, 0, x,
         layer_weights->griffin.linear_out_biases.data_scale1(),
@@ -421,39 +420,23 @@ class GemmaAttention {
     });
   }
 
-  // Sums encoded (`att_out`) over num_heads and head_dim (kQKVDim) into output
-  // (`layer_out`). Compare gemma/modules.py:
-  // attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', encoded)
+  // Sums encoded (`att_out`) over num_heads (`kHeads`) and head_dim (`kQKVDim`)
+  // into output (`layer_out`).
   HWY_NOINLINE void SumHeads(const size_t num_interleaved) {
     PROFILER_ZONE("Gen.Attention.SumHeads");
-    for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
-         ++interleaved_idx) {
-      // TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after
-      // rearranging the weights.
-      float* HWY_RESTRICT att_out = activations_.att_out.Batch(interleaved_idx);
-      float* HWY_RESTRICT layer_out =
-          activations_.att_post2.Batch(interleaved_idx);
-      // Head 0 (and potentially biases) -> layer_out.
-      // attn_vec_einsum_w has shape [kHeads, kQKVDim, kModelDim].
-      constexpr bool kAdd = TConfig::kSoftmaxAttnOutputBiases;
-      const float* bias =
-          kAdd ? layer_weights_.attention_output_biases.data_scale1() : nullptr;
-      MatVecT(
-          layer_weights_.attn_vec_einsum_w, 0, att_out, bias,
-          activations_.even_odd.All(), layer_out, pool_);
-      // Head 1 and following are added to layer_out.
-      for (size_t head = 1; head < kHeads; ++head) {
-        // NOTE: this is a single kModelDim temp output. If parallelized or
-        // using MatMul, add per-thread storage.
-        float* HWY_RESTRICT head_out = activations_.att_post1.All();
-        // TODO: requires MatMul support for offsets.
-        MatVec(
-            layer_weights_.attn_vec_einsum_w, head * kModelDim * kQKVDim,
-            att_out + head * kQKVDim, activations_.even_odd.All(), head_out,
-            pool_);
-        AddFrom(head_out, layer_out, kModelDim);
-      }
-    }
+    constexpr bool kAdd = TConfig::kSoftmaxAttnOutputBiases;
+    const float* bias =
+        kAdd ? layer_weights_.attention_output_biases.data_scale1() : nullptr;
+
+    // att_weights and att_out are concatenated heads, each of length kQKVDim.
+    // Thus the [num_interleaved, kModelDim] matmul output is the sum over
+    // heads. Compare gemma/modules.py:
+    // attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', encoded)
+    MatMul_4x4(
+        num_interleaved, MakeMat(activations_.att_out.All(), kHeads * kQKVDim),
+        MakeMat(layer_weights_.att_weights.data(), kHeads * kQKVDim),
+        layer_weights_.attn_vec_einsum_w.scale(), bias,
+        MakeMat(activations_.att_sums.All(), kModelDim), pool_);
   }
 
  public:
@@ -634,9 +617,9 @@ HWY_NOINLINE void TransformerLayer(
             activations, layer_weights, div_seq_len, kv_caches, pool);
 
   PostNorm(num_interleaved, layer_weights->post_attention_norm_scale,
-           activations.att_post2.All());
+           activations.att_sums.All());
 
-  ResidualConnection(num_interleaved, activations.att_post2.All(),
+  ResidualConnection(num_interleaved, activations.att_sums.All(),
                      activations.x.All(), layer_weights,
                      /*is_attention=*/true);
 
diff --git a/gemma/gemma.h b/gemma/gemma.h
index 1363e1b..990449e 100644
--- a/gemma/gemma.h
+++ b/gemma/gemma.h
@@ -158,6 +158,7 @@ class Gemma {
   const ModelInfo& Info() const { return info_; }
   const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
   const ByteStorageT& Weights() const { return weights_u8_; }
+  ByteStorageT& MutableWeights() { return weights_u8_; }
 
   void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt,
                 size_t start_pos, KVCache& kv_cache, TimingInfo& timing_info);
diff --git a/gemma/weights.cc b/gemma/weights.cc
index 6c675ad..a4d800e 100644
--- a/gemma/weights.cc
+++ b/gemma/weights.cc
@@ -72,6 +72,7 @@ struct LoadCompressedWeightsT {
       }
       HWY_ASSERT(scale_pos == TConfig::kNumTensorScales);
     }
+    c_weights->Reshape();
     return c_weights_u8;
   }
 };
diff --git a/gemma/weights.h b/gemma/weights.h
index 6e0782d..d8d2ebb 100644
--- a/gemma/weights.h
+++ b/gemma/weights.h
@@ -24,6 +24,7 @@
 #include "hwy/aligned_allocator.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
+#include "hwy/profiler.h"
 
 namespace gcpp {
 
@@ -93,6 +94,33 @@
 
   ArrayT ffw_gating_biases;
   ArrayT ffw_output_biases;
+
+  // Reshaped attention; not loaded from disk via ForEachTensor.
+  ArrayT att_weights;
+
+  // Initializes att_weights from attn_vec_einsum_w, hence this must be called
+  // after loading weights via ForEachTensor.
+  // TODO: update compression/convert_weights to bake this in.
+  void Reshape() {
+    PROFILER_ZONE("Startup.Reshape");
+
+    constexpr size_t kModelDim = TConfig::kModelDim;
+    constexpr size_t kHeads = TConfig::kHeads;
+    constexpr size_t kQKVDim = TConfig::kQKVDim;
+
+    // Would have to implement a CompressTraits::Copy for NUQ.
+    static_assert(!hwy::IsSame<Weight, NuqStream>());
+
+    // Reshape [kHeads, kModelDim, kQKVDim] to [kModelDim, kHeads * kQKVDim].
+    for (size_t m = 0; m < kModelDim; ++m) {
+      Weight* HWY_RESTRICT out_row = att_weights.data() + m * kHeads * kQKVDim;
+      for (size_t h = 0; h < kHeads; ++h) {
+        hwy::CopyBytes(
+            attn_vec_einsum_w.data() + h * kModelDim * kQKVDim + m * kQKVDim,
+            out_row + h * kQKVDim, kQKVDim * sizeof(Weight));
+      }
+    }
+  }
 };
 
 // Array instead of single large allocation for parallel mem init. Split out
@@ -135,6 +163,13 @@ struct CompressedWeights {
 
   explicit CompressedWeights(hwy::ThreadPool& pool) : c_layer_ptrs(pool) {}
 
+  // Called by weights.cc after ForEachTensor.
+  void Reshape() {
+    for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
+      GetLayer(layer)->Reshape();
+    }
+  }
+
   void ZeroInit() {
     hwy::ZeroBytes(&embedder_input_embedding, sizeof(embedder_input_embedding));
     hwy::ZeroBytes(&final_norm_scale, sizeof(final_norm_scale));
@@ -174,6 +209,15 @@ struct ZeroInitCompressedWeights {
   }
 };
 
+template <typename TConfig>
+struct ReshapeCompressedWeights {
+  void operator()(ByteStorageT& weights_u8, hwy::ThreadPool& pool) const {
+    CompressedWeights<TConfig>& weights =
+        *reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get());
+    weights.Reshape();
+  }
+};
+
 // TODO: also add RandInitCompressedWeights
 
 template <typename TConfig>
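The core idea of this change: reshaping attn_vec_einsum_w from [kHeads, kModelDim, kQKVDim] into att_weights with shape [kModelDim, kHeads * kQKVDim] turns the old per-head MatVec-plus-accumulate loop into a single matrix product whose inner dimension already sums over heads. The following standalone sketch (plain C++, no gemma.cpp or Highway dependencies; the tiny dimensions and hand-written loops are purely illustrative, not the library's API) checks that equivalence numerically:

// Standalone illustration of the att_weights reshape and the head-summing
// product; dimensions and loops are illustrative, not gemma.cpp code.
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  const size_t kHeads = 2, kQKVDim = 3, kModelDim = 4;

  // attn_vec_einsum_w layout: [kHeads, kModelDim, kQKVDim].
  std::vector<float> w(kHeads * kModelDim * kQKVDim);
  // att_out: concatenated per-head attention outputs, each of length kQKVDim.
  std::vector<float> att_out(kHeads * kQKVDim);
  for (size_t i = 0; i < w.size(); ++i) w[i] = 0.01f * i;
  for (size_t i = 0; i < att_out.size(); ++i) att_out[i] = 0.1f * (i + 1);

  // Old scheme: one matrix-vector product per head, accumulated into layer_out.
  std::vector<float> layer_out(kModelDim, 0.0f);
  for (size_t h = 0; h < kHeads; ++h) {
    for (size_t m = 0; m < kModelDim; ++m) {
      float dot = 0.0f;
      for (size_t q = 0; q < kQKVDim; ++q) {
        dot += w[h * kModelDim * kQKVDim + m * kQKVDim + q] *
               att_out[h * kQKVDim + q];
      }
      layer_out[m] += dot;
    }
  }

  // New scheme: reshape to att_weights[kModelDim, kHeads * kQKVDim], then one
  // row-times-vector product per model row sums over heads implicitly.
  std::vector<float> att_weights(kModelDim * kHeads * kQKVDim);
  for (size_t m = 0; m < kModelDim; ++m) {
    for (size_t h = 0; h < kHeads; ++h) {
      for (size_t q = 0; q < kQKVDim; ++q) {
        att_weights[m * kHeads * kQKVDim + h * kQKVDim + q] =
            w[h * kModelDim * kQKVDim + m * kQKVDim + q];
      }
    }
  }
  std::vector<float> att_sums(kModelDim, 0.0f);
  for (size_t m = 0; m < kModelDim; ++m) {
    for (size_t k = 0; k < kHeads * kQKVDim; ++k) {
      att_sums[m] += att_weights[m * kHeads * kQKVDim + k] * att_out[k];
    }
  }

  // Both schemes produce identical rows.
  for (size_t m = 0; m < kModelDim; ++m) {
    std::printf("m=%zu old=%f new=%f\n", m, layer_out[m], att_sums[m]);
  }
  return 0;
}

Because the reduction over kHeads * kQKVDim is a single inner dimension after the reshape, SumHeads can replace kHeads MatVec calls per token with one batched matmul over all interleaved tokens, which is what the MatMul_4x4 call above does.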