Further cleanup: separate MatMulEnv arg

move row_ptrs into MatMulEnv
Consistent arg order: layer, activations, kv_cache, env

PiperOrigin-RevId: 767886386
This commit is contained in:
Jan Wassenberg 2025-06-05 20:47:57 -07:00 committed by Copybara-Service
parent e774ddbaaa
commit 6ee628ba38
18 changed files with 145 additions and 227 deletions

View File

@ -242,13 +242,6 @@ cc_test(
], ],
) )
cc_library(
name = "common",
srcs = ["gemma/common.cc"],
hdrs = ["gemma/common.h"],
deps = [":configs"],
)
# For building all tests in one command, so we can test several. # For building all tests in one command, so we can test several.
test_suite( test_suite(
name = "ops_tests", name = "ops_tests",
@ -566,11 +559,10 @@ cc_library(
}, },
deps = [ deps = [
":benchmark_helper", ":benchmark_helper",
":common",
":gemma_args", ":gemma_args",
":gemma_lib", ":gemma_lib",
":kv_cache", ":kv_cache",
":ops", ":matmul",
":threading", ":threading",
":threading_context", ":threading_context",
":tokenizer", ":tokenizer",

View File

@ -55,8 +55,6 @@ set(SOURCES
gemma/activations.h gemma/activations.h
gemma/attention.cc gemma/attention.cc
gemma/attention.h gemma/attention.h
gemma/common.cc
gemma/common.h
gemma/configs.cc gemma/configs.cc
gemma/configs.h gemma/configs.h
gemma/gemma_args.h gemma/gemma_args.h

View File

@ -41,13 +41,13 @@ static inline float ChooseQueryScale(const ModelConfig& config) {
} }
struct Activations { struct Activations {
Activations(const ModelConfig& config, size_t batch_size, MatMulEnv* env) Activations(const ModelConfig& config, size_t batch_size,
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
: weights_config(config), : weights_config(config),
layer_config(config.layer_configs[0]), layer_config(config.layer_configs[0]),
seq_len(config.seq_len), seq_len(config.seq_len),
cache_pos_size(config.CachePosSize()), cache_pos_size(config.CachePosSize()),
is_griffin(config.model == Model::GRIFFIN_2B), is_griffin(config.model == Model::GRIFFIN_2B),
query_scale(ChooseQueryScale(config)),
x("x", Extents2D(batch_size, config.model_dim), pad_), x("x", Extents2D(batch_size, config.model_dim), pad_),
// `vocab_size == 0` means it is for Vit part, VitAttention is still MHA // `vocab_size == 0` means it is for Vit part, VitAttention is still MHA
@ -96,19 +96,19 @@ struct Activations {
layer_config.qkv_dim, layer_config.post_qk == PostQKType::HalfRope, layer_config.qkv_dim, layer_config.post_qk == PostQKType::HalfRope,
1000000.0)), 1000000.0)),
env(env) { query_scale(ChooseQueryScale(config)) {
HWY_ASSERT(batch_size != 0); HWY_ASSERT(batch_size != 0);
// For MatMul outputs, precompute their row pointers. // For MatMul outputs, precompute their row pointers.
// If we forget any MatMul outputs here, debug builds print a warning but // If we forget any MatMul outputs here, debug builds print a warning but
// fill them in each MatMul call. // fill them in each MatMul call.
x.AllocateAndAttachRowPtrs(env->row_ptrs); x.AllocateAndAttachRowPtrs(row_ptrs);
q.AllocateAndAttachRowPtrs(env->row_ptrs); q.AllocateAndAttachRowPtrs(row_ptrs);
logits.AllocateAndAttachRowPtrs(env->row_ptrs); logits.AllocateAndAttachRowPtrs(row_ptrs);
att_sums.AllocateAndAttachRowPtrs(env->row_ptrs); att_sums.AllocateAndAttachRowPtrs(row_ptrs);
C1.AllocateAndAttachRowPtrs(env->row_ptrs); C1.AllocateAndAttachRowPtrs(row_ptrs);
C2.AllocateAndAttachRowPtrs(env->row_ptrs); C2.AllocateAndAttachRowPtrs(row_ptrs);
ffw_out.AllocateAndAttachRowPtrs(env->row_ptrs); ffw_out.AllocateAndAttachRowPtrs(row_ptrs);
// Note that BindC on any MatMul output considerably slows down Prefill. // Note that BindC on any MatMul output considerably slows down Prefill.
} }
@ -141,7 +141,6 @@ struct Activations {
size_t seq_len; size_t seq_len;
size_t cache_pos_size = 0; // TODO: after moving KVCache to MatStorageT. size_t cache_pos_size = 0; // TODO: after moving KVCache to MatStorageT.
bool is_griffin = false; bool is_griffin = false;
float query_scale;
const Extents2D none_ = Extents2D(); const Extents2D none_ = Extents2D();
const MatPadding pad_ = MatPadding::kOdd; const MatPadding pad_ = MatPadding::kOdd;
@ -172,7 +171,7 @@ struct Activations {
MatStorageT<float> inv_timescale; MatStorageT<float> inv_timescale;
MatStorageT<float> inv_timescale_global; MatStorageT<float> inv_timescale_global;
MatMulEnv* env; float query_scale;
}; };
} // namespace gcpp } // namespace gcpp

View File

@ -222,7 +222,7 @@ static HWY_INLINE void ComputeQKV(
size_t num_tokens, const QueriesPos& queries_pos, size_t num_tokens, const QueriesPos& queries_pos,
const hwy::Divisor& div_seq_len, const size_t layer_idx, const hwy::Divisor& div_seq_len, const size_t layer_idx,
const LayerWeightsPtrs& layer, Activations& activations, const LayerWeightsPtrs& layer, Activations& activations,
const KVCaches& kv_caches, const int flags, NestedPools& pools) { const KVCaches& kv_caches, const int flags, MatMulEnv& env) {
PROFILER_ZONE("Gen.Attention.QKV"); PROFILER_ZONE("Gen.Attention.QKV");
const size_t num_queries = queries_pos.size(); const size_t num_queries = queries_pos.size();
const size_t num_interleaved = num_tokens * num_queries; const size_t num_interleaved = num_tokens * num_queries;
@ -235,7 +235,7 @@ static HWY_INLINE void ComputeQKV(
// The original qkv_einsum_w has shape [(heads + kv_heads * 2), qkv_dim, // The original qkv_einsum_w has shape [(heads + kv_heads * 2), qkv_dim,
// model_dim], which we reshaped to (heads + kv_heads * 2) * qkv_dim rows. // model_dim], which we reshaped to (heads + kv_heads * 2) * qkv_dim rows.
CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w1, CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w1,
/*add=*/nullptr, *activations.env, activations.q); /*add=*/nullptr, env, activations.q);
// Set up MatMul row pointers for writing to KV, which consists of // Set up MatMul row pointers for writing to KV, which consists of
// `kv_heads` pairs of (k, v) vectors. This safely handles wraparound // `kv_heads` pairs of (k, v) vectors. This safely handles wraparound
@ -250,17 +250,16 @@ static HWY_INLINE void ComputeQKV(
div_seq_len.Remainder(queries_pos[query_idx] + batch_idx); div_seq_len.Remainder(queries_pos[query_idx] + batch_idx);
const size_t kv_offset = const size_t kv_offset =
cache_pos * cache_pos_size + layer_idx * cache_layer_size; cache_pos * cache_pos_size + layer_idx * cache_layer_size;
activations.env->storage.OutRow(interleaved_idx) = env.row_ptrs[0][interleaved_idx] = reinterpret_cast<uint8_t*>(
reinterpret_cast<uint8_t*>(kv_caches[query_idx].kv_cache.get() + kv_caches[query_idx].kv_cache.get() + kv_offset);
kv_offset);
} }
kv_rows.AttachRowPtrs(&activations.env->storage.OutRow(0)); kv_rows.AttachRowPtrs(env.row_ptrs[0].get());
CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2, CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
/*add=*/nullptr, *activations.env, kv_rows); /*add=*/nullptr, env, kv_rows);
// Apply positional encodings for K. // Apply positional encodings for K.
// TODO: 2D parallelism to use more threads. // TODO: 2D parallelism to use more threads.
pools.Pool(0).Run( env.ctx.pools.Pool(0).Run(
0, kv_heads * num_interleaved, 0, kv_heads * num_interleaved,
[&](uint64_t task, size_t /*thread*/) HWY_ATTR { [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
const size_t head = task % kv_heads; const size_t head = task % kv_heads;
@ -289,7 +288,7 @@ static HWY_INLINE void ComputeQKV(
// Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and // Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and
// head_dim (`qkv_dim`) into output (`layer_out`). // head_dim (`qkv_dim`) into output (`layer_out`).
static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer, static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
Activations& activations) { Activations& activations, MatMulEnv& env) {
PROFILER_ZONE("Gen.Attention.SumHeads"); PROFILER_ZONE("Gen.Attention.SumHeads");
const LayerConfig& layer_config = layer.layer_config; const LayerConfig& layer_config = layer.layer_config;
// att_weights and att_out are concatenated heads, each of length // att_weights and att_out are concatenated heads, each of length
@ -302,7 +301,7 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
const float* add = layer_config.softmax_attn_output_biases const float* add = layer_config.softmax_attn_output_biases
? layer.attention_output_biases.PackedScale1() ? layer.attention_output_biases.PackedScale1()
: nullptr; : nullptr;
CallMatMul(activations.att_out, layer.att_weights, add, *activations.env, CallMatMul(activations.att_out, layer.att_weights, add, env,
activations.att_sums); activations.att_sums);
} }
@ -312,7 +311,7 @@ void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos,
const QueriesPos* queries_prefix_end, const QueriesPos* queries_prefix_end,
const hwy::Divisor& div_seq_len, const size_t layer_idx, const hwy::Divisor& div_seq_len, const size_t layer_idx,
const LayerWeightsPtrs& layer, Activations& activations, const LayerWeightsPtrs& layer, Activations& activations,
const KVCaches& kv_caches, int flags) { const KVCaches& kv_caches, MatMulEnv& env, int flags) {
const size_t num_queries = queries_pos.size(); const size_t num_queries = queries_pos.size();
HWY_DASSERT(num_queries <= kv_caches.size()); HWY_DASSERT(num_queries <= kv_caches.size());
@ -331,13 +330,12 @@ void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos,
queries_prefix_end = &queries_prefix_end_span; queries_prefix_end = &queries_prefix_end_span;
} }
NestedPools& pools = activations.env->ctx.pools;
ComputeQKV(num_tokens, queries_pos, div_seq_len, layer_idx, layer, ComputeQKV(num_tokens, queries_pos, div_seq_len, layer_idx, layer,
activations, kv_caches, flags, pools); activations, kv_caches, flags, env);
DotSoftmaxWeightedSum(num_tokens, queries_pos, *queries_prefix_end, DotSoftmaxWeightedSum(num_tokens, queries_pos, *queries_prefix_end,
div_seq_len, layer_idx, layer, activations, kv_caches, div_seq_len, layer_idx, layer, activations, kv_caches,
pools); env.ctx.pools);
SumHeads(layer, activations); SumHeads(layer, activations, env);
} }
// NOLINTNEXTLINE(google-readability-namespace-comments) // NOLINTNEXTLINE(google-readability-namespace-comments)

View File

@ -47,7 +47,7 @@ namespace gcpp {
const QueriesPos* queries_prefix_end, \ const QueriesPos* queries_prefix_end, \
const hwy::Divisor& div_seq_len, const size_t layer_idx, \ const hwy::Divisor& div_seq_len, const size_t layer_idx, \
const LayerWeightsPtrs& layer, Activations& activations, \ const LayerWeightsPtrs& layer, Activations& activations, \
const KVCaches& kv_caches, int flags); \ const KVCaches& kv_caches, MatMulEnv& env, int flags); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE } // namespace NAMESPACE

View File

@ -29,7 +29,6 @@
#include <stdio.h> #include <stdio.h>
#endif #endif
#include "gemma/common.h"
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "gemma/gemma_args.h" #include "gemma/gemma_args.h"
#include "ops/matmul.h" // MatMulEnv #include "ops/matmul.h" // MatMulEnv

View File

@ -1,38 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gemma/common.h"
#include <stddef.h>
#include <string>
#include "gemma/configs.h"
namespace gcpp {
// Surrounds `prompt` (in place) with the turn-delimiting control tokens that
// instruction-tuned models expect. Prompts for any other wrapping mode are
// left untouched.
void Wrap(const ModelConfig& config, size_t pos, std::string& prompt) {
  // Only instruction-tuned models are trained to expect control tokens.
  if (config.wrapping != PromptWrapping::GEMMA_IT) return;

  // A non-zero `pos` means we are continuing a multi-turn dialogue, so the
  // previous turn is terminated with "<end_of_turn>" before the user turn
  // is opened.
  const bool is_first_turn = (pos == 0);
  const char* const user_turn_open =
      is_first_turn ? "<start_of_turn>user\n"
                    : "<end_of_turn>\n<start_of_turn>user\n";

  prompt.insert(0, user_turn_open);
  // Close the user turn and open the model turn.
  prompt.append("<end_of_turn>\n<start_of_turn>model\n");
}
} // namespace gcpp

View File

@ -1,33 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
#include <stddef.h>
#include <string>
#include "gemma/configs.h" // IWYU pragma: export
namespace gcpp {
// Wraps the given prompt using the expected control tokens for IT models.
// `prompt` is modified in place, and only when `config.wrapping` is
// `PromptWrapping::GEMMA_IT`; for any other wrapping it is left unchanged.
// `pos == 0` selects the first-turn prefix; any other `pos` is treated as a
// multi-turn dialogue continuation and additionally prepends "<end_of_turn>".
// DEPRECATED, use WrapAndTokenize instead if a tokenized return value is fine.
void Wrap(const ModelConfig& config, size_t pos, std::string& prompt);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_

View File

@ -18,7 +18,10 @@
#include <stddef.h> #include <stddef.h>
#include <stdint.h> #include <stdint.h>
#include "gemma/activations.h"
#include "gemma/configs.h" #include "gemma/configs.h"
#include "gemma/weights.h"
#include "ops/matmul.h"
#include "util/mat.h" #include "util/mat.h"
#include "hwy/profiler.h" #include "hwy/profiler.h"
@ -103,8 +106,8 @@ void PostNorm(PostNormType post_norm, const MatPtr& weights,
} }
} }
static inline void FFWNoVit(Activations& activations, static inline void FFWNoVit(const LayerWeightsPtrs& layer,
const LayerWeightsPtrs& layer) { Activations& activations, MatMulEnv& env) {
PROFILER_ZONE("Gen.FFW"); PROFILER_ZONE("Gen.FFW");
const LayerConfig& layer_config = layer.layer_config; const LayerConfig& layer_config = layer.layer_config;
const size_t ffh_hidden_dim = layer_config.ff_hidden_dim; const size_t ffh_hidden_dim = layer_config.ff_hidden_dim;
@ -117,16 +120,16 @@ static inline void FFWNoVit(Activations& activations,
add_bias ? layer.ffw_output_biases.PackedScale1() : nullptr; add_bias ? layer.ffw_output_biases.PackedScale1() : nullptr;
// Compute the hidden layer activations. // Compute the hidden layer activations.
CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w1, bias1, CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w1, bias1, env,
*activations.env, activations.C1); activations.C1);
CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w2, bias2, CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w2, bias2, env,
*activations.env, activations.C2); activations.C2);
// Activation (Gelu) and maybe multiply by gate. Store activations in act. // Activation (Gelu) and maybe multiply by gate. Store activations in act.
ActivationBatched(layer_config.activation, activations.C1, &activations.C2); ActivationBatched(layer_config.activation, activations.C1, &activations.C2);
// Hidden layer -> output layer. // Hidden layer -> output layer.
CallMatMul(activations.C1, layer.linear_w, output_bias, *activations.env, CallMatMul(activations.C1, layer.linear_w, output_bias, env,
activations.ffw_out); activations.ffw_out);
} }

View File

@ -68,10 +68,10 @@ void Attention(LayerAttentionType type, size_t num_tokens,
const QueriesPos& queries_prefix_end, const QueriesPos& queries_prefix_end,
const hwy::Divisor& div_seq_len, const size_t layer_idx, const hwy::Divisor& div_seq_len, const size_t layer_idx,
const LayerWeightsPtrs& layer, Activations& activations, const LayerWeightsPtrs& layer, Activations& activations,
const KVCaches& kv_caches) { const KVCaches& kv_caches, MatMulEnv& env) {
if (type == LayerAttentionType::kGemma) { if (type == LayerAttentionType::kGemma) {
GemmaAttention(num_tokens, queries_pos, &queries_prefix_end, div_seq_len, GemmaAttention(num_tokens, queries_pos, &queries_prefix_end, div_seq_len,
layer_idx, layer, activations, kv_caches, layer_idx, layer, activations, kv_caches, env,
/*flags=*/0); /*flags=*/0);
} else { } else {
HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock); HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock);
@ -80,7 +80,7 @@ void Attention(LayerAttentionType type, size_t num_tokens,
const size_t griffin_layer = const size_t griffin_layer =
activations.weights_config.NumLayersOfTypeBefore(type, layer_idx); activations.weights_config.NumLayersOfTypeBefore(type, layer_idx);
GriffinRecurrent(queries_pos, num_tokens, griffin_layer, activations, GriffinRecurrent(queries_pos, num_tokens, griffin_layer, activations,
&layer, kv_caches); &layer, kv_caches, env);
} }
} }
@ -88,14 +88,14 @@ static HWY_NOINLINE void TransformerLayer(
const size_t num_tokens, const QueriesPos& queries_pos, const size_t num_tokens, const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end, const hwy::Divisor& div_seq_len, const QueriesPos& queries_prefix_end, const hwy::Divisor& div_seq_len,
const size_t layer_idx, const LayerWeightsPtrs& layer, const size_t layer_idx, const LayerWeightsPtrs& layer,
Activations& activations, const KVCaches& kv_caches) { Activations& activations, const KVCaches& kv_caches, MatMulEnv& env) {
const LayerConfig& layer_config = layer.layer_config; const LayerConfig& layer_config = layer.layer_config;
RMSNormBatched(activations.x, layer.pre_attention_norm_scale, RMSNormBatched(activations.x, layer.pre_attention_norm_scale,
activations.pre_att_rms_out); activations.pre_att_rms_out);
Attention(layer_config.type, num_tokens, queries_pos, queries_prefix_end, Attention(layer_config.type, num_tokens, queries_pos, queries_prefix_end,
div_seq_len, layer_idx, layer, activations, kv_caches); div_seq_len, layer_idx, layer, activations, kv_caches, env);
PostNorm(layer_config.post_norm, layer.post_attention_norm_scale, PostNorm(layer_config.post_norm, layer.post_attention_norm_scale,
activations.att_sums); activations.att_sums);
@ -107,9 +107,9 @@ static HWY_NOINLINE void TransformerLayer(
activations.pre_ffw_rms_out); activations.pre_ffw_rms_out);
if (layer_config.type == LayerAttentionType::kVit) { if (layer_config.type == LayerAttentionType::kVit) {
FFWVit(activations, layer); FFWVit(layer, activations, env);
} else { } else {
FFWNoVit(activations, layer); FFWNoVit(layer, activations, env);
} }
PostNorm(layer_config.post_norm, layer.post_ffw_norm_scale, PostNorm(layer_config.post_norm, layer.post_ffw_norm_scale,
@ -187,12 +187,11 @@ using QueriesMutablePos = hwy::Span<size_t>;
// Populates KV cache for batches of tokens from one query at a time. // Populates KV cache for batches of tokens from one query at a time.
static HWY_NOINLINE void Prefill( static HWY_NOINLINE void Prefill(
const QueriesPromptTokens& queries_prompt, const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end, const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
const size_t query_idx_start, const ModelConfig& config, const hwy::Divisor& div_seq_len, const ModelConfig& config,
const ModelWeightsPtrs& weights, Activations& activations, const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
const RuntimeConfig& runtime_config, const hwy::Divisor& div_seq_len, Activations& activations, const KVCaches& kv_caches, MatMulEnv& env) {
const KVCaches& kv_caches) {
PROFILER_ZONE("Gen.Prefill"); PROFILER_ZONE("Gen.Prefill");
const size_t num_queries = queries_prompt.size(); const size_t num_queries = queries_prompt.size();
HWY_DASSERT(queries_pos.size() == num_queries); HWY_DASSERT(queries_pos.size() == num_queries);
@ -263,7 +262,7 @@ static HWY_NOINLINE void Prefill(
++layer_idx) { ++layer_idx) {
TransformerLayer(tbatch_size, single_query_pos, single_query_prefix_end, TransformerLayer(tbatch_size, single_query_pos, single_query_prefix_end,
div_seq_len, layer_idx, *weights.GetLayer(layer_idx), div_seq_len, layer_idx, *weights.GetLayer(layer_idx),
activations, single_kv_cache); activations, single_kv_cache, env);
} }
// NOTE: we unconditionally call StreamToken, even if EOS. // NOTE: we unconditionally call StreamToken, even if EOS.
@ -298,9 +297,9 @@ static HWY_NOINLINE void Prefill(
// from each query, and `queries_pos` are their position in the sequence. // from each query, and `queries_pos` are their position in the sequence.
static HWY_NOINLINE void Transformer( static HWY_NOINLINE void Transformer(
const QueriesToken& queries_token, const QueriesMutablePos& queries_pos, const QueriesToken& queries_token, const QueriesMutablePos& queries_pos,
const QueriesPos& queries_prefix_end, const ModelConfig& config, const QueriesPos& queries_prefix_end, const hwy::Divisor& div_seq_len,
const ModelWeightsPtrs& weights, Activations& activations, const ModelConfig& config, const ModelWeightsPtrs& weights,
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches, Activations& activations, const KVCaches& kv_caches, MatMulEnv& env,
const LayersOutputFunc& layers_output, const LayersOutputFunc& layers_output,
const ActivationsObserverFunc& activations_observer) { const ActivationsObserverFunc& activations_observer) {
const size_t num_queries = queries_token.size(); const size_t num_queries = queries_token.size();
@ -323,7 +322,7 @@ static HWY_NOINLINE void Transformer(
for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) { for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) {
TransformerLayer(/*num_tokens=*/1, queries_pos, queries_prefix_end, TransformerLayer(/*num_tokens=*/1, queries_pos, queries_prefix_end,
div_seq_len, layer_idx, *weights.GetLayer(layer_idx), div_seq_len, layer_idx, *weights.GetLayer(layer_idx),
activations, kv_caches); activations, kv_caches, env);
if (activations_observer) { if (activations_observer) {
activations_observer(queries_pos, layer_idx, activations); activations_observer(queries_pos, layer_idx, activations);
@ -381,23 +380,22 @@ class TokenStreamer {
// Runs one decode step for all the queries in the batch. Returns true if all // Runs one decode step for all the queries in the batch. Returns true if all
// queries are at <end_of_sentence>. // queries are at <end_of_sentence>.
static bool DecodeStepT(const ModelConfig& config, static bool DecodeStepT(
const ModelWeightsPtrs& weights, const ModelConfig& config, const ModelWeightsPtrs& weights,
const RuntimeConfig& runtime_config, const RuntimeConfig& runtime_config, const size_t query_idx_start,
const QueriesPromptTokens& queries_prompt, const QueriesPromptTokens& queries_prompt,
const size_t query_idx_start, const KVCaches& kv_caches, const QueriesMutablePos& queries_mutable_pos,
const QueriesPos& queries_prefix_end, const QueriesPos& queries_prefix_end, const hwy::Divisor div_seq_len,
const hwy::Divisor div_seq_len, const size_t vocab_size, const size_t vocab_size, const SampleFunc& sample_token,
const SampleFunc& sample_token, Activations& activations, const KVCaches& kv_caches,
Activations& activations, TokenStreamer& token_streamer, TokenStreamer& token_streamer, std::vector<int>& gen_tokens,
std::vector<int>& gen_tokens, TimingInfo& timing_info, TimingInfo& timing_info, MatMulEnv& env) {
const QueriesMutablePos& queries_mutable_pos) {
const size_t num_queries = queries_prompt.size(); const size_t num_queries = queries_prompt.size();
// Decode generates one token per query and increments // Decode generates one token per query and increments
// queries_mutable_pos. // queries_mutable_pos.
Transformer(QueriesToken(gen_tokens.data(), num_queries), queries_mutable_pos, Transformer(QueriesToken(gen_tokens.data(), num_queries), queries_mutable_pos,
queries_prefix_end, config, weights, activations, div_seq_len, queries_prefix_end, div_seq_len, config, weights, activations,
kv_caches, runtime_config.layers_output, kv_caches, env, runtime_config.layers_output,
runtime_config.activations_observer); runtime_config.activations_observer);
// queries_pos are incremented by Transformer. // queries_pos are incremented by Transformer.
@ -407,7 +405,7 @@ static bool DecodeStepT(const ModelConfig& config,
PROFILER_ZONE("Gen.EmbeddingMatmul"); PROFILER_ZONE("Gen.EmbeddingMatmul");
// Compute logits from last layer activations. // Compute logits from last layer activations.
CallMatMul(activations.x, weights.embedder_input_embedding, CallMatMul(activations.x, weights.embedder_input_embedding,
/*add=*/nullptr, *activations.env, activations.logits); /*add=*/nullptr, env, activations.logits);
} }
PROFILER_ZONE("Gen.Softcap+Sample+Stream"); PROFILER_ZONE("Gen.Softcap+Sample+Stream");
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
@ -468,14 +466,12 @@ static size_t MaxQueryLength(const QueriesPromptTokens& queries_prompt) {
// `StreamFunc` gets the global query index, not relative to the batch. // `StreamFunc` gets the global query index, not relative to the batch.
// //
// `kv_caches` is for the batch, size must match `queries_prompt`. // `kv_caches` is for the batch, size must match `queries_prompt`.
static void GenerateT(const ModelConfig& config, static void GenerateT(
const ModelWeightsPtrs& weights, Activations& activations, const ModelConfig& config, const ModelWeightsPtrs& weights,
const RuntimeConfig& runtime_config, const RuntimeConfig& runtime_config, const size_t query_idx_start,
const QueriesPromptTokens& queries_prompt, const QueriesPromptTokens& queries_prompt, const QueriesPos& queries_pos_in,
const QueriesPos& queries_pos_in, const QueriesPos& queries_prefix_end, Activations& activations,
const QueriesPos& queries_prefix_end, const KVCaches& kv_caches, TimingInfo& timing_info, MatMulEnv& env) {
const size_t query_idx_start, const KVCaches& kv_caches,
TimingInfo& timing_info) {
HWY_ASSERT(queries_pos_in.size() == kv_caches.size()); HWY_ASSERT(queries_pos_in.size() == kv_caches.size());
// Griffin assumes that the recurrent block cache is zero-initialized. // Griffin assumes that the recurrent block cache is zero-initialized.
@ -512,9 +508,9 @@ static void GenerateT(const ModelConfig& config,
// token is the first input token for generation. // token is the first input token for generation.
timing_info.prefill_start = hwy::platform::Now(); timing_info.prefill_start = hwy::platform::Now();
// Note that Prefill calls activations.SetBatchSize, so we reset it below. // Note that Prefill calls activations.SetBatchSize, so we reset it below.
Prefill(queries_prompt, queries_mutable_pos, queries_prefix_end, Prefill(query_idx_start, queries_prompt, queries_mutable_pos,
query_idx_start, config, weights, activations, runtime_config, queries_prefix_end, div_seq_len, config, runtime_config, weights,
div_seq_len, kv_caches); activations, kv_caches, env);
// Compute the number of tokens that were prefilled and notify timing_info. // Compute the number of tokens that were prefilled and notify timing_info.
size_t prefilled_tokens = 0; size_t prefilled_tokens = 0;
for (size_t qi = 0; qi < num_queries; ++qi) { for (size_t qi = 0; qi < num_queries; ++qi) {
@ -543,11 +539,12 @@ static void GenerateT(const ModelConfig& config,
const size_t vocab_size = config.vocab_size; const size_t vocab_size = config.vocab_size;
timing_info.generate_start = hwy::platform::Now(); timing_info.generate_start = hwy::platform::Now();
for (size_t gen = 0; gen < max_generated_tokens; ++gen) { for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
bool all_queries_eos = DecodeStepT( bool all_queries_eos =
config, weights, runtime_config, queries_prompt, query_idx_start, DecodeStepT(config, weights, runtime_config, query_idx_start,
kv_caches, queries_prefix_end, div_seq_len, vocab_size, sample_token, queries_prompt, queries_mutable_pos, queries_prefix_end,
activations, token_streamer, gen_tokens, timing_info, div_seq_len, vocab_size, sample_token, activations,
queries_mutable_pos); kv_caches, token_streamer, gen_tokens, timing_info, env);
if (all_queries_eos) break; if (all_queries_eos) break;
} // foreach token to generate } // foreach token to generate
timing_info.NotifyGenerateDone(); timing_info.NotifyGenerateDone();
@ -557,7 +554,7 @@ static void GenerateT(const ModelConfig& config,
void GenerateSingleT(const ModelConfig& config, const ModelWeightsPtrs& weights, void GenerateSingleT(const ModelConfig& config, const ModelWeightsPtrs& weights,
const RuntimeConfig& runtime_config, const RuntimeConfig& runtime_config,
const PromptTokens& prompt, size_t pos, size_t prefix_end, const PromptTokens& prompt, size_t pos, size_t prefix_end,
KVCache& kv_cache, MatMulEnv* env, KVCache& kv_cache, MatMulEnv& env,
TimingInfo& timing_info) { TimingInfo& timing_info) {
constexpr size_t kNumQueries = 1; constexpr size_t kNumQueries = 1;
const size_t qbatch_start = 0; const size_t qbatch_start = 0;
@ -565,16 +562,16 @@ void GenerateSingleT(const ModelConfig& config, const ModelWeightsPtrs& weights,
const size_t max_batch_size = const size_t max_batch_size =
HWY_MAX(kNumQueries, runtime_config.prefill_tbatch_size); HWY_MAX(kNumQueries, runtime_config.prefill_tbatch_size);
// TODO: move into Gemma? // TODO: move into Gemma?
Activations activations(config, max_batch_size, env); Activations activations(config, max_batch_size, env.row_ptrs);
const QueriesPromptTokens queries_prompt(&prompt, kNumQueries); const QueriesPromptTokens queries_prompt(&prompt, kNumQueries);
QueriesPos queries_pos(&pos, kNumQueries); QueriesPos queries_pos(&pos, kNumQueries);
const QueriesPos queries_prefix_end(&prefix_end, kNumQueries); const QueriesPos queries_prefix_end(&prefix_end, kNumQueries);
const KVCaches kv_caches{&kv_cache, kNumQueries}; const KVCaches kv_caches{&kv_cache, kNumQueries};
GenerateT(config, weights, activations, runtime_config, queries_prompt, GenerateT(config, weights, runtime_config, qbatch_start, queries_prompt,
queries_pos, queries_prefix_end, qbatch_start, kv_caches, queries_pos, queries_prefix_end, activations, kv_caches,
timing_info); timing_info, env);
} }
void GenerateBatchT(const ModelConfig& config, const ModelWeightsPtrs& weights, void GenerateBatchT(const ModelConfig& config, const ModelWeightsPtrs& weights,
@ -582,7 +579,7 @@ void GenerateBatchT(const ModelConfig& config, const ModelWeightsPtrs& weights,
const QueriesPromptTokens& queries_prompt, const QueriesPromptTokens& queries_prompt,
const QueriesPos& queries_pos, const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end, const QueriesPos& queries_prefix_end,
const KVCaches& kv_caches, MatMulEnv* env, const KVCaches& kv_caches, MatMulEnv& env,
TimingInfo& timing_info) { TimingInfo& timing_info) {
const size_t num_queries = queries_prompt.size(); const size_t num_queries = queries_prompt.size();
HWY_ASSERT(queries_pos.size() == num_queries); HWY_ASSERT(queries_pos.size() == num_queries);
@ -591,7 +588,7 @@ void GenerateBatchT(const ModelConfig& config, const ModelWeightsPtrs& weights,
const size_t max_batch_size = const size_t max_batch_size =
HWY_MAX(max_qbatch_size, runtime_config.prefill_tbatch_size); HWY_MAX(max_qbatch_size, runtime_config.prefill_tbatch_size);
Activations activations(config, max_batch_size, env); Activations activations(config, max_batch_size, env.row_ptrs);
for (size_t qbatch_start = 0; qbatch_start < num_queries; for (size_t qbatch_start = 0; qbatch_start < num_queries;
qbatch_start += max_qbatch_size) { qbatch_start += max_qbatch_size) {
@ -604,9 +601,9 @@ void GenerateBatchT(const ModelConfig& config, const ModelWeightsPtrs& weights,
const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start], const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start],
qbatch_size); qbatch_size);
const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size); const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
GenerateT(config, weights, activations, runtime_config, qbatch_prompts, GenerateT(config, weights, runtime_config, qbatch_start, qbatch_prompts,
qbatch_pos, qbatch_prefix_end, qbatch_start, qbatch_kv, qbatch_pos, qbatch_prefix_end, activations, qbatch_kv,
timing_info); timing_info, env);
} }
} }
@ -614,7 +611,7 @@ void GenerateImageTokensT(const ModelConfig& config,
const ModelWeightsPtrs& weights, const ModelWeightsPtrs& weights,
const RuntimeConfig& runtime_config, const RuntimeConfig& runtime_config,
const Image& image, ImageTokens& image_tokens, const Image& image, ImageTokens& image_tokens,
MatMulEnv* env) { MatMulEnv& env) {
if (config.vit_config.layer_configs.empty()) { if (config.vit_config.layer_configs.empty()) {
HWY_ABORT("Model does not support generating image tokens."); HWY_ABORT("Model does not support generating image tokens.");
} }
@ -622,10 +619,10 @@ void GenerateImageTokensT(const ModelConfig& config,
ModelConfig vit_config = GetVitConfig(config); ModelConfig vit_config = GetVitConfig(config);
prefill_runtime_config.prefill_tbatch_size = prefill_runtime_config.prefill_tbatch_size =
vit_config.seq_len / (vit_config.pool_dim * vit_config.pool_dim); vit_config.seq_len / (vit_config.pool_dim * vit_config.pool_dim);
Activations prefill_activations(vit_config, vit_config.seq_len, env); Activations prefill_activations(vit_config, vit_config.seq_len, env.row_ptrs);
// Weights are for the full PaliGemma model, not just the ViT part. // Weights are for the full PaliGemma model, not just the ViT part.
PrefillVit(config, weights, prefill_runtime_config, image, image_tokens, PrefillVit(config, weights, prefill_runtime_config, image, image_tokens,
prefill_activations); prefill_activations, env);
} }
// NOLINTNEXTLINE(google-readability-namespace-comments) // NOLINTNEXTLINE(google-readability-namespace-comments)
@ -677,7 +674,7 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
HWY_DYNAMIC_DISPATCH(GenerateSingleT)(model_.Config(), weights_, HWY_DYNAMIC_DISPATCH(GenerateSingleT)(model_.Config(), weights_,
runtime_config, prompt, pos, prefix_end, runtime_config, prompt, pos, prefix_end,
kv_cache, &env_, timing_info); kv_cache, env_, timing_info);
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning); env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
} }
@ -701,7 +698,7 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
HWY_DYNAMIC_DISPATCH(GenerateBatchT)( HWY_DYNAMIC_DISPATCH(GenerateBatchT)(
model_.Config(), weights_, runtime_config, queries_prompt, queries_pos, model_.Config(), weights_, runtime_config, queries_prompt, queries_pos,
mutable_queries_prefix_end, kv_caches, &env_, timing_info); mutable_queries_prefix_end, kv_caches, env_, timing_info);
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning); env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
} }
@ -712,7 +709,7 @@ void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning); env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
HWY_DYNAMIC_DISPATCH(GenerateImageTokensT)( HWY_DYNAMIC_DISPATCH(GenerateImageTokensT)(
model_.Config(), weights_, runtime_config, image, image_tokens, &env_); model_.Config(), weights_, runtime_config, image, image_tokens, env_);
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning); env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
} }

View File

@ -48,9 +48,9 @@ namespace HWY_NAMESPACE {
void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens, void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens,
size_t griffin_layer, Activations& activations, size_t griffin_layer, Activations& activations,
const LayerWeightsPtrs* layer_weights, const LayerWeightsPtrs* layer_weights,
const KVCaches& kv_caches) { const KVCaches& kv_caches, MatMulEnv& env) {
PROFILER_ZONE("Gen.Griffin"); PROFILER_ZONE("Gen.Griffin");
hwy::ThreadPool& pool = activations.env->ctx.pools.Pool(0); hwy::ThreadPool& pool = env.ctx.pools.Pool(0);
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>; using D = hn::ScalableTag<float>;
const D df; const D df;
@ -183,8 +183,8 @@ void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens,
// Final linear layer. // Final linear layer.
CallMatMul(activations.griffin_x, layer_weights->griffin.linear_out_w, CallMatMul(activations.griffin_x, layer_weights->griffin.linear_out_w,
layer_weights->griffin.linear_out_biases.PackedScale1(), layer_weights->griffin.linear_out_biases.PackedScale1(), env,
*activations.env, activations.att_sums); activations.att_sums);
} // GriffinRecurrent } // GriffinRecurrent
// NOLINTNEXTLINE(google-readability-namespace-comments) // NOLINTNEXTLINE(google-readability-namespace-comments)

View File

@ -31,7 +31,7 @@ namespace gcpp {
void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens, \ void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens, \
size_t griffin_layer, Activations& activations, \ size_t griffin_layer, Activations& activations, \
const LayerWeightsPtrs* layer_weights, \ const LayerWeightsPtrs* layer_weights, \
const KVCaches& kv_caches); \ const KVCaches& kv_caches, MatMulEnv& env); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE } // namespace NAMESPACE

View File

@ -60,7 +60,7 @@ class VitAttention {
HWY_ASSERT(qkv.Rows() == num_tokens_); HWY_ASSERT(qkv.Rows() == num_tokens_);
HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim); HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim);
CallMatMul(activations_.pre_att_rms_out, layer_.vit.qkv_einsum_w, CallMatMul(activations_.pre_att_rms_out, layer_.vit.qkv_einsum_w,
layer_.vit.qkv_einsum_b.PackedScale1(), *activations_.env, qkv); layer_.vit.qkv_einsum_b.PackedScale1(), env_, qkv);
} }
// TODO(philculliton): transition fully to MatMul. // TODO(philculliton): transition fully to MatMul.
@ -100,7 +100,7 @@ class VitAttention {
}); });
// this produces C, a (num_tokens_, seq_len) matrix of dot products // this produces C, a (num_tokens_, seq_len) matrix of dot products
CallMatMul(Q, K, nullptr, *activations_.env, C); CallMatMul(Q, K, nullptr, env_, C);
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
float* HWY_RESTRICT c = C.Row(task); float* HWY_RESTRICT c = C.Row(task);
@ -166,18 +166,19 @@ class VitAttention {
// att_weights and att_out are concatenated heads, each of length // att_weights and att_out are concatenated heads, each of length
// qkv_dim. Thus the [num_tokens_, layer_config_.model_dim] // qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
// matmul output is the sum over heads. // matmul output is the sum over heads.
CallMatMul(activations_.att_out, layer_.vit.attn_out_w, bias, CallMatMul(activations_.att_out, layer_.vit.attn_out_w, bias, env_,
*activations_.env, activations_.att_sums); activations_.att_sums);
} }
public: public:
VitAttention(size_t num_tokens, size_t layer_idx, Activations& activations, VitAttention(size_t num_tokens, size_t layer_idx, Activations& activations,
const LayerWeightsPtrs& layer) const LayerWeightsPtrs& layer, MatMulEnv& env)
: num_tokens_(num_tokens), : num_tokens_(num_tokens),
activations_(activations), activations_(activations),
layer_(layer), layer_(layer),
layer_config_(layer.layer_config), layer_config_(layer.layer_config),
pool_(activations.env->ctx.pools.Pool(0)) {} env_(env),
pool_(env_.ctx.pools.Pool(0)) {}
HWY_INLINE void operator()() { HWY_INLINE void operator()() {
ComputeQKV(); ComputeQKV();
@ -194,12 +195,14 @@ class VitAttention {
Activations& activations_; Activations& activations_;
const LayerWeightsPtrs& layer_; const LayerWeightsPtrs& layer_;
const LayerConfig& layer_config_; const LayerConfig& layer_config_;
MatMulEnv& env_;
hwy::ThreadPool& pool_; hwy::ThreadPool& pool_;
}; };
// Same as FFWNoVit, but with different layer members and no second // Same as FFWNoVit, but with different layer members and no second
// gating matrix. // gating matrix.
void FFWVit(Activations& activations, const LayerWeightsPtrs& layer) { void FFWVit(const LayerWeightsPtrs& layer, Activations& activations,
MatMulEnv& env) {
PROFILER_ZONE("Gen.FFW.ViT"); PROFILER_ZONE("Gen.FFW.ViT");
const LayerConfig& layer_config = layer.layer_config; const LayerConfig& layer_config = layer.layer_config;
@ -209,15 +212,15 @@ void FFWVit(Activations& activations, const LayerWeightsPtrs& layer) {
add_bias ? layer.vit.linear_1_b.PackedScale1() : nullptr; add_bias ? layer.vit.linear_1_b.PackedScale1() : nullptr;
// Compute the hidden layer activations. // Compute the hidden layer activations.
CallMatMul(activations.pre_ffw_rms_out, layer.vit.linear_0_w, bias1, CallMatMul(activations.pre_ffw_rms_out, layer.vit.linear_0_w, bias1, env,
*activations.env, activations.C1); activations.C1);
// Activation (Gelu), store in C1. // Activation (Gelu), store in C1.
ActivationBatched(layer_config.activation, activations.C1); ActivationBatched(layer_config.activation, activations.C1);
// Hidden layer -> output layer. // Hidden layer -> output layer.
CallMatMul(activations.C1, layer.vit.linear_1_w, output_bias, CallMatMul(activations.C1, layer.vit.linear_1_w, output_bias, env,
*activations.env, activations.ffw_out); activations.ffw_out);
} }
// Vit transformer layer. Some comments below refer to the Vit implementation in // Vit transformer layer. Some comments below refer to the Vit implementation in
@ -227,7 +230,7 @@ void FFWVit(Activations& activations, const LayerWeightsPtrs& layer) {
// try merging this with TransformerLayer. // try merging this with TransformerLayer.
void VitTransformerLayer(size_t num_tokens, const size_t layer_idx, void VitTransformerLayer(size_t num_tokens, const size_t layer_idx,
const LayerWeightsPtrs& layer, const LayerWeightsPtrs& layer,
Activations& activations) { Activations& activations, MatMulEnv& env) {
const size_t model_dim = activations.weights_config.model_dim; const size_t model_dim = activations.weights_config.model_dim;
auto type = layer.layer_config.type; auto type = layer.layer_config.type;
HWY_DASSERT(type == LayerAttentionType::kVit); HWY_DASSERT(type == LayerAttentionType::kVit);
@ -245,7 +248,7 @@ void VitTransformerLayer(size_t num_tokens, const size_t layer_idx,
// y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y) // y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y)
// y ~ att_sums // y ~ att_sums
VitAttention(num_tokens, layer_idx, activations, layer)(); VitAttention(num_tokens, layer_idx, activations, layer, env)();
// x = out["+sa"] = x + y // x = out["+sa"] = x + y
AddFromBatched(activations.att_sums, x); AddFromBatched(activations.att_sums, x);
@ -257,7 +260,7 @@ void VitTransformerLayer(size_t num_tokens, const size_t layer_idx,
// y = out["mlp"] = MlpBlock(...)(y) // y = out["mlp"] = MlpBlock(...)(y)
// y ~ ffw_out // y ~ ffw_out
FFWVit(activations, layer); FFWVit(layer, activations, env);
// x = out["+mlp"] = x + y // x = out["+mlp"] = x + y
AddFromBatched(activations.ffw_out, x); AddFromBatched(activations.ffw_out, x);
@ -268,7 +271,8 @@ void VitTransformerLayer(size_t num_tokens, const size_t layer_idx,
static HWY_NOINLINE void EmbedImagePatches(const Image& image, static HWY_NOINLINE void EmbedImagePatches(const Image& image,
const ModelConfig& model_config, const ModelConfig& model_config,
const ModelWeightsPtrs& weights, const ModelWeightsPtrs& weights,
Activations& activations) { Activations& activations,
MatMulEnv& env) {
const size_t model_dim = model_config.vit_config.model_dim; const size_t model_dim = model_config.vit_config.model_dim;
const size_t patch_width = model_config.vit_config.patch_width; const size_t patch_width = model_config.vit_config.patch_width;
const size_t seq_len = model_config.vit_config.seq_len; const size_t seq_len = model_config.vit_config.seq_len;
@ -287,8 +291,7 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image,
image.GetPatch(i, image_patches.Row(i)); image.GetPatch(i, image_patches.Row(i));
} }
CallMatMul(image_patches, weights.vit_img_embedding_kernel, CallMatMul(image_patches, weights.vit_img_embedding_kernel,
weights.vit_img_embedding_bias.PackedScale1(), *activations.env, weights.vit_img_embedding_bias.PackedScale1(), env, activations.x);
activations.x);
// Add position embeddings. // Add position embeddings.
CallUpcastedActivation(&weights.vit_img_pos_embedding, CallUpcastedActivation(&weights.vit_img_pos_embedding,
[&](const auto* weights_t) { [&](const auto* weights_t) {
@ -300,18 +303,19 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image,
void PrefillVit(const ModelConfig& model_config, void PrefillVit(const ModelConfig& model_config,
const ModelWeightsPtrs& weights, const ModelWeightsPtrs& weights,
const RuntimeConfig& runtime_config, const Image& image, const RuntimeConfig& runtime_config, const Image& image,
ImageTokens& image_tokens, Activations& activations) { ImageTokens& image_tokens, Activations& activations,
MatMulEnv& env) {
PROFILER_ZONE("Gen.PrefillVit"); PROFILER_ZONE("Gen.PrefillVit");
const size_t num_tokens = model_config.vit_config.seq_len; const size_t num_tokens = model_config.vit_config.seq_len;
const size_t vit_model_dim = model_config.vit_config.model_dim; const size_t vit_model_dim = model_config.vit_config.model_dim;
HWY_ASSERT(num_tokens == activations.x.Rows()); HWY_ASSERT(num_tokens == activations.x.Rows());
// Embed the image patches. // Embed the image patches.
EmbedImagePatches(image, model_config, weights, activations); EmbedImagePatches(image, model_config, weights, activations, env);
// Go through all layers. // Go through all layers.
for (size_t layer_idx = 0; for (size_t layer_idx = 0;
layer_idx < model_config.vit_config.layer_configs.size(); ++layer_idx) { layer_idx < model_config.vit_config.layer_configs.size(); ++layer_idx) {
VitTransformerLayer(num_tokens, layer_idx, *weights.VitLayer(layer_idx), VitTransformerLayer(num_tokens, layer_idx, *weights.VitLayer(layer_idx),
activations); activations, env);
} }
// Final Layernorm. // Final Layernorm.
LayerNormBatched(activations.x, weights.vit_encoder_norm_scale, LayerNormBatched(activations.x, weights.vit_encoder_norm_scale,
@ -329,8 +333,7 @@ void PrefillVit(const ModelConfig& model_config,
// Apply head embedding into image_tokens of size of the LLM kModelDim. // Apply head embedding into image_tokens of size of the LLM kModelDim.
CallMatMul(activations.x, weights.vit_img_head_kernel, CallMatMul(activations.x, weights.vit_img_head_kernel,
weights.vit_img_head_bias.PackedScale1(), *activations.env, weights.vit_img_head_bias.PackedScale1(), env, image_tokens);
image_tokens);
} }
// NOLINTNEXTLINE(google-readability-namespace-comments) // NOLINTNEXTLINE(google-readability-namespace-comments)

View File

@ -28,12 +28,14 @@ namespace gcpp {
// Passed to HWY_VISIT_TARGETS; declares for one target. // Passed to HWY_VISIT_TARGETS; declares for one target.
#define GEMMA_DECL_VIT(TARGET, NAMESPACE) \ #define GEMMA_DECL_VIT(TARGET, NAMESPACE) \
namespace NAMESPACE { \ namespace NAMESPACE { \
void FFWVit(Activations& activations, const LayerWeightsPtrs& layer); \ void FFWVit(const LayerWeightsPtrs& layer, Activations& activations, \
MatMulEnv& env); \
\ \
void PrefillVit(const ModelConfig& model_config, \ void PrefillVit(const ModelConfig& model_config, \
const ModelWeightsPtrs& weights, \ const ModelWeightsPtrs& weights, \
const RuntimeConfig& runtime_config, const Image& image, \ const RuntimeConfig& runtime_config, const Image& image, \
ImageTokens& image_tokens, Activations& activations); \ ImageTokens& image_tokens, Activations& activations, \
MatMulEnv& env); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE } // namespace NAMESPACE

View File

@ -756,7 +756,7 @@ class DotStats {
// Compensated and Double are very accurate. // Compensated and Double are very accurate.
ASSERT_LESS(kCompensated, s_muls[kCompensated].Min(), 1.0f + 2E-6f); ASSERT_LESS(kCompensated, s_muls[kCompensated].Min(), 1.0f + 2E-6f);
ASSERT_LESS(kCompensated, s_muls[kCompensated].Max(), 1.0f + 2E-5f); ASSERT_LESS(kCompensated, s_muls[kCompensated].Max(), 1.0f + 1E-4f);
ASSERT_LESS(kDouble, s_muls[kDouble].Min(), 1.0f + 2E-6f); ASSERT_LESS(kDouble, s_muls[kDouble].Min(), 1.0f + 2E-6f);
ASSERT_LESS(kDouble, s_muls[kDouble].Max(), 1.0f + 2E-5f); ASSERT_LESS(kDouble, s_muls[kDouble].Max(), 1.0f + 2E-5f);

View File

@ -1324,9 +1324,9 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
} }
HWY_DASSERT(C.HasPtr()); HWY_DASSERT(C.HasPtr());
for (size_t r = 0; r < C.Rows(); ++r) { for (size_t r = 0; r < C.Rows(); ++r) {
env.storage.OutRow(r) = reinterpret_cast<uint8_t*>(C.Row(r)); env.row_ptrs[0][r] = reinterpret_cast<uint8_t*>(C.Row(r));
} }
C_rows = CRows<TC>(&env.storage.OutRow(0)); C_rows = CRows<TC>(env.row_ptrs[0].get());
} }
const Allocator& allocator = env.ctx.allocator; const Allocator& allocator = env.ctx.allocator;

View File

@ -427,6 +427,8 @@ MatMulEnv::MatMulEnv(ThreadingContext& ctx)
: ctx(ctx), parallel(ctx), storage(ctx.allocator, parallel) { : ctx(ctx), parallel(ctx), storage(ctx.allocator, parallel) {
char cpu100[100]; char cpu100[100];
have_timer_stop = hwy::platform::HaveTimerStop(cpu100); have_timer_stop = hwy::platform::HaveTimerStop(cpu100);
row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(MMStorage::kMaxM));
} }
void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel) { void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {

View File

@ -231,10 +231,9 @@ class MMStorage {
// Internally threaded; must not be called concurrently with the same // Internally threaded; must not be called concurrently with the same
// `ThreadingContext` (used via `parallel`). // `ThreadingContext` (used via `parallel`).
MMStorage(const Allocator& allocator, MMParallel& parallel) MMStorage(const Allocator& allocator, MMParallel& parallel)
: out_rows(hwy::AllocateAligned<uint8_t*>(kMaxM)), : // Per-worker copies of `partial` would be wasteful. We instead
// Per-worker copies of `partial` would be wasteful. We instead allocate // allocate one instance of the maximum matrix extents because threads
// one instance of the maximum matrix extents because threads write at // write at false-sharing-free granularity.
// false-sharing-free granularity.
partial_storage_("partial_storage", Extents2D(kMaxM, kMaxN), partial_storage_("partial_storage", Extents2D(kMaxM, kMaxN),
MatPadding::kOdd), MatPadding::kOdd),
// Same stride independent of the actual C.Cols() so we can pre-bind. // Same stride independent of the actual C.Cols() so we can pre-bind.
@ -260,8 +259,6 @@ class MMStorage {
BindC(partial_storage_, parallel); BindC(partial_storage_, parallel);
} }
uint8_t*& OutRow(size_t row_idx) { return out_rows[row_idx]; }
// Returns per-package matrix view. // Returns per-package matrix view.
RowPtrBF A(size_t pkg_idx, const Extents2D& extents) const { RowPtrBF A(size_t pkg_idx, const Extents2D& extents) const {
HWY_DASSERT(extents.rows <= kMaxM); HWY_DASSERT(extents.rows <= kMaxM);
@ -273,11 +270,6 @@ class MMStorage {
RowPtrD Partial() const { return partial_; } RowPtrD Partial() const { return partial_; }
private: private:
// Enables arbitrary output rows. Most callers pass `RowPtr`, which assumes a
// constant stride, but GemmaAttention::ComputeQKV writes to differing KV
// positions per query / output row. `kMaxM` elements are too large for the
// stack, hence dynamic allocation.
hwy::AlignedFreeUniquePtr<uint8_t*[]> out_rows;
std::unique_ptr<MatStorageT<BF16>> pkg_A_[MMParallel::kMaxPackages]; std::unique_ptr<MatStorageT<BF16>> pkg_A_[MMParallel::kMaxPackages];
MatStorageT<double> partial_storage_; MatStorageT<double> partial_storage_;
RowPtrD partial_; RowPtrD partial_;
@ -667,9 +659,13 @@ struct MatMulEnv {
MMKeys keys; MMKeys keys;
std::vector<MMPerKey> per_key; std::vector<MMPerKey> per_key;
// Pass to MatPtr::AllocateAndAttachRowPtrs. // Storage for arbitrary output rows, see `MatPtr::AllocateAndAttachRowPtrs`.
// Per-tensor allocations to make it likelier that asan detects bugs such as // Most MatMul callers use strided MatPtr, but GemmaAttention::ComputeQKV
// use after free, overrun, and dangling references. // writes to differing KV positions per query / output row.
// The first entry is sufficient for any MatMul, but also potentially
// overwritten by each MatMul. Subsequent entries are precomputed for tensors
// and not overwritten. Per-tensor allocations make it likelier that asan
// detects bugs such as use after free, overrun, and dangling references.
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs; std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
}; };