diff --git a/BUILD.bazel b/BUILD.bazel index 6ca9003..eab1476 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -242,13 +242,6 @@ cc_test( ], ) -cc_library( - name = "common", - srcs = ["gemma/common.cc"], - hdrs = ["gemma/common.h"], - deps = [":configs"], -) - # For building all tests in one command, so we can test several. test_suite( name = "ops_tests", @@ -566,11 +559,10 @@ cc_library( }, deps = [ ":benchmark_helper", - ":common", ":gemma_args", ":gemma_lib", ":kv_cache", - ":ops", + ":matmul", ":threading", ":threading_context", ":tokenizer", diff --git a/CMakeLists.txt b/CMakeLists.txt index e7dcfae..cb91f80 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -55,8 +55,6 @@ set(SOURCES gemma/activations.h gemma/attention.cc gemma/attention.h - gemma/common.cc - gemma/common.h gemma/configs.cc gemma/configs.h gemma/gemma_args.h diff --git a/gemma/activations.h b/gemma/activations.h index 7874423..fa04c0b 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -41,13 +41,13 @@ static inline float ChooseQueryScale(const ModelConfig& config) { } struct Activations { - Activations(const ModelConfig& config, size_t batch_size, MatMulEnv* env) + Activations(const ModelConfig& config, size_t batch_size, + std::vector>& row_ptrs) : weights_config(config), layer_config(config.layer_configs[0]), seq_len(config.seq_len), cache_pos_size(config.CachePosSize()), is_griffin(config.model == Model::GRIFFIN_2B), - query_scale(ChooseQueryScale(config)), x("x", Extents2D(batch_size, config.model_dim), pad_), // `vocab_size == 0` means it is for Vit part, VitAttention is still MHA @@ -96,19 +96,19 @@ struct Activations { layer_config.qkv_dim, layer_config.post_qk == PostQKType::HalfRope, 1000000.0)), - env(env) { + query_scale(ChooseQueryScale(config)) { HWY_ASSERT(batch_size != 0); // For MatMul outputs, precompute their row pointers. // If we forget any MatMul outputs here, debug builds print a warning but // fill them in each MatMul call. - x.AllocateAndAttachRowPtrs(env->row_ptrs); - q.AllocateAndAttachRowPtrs(env->row_ptrs); - logits.AllocateAndAttachRowPtrs(env->row_ptrs); - att_sums.AllocateAndAttachRowPtrs(env->row_ptrs); - C1.AllocateAndAttachRowPtrs(env->row_ptrs); - C2.AllocateAndAttachRowPtrs(env->row_ptrs); - ffw_out.AllocateAndAttachRowPtrs(env->row_ptrs); + x.AllocateAndAttachRowPtrs(row_ptrs); + q.AllocateAndAttachRowPtrs(row_ptrs); + logits.AllocateAndAttachRowPtrs(row_ptrs); + att_sums.AllocateAndAttachRowPtrs(row_ptrs); + C1.AllocateAndAttachRowPtrs(row_ptrs); + C2.AllocateAndAttachRowPtrs(row_ptrs); + ffw_out.AllocateAndAttachRowPtrs(row_ptrs); // Note that BindC on any MatMul output considerably slows down Prefill. } @@ -141,7 +141,6 @@ struct Activations { size_t seq_len; size_t cache_pos_size = 0; // TODO: after moving KVCache to MatStorageT. bool is_griffin = false; - float query_scale; const Extents2D none_ = Extents2D(); const MatPadding pad_ = MatPadding::kOdd; @@ -172,7 +171,7 @@ struct Activations { MatStorageT inv_timescale; MatStorageT inv_timescale_global; - MatMulEnv* env; + float query_scale; }; } // namespace gcpp diff --git a/gemma/attention.cc b/gemma/attention.cc index c79ac89..f7f6ec1 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -222,7 +222,7 @@ static HWY_INLINE void ComputeQKV( size_t num_tokens, const QueriesPos& queries_pos, const hwy::Divisor& div_seq_len, const size_t layer_idx, const LayerWeightsPtrs& layer, Activations& activations, - const KVCaches& kv_caches, const int flags, NestedPools& pools) { + const KVCaches& kv_caches, const int flags, MatMulEnv& env) { PROFILER_ZONE("Gen.Attention.QKV"); const size_t num_queries = queries_pos.size(); const size_t num_interleaved = num_tokens * num_queries; @@ -235,7 +235,7 @@ static HWY_INLINE void ComputeQKV( // The original qkv_einsum_w has shape [(heads + kv_heads * 2), qkv_dim, // model_dim], which we reshaped to (heads + kv_heads * 2) * qkv_dim rows. CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w1, - /*add=*/nullptr, *activations.env, activations.q); + /*add=*/nullptr, env, activations.q); // Set up MatMul row pointers for writing to KV, which consists of // `kv_heads` pairs of (k, v) vectors. This safely handles wraparound @@ -250,17 +250,16 @@ static HWY_INLINE void ComputeQKV( div_seq_len.Remainder(queries_pos[query_idx] + batch_idx); const size_t kv_offset = cache_pos * cache_pos_size + layer_idx * cache_layer_size; - activations.env->storage.OutRow(interleaved_idx) = - reinterpret_cast(kv_caches[query_idx].kv_cache.get() + - kv_offset); + env.row_ptrs[0][interleaved_idx] = reinterpret_cast( + kv_caches[query_idx].kv_cache.get() + kv_offset); } - kv_rows.AttachRowPtrs(&activations.env->storage.OutRow(0)); + kv_rows.AttachRowPtrs(env.row_ptrs[0].get()); CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2, - /*add=*/nullptr, *activations.env, kv_rows); + /*add=*/nullptr, env, kv_rows); // Apply positional encodings for K. // TODO: 2D parallelism to use more threads. - pools.Pool(0).Run( + env.ctx.pools.Pool(0).Run( 0, kv_heads * num_interleaved, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { const size_t head = task % kv_heads; @@ -289,7 +288,7 @@ static HWY_INLINE void ComputeQKV( // Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and // head_dim (`qkv_dim`) into output (`layer_out`). static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer, - Activations& activations) { + Activations& activations, MatMulEnv& env) { PROFILER_ZONE("Gen.Attention.SumHeads"); const LayerConfig& layer_config = layer.layer_config; // att_weights and att_out are concatenated heads, each of length @@ -302,7 +301,7 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer, const float* add = layer_config.softmax_attn_output_biases ? layer.attention_output_biases.PackedScale1() : nullptr; - CallMatMul(activations.att_out, layer.att_weights, add, *activations.env, + CallMatMul(activations.att_out, layer.att_weights, add, env, activations.att_sums); } @@ -312,7 +311,7 @@ void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos, const QueriesPos* queries_prefix_end, const hwy::Divisor& div_seq_len, const size_t layer_idx, const LayerWeightsPtrs& layer, Activations& activations, - const KVCaches& kv_caches, int flags) { + const KVCaches& kv_caches, MatMulEnv& env, int flags) { const size_t num_queries = queries_pos.size(); HWY_DASSERT(num_queries <= kv_caches.size()); @@ -331,13 +330,12 @@ void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos, queries_prefix_end = &queries_prefix_end_span; } - NestedPools& pools = activations.env->ctx.pools; ComputeQKV(num_tokens, queries_pos, div_seq_len, layer_idx, layer, - activations, kv_caches, flags, pools); + activations, kv_caches, flags, env); DotSoftmaxWeightedSum(num_tokens, queries_pos, *queries_prefix_end, div_seq_len, layer_idx, layer, activations, kv_caches, - pools); - SumHeads(layer, activations); + env.ctx.pools); + SumHeads(layer, activations, env); } // NOLINTNEXTLINE(google-readability-namespace-comments) diff --git a/gemma/attention.h b/gemma/attention.h index c8b527f..b43aeae 100644 --- a/gemma/attention.h +++ b/gemma/attention.h @@ -47,7 +47,7 @@ namespace gcpp { const QueriesPos* queries_prefix_end, \ const hwy::Divisor& div_seq_len, const size_t layer_idx, \ const LayerWeightsPtrs& layer, Activations& activations, \ - const KVCaches& kv_caches, int flags); \ + const KVCaches& kv_caches, MatMulEnv& env, int flags); \ /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ } // namespace NAMESPACE diff --git a/gemma/bindings/context.h b/gemma/bindings/context.h index f3b295f..9b6fe94 100644 --- a/gemma/bindings/context.h +++ b/gemma/bindings/context.h @@ -29,7 +29,6 @@ #include #endif -#include "gemma/common.h" #include "gemma/gemma.h" #include "gemma/gemma_args.h" #include "ops/matmul.h" // MatMulEnv diff --git a/gemma/common.cc b/gemma/common.cc deleted file mode 100644 index 00cb05d..0000000 --- a/gemma/common.cc +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "gemma/common.h" - -#include - -#include - -#include "gemma/configs.h" - -namespace gcpp { - -void Wrap(const ModelConfig& config, size_t pos, std::string& prompt) { - - // Instruction-tuned models are trained to expect control tokens. - if (config.wrapping == PromptWrapping::GEMMA_IT) { - // Prepend "" if this is a multi-turn dialogue continuation. - const std::string start = (pos == 0) - ? "user\n" - : "\nuser\n"; - prompt = start + prompt + "\nmodel\n"; - } -} - -} // namespace gcpp diff --git a/gemma/common.h b/gemma/common.h deleted file mode 100644 index 37e903d..0000000 --- a/gemma/common.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_ -#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_ - -#include - -#include - -#include "gemma/configs.h" // IWYU pragma: export - -namespace gcpp { - -// Wraps the given prompt using the expected control tokens for IT models. -// DEPRECATED, use WrapAndTokenize instead if a tokenized return value is fine. -void Wrap(const ModelConfig& config, size_t pos, std::string& prompt); - -} // namespace gcpp - -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_ diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 9af7a16..e89a3ad 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -18,7 +18,10 @@ #include #include +#include "gemma/activations.h" #include "gemma/configs.h" +#include "gemma/weights.h" +#include "ops/matmul.h" #include "util/mat.h" #include "hwy/profiler.h" @@ -103,8 +106,8 @@ void PostNorm(PostNormType post_norm, const MatPtr& weights, } } -static inline void FFWNoVit(Activations& activations, - const LayerWeightsPtrs& layer) { +static inline void FFWNoVit(const LayerWeightsPtrs& layer, + Activations& activations, MatMulEnv& env) { PROFILER_ZONE("Gen.FFW"); const LayerConfig& layer_config = layer.layer_config; const size_t ffh_hidden_dim = layer_config.ff_hidden_dim; @@ -117,16 +120,16 @@ static inline void FFWNoVit(Activations& activations, add_bias ? layer.ffw_output_biases.PackedScale1() : nullptr; // Compute the hidden layer activations. - CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w1, bias1, - *activations.env, activations.C1); - CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w2, bias2, - *activations.env, activations.C2); + CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w1, bias1, env, + activations.C1); + CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w2, bias2, env, + activations.C2); // Activation (Gelu) and maybe multiply by gate. Store activations in act. ActivationBatched(layer_config.activation, activations.C1, &activations.C2); // Hidden layer -> output layer. - CallMatMul(activations.C1, layer.linear_w, output_bias, *activations.env, + CallMatMul(activations.C1, layer.linear_w, output_bias, env, activations.ffw_out); } diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 6aadcaa..3f7c149 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -68,10 +68,10 @@ void Attention(LayerAttentionType type, size_t num_tokens, const QueriesPos& queries_prefix_end, const hwy::Divisor& div_seq_len, const size_t layer_idx, const LayerWeightsPtrs& layer, Activations& activations, - const KVCaches& kv_caches) { + const KVCaches& kv_caches, MatMulEnv& env) { if (type == LayerAttentionType::kGemma) { GemmaAttention(num_tokens, queries_pos, &queries_prefix_end, div_seq_len, - layer_idx, layer, activations, kv_caches, + layer_idx, layer, activations, kv_caches, env, /*flags=*/0); } else { HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock); @@ -80,7 +80,7 @@ void Attention(LayerAttentionType type, size_t num_tokens, const size_t griffin_layer = activations.weights_config.NumLayersOfTypeBefore(type, layer_idx); GriffinRecurrent(queries_pos, num_tokens, griffin_layer, activations, - &layer, kv_caches); + &layer, kv_caches, env); } } @@ -88,14 +88,14 @@ static HWY_NOINLINE void TransformerLayer( const size_t num_tokens, const QueriesPos& queries_pos, const QueriesPos& queries_prefix_end, const hwy::Divisor& div_seq_len, const size_t layer_idx, const LayerWeightsPtrs& layer, - Activations& activations, const KVCaches& kv_caches) { + Activations& activations, const KVCaches& kv_caches, MatMulEnv& env) { const LayerConfig& layer_config = layer.layer_config; RMSNormBatched(activations.x, layer.pre_attention_norm_scale, activations.pre_att_rms_out); Attention(layer_config.type, num_tokens, queries_pos, queries_prefix_end, - div_seq_len, layer_idx, layer, activations, kv_caches); + div_seq_len, layer_idx, layer, activations, kv_caches, env); PostNorm(layer_config.post_norm, layer.post_attention_norm_scale, activations.att_sums); @@ -107,9 +107,9 @@ static HWY_NOINLINE void TransformerLayer( activations.pre_ffw_rms_out); if (layer_config.type == LayerAttentionType::kVit) { - FFWVit(activations, layer); + FFWVit(layer, activations, env); } else { - FFWNoVit(activations, layer); + FFWNoVit(layer, activations, env); } PostNorm(layer_config.post_norm, layer.post_ffw_norm_scale, @@ -187,12 +187,11 @@ using QueriesMutablePos = hwy::Span; // Populates KV cache for batches of tokens from one query at a time. static HWY_NOINLINE void Prefill( - const QueriesPromptTokens& queries_prompt, + const size_t query_idx_start, const QueriesPromptTokens& queries_prompt, const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end, - const size_t query_idx_start, const ModelConfig& config, - const ModelWeightsPtrs& weights, Activations& activations, - const RuntimeConfig& runtime_config, const hwy::Divisor& div_seq_len, - const KVCaches& kv_caches) { + const hwy::Divisor& div_seq_len, const ModelConfig& config, + const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights, + Activations& activations, const KVCaches& kv_caches, MatMulEnv& env) { PROFILER_ZONE("Gen.Prefill"); const size_t num_queries = queries_prompt.size(); HWY_DASSERT(queries_pos.size() == num_queries); @@ -263,7 +262,7 @@ static HWY_NOINLINE void Prefill( ++layer_idx) { TransformerLayer(tbatch_size, single_query_pos, single_query_prefix_end, div_seq_len, layer_idx, *weights.GetLayer(layer_idx), - activations, single_kv_cache); + activations, single_kv_cache, env); } // NOTE: we unconditionally call StreamToken, even if EOS. @@ -298,9 +297,9 @@ static HWY_NOINLINE void Prefill( // from each query, and `queries_pos` are their position in the sequence. static HWY_NOINLINE void Transformer( const QueriesToken& queries_token, const QueriesMutablePos& queries_pos, - const QueriesPos& queries_prefix_end, const ModelConfig& config, - const ModelWeightsPtrs& weights, Activations& activations, - const hwy::Divisor& div_seq_len, const KVCaches& kv_caches, + const QueriesPos& queries_prefix_end, const hwy::Divisor& div_seq_len, + const ModelConfig& config, const ModelWeightsPtrs& weights, + Activations& activations, const KVCaches& kv_caches, MatMulEnv& env, const LayersOutputFunc& layers_output, const ActivationsObserverFunc& activations_observer) { const size_t num_queries = queries_token.size(); @@ -323,7 +322,7 @@ static HWY_NOINLINE void Transformer( for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) { TransformerLayer(/*num_tokens=*/1, queries_pos, queries_prefix_end, div_seq_len, layer_idx, *weights.GetLayer(layer_idx), - activations, kv_caches); + activations, kv_caches, env); if (activations_observer) { activations_observer(queries_pos, layer_idx, activations); @@ -381,23 +380,22 @@ class TokenStreamer { // Runs one decode step for all the queries in the batch. Returns true if all // queries are at . -static bool DecodeStepT(const ModelConfig& config, - const ModelWeightsPtrs& weights, - const RuntimeConfig& runtime_config, - const QueriesPromptTokens& queries_prompt, - const size_t query_idx_start, const KVCaches& kv_caches, - const QueriesPos& queries_prefix_end, - const hwy::Divisor div_seq_len, const size_t vocab_size, - const SampleFunc& sample_token, - Activations& activations, TokenStreamer& token_streamer, - std::vector& gen_tokens, TimingInfo& timing_info, - const QueriesMutablePos& queries_mutable_pos) { +static bool DecodeStepT( + const ModelConfig& config, const ModelWeightsPtrs& weights, + const RuntimeConfig& runtime_config, const size_t query_idx_start, + const QueriesPromptTokens& queries_prompt, + const QueriesMutablePos& queries_mutable_pos, + const QueriesPos& queries_prefix_end, const hwy::Divisor div_seq_len, + const size_t vocab_size, const SampleFunc& sample_token, + Activations& activations, const KVCaches& kv_caches, + TokenStreamer& token_streamer, std::vector& gen_tokens, + TimingInfo& timing_info, MatMulEnv& env) { const size_t num_queries = queries_prompt.size(); // Decode generates one token per query and increments // queries_mutable_pos. Transformer(QueriesToken(gen_tokens.data(), num_queries), queries_mutable_pos, - queries_prefix_end, config, weights, activations, div_seq_len, - kv_caches, runtime_config.layers_output, + queries_prefix_end, div_seq_len, config, weights, activations, + kv_caches, env, runtime_config.layers_output, runtime_config.activations_observer); // queries_pos are incremented by Transformer. @@ -407,7 +405,7 @@ static bool DecodeStepT(const ModelConfig& config, PROFILER_ZONE("Gen.EmbeddingMatmul"); // Compute logits from last layer activations. CallMatMul(activations.x, weights.embedder_input_embedding, - /*add=*/nullptr, *activations.env, activations.logits); + /*add=*/nullptr, env, activations.logits); } PROFILER_ZONE("Gen.Softcap+Sample+Stream"); for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { @@ -468,14 +466,12 @@ static size_t MaxQueryLength(const QueriesPromptTokens& queries_prompt) { // `StreamFunc` gets the global query index, not relative to the batch. // // `kv_caches` is for the batch, size must match `queries_prompt`. -static void GenerateT(const ModelConfig& config, - const ModelWeightsPtrs& weights, Activations& activations, - const RuntimeConfig& runtime_config, - const QueriesPromptTokens& queries_prompt, - const QueriesPos& queries_pos_in, - const QueriesPos& queries_prefix_end, - const size_t query_idx_start, const KVCaches& kv_caches, - TimingInfo& timing_info) { +static void GenerateT( + const ModelConfig& config, const ModelWeightsPtrs& weights, + const RuntimeConfig& runtime_config, const size_t query_idx_start, + const QueriesPromptTokens& queries_prompt, const QueriesPos& queries_pos_in, + const QueriesPos& queries_prefix_end, Activations& activations, + const KVCaches& kv_caches, TimingInfo& timing_info, MatMulEnv& env) { HWY_ASSERT(queries_pos_in.size() == kv_caches.size()); // Griffin assumes that the recurrent block cache is zero-initialized. @@ -512,9 +508,9 @@ static void GenerateT(const ModelConfig& config, // token is the first input token for generation. timing_info.prefill_start = hwy::platform::Now(); // Note that Prefill calls activations.SetBatchSize, so we reset it below. - Prefill(queries_prompt, queries_mutable_pos, queries_prefix_end, - query_idx_start, config, weights, activations, runtime_config, - div_seq_len, kv_caches); + Prefill(query_idx_start, queries_prompt, queries_mutable_pos, + queries_prefix_end, div_seq_len, config, runtime_config, weights, + activations, kv_caches, env); // Compute the number of tokens that were prefilled and notify timing_info. size_t prefilled_tokens = 0; for (size_t qi = 0; qi < num_queries; ++qi) { @@ -543,11 +539,12 @@ static void GenerateT(const ModelConfig& config, const size_t vocab_size = config.vocab_size; timing_info.generate_start = hwy::platform::Now(); for (size_t gen = 0; gen < max_generated_tokens; ++gen) { - bool all_queries_eos = DecodeStepT( - config, weights, runtime_config, queries_prompt, query_idx_start, - kv_caches, queries_prefix_end, div_seq_len, vocab_size, sample_token, - activations, token_streamer, gen_tokens, timing_info, - queries_mutable_pos); + bool all_queries_eos = + DecodeStepT(config, weights, runtime_config, query_idx_start, + queries_prompt, queries_mutable_pos, queries_prefix_end, + div_seq_len, vocab_size, sample_token, activations, + kv_caches, token_streamer, gen_tokens, timing_info, env); + if (all_queries_eos) break; } // foreach token to generate timing_info.NotifyGenerateDone(); @@ -557,7 +554,7 @@ static void GenerateT(const ModelConfig& config, void GenerateSingleT(const ModelConfig& config, const ModelWeightsPtrs& weights, const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos, size_t prefix_end, - KVCache& kv_cache, MatMulEnv* env, + KVCache& kv_cache, MatMulEnv& env, TimingInfo& timing_info) { constexpr size_t kNumQueries = 1; const size_t qbatch_start = 0; @@ -565,16 +562,16 @@ void GenerateSingleT(const ModelConfig& config, const ModelWeightsPtrs& weights, const size_t max_batch_size = HWY_MAX(kNumQueries, runtime_config.prefill_tbatch_size); // TODO: move into Gemma? - Activations activations(config, max_batch_size, env); + Activations activations(config, max_batch_size, env.row_ptrs); const QueriesPromptTokens queries_prompt(&prompt, kNumQueries); QueriesPos queries_pos(&pos, kNumQueries); const QueriesPos queries_prefix_end(&prefix_end, kNumQueries); const KVCaches kv_caches{&kv_cache, kNumQueries}; - GenerateT(config, weights, activations, runtime_config, queries_prompt, - queries_pos, queries_prefix_end, qbatch_start, kv_caches, - timing_info); + GenerateT(config, weights, runtime_config, qbatch_start, queries_prompt, + queries_pos, queries_prefix_end, activations, kv_caches, + timing_info, env); } void GenerateBatchT(const ModelConfig& config, const ModelWeightsPtrs& weights, @@ -582,7 +579,7 @@ void GenerateBatchT(const ModelConfig& config, const ModelWeightsPtrs& weights, const QueriesPromptTokens& queries_prompt, const QueriesPos& queries_pos, const QueriesPos& queries_prefix_end, - const KVCaches& kv_caches, MatMulEnv* env, + const KVCaches& kv_caches, MatMulEnv& env, TimingInfo& timing_info) { const size_t num_queries = queries_prompt.size(); HWY_ASSERT(queries_pos.size() == num_queries); @@ -591,7 +588,7 @@ void GenerateBatchT(const ModelConfig& config, const ModelWeightsPtrs& weights, const size_t max_batch_size = HWY_MAX(max_qbatch_size, runtime_config.prefill_tbatch_size); - Activations activations(config, max_batch_size, env); + Activations activations(config, max_batch_size, env.row_ptrs); for (size_t qbatch_start = 0; qbatch_start < num_queries; qbatch_start += max_qbatch_size) { @@ -604,9 +601,9 @@ void GenerateBatchT(const ModelConfig& config, const ModelWeightsPtrs& weights, const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start], qbatch_size); const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size); - GenerateT(config, weights, activations, runtime_config, qbatch_prompts, - qbatch_pos, qbatch_prefix_end, qbatch_start, qbatch_kv, - timing_info); + GenerateT(config, weights, runtime_config, qbatch_start, qbatch_prompts, + qbatch_pos, qbatch_prefix_end, activations, qbatch_kv, + timing_info, env); } } @@ -614,7 +611,7 @@ void GenerateImageTokensT(const ModelConfig& config, const ModelWeightsPtrs& weights, const RuntimeConfig& runtime_config, const Image& image, ImageTokens& image_tokens, - MatMulEnv* env) { + MatMulEnv& env) { if (config.vit_config.layer_configs.empty()) { HWY_ABORT("Model does not support generating image tokens."); } @@ -622,10 +619,10 @@ void GenerateImageTokensT(const ModelConfig& config, ModelConfig vit_config = GetVitConfig(config); prefill_runtime_config.prefill_tbatch_size = vit_config.seq_len / (vit_config.pool_dim * vit_config.pool_dim); - Activations prefill_activations(vit_config, vit_config.seq_len, env); + Activations prefill_activations(vit_config, vit_config.seq_len, env.row_ptrs); // Weights are for the full PaliGemma model, not just the ViT part. PrefillVit(config, weights, prefill_runtime_config, image, image_tokens, - prefill_activations); + prefill_activations, env); } // NOLINTNEXTLINE(google-readability-namespace-comments) @@ -677,7 +674,7 @@ void Gemma::Generate(const RuntimeConfig& runtime_config, HWY_DYNAMIC_DISPATCH(GenerateSingleT)(model_.Config(), weights_, runtime_config, prompt, pos, prefix_end, - kv_cache, &env_, timing_info); + kv_cache, env_, timing_info); env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning); } @@ -701,7 +698,7 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config, HWY_DYNAMIC_DISPATCH(GenerateBatchT)( model_.Config(), weights_, runtime_config, queries_prompt, queries_pos, - mutable_queries_prefix_end, kv_caches, &env_, timing_info); + mutable_queries_prefix_end, kv_caches, env_, timing_info); env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning); } @@ -712,7 +709,7 @@ void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config, env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning); HWY_DYNAMIC_DISPATCH(GenerateImageTokensT)( - model_.Config(), weights_, runtime_config, image, image_tokens, &env_); + model_.Config(), weights_, runtime_config, image, image_tokens, env_); env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning); } diff --git a/gemma/griffin.cc b/gemma/griffin.cc index 46606b1..59faca4 100644 --- a/gemma/griffin.cc +++ b/gemma/griffin.cc @@ -48,9 +48,9 @@ namespace HWY_NAMESPACE { void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens, size_t griffin_layer, Activations& activations, const LayerWeightsPtrs* layer_weights, - const KVCaches& kv_caches) { + const KVCaches& kv_caches, MatMulEnv& env) { PROFILER_ZONE("Gen.Griffin"); - hwy::ThreadPool& pool = activations.env->ctx.pools.Pool(0); + hwy::ThreadPool& pool = env.ctx.pools.Pool(0); namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; const D df; @@ -183,8 +183,8 @@ void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens, // Final linear layer. CallMatMul(activations.griffin_x, layer_weights->griffin.linear_out_w, - layer_weights->griffin.linear_out_biases.PackedScale1(), - *activations.env, activations.att_sums); + layer_weights->griffin.linear_out_biases.PackedScale1(), env, + activations.att_sums); } // GriffinRecurrent // NOLINTNEXTLINE(google-readability-namespace-comments) diff --git a/gemma/griffin.h b/gemma/griffin.h index 77011a3..fea514c 100644 --- a/gemma/griffin.h +++ b/gemma/griffin.h @@ -31,7 +31,7 @@ namespace gcpp { void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens, \ size_t griffin_layer, Activations& activations, \ const LayerWeightsPtrs* layer_weights, \ - const KVCaches& kv_caches); \ + const KVCaches& kv_caches, MatMulEnv& env); \ /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ } // namespace NAMESPACE diff --git a/gemma/vit.cc b/gemma/vit.cc index f9c50ff..14d9fd9 100644 --- a/gemma/vit.cc +++ b/gemma/vit.cc @@ -60,7 +60,7 @@ class VitAttention { HWY_ASSERT(qkv.Rows() == num_tokens_); HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim); CallMatMul(activations_.pre_att_rms_out, layer_.vit.qkv_einsum_w, - layer_.vit.qkv_einsum_b.PackedScale1(), *activations_.env, qkv); + layer_.vit.qkv_einsum_b.PackedScale1(), env_, qkv); } // TODO(philculliton): transition fully to MatMul. @@ -100,7 +100,7 @@ class VitAttention { }); // this produces C, a (num_tokens_, seq_len) matrix of dot products - CallMatMul(Q, K, nullptr, *activations_.env, C); + CallMatMul(Q, K, nullptr, env_, C); pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { float* HWY_RESTRICT c = C.Row(task); @@ -166,18 +166,19 @@ class VitAttention { // att_weights and att_out are concatenated heads, each of length // qkv_dim. Thus the [num_tokens_, layer_config_.model_dim] // matmul output is the sum over heads. - CallMatMul(activations_.att_out, layer_.vit.attn_out_w, bias, - *activations_.env, activations_.att_sums); + CallMatMul(activations_.att_out, layer_.vit.attn_out_w, bias, env_, + activations_.att_sums); } public: VitAttention(size_t num_tokens, size_t layer_idx, Activations& activations, - const LayerWeightsPtrs& layer) + const LayerWeightsPtrs& layer, MatMulEnv& env) : num_tokens_(num_tokens), activations_(activations), layer_(layer), layer_config_(layer.layer_config), - pool_(activations.env->ctx.pools.Pool(0)) {} + env_(env), + pool_(env_.ctx.pools.Pool(0)) {} HWY_INLINE void operator()() { ComputeQKV(); @@ -194,12 +195,14 @@ class VitAttention { Activations& activations_; const LayerWeightsPtrs& layer_; const LayerConfig& layer_config_; + MatMulEnv& env_; hwy::ThreadPool& pool_; }; // Same as FFWNoVit, but with different layer members and no second // gating matrix. -void FFWVit(Activations& activations, const LayerWeightsPtrs& layer) { +void FFWVit(const LayerWeightsPtrs& layer, Activations& activations, + MatMulEnv& env) { PROFILER_ZONE("Gen.FFW.ViT"); const LayerConfig& layer_config = layer.layer_config; @@ -209,15 +212,15 @@ void FFWVit(Activations& activations, const LayerWeightsPtrs& layer) { add_bias ? layer.vit.linear_1_b.PackedScale1() : nullptr; // Compute the hidden layer activations. - CallMatMul(activations.pre_ffw_rms_out, layer.vit.linear_0_w, bias1, - *activations.env, activations.C1); + CallMatMul(activations.pre_ffw_rms_out, layer.vit.linear_0_w, bias1, env, + activations.C1); // Activation (Gelu), store in C1. ActivationBatched(layer_config.activation, activations.C1); // Hidden layer -> output layer. - CallMatMul(activations.C1, layer.vit.linear_1_w, output_bias, - *activations.env, activations.ffw_out); + CallMatMul(activations.C1, layer.vit.linear_1_w, output_bias, env, + activations.ffw_out); } // Vit transformer layer. Some comments below refer to the Vit implementation in @@ -227,7 +230,7 @@ void FFWVit(Activations& activations, const LayerWeightsPtrs& layer) { // try merging this with TransformerLayer. void VitTransformerLayer(size_t num_tokens, const size_t layer_idx, const LayerWeightsPtrs& layer, - Activations& activations) { + Activations& activations, MatMulEnv& env) { const size_t model_dim = activations.weights_config.model_dim; auto type = layer.layer_config.type; HWY_DASSERT(type == LayerAttentionType::kVit); @@ -245,7 +248,7 @@ void VitTransformerLayer(size_t num_tokens, const size_t layer_idx, // y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y) // y ~ att_sums - VitAttention(num_tokens, layer_idx, activations, layer)(); + VitAttention(num_tokens, layer_idx, activations, layer, env)(); // x = out["+sa"] = x + y AddFromBatched(activations.att_sums, x); @@ -257,7 +260,7 @@ void VitTransformerLayer(size_t num_tokens, const size_t layer_idx, // y = out["mlp"] = MlpBlock(...)(y) // y ~ ffw_out - FFWVit(activations, layer); + FFWVit(layer, activations, env); // x = out["+mlp"] = x + y AddFromBatched(activations.ffw_out, x); @@ -268,7 +271,8 @@ void VitTransformerLayer(size_t num_tokens, const size_t layer_idx, static HWY_NOINLINE void EmbedImagePatches(const Image& image, const ModelConfig& model_config, const ModelWeightsPtrs& weights, - Activations& activations) { + Activations& activations, + MatMulEnv& env) { const size_t model_dim = model_config.vit_config.model_dim; const size_t patch_width = model_config.vit_config.patch_width; const size_t seq_len = model_config.vit_config.seq_len; @@ -287,8 +291,7 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image, image.GetPatch(i, image_patches.Row(i)); } CallMatMul(image_patches, weights.vit_img_embedding_kernel, - weights.vit_img_embedding_bias.PackedScale1(), *activations.env, - activations.x); + weights.vit_img_embedding_bias.PackedScale1(), env, activations.x); // Add position embeddings. CallUpcastedActivation(&weights.vit_img_pos_embedding, [&](const auto* weights_t) { @@ -300,18 +303,19 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image, void PrefillVit(const ModelConfig& model_config, const ModelWeightsPtrs& weights, const RuntimeConfig& runtime_config, const Image& image, - ImageTokens& image_tokens, Activations& activations) { + ImageTokens& image_tokens, Activations& activations, + MatMulEnv& env) { PROFILER_ZONE("Gen.PrefillVit"); const size_t num_tokens = model_config.vit_config.seq_len; const size_t vit_model_dim = model_config.vit_config.model_dim; HWY_ASSERT(num_tokens == activations.x.Rows()); // Embed the image patches. - EmbedImagePatches(image, model_config, weights, activations); + EmbedImagePatches(image, model_config, weights, activations, env); // Go through all layers. for (size_t layer_idx = 0; layer_idx < model_config.vit_config.layer_configs.size(); ++layer_idx) { VitTransformerLayer(num_tokens, layer_idx, *weights.VitLayer(layer_idx), - activations); + activations, env); } // Final Layernorm. LayerNormBatched(activations.x, weights.vit_encoder_norm_scale, @@ -329,8 +333,7 @@ void PrefillVit(const ModelConfig& model_config, // Apply head embedding into image_tokens of size of the LLM kModelDim. CallMatMul(activations.x, weights.vit_img_head_kernel, - weights.vit_img_head_bias.PackedScale1(), *activations.env, - image_tokens); + weights.vit_img_head_bias.PackedScale1(), env, image_tokens); } // NOLINTNEXTLINE(google-readability-namespace-comments) diff --git a/gemma/vit.h b/gemma/vit.h index 085081d..34d2307 100644 --- a/gemma/vit.h +++ b/gemma/vit.h @@ -28,12 +28,14 @@ namespace gcpp { // Passed to HWY_VISIT_TARGETS; declares for one target. #define GEMMA_DECL_VIT(TARGET, NAMESPACE) \ namespace NAMESPACE { \ - void FFWVit(Activations& activations, const LayerWeightsPtrs& layer); \ + void FFWVit(const LayerWeightsPtrs& layer, Activations& activations, \ + MatMulEnv& env); \ \ void PrefillVit(const ModelConfig& model_config, \ const ModelWeightsPtrs& weights, \ const RuntimeConfig& runtime_config, const Image& image, \ - ImageTokens& image_tokens, Activations& activations); \ + ImageTokens& image_tokens, Activations& activations, \ + MatMulEnv& env); \ /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ } // namespace NAMESPACE diff --git a/ops/dot_test.cc b/ops/dot_test.cc index e089e64..f210de5 100644 --- a/ops/dot_test.cc +++ b/ops/dot_test.cc @@ -756,7 +756,7 @@ class DotStats { // Compensated and Double are very accurate. ASSERT_LESS(kCompensated, s_muls[kCompensated].Min(), 1.0f + 2E-6f); - ASSERT_LESS(kCompensated, s_muls[kCompensated].Max(), 1.0f + 2E-5f); + ASSERT_LESS(kCompensated, s_muls[kCompensated].Max(), 1.0f + 1E-4f); ASSERT_LESS(kDouble, s_muls[kDouble].Min(), 1.0f + 2E-6f); ASSERT_LESS(kDouble, s_muls[kDouble].Max(), 1.0f + 2E-5f); diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 9454b66..563be4c 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -1324,9 +1324,9 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, } HWY_DASSERT(C.HasPtr()); for (size_t r = 0; r < C.Rows(); ++r) { - env.storage.OutRow(r) = reinterpret_cast(C.Row(r)); + env.row_ptrs[0][r] = reinterpret_cast(C.Row(r)); } - C_rows = CRows(&env.storage.OutRow(0)); + C_rows = CRows(env.row_ptrs[0].get()); } const Allocator& allocator = env.ctx.allocator; diff --git a/ops/matmul.cc b/ops/matmul.cc index bc662e8..21a1b91 100644 --- a/ops/matmul.cc +++ b/ops/matmul.cc @@ -427,6 +427,8 @@ MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx), parallel(ctx), storage(ctx.allocator, parallel) { char cpu100[100]; have_timer_stop = hwy::platform::HaveTimerStop(cpu100); + + row_ptrs.push_back(hwy::AllocateAligned(MMStorage::kMaxM)); } void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel) { diff --git a/ops/matmul.h b/ops/matmul.h index 4e323d4..30ed703 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -231,10 +231,9 @@ class MMStorage { // Internally threaded; must not be called concurrently with the same // `ThreadingContext` (used via `parallel`). MMStorage(const Allocator& allocator, MMParallel& parallel) - : out_rows(hwy::AllocateAligned(kMaxM)), - // Per-worker copies of `partial` would be wasteful. We instead allocate - // one instance of the maximum matrix extents because threads write at - // false-sharing-free granularity. + : // Per-worker copies of `partial` would be wasteful. We instead + // allocate one instance of the maximum matrix extents because threads + // write at false-sharing-free granularity. partial_storage_("partial_storage", Extents2D(kMaxM, kMaxN), MatPadding::kOdd), // Same stride independent of the actual C.Cols() so we can pre-bind. @@ -260,8 +259,6 @@ class MMStorage { BindC(partial_storage_, parallel); } - uint8_t*& OutRow(size_t row_idx) { return out_rows[row_idx]; } - // Returns per-package matrix view. RowPtrBF A(size_t pkg_idx, const Extents2D& extents) const { HWY_DASSERT(extents.rows <= kMaxM); @@ -273,11 +270,6 @@ class MMStorage { RowPtrD Partial() const { return partial_; } private: - // Enables arbitrary output rows. Most callers pass `RowPtr`, which assumes a - // constant stride, but GemmaAttention::ComputeQKV writes to differing KV - // positions per query / output row. `kMaxM` elements are too large for the - // stack, hence dynamic allocation. - hwy::AlignedFreeUniquePtr out_rows; std::unique_ptr> pkg_A_[MMParallel::kMaxPackages]; MatStorageT partial_storage_; RowPtrD partial_; @@ -667,9 +659,13 @@ struct MatMulEnv { MMKeys keys; std::vector per_key; - // Pass to MatPtr::AllocateAndAttachRowPtrs. - // Per-tensor allocations to make it likelier that asan detects bugs such as - // use after free, overrun, and dangling references. + // Storage for arbitrary output rows, see `MatPtr::AllocateAndAttachRowPtrs`. + // Most MatMul callers use strided MatPtr, but GemmaAttention::ComputeQKV + // writes to differing KV positions per query / output row. + // The first entry is sufficient for any MatMul, but also potentially + // overwritten by each MatMul. Subsequent entries are precomputed for tensors + // and not overwritten. Per-tensor allocations make it likelier that asan + // detects bugs such as use after free, overrun, and dangling references. std::vector> row_ptrs; };