From 3a266c662cad2f1eb69503d72e110d5e61777e48 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Thu, 5 Jun 2025 05:36:08 -0700 Subject: [PATCH] Split gemma-inl into separate source files weights, mat: zero-initialize padding, required since the MatMul "avoid B decompress" optimization. PiperOrigin-RevId: 767562313 --- BUILD.bazel | 81 +- CMakeLists.txt | 6 + examples/simplified_gemma/BUILD.bazel | 2 +- gemma/activations.h | 13 + gemma/attention.cc | 346 ++++++ gemma/attention.h | 63 ++ gemma/common.cc | 30 - gemma/common.h | 9 - gemma/gemma-inl.h | 1412 +------------------------ gemma/gemma.cc | 580 ++++++++++ gemma/gemma.h | 3 - gemma/griffin.cc | 193 ++++ gemma/griffin.h | 47 + gemma/vit.cc | 339 ++++++ gemma/vit.h | 49 + gemma/weights.cc | 4 +- gemma/weights.h | 2 + ops/matmul-inl.h | 4 + ops/ops-inl.h | 10 + ops/ops_test.cc | 2 +- util/mat.cc | 7 +- 21 files changed, 1736 insertions(+), 1466 deletions(-) create mode 100644 gemma/attention.cc create mode 100644 gemma/attention.h create mode 100644 gemma/griffin.cc create mode 100644 gemma/griffin.h create mode 100644 gemma/vit.cc create mode 100644 gemma/vit.h diff --git a/BUILD.bazel b/BUILD.bazel index 2792314..6ca9003 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -216,8 +216,8 @@ cc_library( ":configs", ":gemma_args", ":mat", + ":matmul", ":model_store", - ":ops", ":tensor_info", ":threading_context", "//compression:compress", @@ -246,11 +246,7 @@ cc_library( name = "common", srcs = ["gemma/common.cc"], hdrs = ["gemma/common.h"], - deps = [ - ":basics", - ":configs", - "@highway//:hwy", # base.h - ], + deps = [":configs"], ) # For building all tests in one command, so we can test several. @@ -260,42 +256,25 @@ test_suite( ) cc_library( - name = "ops", - srcs = [ - "ops/matmul.cc", - ], - hdrs = [ - "ops/matmul.h", - "ops/ops.h", - ], - textual_hdrs = [ - "ops/dot-inl.h", - "ops/sum-inl.h", - "ops/fp_arith-inl.h", - "ops/matmul-inl.h", - "ops/matvec-inl.h", - "ops/ops-inl.h", - ], + name = "matmul", + srcs = ["ops/matmul.cc"], + hdrs = ["ops/matmul.h"], + textual_hdrs = ["ops/matmul-inl.h"], deps = [ ":allocator", ":basics", ":mat", ":threading_context", "//compression:compress", - "@highway//:algo", "@highway//:bit_set", "@highway//:hwy", - "@highway//:math", - "@highway//:matvec", "@highway//:nanobenchmark", "@highway//:profiler", - "@highway//:thread_pool", - "@highway//hwy/contrib/sort:vqsort", ], ) cc_library( - name = "matmul", + name = "matmul_static", srcs = [ # single-file build time is ~30sec for msan, hence shard. 
"ops/matmul_static_bf16.cc", @@ -313,7 +292,7 @@ cc_library( deps = [ ":allocator", ":basics", - ":ops", + ":matmul", ":threading_context", "//compression:compress", "//compression:types", @@ -323,6 +302,34 @@ cc_library( ], ) +cc_library( + name = "ops", + hdrs = ["ops/ops.h"], + textual_hdrs = [ + "ops/dot-inl.h", + "ops/sum-inl.h", + "ops/fp_arith-inl.h", + "ops/matvec-inl.h", + "ops/ops-inl.h", + ], + deps = [ + ":allocator", + ":basics", + ":mat", + ":matmul", + ":matmul_static", + ":threading_context", + "//compression:compress", + "@highway//:algo", + "@highway//:hwy", + "@highway//:math", + "@highway//:matvec", + "@highway//:profiler", + "@highway//:thread_pool", + "@highway//hwy/contrib/sort:vqsort", + ], +) + cc_test( name = "dot_test", size = "small", @@ -359,7 +366,7 @@ cc_test( deps = [ ":allocator", ":basics", - ":common", + ":gemma_lib", ":mat", ":ops", ":test_util", @@ -402,6 +409,7 @@ cc_test( ":basics", ":mat", ":matmul", + ":matmul_static", ":ops", ":threading_context", "@googletest//:gtest_main", # buildcleaner: keep @@ -458,7 +466,7 @@ cc_library( ":args", ":basics", ":mat", - ":ops", # matmul.h + ":matmul", "//io", "@highway//:hwy", ], @@ -467,11 +475,17 @@ cc_library( cc_library( name = "gemma_lib", srcs = [ + "gemma/attention.cc", "gemma/gemma.cc", + "gemma/griffin.cc", + "gemma/vit.cc", ], hdrs = [ "gemma/activations.h", + "gemma/attention.h", "gemma/gemma.h", + "gemma/griffin.h", + "gemma/vit.h", ], exec_properties = { # Avoid linker OOMs when building with sanitizer instrumentation. @@ -479,12 +493,10 @@ cc_library( }, textual_hdrs = [ "gemma/gemma-inl.h", - # Placeholder for internal file2, do not remove, ], deps = [ ":allocator", ":basics", - ":common", ":configs", ":gemma_args", ":kv_cache", @@ -527,6 +539,7 @@ cc_library( ":cross_entropy", ":gemma_args", ":gemma_lib", + ":matmul", ":ops", ":threading_context", ":tokenizer", @@ -616,7 +629,7 @@ cc_binary( ":benchmark_helper", ":gemma_args", ":gemma_lib", - ":ops", + ":matmul", ":tokenizer", "//compression:types", "//paligemma:image", diff --git a/CMakeLists.txt b/CMakeLists.txt index c425cf1..e7dcfae 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -53,6 +53,8 @@ set(SOURCES evals/cross_entropy.cc evals/cross_entropy.h gemma/activations.h + gemma/attention.cc + gemma/attention.h gemma/common.cc gemma/common.h gemma/configs.cc @@ -61,6 +63,8 @@ set(SOURCES gemma/gemma-inl.h gemma/gemma.cc gemma/gemma.h + gemma/griffin.cc + gemma/griffin.h gemma/kv_cache.cc gemma/kv_cache.h gemma/model_store.cc @@ -69,6 +73,8 @@ set(SOURCES gemma/tensor_info.h gemma/tokenizer.cc gemma/tokenizer.h + gemma/vit.cc + gemma/vit.h gemma/weights.cc gemma/weights.h io/blob_store.cc diff --git a/examples/simplified_gemma/BUILD.bazel b/examples/simplified_gemma/BUILD.bazel index 2678ada..98c0f5e 100644 --- a/examples/simplified_gemma/BUILD.bazel +++ b/examples/simplified_gemma/BUILD.bazel @@ -12,7 +12,7 @@ cc_library( deps = [ "//:gemma_args", "//:gemma_lib", - "//:ops", + "//:matmul", "//:threading_context", "//:tokenizer", "@highway//:hwy", diff --git a/gemma/activations.h b/gemma/activations.h index 7563617..7874423 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -16,6 +16,7 @@ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_ +#include // sqrtf #include #include @@ -29,6 +30,16 @@ namespace gcpp { +// Returns the scale value to use for the query in the attention computation. +// Also called by ops_test. 
+static inline float ChooseQueryScale(const ModelConfig& config) { + if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads) + return 1.0f / sqrtf(static_cast(config.model_dim / + config.layer_configs[0].heads)); + // QueryScaleType::SqrtKeySize + return 1.0f / sqrtf(static_cast(config.layer_configs[0].qkv_dim)); +} + struct Activations { Activations(const ModelConfig& config, size_t batch_size, MatMulEnv* env) : weights_config(config), @@ -36,6 +47,7 @@ struct Activations { seq_len(config.seq_len), cache_pos_size(config.CachePosSize()), is_griffin(config.model == Model::GRIFFIN_2B), + query_scale(ChooseQueryScale(config)), x("x", Extents2D(batch_size, config.model_dim), pad_), // `vocab_size == 0` means it is for Vit part, VitAttention is still MHA @@ -129,6 +141,7 @@ struct Activations { size_t seq_len; size_t cache_pos_size = 0; // TODO: after moving KVCache to MatStorageT. bool is_griffin = false; + float query_scale; const Extents2D none_ = Extents2D(); const MatPadding pad_ = MatPadding::kOdd; diff --git a/gemma/attention.cc b/gemma/attention.cc new file mode 100644 index 0000000..c79ac89 --- /dev/null +++ b/gemma/attention.cc @@ -0,0 +1,346 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include + +#include "gemma/activations.h" +#include "gemma/gemma.h" +#include "gemma/gemma_args.h" +#include "gemma/weights.h" +#include "hwy/contrib/thread_pool/thread_pool.h" +#include "hwy/profiler.h" + +// Compiles this file for multiple architectures via "foreach_target.h", to +// which we pass the filename via macro 'argument'. +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "gemma/attention.cc" // NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +// After highway.h +#include "ops/ops-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { + +// Computes Q.K scores, which are "logits" (or scores) stored to att. +// `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim]. +static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos, + const hwy::Divisor& div_seq_len, + const float* HWY_RESTRICT q, + const MatPtrT& k, float* HWY_RESTRICT att) { + if (HWY_LIKELY(last_pos < static_cast(div_seq_len.GetDivisor()))) { + // Slightly faster: no wraparound. 
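+    // Every position in [start_pos, last_pos] maps directly to a cache row;
+    // the else branch instead wraps positions via div_seq_len.Remainder().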
+ for (size_t pos = start_pos; pos <= last_pos; ++pos) { + const float score = Dot(q, k.Row(pos), k.Cols()); + att[pos] = score; + } + } else { + for (size_t pos = start_pos; pos <= last_pos; ++pos) { + const size_t pos_modulo = div_seq_len.Remainder(pos); + const float score = Dot(q, k.Row(pos_modulo), k.Cols()); + att[pos_modulo] = score; + } + } +} + +template +static void PositionalEncodingQK(U* qk, const size_t qkv_dim, + const size_t layer_idx, + const LayerWeightsPtrs& layer, + const Activations& activations, + const size_t pos, const float mul = 1.0f) { + const PostQKType& post_qk = layer.layer_config.post_qk; + // qk is either q or k, so qkv_dim is the length we operate on. + const float* inv_timescale = activations.inv_timescale.PackedScale1(); + bool is_global_layer = + activations.weights_config.attention_window_sizes[layer_idx] == + activations.seq_len; + // TODO: add a config flag instead of hardcoding the model. + if (is_global_layer && IsVLM(activations.weights_config.model)) { + inv_timescale = activations.inv_timescale_global.PackedScale1(); + } + // PostQKType::Rope + if (post_qk == PostQKType::HalfRope) { + Rope(qk, qkv_dim / 2, inv_timescale, pos); + if (mul != 1.0f) MulByConst(mul, qk, qkv_dim); + } else { + RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos); + } +} + +// Accumulates the sum of v (from `kv_cache`) * probability (`att`) into +// `att_out`. Equivalent in gemma/modules.py: +// encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj) +// `v` is a strided view of the kv cache with dimensions [seq_len, qkv_dim]. +static HWY_INLINE void WeightedSumV(const size_t start_pos, + const size_t last_pos, + const hwy::Divisor& div_seq_len, + const float* HWY_RESTRICT att, + const MatPtrT& v, + float* HWY_RESTRICT att_out) { + const size_t qkv_dim = v.Cols(); + hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out)); + + if (HWY_LIKELY(last_pos < static_cast(div_seq_len.GetDivisor()))) { + // Slightly faster: no wraparound. + for (size_t pos = start_pos; pos <= last_pos; ++pos) { + MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols()); + } + } else { + for (size_t pos = start_pos; pos <= last_pos; ++pos) { + const size_t pos_modulo = div_seq_len.Remainder(pos); + const float* HWY_RESTRICT v_ptr = v.Row(pos_modulo); + MulByConstAndAdd(att[pos_modulo], v_ptr, att_out, v.Cols()); + } + } +} + +// Calculates the attention outputs for a single q. +void SingleDotSoftmaxWeightedSum( + const size_t pos, const size_t start_pos, const size_t last_pos, + const hwy::Divisor& div_seq_len, float* HWY_RESTRICT q, + const MatPtrT& k, const MatPtrT& v, const size_t layer_idx, + const LayerWeightsPtrs& layer, const Activations& activations, + float* HWY_RESTRICT att, float* HWY_RESTRICT att_out) { + const size_t qkv_dim = layer.layer_config.qkv_dim; + const float att_cap = activations.weights_config.att_cap; + const float query_scale = activations.query_scale; + + // Apply rope and scaling to Q. + if (layer.query_norm_scale.HasPtr()) { + CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) { + RMSNormInplace(weights_t->PackedScale1(), 0, q, qkv_dim); + }); + } + PositionalEncodingQK(q, qkv_dim, layer_idx, layer, activations, pos, + query_scale); + + QDotK(start_pos, last_pos, div_seq_len, q, k, att); + + // SoftMax with optional SoftCap yields "probabilities" in att. 
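+  // A nonzero att_cap applies tanh-based soft-capping, squashing scores into
+  // roughly (-att_cap, att_cap) before the Softmax; a zero cap is a no-op.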
+ const size_t att_len = + HWY_MIN(last_pos + 1, static_cast(div_seq_len.GetDivisor())); + MaybeLogitsSoftCap(att_cap, att, att_len); + Softmax(att, att_len); + + WeightedSumV(start_pos, last_pos, div_seq_len, att, v, att_out); +} + +// The attention window usually starts at 0 unless `pos` is larger than +// the attention window size, then it is `pos` - window_size + 1. +static HWY_INLINE size_t StartPos(size_t pos, const ModelConfig& config, + size_t layer_idx) { + const size_t att_window_size = config.attention_window_sizes[layer_idx]; + return pos - HWY_MIN(att_window_size - 1, pos); +} + +void DotSoftmaxWeightedSum( + const size_t num_tokens, const QueriesPos& queries_pos, + const QueriesPos& queries_prefix_end, const hwy::Divisor& div_seq_len, + const size_t layer_idx, const LayerWeightsPtrs& layer, + Activations& activations, const KVCaches& kv_caches, NestedPools& pools) { + const size_t num_queries = queries_pos.size(); + const LayerConfig& layer_config = layer.layer_config; + PROFILER_ZONE("Gen.Attention.DotSoftmax"); + + // A "head group" in the context of GQA refers to a collection of query + // heads that share the same key and value heads. + const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads; + + const size_t cache_layer_size = layer_config.CacheLayerSize(); + const size_t cache_pos_size = activations.cache_pos_size; + + // For each head (token, query), compute Q.K, softmax, and weighted V. + // TODO: nested parallelism to use more threads. + pools.Pool(0).Run( + 0, layer_config.heads * num_tokens * num_queries, + [&](uint64_t task, size_t /*thread*/) HWY_ATTR { + const size_t head = task % layer_config.heads; + const size_t interleaved_idx = task / layer_config.heads; + const size_t query_idx = interleaved_idx % num_queries; + const size_t batch_idx = interleaved_idx / num_queries; + const size_t qkv_dim = layer_config.qkv_dim; + const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2; + + float* HWY_RESTRICT q = + activations.q.Row(interleaved_idx) + head * qkv_dim; + float* HWY_RESTRICT att = + activations.att.Row(interleaved_idx) + head * activations.seq_len; + float* HWY_RESTRICT att_out = + activations.att_out.Row(interleaved_idx) + head * qkv_dim; + + // Make strided views into the kv cache entries for the current + // query and head. + KVCache& kv_cache = kv_caches[query_idx]; + const size_t kv_head_offset = + layer_idx * cache_layer_size + head_offset; + MatPtrT k("k_view", Extents2D(kv_cache.seq_len, qkv_dim)); + k.SetPtr(kv_cache.kv_cache.get() + kv_head_offset, + /*stride=*/cache_pos_size); + MatPtrT v("v_view", Extents2D(kv_cache.seq_len, qkv_dim)); + v.SetPtr(kv_cache.kv_cache.get() + kv_head_offset + qkv_dim, + /*stride=*/cache_pos_size); + + // Find the token position in the query and calculate the range + // of cache positions to attend to. + const size_t pos = queries_pos[query_idx] + batch_idx; + const size_t start_pos = + StartPos(pos, activations.weights_config, layer_idx); + size_t last_pos = pos; + const size_t prefix_end = queries_prefix_end[query_idx]; + if (prefix_end > 0 && prefix_end - 1 > last_pos) { + // last_pos in QDotK and WeightedSumV is inclusive. + last_pos = prefix_end - 1; + } + + SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, div_seq_len, q, k, + v, layer_idx, layer, activations, att, + att_out); + }); +} + +// Fills activations.q and writes to KV cache. 
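+// Q is produced by a single MatMul into activations.q. K and V pairs are
+// written directly into each query's KV cache rows via AttachRowPtrs, so no
+// separate copy into the cache is required.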
+static HWY_INLINE void ComputeQKV( + size_t num_tokens, const QueriesPos& queries_pos, + const hwy::Divisor& div_seq_len, const size_t layer_idx, + const LayerWeightsPtrs& layer, Activations& activations, + const KVCaches& kv_caches, const int flags, NestedPools& pools) { + PROFILER_ZONE("Gen.Attention.QKV"); + const size_t num_queries = queries_pos.size(); + const size_t num_interleaved = num_tokens * num_queries; + const LayerConfig& layer_config = layer.layer_config; + const size_t qkv_dim = layer_config.qkv_dim; + const size_t kv_heads = layer_config.kv_heads; + const size_t cache_layer_size = layer_config.CacheLayerSize(); + const size_t cache_pos_size = activations.cache_pos_size; + + // The original qkv_einsum_w has shape [(heads + kv_heads * 2), qkv_dim, + // model_dim], which we reshaped to (heads + kv_heads * 2) * qkv_dim rows. + CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w1, + /*add=*/nullptr, *activations.env, activations.q); + + // Set up MatMul row pointers for writing to KV, which consists of + // `kv_heads` pairs of (k, v) vectors. This safely handles wraparound + // because rows are computed modulo seq_len. + MatPtrT kv_rows("kv", Extents2D(activations.pre_att_rms_out.Rows(), + layer.qkv_einsum_w2.Rows())); + for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved; + ++interleaved_idx) { + const size_t query_idx = interleaved_idx % num_queries; + const size_t batch_idx = interleaved_idx / num_queries; + const size_t cache_pos = + div_seq_len.Remainder(queries_pos[query_idx] + batch_idx); + const size_t kv_offset = + cache_pos * cache_pos_size + layer_idx * cache_layer_size; + activations.env->storage.OutRow(interleaved_idx) = + reinterpret_cast(kv_caches[query_idx].kv_cache.get() + + kv_offset); + } + kv_rows.AttachRowPtrs(&activations.env->storage.OutRow(0)); + CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2, + /*add=*/nullptr, *activations.env, kv_rows); + + // Apply positional encodings for K. + // TODO: 2D parallelism to use more threads. + pools.Pool(0).Run( + 0, kv_heads * num_interleaved, + [&](uint64_t task, size_t /*thread*/) HWY_ATTR { + const size_t head = task % kv_heads; + const size_t interleaved_idx = task / kv_heads; + const size_t query_idx = interleaved_idx % num_queries; + const size_t batch_idx = interleaved_idx / num_queries; + const size_t pos = queries_pos[query_idx] + batch_idx; + const size_t cache_pos = div_seq_len.Remainder(pos); + const size_t kv_offset = cache_pos * cache_pos_size + + layer_idx * cache_layer_size + + head * qkv_dim * 2; + KVCache& kv_cache = kv_caches[query_idx]; + float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset; + + // Apply further processing to K. + if (layer.key_norm_scale.HasPtr()) { + CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) { + RMSNormInplace(weights_t->PackedScale1(), 0, kv, qkv_dim); + }); + } + + PositionalEncodingQK(kv, qkv_dim, layer_idx, layer, activations, pos); + }); +} + +// Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and +// head_dim (`qkv_dim`) into output (`layer_out`). +static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer, + Activations& activations) { + PROFILER_ZONE("Gen.Attention.SumHeads"); + const LayerConfig& layer_config = layer.layer_config; + // att_weights and att_out are concatenated heads, each of length + // layer_config.qkv_dim. Thus the [num_interleaved, + // layer_config.model_dim] matmul output is the sum over heads. 
Compare + // gemma/modules.py: attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', + // encoded) + HWY_DASSERT(layer_config.model_dim != 0 && layer_config.heads != 0 && + layer_config.qkv_dim != 0); + const float* add = layer_config.softmax_attn_output_biases + ? layer.attention_output_biases.PackedScale1() + : nullptr; + CallMatMul(activations.att_out, layer.att_weights, add, *activations.env, + activations.att_sums); +} + +// `queries_prefix_end` can be null (interpreted as all-zero) for standard +// causal attention, and must be non-null for prefix-LM style attention. +void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos, + const QueriesPos* queries_prefix_end, + const hwy::Divisor& div_seq_len, const size_t layer_idx, + const LayerWeightsPtrs& layer, Activations& activations, + const KVCaches& kv_caches, int flags) { + const size_t num_queries = queries_pos.size(); + HWY_DASSERT(num_queries <= kv_caches.size()); + + const LayerConfig& layer_config = layer.layer_config; + HWY_DASSERT(!layer_config.IsMHA()); // No longer supported. + HWY_DASSERT_M((layer_config.heads % layer_config.kv_heads) == 0, + "query heads must be a multiple of key-value heads"); + (void)layer_config; // only used in HWY_DASSERT + + std::vector queries_prefix_end_vec; + QueriesPos queries_prefix_end_span; + if (queries_prefix_end == nullptr) { + queries_prefix_end_vec.assign(num_queries, 0); + queries_prefix_end_span = QueriesPos(queries_prefix_end_vec.data(), + queries_prefix_end_vec.size()); + queries_prefix_end = &queries_prefix_end_span; + } + + NestedPools& pools = activations.env->ctx.pools; + ComputeQKV(num_tokens, queries_pos, div_seq_len, layer_idx, layer, + activations, kv_caches, flags, pools); + DotSoftmaxWeightedSum(num_tokens, queries_pos, *queries_prefix_end, + div_seq_len, layer_idx, layer, activations, kv_caches, + pools); + SumHeads(layer, activations); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); diff --git a/gemma/attention.h b/gemma/attention.h new file mode 100644 index 0000000..c8b527f --- /dev/null +++ b/gemma/attention.h @@ -0,0 +1,63 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ATTENTION_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_ATTENTION_H_ + +// Declares GemmaAttention for all SIMD targets. + +#include + +#include "gemma/gemma.h" +#include "hwy/highway.h" + +namespace gcpp { + +// Passed to HWY_VISIT_TARGETS; declares for one target. 
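+// HWY_VISIT_TARGETS below expands it once per compiled SIMD target namespace
+// (e.g. N_AVX3, N_NEON), matching the per-target definitions generated via
+// foreach_target.h in attention.cc.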
+#define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \ + namespace NAMESPACE { \ + void SingleDotSoftmaxWeightedSum( \ + const size_t pos, const size_t start_pos, const size_t last_pos, \ + const hwy::Divisor& div_seq_len, float* HWY_RESTRICT q, \ + const MatPtrT& k, const MatPtrT& v, size_t layer_idx, \ + const LayerWeightsPtrs& layer, const Activations& activations, \ + float* HWY_RESTRICT att, float* HWY_RESTRICT att_out); \ + \ + void DotSoftmaxWeightedSum(const size_t num_tokens, \ + const QueriesPos& queries_pos, \ + const QueriesPos& queries_prefix_end, \ + const hwy::Divisor& div_seq_len, \ + size_t layer_idx, const LayerWeightsPtrs& layer, \ + Activations& activations, \ + const KVCaches& kv_caches, NestedPools& pools); \ + \ + void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos, \ + const QueriesPos* queries_prefix_end, \ + const hwy::Divisor& div_seq_len, const size_t layer_idx, \ + const LayerWeightsPtrs& layer, Activations& activations, \ + const KVCaches& kv_caches, int flags); \ + /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ + } // namespace NAMESPACE + +// Function declarations for each SIMD target. Allows direct call from the +// per-target namespace. We may later replace this with dynamic dispatch if +// the overhead is acceptable. +HWY_VISIT_TARGETS(GEMMA_DECL_ATTENTION) + +#undef GEMMA_DECL_ATTENTION + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ATTENTION_H_ diff --git a/gemma/common.cc b/gemma/common.cc index 76b90b5..00cb05d 100644 --- a/gemma/common.cc +++ b/gemma/common.cc @@ -15,15 +15,11 @@ #include "gemma/common.h" -#include // sqrtf #include #include -#include #include "gemma/configs.h" -#include "util/basics.h" // BF16 -#include "hwy/base.h" // ConvertScalarTo namespace gcpp { @@ -39,30 +35,4 @@ void Wrap(const ModelConfig& config, size_t pos, std::string& prompt) { } } -float EmbeddingScaling(size_t model_dim) { - // Round to bf16 to match Gemma's Embedder, which casts before mul. - return hwy::ConvertScalarTo( - hwy::ConvertScalarTo(sqrtf(static_cast(model_dim)))); -} - -float ChooseQueryScale(const ModelConfig& config) { - if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads) - return 1.0f / sqrtf(static_cast(config.model_dim / - config.layer_configs[0].heads)); - // QueryScaleType::SqrtKeySize - return 1.0f / sqrtf(static_cast(config.layer_configs[0].qkv_dim)); -} - -void RangeChecks(const ModelConfig& weights_config, - size_t& max_generated_tokens, const size_t prompt_size) { - if (!weights_config.use_local_attention) { - if (max_generated_tokens > weights_config.seq_len) { - HWY_WARN("max_generated_tokens %zu > kSeqLen %u, truncating.", - max_generated_tokens, weights_config.seq_len); - max_generated_tokens = weights_config.seq_len; - } - } - HWY_ASSERT(prompt_size > 0); -} - } // namespace gcpp diff --git a/gemma/common.h b/gemma/common.h index 934c6a7..37e903d 100644 --- a/gemma/common.h +++ b/gemma/common.h @@ -28,15 +28,6 @@ namespace gcpp { // DEPRECATED, use WrapAndTokenize instead if a tokenized return value is fine. void Wrap(const ModelConfig& config, size_t pos, std::string& prompt); -// Returns the scale value to use for the embedding (basically sqrt model_dim). -float EmbeddingScaling(size_t model_dim); - -// Returns the scale value to use for the query in the attention computation. 
-float ChooseQueryScale(const ModelConfig& config); - -void RangeChecks(const ModelConfig& weights_config, - size_t& max_generated_tokens, size_t prompt_size); - } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_ diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 0c946e3..9af7a16 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -13,31 +13,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -// SIMD functions for Gemma/Griffin transformers. +// Transformer components shared between vit.cc and attention.cc. -#include // sqrtf #include #include -#include -#include // std::min -#include // std::iota -#include - -#include "gemma/activations.h" -#include "gemma/common.h" // EmbeddingScaling #include "gemma/configs.h" -#include "gemma/gemma.h" -#include "gemma/kv_cache.h" -#include "gemma/weights.h" -#include "ops/matmul_static.h" -#include "paligemma/image.h" #include "util/mat.h" -#include "util/threading_context.h" -#include "hwy/aligned_allocator.h" // Span -#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/profiler.h" -#include "hwy/timer.h" // Include guard (still compiled once per target) #if defined(THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_) == \ @@ -50,686 +33,12 @@ #include "hwy/highway.h" // After highway.h -#include "ops/matvec-inl.h" #include "ops/ops-inl.h" HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { -template -MMPerKey* CallMatMul(const MatPtrT& A, const MatPtr& B, - const float* HWY_RESTRICT add, MatMulEnv& env, - MatPtrT& C) { - return CallUpcasted( - &B, [&](const auto* B_t) { return MatMulStatic(A, *B_t, add, env, C); }); -} - -// Different functions use different naming conventions for the number of -// tokens. Functions that are query-independent, such as RMSNorm*, call the -// count `num_interleaved`. Functions that are query-dependent, such as -// `Attention`, use separate `num_tokens` and `num_queries`. `num_tokens` is the -// number of tokens from one query: 1 for decode, otherwise prefill_tbatch_size. - -static HWY_NOINLINE void GriffinRecurrent(const QueriesPos& queries_pos, - size_t num_tokens, - size_t griffin_layer, - Activations& activations, - const LayerWeightsPtrs* layer_weights, - const KVCaches& kv_caches) { - PROFILER_ZONE("Gen.Griffin"); - hwy::ThreadPool& pool = activations.env->ctx.pools.Pool(0); - namespace hn = hwy::HWY_NAMESPACE; - using D = hn::ScalableTag; - const D df; - - const size_t model_dim = layer_weights->layer_config.model_dim; - HWY_DASSERT(model_dim % hn::Lanes(df) == 0); - - const size_t heads = layer_weights->layer_config.heads; - const size_t conv_1d_width = layer_weights->layer_config.conv1d_width; - HWY_ASSERT_M(conv_1d_width % 2 == 0, "Conv width must be even"); - const size_t kHeadDim = model_dim / heads; - const size_t kMatrixSize = kHeadDim * kHeadDim; - - const size_t num_queries = queries_pos.size(); - const hwy::Divisor div_num_q(static_cast(num_queries)); - const size_t num_interleaved = num_tokens * num_queries; - - // X / Y linear layers. 
- // TODO: MatMul - HWY_DASSERT(activations.griffin_y.Rows() == activations.griffin_x.Rows()); - HWY_DASSERT(num_interleaved == activations.griffin_y.Rows()); - CallUpcastedSame( - &layer_weights->griffin.linear_x_w, &layer_weights->griffin.linear_y_w, - [&](const auto* wx, const auto* wy) { - for (size_t r = 0; r < num_interleaved; ++r) { - float* HWY_RESTRICT y = activations.griffin_y.Row(r); - float* HWY_RESTRICT x = activations.griffin_x.Row(r); - TwoMatVecAdd( - *wx, *wy, 0, model_dim, model_dim, - activations.pre_att_rms_out.Row(r), - /*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(), - /*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(), - /*out0=*/x, /*out1=*/y, pool); - Gelu(y, model_dim); - } - }); - - // Conv1D. - for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved; - ++interleaved_idx) { - const size_t query_idx = div_num_q.Remainder(interleaved_idx); - const size_t batch_idx = div_num_q.Divide(interleaved_idx); - const size_t pos = queries_pos[query_idx] + batch_idx; - float* HWY_RESTRICT x = activations.griffin_x.Row(query_idx); - - // cache[i] = input at time t-i. - float* HWY_RESTRICT cache[kMaxConv1DWidth]; - cache[0] = x; - for (size_t i = 1; i < conv_1d_width; i++) { - cache[i] = - kv_caches[query_idx].conv1d_cache.Row(griffin_layer) + - ((pos + conv_1d_width - 1 - i) % (conv_1d_width - 1)) * model_dim; - } - for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) { - auto xv = hn::Load(df, x + i); - auto accum0 = - hn::Load(df, layer_weights->griffin.conv_biases.PackedScale1() + i); - auto accum1 = hn::Zero(df); - for (size_t l = 0; 2 * l < conv_1d_width; l++) { - auto wv0 = - hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() + - (conv_1d_width - 1 - 2 * l) * model_dim + i); - auto wv1 = - hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() + - (conv_1d_width - 2 - 2 * l) * model_dim + i); - accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0); - accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1); - } - hn::Store(hn::Add(accum0, accum1), df, x + i); - hn::Store(xv, df, cache[HWY_MAX(conv_1d_width, 1) - 1] + i); - } - } - - // RGLRU - for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved; - ++interleaved_idx) { - const size_t query_idx = div_num_q.Remainder(interleaved_idx); - const size_t batch_idx = div_num_q.Divide(interleaved_idx); - const size_t pos = queries_pos[query_idx] + batch_idx; - - float* HWY_RESTRICT x = activations.griffin_x.Row(query_idx); - float* HWY_RESTRICT y = activations.griffin_y.Row(query_idx); - float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Row(query_idx); - float* HWY_RESTRICT a = activations.griffin_multiplier.Row(query_idx); - float* HWY_RESTRICT rnn_state = - kv_caches[query_idx].rglru_cache.Row(griffin_layer); - - pool.Run(0, heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { - size_t head_offset = head * kHeadDim; - CallUpcasted(&layer_weights->griffin.gate_w, [&](const auto* gate_w) { - TwoOfsMatVecAddLoop( - *gate_w, kMatrixSize * head, kMatrixSize * (heads + head), kHeadDim, - kHeadDim, x + head_offset, - /*add0=*/layer_weights->griffin.gate_biases.PackedScale1() + - head_offset, - /*add1=*/layer_weights->griffin.gate_biases.PackedScale1() + - model_dim + head_offset, - /*out0=*/gate_x + head_offset, /*out1=*/a + head_offset); - }); - Sigmoid(gate_x + head_offset, kHeadDim); - Sigmoid(a + head_offset, kHeadDim); - const auto fn_mul = [](D d, hn::Vec x, hn::Vec gate_x) - HWY_ATTR { return hn::Mul(x, gate_x); }; - 
hn::Transform1(D(), a + head_offset, kHeadDim, - layer_weights->griffin.a.PackedScale1() + head_offset, - fn_mul); - hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset, - fn_mul); - // RNN scan - HWY_DASSERT(kHeadDim % hn::Lanes(df) == 0); - for (size_t i = 0; i < kHeadDim; i += hn::Lanes(df)) { - auto log_a = hn::Load(df, a + head_offset + i); - auto gated_x = hn::Load(df, x + head_offset + i); - auto rnn = hn::Load(df, rnn_state + head_offset + i); - auto a = hn::Exp(df, log_a); - auto x_multiplier = hn::Sqrt(hn::NegMulAdd(a, a, hn::Set(df, 1.0f))); - if (pos == 0) { - x_multiplier = hn::Set(df, 1.0f); - } - auto new_x = hn::MulAdd(x_multiplier, gated_x, hn::Mul(a, rnn)); - hn::Store(new_x, df, rnn_state + head_offset + i); - - // Join branches. - auto yv = hn::Load(df, y + head_offset + i); - auto pre_out = hn::Mul(yv, new_x); - hn::Store(pre_out, df, x + head_offset + i); - } - }); - } // interleaved_idx - - // Final linear layer. - CallMatMul(activations.griffin_x, layer_weights->griffin.linear_out_w, - layer_weights->griffin.linear_out_biases.PackedScale1(), - *activations.env, activations.att_sums); -} // GriffinRecurrent - -// Wrapper class; holds arguments in member variables to shorten call sites. -class GemmaAttention { - // The attention window usually starts at 0 unless `pos` is larger than - // the attention window size, then it is `pos` - window_size + 1. - HWY_INLINE size_t StartPos(size_t pos, size_t layer) { - const size_t att_window_size = - activations_.weights_config.attention_window_sizes[layer]; - return pos - std::min(att_window_size - 1, pos); - } - - template - HWY_INLINE void PositionalEncodingQK(U* qk, size_t pos, size_t layer, - const float mul) { - // qk is either q or k, so qkv_dim is the length we operate on. - const size_t qkv_dim = layer_config_.qkv_dim; - const float* inv_timescale = activations_.inv_timescale.PackedScale1(); - bool is_global_layer = - activations_.weights_config.attention_window_sizes[layer] == - activations_.seq_len; - // TODO: add a config flag instead of hardcoding the model. - if (is_global_layer && IsVLM(activations_.weights_config.model)) { - inv_timescale = activations_.inv_timescale_global.PackedScale1(); - } - // PostQKType::Rope - (void)layer; - if (layer_weights_.layer_config.post_qk == PostQKType::HalfRope) { - Rope(qk, qkv_dim / 2, inv_timescale, pos); - if (mul != 1.0f) MulByConst(mul, qk, qkv_dim); - } else { - RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos); - } - } - - // Fills activations.q and writes to KV cache. - HWY_NOINLINE void ComputeQKV(const size_t num_interleaved) { - PROFILER_ZONE("Gen.Attention.QKV"); - const size_t qkv_dim = layer_config_.qkv_dim; - const size_t kv_heads = layer_config_.kv_heads; - - // The original qkv_einsum_w has shape [(heads + kv_heads * 2), qkv_dim, - // model_dim], which we reshaped to (heads + kv_heads * 2) * qkv_dim rows. - CallMatMul(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w1, - /*add=*/nullptr, *activations_.env, activations_.q); - - // Set up MatMul row pointers for writing to KV, which consists of - // `kv_heads` pairs of (k, v) vectors. This safely handles wraparound - // because rows are computed modulo seq_len. 
- MatPtrT kv_rows("kv", - Extents2D(activations_.pre_att_rms_out.Rows(), - layer_weights_.qkv_einsum_w2.Rows())); - for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved; - ++interleaved_idx) { - const size_t query_idx = interleaved_idx % num_queries_; - const size_t batch_idx = interleaved_idx / num_queries_; - const size_t cache_pos = - div_seq_len_.Remainder(queries_pos_[query_idx] + batch_idx); - const size_t kv_offset = - cache_pos * cache_pos_size_ + layer_ * cache_layer_size_; - activations_.env->storage.OutRow(interleaved_idx) = - reinterpret_cast(kv_caches_[query_idx].kv_cache.get() + - kv_offset); - } - kv_rows.AttachRowPtrs(&activations_.env->storage.OutRow(0)); - CallMatMul(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w2, - /*add=*/nullptr, *activations_.env, kv_rows); - - // Apply positional encodings for K (and copy KV to cache if MHA). - pool_.Run(0, kv_heads * num_interleaved, - [&](uint64_t task, size_t /*thread*/) HWY_ATTR { - const size_t head = task % kv_heads; - const size_t interleaved_idx = task / kv_heads; - const size_t query_idx = interleaved_idx % num_queries_; - const size_t batch_idx = interleaved_idx / num_queries_; - const size_t pos = queries_pos_[query_idx] + batch_idx; - const size_t cache_pos = div_seq_len_.Remainder(pos); - const size_t kv_offset = cache_pos * cache_pos_size_ + - layer_ * cache_layer_size_ + - head * qkv_dim * 2; - KVCache& kv_cache = kv_caches_[query_idx]; - float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset; - - // Apply further processing to K. - if (layer_weights_.key_norm_scale.HasPtr()) { - CallUpcasted(&layer_weights_.key_norm_scale, - [&](const auto* weights_t) { - RMSNormInplace(weights_t->PackedScale1(), 0, - kv, qkv_dim); - }); - } - PositionalEncodingQK(kv, pos, layer_, /*mul=*/1.0f); - }); - } - - // Computes Q.K scores, which are "logits" (or scores) stored to att. - // `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim]. - HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos, - const float* HWY_RESTRICT q, const MatPtrT& k, - float* HWY_RESTRICT att) { - const size_t qkv_dim = layer_config_.qkv_dim; - if (HWY_LIKELY(last_pos < activations_.seq_len)) { - // Slightly faster: no wraparound. - for (size_t pos = start_pos; pos <= last_pos; ++pos) { - const float* HWY_RESTRICT k_ptr = k.Row(pos); - const float score = Dot(q, k_ptr, qkv_dim); - att[pos] = score; - } - } else { - for (size_t pos = start_pos; pos <= last_pos; ++pos) { - const size_t cache_pos = div_seq_len_.Remainder(pos); - const float* HWY_RESTRICT k_ptr = k.Row(cache_pos); - const float score = Dot(q, k_ptr, qkv_dim); - att[pos % activations_.seq_len] = score; - } - } - } - - // Accumulates the sum of v (from `kv_cache`) * probability (`att`) into - // `att_out`. Equivalent in gemma/modules.py: - // encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj) - // `v` is a strided view of the kv cache with dimensions [seq_len, qkv_dim]. - HWY_INLINE void WeightedSumV(const size_t start_pos, const size_t last_pos, - const float* HWY_RESTRICT att, - const MatPtrT& v, - float* HWY_RESTRICT att_out) const { - const size_t qkv_dim = layer_config_.qkv_dim; - hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out)); - - if (HWY_LIKELY(last_pos < activations_.seq_len)) { - // Slightly faster: no wraparound. 
- for (size_t pos = start_pos; pos <= last_pos; ++pos) { - const float* HWY_RESTRICT v_ptr = v.Row(pos); - MulByConstAndAdd(att[pos], v_ptr, att_out, qkv_dim); - } - } else { - for (size_t pos = start_pos; pos <= last_pos; ++pos) { - const size_t cache_pos = div_seq_len_.Remainder(pos); - const float* HWY_RESTRICT v_ptr = v.Row(cache_pos); - MulByConstAndAdd(att[pos % activations_.seq_len], v_ptr, att_out, - qkv_dim); - } - } - } - - public: - // Calculates the attention outputs for a single q. - HWY_INLINE void SingleDotSoftmaxWeightedSum( - float* HWY_RESTRICT q, const MatPtrT& k, const MatPtrT& v, - float* HWY_RESTRICT att, float* HWY_RESTRICT att_out, - const float query_scale, const size_t pos, const size_t start_pos, - const size_t last_pos) { - const size_t qkv_dim = layer_config_.qkv_dim; - - // Apply rope and scaling to Q. - if (layer_weights_.query_norm_scale.HasPtr()) { - CallUpcasted(&layer_weights_.query_norm_scale, - [&](const auto* weights_t) { - RMSNormInplace(weights_t->PackedScale1(), 0, q, qkv_dim); - }); - } - PositionalEncodingQK(q, pos, layer_, query_scale); - - QDotK(start_pos, last_pos, q, k, att); - - // SoftMax with optional SoftCap yields "probabilities" in att. - const size_t att_len = std::min(last_pos + 1, activations_.seq_len); - MaybeLogitsSoftCap(activations_.weights_config.att_cap, att, att_len); - Softmax(att, att_len); - - WeightedSumV(start_pos, last_pos, att, v, att_out); - } - - HWY_NOINLINE void DotSoftmaxWeightedSum(const size_t num_interleaved) { - PROFILER_ZONE("Gen.Attention.DotSoftmax"); - const float query_scale = ChooseQueryScale(activations_.weights_config); - - // A "head group" in the context of GQA refers to a collection of query - // heads that share the same key and value heads. - const size_t kHeadGroups = layer_config_.heads / layer_config_.kv_heads; - - // For each head (token, query), compute Q.K, softmax, and weighted V. - pool_.Run( - 0, layer_config_.heads * num_interleaved, - [&](uint64_t task, size_t /*thread*/) HWY_ATTR { - const size_t head = task % layer_config_.heads; - const size_t interleaved_idx = task / layer_config_.heads; - const size_t query_idx = interleaved_idx % num_queries_; - const size_t batch_idx = interleaved_idx / num_queries_; - const size_t qkv_dim = layer_config_.qkv_dim; - const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2; - - float* HWY_RESTRICT q = - activations_.q.Row(interleaved_idx) + head * qkv_dim; - float* HWY_RESTRICT att = activations_.att.Row(interleaved_idx) + - head * activations_.seq_len; - float* HWY_RESTRICT att_out = - activations_.att_out.Row(interleaved_idx) + head * qkv_dim; - - // Make strided views into the kv cache entries for the current - // query and head. - KVCache& kv_cache = kv_caches_[query_idx]; - const size_t kv_head_offset = - layer_ * cache_layer_size_ + head_offset; - MatPtrT k("k_view", Extents2D(kv_cache.seq_len, qkv_dim)); - k.SetPtr(kv_cache.kv_cache.get() + kv_head_offset, - /*stride=*/cache_pos_size_); - MatPtrT v("v_view", Extents2D(kv_cache.seq_len, qkv_dim)); - v.SetPtr(kv_cache.kv_cache.get() + kv_head_offset + qkv_dim, - /*stride=*/cache_pos_size_); - - // Find the token position in the query and calculate the range - // of cache positions to attend to. - const size_t pos = queries_pos_[query_idx] + batch_idx; - const size_t start_pos = StartPos(pos, layer_); - size_t last_pos = pos; - const size_t prefix_end = queries_prefix_end_[query_idx]; - if (prefix_end > 0 && prefix_end - 1 > last_pos) { - // last_pos in QDotK and WeightedSumV is inclusive. 
- last_pos = prefix_end - 1; - } - - SingleDotSoftmaxWeightedSum(q, k, v, att, att_out, query_scale, pos, - start_pos, last_pos); - }); - } - - private: - // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and - // head_dim (`qkv_dim`) into output (`layer_out`). - HWY_NOINLINE void SumHeads() { - PROFILER_ZONE("Gen.Attention.SumHeads"); - // att_weights and att_out are concatenated heads, each of length - // layer_config_.qkv_dim. Thus the [num_interleaved, - // layer_config_.model_dim] matmul output is the sum over heads. Compare - // gemma/modules.py: attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', - // encoded) - HWY_DASSERT(layer_config_.model_dim != 0 && layer_config_.heads != 0 && - layer_config_.qkv_dim != 0); - HWY_DASSERT(layer_weights_.att_weights.HasPtr()); - HWY_DASSERT(activations_.att_out.HasPtr()); - HWY_DASSERT(activations_.att_sums.HasPtr()); - - const float* add = - layer_weights_.layer_config.softmax_attn_output_biases - ? layer_weights_.attention_output_biases.PackedScale1() - : nullptr; - CallMatMul(activations_.att_out, layer_weights_.att_weights, add, - *activations_.env, activations_.att_sums); - } - - public: - // Constructor with explicit initialization of queries_prefix_end. This is - // needed for the Prefix-LM style attention. For standard causal attention, - // the other constructor can be used. - GemmaAttention(const QueriesPos& queries_pos, - const QueriesPos& queries_prefix_end, size_t num_tokens, - size_t layer, Activations& activations, - const LayerWeightsPtrs* layer_weights, - const hwy::Divisor& div_seq_len, const KVCaches& kv_caches) - : GemmaAttention(queries_pos, &queries_prefix_end, num_tokens, layer, - activations, layer_weights, div_seq_len, kv_caches, - activations.env->ctx) {} - // Constructor with default initialization to 0 for queries_prefix_end. - GemmaAttention(const QueriesPos& queries_pos, size_t num_tokens, size_t layer, - Activations& activations, - const LayerWeightsPtrs* layer_weights, - const hwy::Divisor& div_seq_len, const KVCaches& kv_caches) - : GemmaAttention(queries_pos, nullptr, num_tokens, layer, activations, - layer_weights, div_seq_len, kv_caches, - activations.env->ctx) {} - // Constructor with an explicit ThreadingContext. This is needed for - // experimental code that invokes methods that do not use `activations.env`. - // Callers should not have to construct an `activations.env` just to pass in - // the threading context. - GemmaAttention(const QueriesPos& queries_pos, - const QueriesPos& queries_prefix_end, size_t num_tokens, - size_t layer, Activations& activations, - const LayerWeightsPtrs* layer_weights, - const hwy::Divisor& div_seq_len, const KVCaches& kv_caches, - ThreadingContext& ctx) - : GemmaAttention(queries_pos, &queries_prefix_end, num_tokens, layer, - activations, layer_weights, div_seq_len, kv_caches, - ctx) {} - - // Full attention computation in three steps. - HWY_INLINE void operator()() { - const size_t num_interleaved = num_tokens_ * num_queries_; - ComputeQKV(num_interleaved); - DotSoftmaxWeightedSum(num_interleaved); - SumHeads(); - } - - private: - // Delegated Constructor that does most of the common work. 
- GemmaAttention(const QueriesPos& queries_pos, - const QueriesPos* queries_prefix_end, size_t num_tokens, - size_t layer, Activations& activations, - const LayerWeightsPtrs* layer_weights, - const hwy::Divisor& div_seq_len, const KVCaches& kv_caches, - ThreadingContext& ctx) - : queries_pos_(queries_pos), - num_queries_(queries_pos.size()), - num_tokens_(num_tokens), - layer_(layer), - layer_config_(layer_weights->layer_config), - cache_layer_size_(layer_weights->layer_config.CacheLayerSize()), - cache_pos_size_(activations.cache_pos_size), - activations_(activations), - layer_weights_(*layer_weights), - div_seq_len_(div_seq_len), - kv_caches_(kv_caches), - pool_(ctx.pools.Pool(0)) { - HWY_DASSERT(!layer_config_.IsMHA()); // No longer supported. - HWY_DASSERT(num_queries_ <= kv_caches_.size()); - HWY_DASSERT_M((layer_config_.heads % layer_config_.kv_heads) == 0, - "query heads must be a multiple of key-value heads"); - if (queries_prefix_end != nullptr) { - queries_prefix_end_ = *queries_prefix_end; - } else { - queries_prefix_end_vec_.assign(num_queries_, 0); - queries_prefix_end_ = QueriesPos(queries_prefix_end_vec_.data(), - queries_prefix_end_vec_.size()); - } - } - - const QueriesPos& queries_pos_; - std::vector queries_prefix_end_vec_; - QueriesPos queries_prefix_end_; - const size_t num_queries_; - const size_t num_tokens_; - const size_t layer_; - const LayerConfig& layer_config_; - const size_t cache_layer_size_ = 0; - const size_t cache_pos_size_ = 0; - - Activations& activations_; - const LayerWeightsPtrs& layer_weights_; - const hwy::Divisor& div_seq_len_; - const KVCaches& kv_caches_; - hwy::ThreadPool& pool_; -}; - -static HWY_NOINLINE void Attention( - LayerAttentionType type, const QueriesPos& queries_pos, - const QueriesPos& queries_prefix_end, size_t num_tokens, size_t layer, - Activations& activations, const LayerWeightsPtrs* layer_weights, - const hwy::Divisor& div_seq_len, const KVCaches& kv_caches) { - if (type == LayerAttentionType::kGemma) { - GemmaAttention(queries_pos, queries_prefix_end, num_tokens, layer, - activations, layer_weights, div_seq_len, kv_caches)(); - } else { - HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock); - // KVCache conv1d_cache and rglru_cache have one row per *Griffin* layer, - // so map `layer` to the Griffin layer index. - const size_t griffin_layer = - activations.weights_config.NumLayersOfTypeBefore(type, layer); - GriffinRecurrent(queries_pos, num_tokens, griffin_layer, activations, - layer_weights, kv_caches); - } -} - -// Wrapper class; holds arguments in member variables to shorten call sites. -// The main differences to GemmaAttention are: -// - no KV Cache necessary, attention is always all-to-all and not causal. -// - no potential wrap-around, attention always goes from 0 to kSeqLen. -// - no need for batching, as we are always computing attention for kSeqLen -// tokens. -// This results in a much simpler implementation. However, to avoid duplicating -// code, we should still consider merging the two classes. -// TODO(keysers): Refactor to share code with GemmaAttention. -class VitAttention { - // Computes Q, K, V for all heads, stored in activations_.q. 
- HWY_NOINLINE void ComputeQKV() { - PROFILER_ZONE("Gen.VitAttention.QKV"); - auto& qkv = activations_.q; - HWY_ASSERT(qkv.Rows() == num_tokens_); - HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim); - CallMatMul(activations_.pre_att_rms_out, layer_weights_.vit.qkv_einsum_w, - layer_weights_.vit.qkv_einsum_b.PackedScale1(), - *activations_.env, qkv); - } - - // TODO(philculliton): transition fully to MatMul. - HWY_NOINLINE void DotSoftmaxWeightedSumMatrix() { - const size_t qkv_dim = layer_config_.qkv_dim; - const size_t heads = layer_config_.heads; - HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA"); - const size_t seq_len = activations_.seq_len; - const float query_scale = 1.0f / sqrtf(static_cast(qkv_dim)); - PROFILER_ZONE("Gen.VitAttention.DotSoftmax"); - - // Shift Q, K, VT to MatStorageT. - MatStorageT Q("Q2", Extents2D(num_tokens_, qkv_dim), - MatPadding::kPacked); - MatStorageT K("K2", Extents2D(seq_len, qkv_dim), - MatPadding::kPacked); - MatStorageT C("C2", Extents2D(num_tokens_, seq_len), - MatPadding::kPacked); - - // Initialize att_out to zero prior to head loop. - ZeroInit(activations_.att_out); - - for (size_t head = 0; head < heads; ++head) { - pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { - const size_t token = task; - float* HWY_RESTRICT q = activations_.q.Row(token) + head * 3 * qkv_dim; - // TODO: shift to MatMul with A.scale once MatMul is confirmed working - MulByConst(query_scale, q, qkv_dim); - hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float)); - }); - - pool_.Run(0, seq_len, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { - const size_t seq_idx = task; - float* HWY_RESTRICT k = - activations_.q.Row(seq_idx) + head * 3 * qkv_dim + qkv_dim; - hwy::CopyBytes(k, K.Row(seq_idx), qkv_dim * sizeof(float)); - }); - - // this produces C, a (num_tokens_, seq_len) matrix of dot products - CallMatMul(Q, K, nullptr, *activations_.env, C); - - pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { - float* HWY_RESTRICT c = C.Row(task); - Softmax(c, C.Cols()); - }); - - pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { - size_t token = task; - float* HWY_RESTRICT att_out = - activations_.att_out.Row(token) + head * qkv_dim; - for (size_t i = 0; i < seq_len; ++i) { - float* HWY_RESTRICT v = - activations_.q.Row(i) + head * 3 * qkv_dim + 2 * qkv_dim; - MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim); - } - }); - } - } - - HWY_NOINLINE void DotSoftmaxWeightedSum() { - const size_t qkv_dim = layer_config_.qkv_dim; - const size_t heads = layer_config_.heads; - HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA"); - const size_t seq_len = activations_.seq_len; - const float query_scale = 1.0f / sqrtf(static_cast(qkv_dim)); - PROFILER_ZONE("Gen.VitAttention.DotSoftmax"); - - // Compute Q.K, softmax, and weighted V. - pool_.Run(0, layer_config_.heads * num_tokens_, - [&](uint64_t task, size_t /*thread*/) HWY_ATTR { - const size_t head = task % layer_config_.heads; - const size_t token = task / layer_config_.heads; - // Compute Q.K scores, which are "logits" stored in head_att. 
- float* HWY_RESTRICT q = - activations_.q.Row(token) + head * 3 * qkv_dim; - MulByConst(query_scale, q, qkv_dim); - float* HWY_RESTRICT head_att = - activations_.att.Row(token) + head * activations_.seq_len; - for (size_t i = 0; i < seq_len; ++i) { - float* HWY_RESTRICT k = - activations_.q.Row(i) + head * 3 * qkv_dim + qkv_dim; - head_att[i] = Dot(q, k, qkv_dim); // score = q.k - } - // SoftMax yields "probabilities" in head_att. - Softmax(head_att, seq_len); - // Compute weighted sum of v into att_out. - float* HWY_RESTRICT att_out = - activations_.att_out.Row(token) + head * qkv_dim; - hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out)); - for (size_t i = 0; i < seq_len; ++i) { - float* HWY_RESTRICT v = - activations_.q.Row(i) + head * 3 * qkv_dim + 2 * qkv_dim; - MulByConstAndAdd(head_att[i], v, att_out, qkv_dim); - } - }); - } - - // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and - // head_dim (`qkv_dim`) into output (`att_sums`). - HWY_NOINLINE void SumHeads() { - PROFILER_ZONE("Gen.VitAttention.SumHeads"); - auto* bias = layer_weights_.vit.attn_out_b.PackedScale1(); - // att_weights and att_out are concatenated heads, each of length - // qkv_dim. Thus the [num_tokens_, layer_config_.model_dim] - // matmul output is the sum over heads. - CallMatMul(activations_.att_out, layer_weights_.vit.attn_out_w, bias, - *activations_.env, activations_.att_sums); - } - - public: - VitAttention(size_t num_tokens, size_t layer, Activations& activations, - const LayerWeightsPtrs* layer_weights) - : num_tokens_(num_tokens), - activations_(activations), - layer_weights_(*layer_weights), - layer_config_(layer_weights->layer_config), - pool_(activations.env->ctx.pools.Pool(0)) {} - - HWY_INLINE void operator()() { - ComputeQKV(); - if (activations_.weights_config.wrapping == PromptWrapping::GEMMA_VLM) { - DotSoftmaxWeightedSumMatrix(); - } else { - DotSoftmaxWeightedSum(); - } - SumHeads(); - } - - private: - const size_t num_tokens_; - Activations& activations_; - const LayerWeightsPtrs& layer_weights_; - const LayerConfig& layer_config_; - hwy::ThreadPool& pool_; -}; - template HWY_NOINLINE void Activation(ActivationType activation, T* HWY_RESTRICT c1, const T* HWY_RESTRICT c2, size_t count) { @@ -760,7 +69,8 @@ void ActivationBatched(ActivationType activation, Mat& c1) { } template -void ActivationBatched(ActivationType activation, Mat& c1, const Mat* c2) { +HWY_NOINLINE void ActivationBatched(ActivationType activation, Mat& c1, + const Mat* c2) { using T = typename Mat::T; HWY_DASSERT(c1.SameShape(*c2)); if (c2 && c2->HasPtr()) { @@ -775,117 +85,10 @@ void ActivationBatched(ActivationType activation, Mat& c1, const Mat* c2) { } } -static HWY_NOINLINE void FFWNoVit(Activations& activations, - const LayerWeightsPtrs* layer_weights) { - PROFILER_ZONE("Gen.FFW"); - const size_t ffh_hidden_dim = layer_weights->layer_config.ff_hidden_dim; - - const bool add_bias = layer_weights->layer_config.ff_biases; - const float* bias1 = - add_bias ? layer_weights->ffw_gating_biases.PackedScale1() : nullptr; - const float* bias2 = add_bias ? bias1 + ffh_hidden_dim : nullptr; - const float* output_bias = - add_bias ? layer_weights->ffw_output_biases.PackedScale1() : nullptr; - - // Compute the hidden layer activations. 
- CallMatMul(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w1, - bias1, *activations.env, activations.C1); - CallMatMul(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w2, - bias2, *activations.env, activations.C2); - - // Activation (Gelu) and maybe multiply by gate. Store activations in act. - ActivationBatched(layer_weights->layer_config.activation, activations.C1, - &activations.C2); - - // Hidden layer -> output layer. - CallMatMul(activations.C1, layer_weights->linear_w, output_bias, - *activations.env, activations.ffw_out); -} - -// Same as FFWNoVit, but with different layer_weights members and no second -// gating matrix. -static HWY_NOINLINE void FFWVit(Activations& activations, - const LayerWeightsPtrs* layer_weights) { - PROFILER_ZONE("Gen.FFW.ViT"); - - const bool add_bias = layer_weights->layer_config.ff_biases; - const float* bias1 = - add_bias ? layer_weights->vit.linear_0_b.PackedScale1() : nullptr; - const float* output_bias = - add_bias ? layer_weights->vit.linear_1_b.PackedScale1() : nullptr; - - // Compute the hidden layer activations. - CallMatMul(activations.pre_ffw_rms_out, layer_weights->vit.linear_0_w, bias1, - *activations.env, activations.C1); - - // Activation (Gelu), store in C1. - ActivationBatched(layer_weights->layer_config.activation, activations.C1); - - // Hidden layer -> output layer. - CallMatMul(activations.C1, layer_weights->vit.linear_1_w, output_bias, - *activations.env, activations.ffw_out); -} - -// `batch_idx` indicates which row of `x` to write to. -// `pos` is the *token*'s position, not the start of the batch, because this is -// called for batches of tokens in prefill, but batches of queries in decode. -// -// For GEMMA_VLM, image tokens are copied into -2 locations (per the Gemma 3 -// spec) until we run out of image tokens. This allows for a multi-image prompt -// if -2 locations with appropriate begin/end image tokens are created by the -// calling application. -// Returns new image_token_position. -static HWY_NOINLINE size_t -EmbedMMToken(int token, size_t batch_idx, size_t pos, size_t pos_in_prompt, - const ModelConfig& model_config, const ModelWeightsPtrs& weights, - MatStorageT& x, const ImageTokens* image_tokens = nullptr, - size_t image_token_position = 0) { - // Image tokens just need to be copied. - if (model_config.wrapping == PromptWrapping::GEMMA_VLM && - image_tokens != nullptr && token == -2 && - image_token_position < image_tokens->Rows()) { - hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(batch_idx), - x.Cols() * x.ElementBytes()); - return image_token_position + 1; - } - - if (model_config.wrapping == PromptWrapping::PALIGEMMA && - image_tokens != nullptr && pos_in_prompt < image_tokens->Rows()) { - hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(batch_idx), - x.Cols() * x.ElementBytes()); - return image_token_position; - } - - const size_t model_dim = model_config.model_dim; - const float emb_scaling = EmbeddingScaling(model_dim); - - HWY_DASSERT(token >= 0); - HWY_DASSERT(token < static_cast(model_config.vocab_size)); - - CallUpcasted(&weights.embedder_input_embedding, [&](const auto* weights_t) { - // Using `Stride` to compute the offset works for both NUQ (because we use - // an offset and NUQ is never padded) and padded, because non-NUQ types are - // seekable, hence the offset can also skip any padding. 
- const size_t embedding_ofs = token * weights_t->Stride(); - HWY_ASSERT(weights_t->Cols() == model_dim); - const auto embedding_span = - MakeSpan(weights_t->Row(0), embedding_ofs + model_dim); - const hn::ScalableTag df; - DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(batch_idx), - model_dim); - MulByConst(emb_scaling * weights_t->Scale(), x.Row(batch_idx), model_dim); - }); - - if (model_config.absolute_pe) { - AddAbsolutePositionalEmbeddings(x.Row(batch_idx), model_dim, pos); - } - return image_token_position; -} - template HWY_NOINLINE void ResidualConnection(const MatPtrT& other, MatPtrT& HWY_RESTRICT x, - const LayerWeights* layer_weights, + const LayerWeights& layer, bool is_attention) { // ResidualType::Add AddFromBatched(other, x); @@ -900,588 +103,31 @@ void PostNorm(PostNormType post_norm, const MatPtr& weights, } } -static HWY_NOINLINE void TransformerLayer( - const QueriesPos& queries_pos, const QueriesPos& queries_prefix_end, - size_t num_tokens, size_t cache_layer_idx, - const LayerWeightsPtrs* layer_weights, Activations& activations, - const hwy::Divisor& div_seq_len, const KVCaches& kv_caches) { - auto type = layer_weights->layer_config.type; - - RMSNormBatched(activations.x, layer_weights->pre_attention_norm_scale, - activations.pre_att_rms_out); - - Attention(type, queries_pos, queries_prefix_end, num_tokens, cache_layer_idx, - activations, layer_weights, div_seq_len, kv_caches); - - PostNorm(layer_weights->layer_config.post_norm, - layer_weights->post_attention_norm_scale, activations.att_sums); - - ResidualConnection(activations.att_sums, activations.x, layer_weights, - /*is_attention=*/true); - - RMSNormBatched(activations.x, layer_weights->pre_ffw_norm_scale, - activations.pre_ffw_rms_out); - - if (layer_weights->layer_config.type == LayerAttentionType::kVit) { - FFWVit(activations, layer_weights); - } else { - FFWNoVit(activations, layer_weights); - } - - PostNorm(layer_weights->layer_config.post_norm, - layer_weights->post_ffw_norm_scale, activations.ffw_out); - - ResidualConnection(activations.ffw_out, activations.x, layer_weights, - /*is_attention=*/false); -} - -// Vit transformer layer. Some comments below refer to the Vit implementation in -// the Big Vision codebase. See -// github.com/google-research/big_vision/blob/main/big_vision/models/vit.py -// TODO(keysers): consider adding a wrapper for both LayerNorm with RMSNorm and -// try merging this with TransformerLayer. 
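// The block structure of TransformerLayer above, written out as a minimal
// sketch with placeholder callables (the real code is batched over rows and
// also applies optional post-norms after attention and FFW):
#include <stddef.h>
#include <functional>
#include <vector>
using VecF = std::vector<float>;
using FnF = std::function<VecF(const VecF&)>;
static void PreNormResidualLayerSketch(VecF& x, const FnF& rms_norm,
                                       const FnF& attention, const FnF& ffw) {
  VecF att = attention(rms_norm(x));  // pre-norm, then attention
  for (size_t i = 0; i < x.size(); ++i) x[i] += att[i];  // residual add
  VecF f = ffw(rms_norm(x));          // pre-norm, then feed-forward
  for (size_t i = 0; i < x.size(); ++i) x[i] += f[i];    // residual add
}
// VitTransformerLayer below follows the same shape, with LayerNorm
// (scale + bias) in place of RMSNorm.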
-static HWY_NOINLINE void VitTransformerLayer( - size_t num_tokens, size_t layer, const LayerWeightsPtrs* layer_weights, - Activations& activations) { - const size_t model_dim = activations.weights_config.model_dim; - auto type = layer_weights->layer_config.type; - HWY_DASSERT(type == LayerAttentionType::kVit); - (void)type; - (void)model_dim; - - auto& x = activations.x; - HWY_DASSERT(x.Rows() == num_tokens); - HWY_DASSERT(x.Cols() == model_dim); - - // y = nn.LayerNorm()(x) - // y ~ pre_att_rms_out - LayerNormBatched(x, layer_weights->vit.layer_norm_0_scale, - layer_weights->vit.layer_norm_0_bias, - activations.pre_att_rms_out); - - // y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y) - // y ~ att_sums - VitAttention(num_tokens, layer, activations, layer_weights)(); - - // x = out["+sa"] = x + y - AddFromBatched(activations.att_sums, x); - - // y = nn.LayerNorm()(x) - // y ~ pre_ffw_rms_out - LayerNormBatched(x, layer_weights->vit.layer_norm_1_scale, - layer_weights->vit.layer_norm_1_bias, - activations.pre_ffw_rms_out); - - // y = out["mlp"] = MlpBlock(...)(y) - // y ~ ffw_out - FFWVit(activations, layer_weights); - - // x = out["+mlp"] = x + y - AddFromBatched(activations.ffw_out, x); -} - -// Prefill() and Transformer() increment positions in-place. -using QueriesMutablePos = hwy::Span; - -// Populates KV cache for batches of tokens from one query at a time. -static HWY_NOINLINE void Prefill( - const QueriesPromptTokens& queries_prompt, - const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end, - const size_t query_idx_start, const ModelConfig& config, - const ModelWeightsPtrs& weights, Activations& activations, - const RuntimeConfig& runtime_config, const hwy::Divisor& div_seq_len, - const KVCaches& kv_caches) { - PROFILER_ZONE("Gen.Prefill"); - const size_t num_queries = queries_prompt.size(); - HWY_DASSERT(queries_pos.size() == num_queries); - HWY_DASSERT(queries_prefix_end.size() == num_queries); - HWY_DASSERT(kv_caches.size() == num_queries); - - // Batches are important for amortizing loading weights over multiple tokens. - // This is possible in prefill because we know all tokens beforehand, whereas - // decode depends on the previous output token. However, each prefill batch of - // a query requires that preceding batches already wrote to the KV cache, - // hence we sequentially loop over token batches. We can reduce the number of - // iterations by increasing the batch size, but this also increases arithmetic - // intensity, and so we are eventually compute-limited. We could devote some - // threads to parallelizing over queries, but for simplicity we assign them - // all to MatMul. - const size_t max_tbatch_size = runtime_config.prefill_tbatch_size; - - // For each query. `qi` is within the batch, not the global query index. - for (size_t qi = 0; qi < num_queries; ++qi) { - // Single query at a time, so pass slices of the spans because - // GemmaAttention will only access the first KV cache and position. - QueriesPos single_query_pos(&queries_pos[qi], 1); - QueriesPos single_query_prefix_end(&queries_prefix_end[qi], 1); - KVCaches single_kv_cache(&kv_caches[qi], 1); - - const size_t prompt_size = queries_prompt[qi].size(); - // In autoregressive mode, we don't need to prefill the last token, so - 1. - size_t prefill_this_query = prompt_size - 1; - const size_t prefix_end_this_query = queries_prefix_end[qi]; - // We can't attend beyond the prompt_size. 
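// The length bookkeeping in the assert and special case that follow, condensed
// into a scalar sketch (same semantics as above; names are illustrative):
#include <stddef.h>
static size_t PrefillLengthSketch(size_t prompt_size, size_t prefix_end,
                                  bool* attended_to_last_token) {
  // Normally the last prompt token is not prefilled: it becomes the first
  // generation input. If the prefix covers the whole prompt, prefill it too;
  // the caller then rewinds the position by one before generation starts.
  size_t n = prompt_size - 1;
  *attended_to_last_token = (n < prefix_end);
  if (*attended_to_last_token) n += 1;
  return n;
}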
- HWY_ASSERT(prefix_end_this_query <= prompt_size); - // Special case: if the prefix includes the last token, we need to prefill - // the last token, too. However, we need to rewind this for the generation - // of the first token. So we need to keep track of this. - // TODO: consider implementing masking instead of this logic? - const bool attend_to_last_token = - (prefill_this_query < prefix_end_this_query); - if (attend_to_last_token) { - // The difference can be at most 1. - prefill_this_query += 1; - HWY_ASSERT(prefill_this_query == prefix_end_this_query); - } - // In prefix-LM mode, we need to look at all the tokens for the prefix in - // one iteration through the layers, so we need a large enough batch size. - HWY_ASSERT(prefix_end_this_query == 0 || - max_tbatch_size >= prefill_this_query); - - // For each batch of tokens in the query: - for (size_t tbatch_start = 0; tbatch_start < prefill_this_query; - tbatch_start += max_tbatch_size) { - const size_t tbatch_size = - HWY_MIN(max_tbatch_size, prefill_this_query - tbatch_start); - activations.SetBatchSize(tbatch_size); - - // Fill activations.x (much faster than TransformerLayer). - size_t image_token_position = 0; - for (size_t ti = 0; ti < tbatch_size; ++ti) { - const size_t pos = queries_pos[qi] + ti; - const size_t pos_in_prompt = tbatch_start + ti; - const int token = queries_prompt[qi][pos_in_prompt]; - image_token_position = EmbedMMToken( - token, ti, pos, pos_in_prompt, config, weights, activations.x, - runtime_config.image_tokens, image_token_position); - } - - // Transformer with one batch of tokens from a single query. - for (size_t layer = 0; layer < config.layer_configs.size(); ++layer) { - const LayerWeightsPtrs* layer_weights = weights.GetLayer(layer); - TransformerLayer(single_query_pos, single_query_prefix_end, tbatch_size, - layer, layer_weights, activations, div_seq_len, - single_kv_cache); - } - - // NOTE: we unconditionally call StreamToken, even if EOS. - for (size_t ti = 0; ti < tbatch_size; ++ti) { - const size_t pos = queries_pos[qi] + ti; - const size_t pos_in_prompt = tbatch_start + ti; - const int token = queries_prompt[qi][pos_in_prompt]; - if (pos_in_prompt < prompt_size - 1) { - runtime_config.StreamToken(query_idx_start + qi, pos, token, 0.0f); - } else { - // The last token will be streamed later and we should only get here - // if we need to attend to the last token because it is in the prefix. - HWY_ASSERT(attend_to_last_token); - } - } - - queries_pos[qi] += tbatch_size; - } // for tbatch_start - if (attend_to_last_token) { - // We need to rewind the position for the last token that we only - // attended to to make sure the prefix LM sees everything. - // This means we duplicate work on the last prompt token in autoregressive - // decoding. Alternatives: (1) real masking; (2) always prefill the last - // token and only generate the next one from the already prefilled - // activations. - queries_pos[qi] -= 1; - } - } -} - -// Gets the patches of the image and embeds them with the image embedding -// kernel. The result is stored in activations.x. 
-static HWY_NOINLINE void EmbedImagePatches(const Image& image, - const ModelConfig& model_config, - const ModelWeightsPtrs& weights, - Activations& activations) { - const size_t model_dim = model_config.vit_config.model_dim; - const size_t patch_width = model_config.vit_config.patch_width; - const size_t seq_len = model_config.vit_config.seq_len; - const size_t patch_size = patch_width * patch_width * 3; - HWY_DASSERT(weights.vit_img_embedding_kernel.Rows() == model_dim); - HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_size); - HWY_DASSERT(activations.x.Cols() == model_dim); - // img/embedding/kernel has original shape (14, 14, 3, 1152) - // H x W x C x D transposed to D x (H x W x C) so here (1152, 14 * 14 * 3) - // image_patches is (256, 14 * 14 * 3) - // Must be padded, see `DoDecompressA`. - MatStorageT image_patches("patches", Extents2D(seq_len, patch_size), - MatPadding::kOdd); - for (size_t i = 0; i < seq_len; ++i) { - image.GetPatch(i, image_patches.Row(i)); - } - CallMatMul(image_patches, weights.vit_img_embedding_kernel, - weights.vit_img_embedding_bias.PackedScale1(), *activations.env, - activations.x); - // Add position embeddings. - CallUpcastedActivation(&weights.vit_img_pos_embedding, - [&](const auto* weights_t) { - AddFromBatched(*weights_t, activations.x); - }); -} - -// Prefills the image tokens with the ViT encoder. -static HWY_NOINLINE void PrefillVit(const ModelConfig& model_config, - const ModelWeightsPtrs& weights, - const RuntimeConfig& runtime_config, - const Image& image, - ImageTokens& image_tokens, - Activations& activations) { - PROFILER_ZONE("Gen.PrefillVit"); - const size_t num_tokens = model_config.vit_config.seq_len; - const size_t vit_model_dim = model_config.vit_config.model_dim; - HWY_ASSERT(num_tokens == activations.x.Rows()); - // Embed the image patches. - EmbedImagePatches(image, model_config, weights, activations); - // Go through all layers. - for (size_t layer = 0; layer < model_config.vit_config.layer_configs.size(); - ++layer) { - const LayerWeightsPtrs* layer_weights = weights.VitLayer(layer); - VitTransformerLayer(num_tokens, layer, layer_weights, activations); - } - // Final Layernorm. - LayerNormBatched(activations.x, weights.vit_encoder_norm_scale, - weights.vit_encoder_norm_bias, activations.x); - - if (model_config.wrapping == PromptWrapping::GEMMA_VLM) { - activations.x = AvgPool4x4(activations.x); - - // Apply soft embedding norm before input projection. - CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) { - RMSNormInplace(weights_t->PackedScale1(), 0, activations.x.Row(0), - vit_model_dim); - }); - } - - // Apply head embedding into image_tokens of size of the LLM kModelDim. - CallMatMul(activations.x, weights.vit_img_head_kernel, - weights.vit_img_head_bias.PackedScale1(), *activations.env, - image_tokens); -} - -// Generates one token for each query. `queries_token` is the previous token -// from each query, and `queries_pos` are their position in the sequence. 
-static HWY_NOINLINE void Transformer( - const QueriesToken& queries_token, const QueriesMutablePos& queries_pos, - const QueriesPos& queries_prefix_end, const ModelConfig& config, - const ModelWeightsPtrs& weights, Activations& activations, - const hwy::Divisor& div_seq_len, const KVCaches& kv_caches, - const LayersOutputFunc& layers_output, - const ActivationsObserverFunc& activations_observer) { - const size_t num_queries = queries_token.size(); - HWY_DASSERT(queries_pos.size() == num_queries); - HWY_DASSERT(queries_prefix_end.size() == num_queries); - - if (layers_output) { - for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { - const float token_f = queries_token[query_idx]; - layers_output(query_idx, queries_pos[query_idx], "tokens", -1, &token_f, - 1); - } - } - - for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { - EmbedMMToken(queries_token[query_idx], query_idx, queries_pos[query_idx], - /*pos_in_prompt=*/0, config, weights, activations.x); - } - - for (size_t layer = 0; layer < weights.c_layers.size(); ++layer) { - const LayerWeightsPtrs* layer_weights = weights.GetLayer(layer); - TransformerLayer(queries_pos, queries_prefix_end, /*num_tokens=*/1, layer, - layer_weights, activations, div_seq_len, kv_caches); - - if (activations_observer) { - activations_observer(queries_pos, layer, activations); - } - } - - RMSNormInplaceBatched(weights.final_norm_scale, activations.x); - - if (activations_observer) { - activations_observer(queries_pos, -1, activations); - } - for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { - queries_pos[query_idx] += 1; - } -} - -// Placeholder for internal test3, do not remove - -// Returns the min and max number of tokens for all queries. -static size_t MaxQueryLength(const QueriesPromptTokens& queries_prompt) { - size_t max_prompt_size = 0; - for (size_t i = 0; i < queries_prompt.size(); ++i) { - max_prompt_size = std::max(max_prompt_size, queries_prompt[i].size()); - } - return max_prompt_size; -} - -// Holds "is at end of stream" state for each query. -class TokenStreamer { - public: - explicit TokenStreamer(const RuntimeConfig& runtime_config, - const ModelConfig& model_config) - : runtime_config_(runtime_config), model_config_(model_config) {} - - // Returns whether the query was already at, or has just reached, the end of - // the stream: either via token == eos_id, or StreamToken returning false. - bool operator()(size_t query_idx, size_t pos, int token, float prob) { - if (HWY_UNLIKELY(is_eos_.Get(query_idx))) return true; - - if (!runtime_config_.StreamToken(query_idx, pos, token, prob) || - model_config_.IsEOS(token)) { - is_eos_.Set(query_idx); - return true; - } - - return false; - } - - private: - const RuntimeConfig& runtime_config_; - const ModelConfig& model_config_; - hwy::BitSet4096<> is_eos_; -}; - -HWY_INLINE SampleFunc ChooseSampleFunc(const RuntimeConfig& runtime_config) { - // If user provided a sample_func, use it. - if (runtime_config.sample_func) return runtime_config.sample_func; - - // Fast path for top-1 with no accept_token. - if (runtime_config.top_k == 1 && !runtime_config.accept_token) { - return [](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb { - PROFILER_ZONE("Gen.Sample Top1"); - return Top1OfSoftmax(logits, vocab_size); - }; - } - - // General case: Softmax with top-k sampling. 
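// What "softmax + top-k" sampling means, as an unoptimized sketch. The fused
// kernel invoked below differs in details (single pass, accept_token
// filtering), and the temperature convention shown here (logit scaling) is an
// assumption, not taken from the fused implementation.
#include <math.h>
#include <stddef.h>
#include <algorithm>
#include <random>
#include <vector>
static int SampleTopKSketch(const float* logits, size_t vocab_size, size_t k,
                            float temperature, std::mt19937& gen) {
  // Unnormalized softmax weights; discrete_distribution normalizes for us.
  std::vector<double> w(vocab_size);
  const float max_logit = *std::max_element(logits, logits + vocab_size);
  for (size_t i = 0; i < vocab_size; ++i)
    w[i] = exp((logits[i] - max_logit) / temperature);
  // Keep only the k most probable tokens (k <= vocab_size assumed).
  std::vector<int> idx(vocab_size);
  for (size_t i = 0; i < vocab_size; ++i) idx[i] = static_cast<int>(i);
  std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                    [&](int a, int b) { return w[a] > w[b]; });
  std::vector<double> topk(k);
  for (size_t i = 0; i < k; ++i) topk[i] = w[idx[i]];
  // Sample proportionally to the surviving probabilities.
  std::discrete_distribution<size_t> dist(topk.begin(), topk.end());
  return idx[dist(gen)];
}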
- return [&runtime_config](float* logits, - size_t vocab_size) HWY_ATTR -> TokenAndProb { - PROFILER_ZONE("Gen.Sample general"); - return FusedSoftmaxAndSampleTopK( - logits, runtime_config.top_k, vocab_size, *runtime_config.gen, - runtime_config.temperature, runtime_config.accept_token); - }; -} - -// Runs one decode step for all the queries in the batch. Returns true if all -// queries are at . -static bool DecodeStepT(const ModelConfig& config, - const ModelWeightsPtrs& weights, - const RuntimeConfig& runtime_config, - const QueriesPromptTokens& queries_prompt, - const size_t query_idx_start, const KVCaches& kv_caches, - const QueriesPos& queries_prefix_end, - const hwy::Divisor div_seq_len, const size_t vocab_size, - const SampleFunc& sample_token, - Activations& activations, TokenStreamer& token_streamer, - std::vector& gen_tokens, TimingInfo& timing_info, - const QueriesMutablePos& queries_mutable_pos) { - const size_t num_queries = queries_prompt.size(); - // Decode generates one token per query and increments - // queries_mutable_pos. - Transformer(QueriesToken(gen_tokens.data(), num_queries), queries_mutable_pos, - queries_prefix_end, config, weights, activations, div_seq_len, - kv_caches, runtime_config.layers_output, - runtime_config.activations_observer); - // queries_pos are incremented by Transformer. - - HWY_DASSERT(num_queries == activations.x.Rows()); - bool all_queries_eos = true; - { - PROFILER_ZONE("Gen.EmbeddingMatmul"); - // Compute logits from last layer activations. - CallMatMul(activations.x, weights.embedder_input_embedding, - /*add=*/nullptr, *activations.env, activations.logits); - } - PROFILER_ZONE("Gen.Softcap+Sample+Stream"); - for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { - float* HWY_RESTRICT logits = activations.logits.Row(query_idx); - MaybeLogitsSoftCap(config.final_cap, logits, vocab_size); - const TokenAndProb tp = sample_token(logits, vocab_size); - timing_info.NotifyGenerated(); - - const bool is_eos = - token_streamer(query_idx_start + query_idx, - queries_mutable_pos[query_idx], tp.token, tp.prob); - all_queries_eos &= is_eos; - gen_tokens[query_idx] = is_eos ? config.eos_id : tp.token; - } - return all_queries_eos; -} - -// Generates one continuation for each query in `queries_prompt`, which is one -// qbatch whose size is at most the `batch_size` passed to -// `activations.Allocate`. -// -// `queries_pos` stores the KV cache position for each query. In the first turn -// of a chat, pos = 0; we increment each query's position after each token. -// -// `query_idx_start` is the query_idx of the first query in the batch, so that -// `StreamFunc` gets the global query index, not relative to the batch. -// -// `kv_caches` is for the batch, size must match `queries_prompt`. -static void GenerateT(const ModelConfig& config, - const ModelWeightsPtrs& weights, Activations& activations, - const RuntimeConfig& runtime_config, - const QueriesPromptTokens& queries_prompt, - const QueriesPos& queries_pos_in, - const QueriesPos& queries_prefix_end, - const size_t query_idx_start, const KVCaches& kv_caches, - TimingInfo& timing_info) { - HWY_ASSERT(queries_pos_in.size() == kv_caches.size()); - - // Griffin assumes that the recurrent block cache is zero-initialized. - for (size_t i = 0; i < kv_caches.size(); ++i) { - if (queries_pos_in[i] == 0) { - kv_caches[i].ZeroGriffinCache(); // No-op for non-Griffin models. - } - } - - // Copy so we can increment without requiring users to pass in a mutable span. 
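// MaybeLogitsSoftCap, used in DecodeStepT above, sketched here under the usual
// soft-capping convention (an assumption about its exact form): squash each
// logit into (-cap, cap) via cap * tanh(logit / cap), with a non-positive cap
// meaning "disabled".
#include <math.h>
#include <stddef.h>
static void LogitsSoftCapSketch(float cap, float* logits, size_t n) {
  if (cap <= 0.0f) return;  // capping not configured for this model
  for (size_t i = 0; i < n; ++i) logits[i] = cap * tanhf(logits[i] / cap);
}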
- std::vector queries_pos_copy(queries_pos_in.cbegin(), - queries_pos_in.cend()); - const QueriesMutablePos queries_mutable_pos(queries_pos_copy.data(), - queries_pos_copy.size()); - - // Sanity check: prompts should not be empty, nor start with EOS. - for (size_t query_idx = 0; query_idx < queries_prompt.size(); ++query_idx) { - const PromptTokens& prompt = queries_prompt[query_idx]; - HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id); - } - - const size_t num_queries = queries_prompt.size(); - HWY_ASSERT(num_queries <= 4096); // TokenStreamer uses BitSet4096. - HWY_ASSERT(num_queries <= activations.x.Rows()); - HWY_ASSERT(queries_pos_in.size() == num_queries); - HWY_ASSERT(kv_caches.size() == num_queries); - const hwy::Divisor div_seq_len(static_cast(kv_caches[0].seq_len)); - size_t max_prompt_size = MaxQueryLength(queries_prompt); - size_t max_generated_tokens = runtime_config.max_generated_tokens; - RangeChecks(config, max_generated_tokens, max_prompt_size); - const SampleFunc sample_token = ChooseSampleFunc(runtime_config); - - // Prefill stops before min_prompt_size - 1 because the last prompt - // token is the first input token for generation. - timing_info.prefill_start = hwy::platform::Now(); - // Note that Prefill calls activations.SetBatchSize, so we reset it below. - Prefill(queries_prompt, queries_mutable_pos, queries_prefix_end, - query_idx_start, config, weights, activations, runtime_config, - div_seq_len, kv_caches); - // Compute the number of tokens that were prefilled and notify timing_info. - size_t prefilled_tokens = 0; - for (size_t qi = 0; qi < num_queries; ++qi) { - prefilled_tokens += queries_prompt[qi].size() - 1; - } - timing_info.NotifyPrefill(prefilled_tokens); - // queries_pos are incremented by Prefill. - activations.SetBatchSize(num_queries); - - // Storage for the last generated token from each query, passed to the next - // Transformer() call. - std::vector gen_tokens(num_queries); - - // Stream the last prompt token from each query and fill gen_tokens. - TokenStreamer token_streamer(runtime_config, config); - for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { - size_t last_token_pos_in_prompt = - queries_mutable_pos[query_idx] - queries_pos_in[query_idx]; - gen_tokens[query_idx] = queries_prompt[query_idx][last_token_pos_in_prompt]; - (void)token_streamer(query_idx_start + query_idx, - queries_mutable_pos[query_idx], gen_tokens[query_idx], - 0.0f); - } - - { - const size_t vocab_size = config.vocab_size; - timing_info.generate_start = hwy::platform::Now(); - for (size_t gen = 0; gen < max_generated_tokens; ++gen) { - bool all_queries_eos = DecodeStepT( - config, weights, runtime_config, queries_prompt, query_idx_start, - kv_caches, queries_prefix_end, div_seq_len, vocab_size, sample_token, - activations, token_streamer, gen_tokens, timing_info, - queries_mutable_pos); - if (all_queries_eos) break; - } // foreach token to generate - timing_info.NotifyGenerateDone(); - } -} - -static HWY_MAYBE_UNUSED void GenerateSingleT( - const ModelConfig& config, const ModelWeightsPtrs& weights, - const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos, - size_t prefix_end, KVCache& kv_cache, MatMulEnv* env, - TimingInfo& timing_info) { - constexpr size_t kNumQueries = 1; - const size_t qbatch_start = 0; - - const size_t max_batch_size = - HWY_MAX(kNumQueries, runtime_config.prefill_tbatch_size); - // TODO: move into Gemma? 
- Activations activations(config, max_batch_size, env); - - const QueriesPromptTokens queries_prompt(&prompt, kNumQueries); - QueriesPos queries_pos(&pos, kNumQueries); - const QueriesPos queries_prefix_end(&prefix_end, kNumQueries); - const KVCaches kv_caches{&kv_cache, kNumQueries}; - - GenerateT(config, weights, activations, runtime_config, queries_prompt, - queries_pos, queries_prefix_end, qbatch_start, kv_caches, - timing_info); -} - -static HWY_MAYBE_UNUSED void GenerateBatchT( - const ModelConfig& config, const ModelWeightsPtrs& weights, - const RuntimeConfig& runtime_config, - const QueriesPromptTokens& queries_prompt, const QueriesPos& queries_pos, - const QueriesPos& queries_prefix_end, const KVCaches& kv_caches, - MatMulEnv* env, TimingInfo& timing_info) { - const size_t num_queries = queries_prompt.size(); - HWY_ASSERT(queries_pos.size() == num_queries); - HWY_ASSERT(kv_caches.size() >= num_queries); - const size_t max_qbatch_size = runtime_config.decode_qbatch_size; - const size_t max_batch_size = - HWY_MAX(max_qbatch_size, runtime_config.prefill_tbatch_size); - - Activations activations(config, max_batch_size, env); - - for (size_t qbatch_start = 0; qbatch_start < num_queries; - qbatch_start += max_qbatch_size) { - // Generate one batch of tokens from `qbatch_size` queries. - const size_t qbatch_size = - HWY_MIN(num_queries - qbatch_start, max_qbatch_size); - const QueriesPromptTokens qbatch_prompts(&queries_prompt[qbatch_start], - qbatch_size); - QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size); - const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start], - qbatch_size); - const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size); - GenerateT(config, weights, activations, runtime_config, qbatch_prompts, - qbatch_pos, qbatch_prefix_end, qbatch_start, qbatch_kv, - timing_info); - } -} - -static HWY_MAYBE_UNUSED void GenerateImageTokensT( - const ModelConfig& config, const ModelWeightsPtrs& weights, - const RuntimeConfig& runtime_config, const Image& image, - ImageTokens& image_tokens, MatMulEnv* env) { - if (config.vit_config.layer_configs.empty()) { - HWY_ABORT("Model does not support generating image tokens."); - } - RuntimeConfig prefill_runtime_config = runtime_config; - ModelConfig vit_config = GetVitConfig(config); - prefill_runtime_config.prefill_tbatch_size = - vit_config.seq_len / (vit_config.pool_dim * vit_config.pool_dim); - Activations prefill_activations(vit_config, vit_config.seq_len, env); - // Weights are for the full PaliGemma model, not just the ViT part. - PrefillVit(config, weights, prefill_runtime_config, image, image_tokens, - prefill_activations); +static inline void FFWNoVit(Activations& activations, + const LayerWeightsPtrs& layer) { + PROFILER_ZONE("Gen.FFW"); + const LayerConfig& layer_config = layer.layer_config; + const size_t ffh_hidden_dim = layer_config.ff_hidden_dim; + + const bool add_bias = layer_config.ff_biases; + const float* bias1 = + add_bias ? layer.ffw_gating_biases.PackedScale1() : nullptr; + const float* bias2 = add_bias ? bias1 + ffh_hidden_dim : nullptr; + const float* output_bias = + add_bias ? layer.ffw_output_biases.PackedScale1() : nullptr; + + // Compute the hidden layer activations. + CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w1, bias1, + *activations.env, activations.C1); + CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w2, bias2, + *activations.env, activations.C2); + + // Activation (Gelu) and maybe multiply by gate. Store activations in act. 
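// What the batched call below does per element, as a scalar sketch: the gated
// variant computes C1 = Gelu(C1) * C2 (a GeGLU-style gated FFW); the
// single-matrix ViT path applies Gelu(C1) only. The tanh approximation of GELU
// is shown for illustration, not as the exact kernel used.
#include <math.h>
#include <stddef.h>
static inline float GeluScalarSketch(float x) {
  return 0.5f * x *
         (1.0f + tanhf(0.7978845608f * (x + 0.044715f * x * x * x)));
}
static void GatedActivationSketch(float* c1, const float* c2, size_t n) {
  for (size_t i = 0; i < n; ++i) c1[i] = GeluScalarSketch(c1[i]) * c2[i];
}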
+ ActivationBatched(layer_config.activation, activations.C1, &activations.C2); + + // Hidden layer -> output layer. + CallMatMul(activations.C1, layer.linear_w, output_bias, *activations.env, + activations.ffw_out); } // NOLINTNEXTLINE(google-readability-namespace-comments) diff --git a/gemma/gemma.cc b/gemma/gemma.cc index bd37dc5..6aadcaa 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -27,11 +27,15 @@ #include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/highway.h" // After highway.h +#include "gemma/attention.h" // includes highway.h #include "gemma/gemma-inl.h" +#include "gemma/griffin.h" // includes highway.h +#include "gemma/vit.h" // includes highway.h #ifndef GEMMA_CC_ONCE #define GEMMA_CC_ONCE +#include // sqrtf #include #include #include @@ -49,10 +53,586 @@ #include "ops/matmul.h" #include "paligemma/image.h" #include "util/threading_context.h" +#include "hwy/aligned_allocator.h" // Span #include "hwy/base.h" +#include "hwy/timer.h" #endif // GEMMA_CC_ONCE +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { + +void Attention(LayerAttentionType type, size_t num_tokens, + const QueriesPos& queries_pos, + const QueriesPos& queries_prefix_end, + const hwy::Divisor& div_seq_len, const size_t layer_idx, + const LayerWeightsPtrs& layer, Activations& activations, + const KVCaches& kv_caches) { + if (type == LayerAttentionType::kGemma) { + GemmaAttention(num_tokens, queries_pos, &queries_prefix_end, div_seq_len, + layer_idx, layer, activations, kv_caches, + /*flags=*/0); + } else { + HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock); + // KVCache conv1d_cache and rglru_cache have one row per *Griffin* layer, + // so map `layer` to the Griffin layer index. + const size_t griffin_layer = + activations.weights_config.NumLayersOfTypeBefore(type, layer_idx); + GriffinRecurrent(queries_pos, num_tokens, griffin_layer, activations, + &layer, kv_caches); + } +} + +static HWY_NOINLINE void TransformerLayer( + const size_t num_tokens, const QueriesPos& queries_pos, + const QueriesPos& queries_prefix_end, const hwy::Divisor& div_seq_len, + const size_t layer_idx, const LayerWeightsPtrs& layer, + Activations& activations, const KVCaches& kv_caches) { + const LayerConfig& layer_config = layer.layer_config; + + RMSNormBatched(activations.x, layer.pre_attention_norm_scale, + activations.pre_att_rms_out); + + Attention(layer_config.type, num_tokens, queries_pos, queries_prefix_end, + div_seq_len, layer_idx, layer, activations, kv_caches); + + PostNorm(layer_config.post_norm, layer.post_attention_norm_scale, + activations.att_sums); + + ResidualConnection(activations.att_sums, activations.x, layer, + /*is_attention=*/true); + + RMSNormBatched(activations.x, layer.pre_ffw_norm_scale, + activations.pre_ffw_rms_out); + + if (layer_config.type == LayerAttentionType::kVit) { + FFWVit(activations, layer); + } else { + FFWNoVit(activations, layer); + } + + PostNorm(layer_config.post_norm, layer.post_ffw_norm_scale, + activations.ffw_out); + + ResidualConnection(activations.ffw_out, activations.x, layer, + /*is_attention=*/false); +} + +// Returns the scale value to use for the embedding (basically sqrt model_dim). +static float EmbeddingScaling(size_t model_dim) { + // Round to bf16 to match Gemma's Embedder, which casts before mul. + return hwy::ConvertScalarTo( + hwy::ConvertScalarTo(sqrtf(static_cast(model_dim)))); +} + +// `batch_idx` indicates which row of `x` to write to. 
+// `pos` is the *token*'s position, not the start of the batch, because this is +// called for batches of tokens in prefill, but batches of queries in decode. +// +// For GEMMA_VLM, image tokens are copied into -2 locations (per the Gemma 3 +// spec) until we run out of image tokens. This allows for a multi-image prompt +// if -2 locations with appropriate begin/end image tokens are created by the +// calling application. +// Returns new image_token_position. +static HWY_NOINLINE size_t +EmbedMMToken(int token, size_t batch_idx, size_t pos, size_t pos_in_prompt, + const ModelConfig& model_config, const ModelWeightsPtrs& weights, + MatStorageT& x, const ImageTokens* image_tokens = nullptr, + size_t image_token_position = 0) { + // Image tokens just need to be copied. + if (model_config.wrapping == PromptWrapping::GEMMA_VLM && + image_tokens != nullptr && token == -2 && + image_token_position < image_tokens->Rows()) { + hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(batch_idx), + x.Cols() * x.ElementBytes()); + return image_token_position + 1; + } + + if (model_config.wrapping == PromptWrapping::PALIGEMMA && + image_tokens != nullptr && pos_in_prompt < image_tokens->Rows()) { + hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(batch_idx), + x.Cols() * x.ElementBytes()); + return image_token_position; + } + + const size_t model_dim = model_config.model_dim; + const float emb_scaling = EmbeddingScaling(model_dim); + + HWY_DASSERT(token >= 0); + HWY_DASSERT(token < static_cast(model_config.vocab_size)); + + CallUpcasted(&weights.embedder_input_embedding, [&](const auto* weights_t) { + // Using `Stride` to compute the offset works for both NUQ (because we use + // an offset and NUQ is never padded) and padded, because non-NUQ types are + // seekable, hence the offset can also skip any padding. + const size_t embedding_ofs = token * weights_t->Stride(); + HWY_ASSERT(weights_t->Cols() == model_dim); + const auto embedding_span = + MakeSpan(weights_t->Row(0), embedding_ofs + model_dim); + const hn::ScalableTag df; + DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(batch_idx), + model_dim); + MulByConst(emb_scaling * weights_t->Scale(), x.Row(batch_idx), model_dim); + }); + + if (model_config.absolute_pe) { + AddAbsolutePositionalEmbeddings(x.Row(batch_idx), model_dim, pos); + } + return image_token_position; +} + +// Prefill() and Transformer() increment positions in-place. +using QueriesMutablePos = hwy::Span; + +// Populates KV cache for batches of tokens from one query at a time. +static HWY_NOINLINE void Prefill( + const QueriesPromptTokens& queries_prompt, + const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end, + const size_t query_idx_start, const ModelConfig& config, + const ModelWeightsPtrs& weights, Activations& activations, + const RuntimeConfig& runtime_config, const hwy::Divisor& div_seq_len, + const KVCaches& kv_caches) { + PROFILER_ZONE("Gen.Prefill"); + const size_t num_queries = queries_prompt.size(); + HWY_DASSERT(queries_pos.size() == num_queries); + HWY_DASSERT(queries_prefix_end.size() == num_queries); + HWY_DASSERT(kv_caches.size() == num_queries); + + // Batches are important for amortizing loading weights over multiple tokens. + // This is possible in prefill because we know all tokens beforehand, whereas + // decode depends on the previous output token. However, each prefill batch of + // a query requires that preceding batches already wrote to the KV cache, + // hence we sequentially loop over token batches. 
We can reduce the number of + // iterations by increasing the batch size, but this also increases arithmetic + // intensity, and so we are eventually compute-limited. We could devote some + // threads to parallelizing over queries, but for simplicity we assign them + // all to MatMul. + const size_t max_tbatch_size = runtime_config.prefill_tbatch_size; + + // For each query. `qi` is within the batch, not the global query index. + for (size_t qi = 0; qi < num_queries; ++qi) { + // Single query at a time, so pass slices of the spans because + // GemmaAttention will only access the first KV cache and position. + QueriesPos single_query_pos(&queries_pos[qi], 1); + QueriesPos single_query_prefix_end(&queries_prefix_end[qi], 1); + KVCaches single_kv_cache(&kv_caches[qi], 1); + + const size_t prompt_size = queries_prompt[qi].size(); + // In autoregressive mode, we don't need to prefill the last token, so - 1. + size_t prefill_this_query = prompt_size - 1; + const size_t prefix_end_this_query = queries_prefix_end[qi]; + // We can't attend beyond the prompt_size. + HWY_ASSERT(prefix_end_this_query <= prompt_size); + // Special case: if the prefix includes the last token, we need to prefill + // the last token, too. However, we need to rewind this for the generation + // of the first token. So we need to keep track of this. + // TODO: consider implementing masking instead of this logic? + const bool attend_to_last_token = + (prefill_this_query < prefix_end_this_query); + if (attend_to_last_token) { + // The difference can be at most 1. + prefill_this_query += 1; + HWY_ASSERT(prefill_this_query == prefix_end_this_query); + } + // In prefix-LM mode, we need to look at all the tokens for the prefix in + // one iteration through the layers, so we need a large enough batch size. + HWY_ASSERT(prefix_end_this_query == 0 || + max_tbatch_size >= prefill_this_query); + + // For each batch of tokens in the query: + for (size_t tbatch_start = 0; tbatch_start < prefill_this_query; + tbatch_start += max_tbatch_size) { + const size_t tbatch_size = + HWY_MIN(max_tbatch_size, prefill_this_query - tbatch_start); + activations.SetBatchSize(tbatch_size); + + // Fill activations.x (much faster than TransformerLayer). + size_t image_token_position = 0; + for (size_t ti = 0; ti < tbatch_size; ++ti) { + const size_t pos = queries_pos[qi] + ti; + const size_t pos_in_prompt = tbatch_start + ti; + const int token = queries_prompt[qi][pos_in_prompt]; + image_token_position = EmbedMMToken( + token, ti, pos, pos_in_prompt, config, weights, activations.x, + runtime_config.image_tokens, image_token_position); + } + + // Transformer with one batch of tokens from a single query. + for (size_t layer_idx = 0; layer_idx < config.layer_configs.size(); + ++layer_idx) { + TransformerLayer(tbatch_size, single_query_pos, single_query_prefix_end, + div_seq_len, layer_idx, *weights.GetLayer(layer_idx), + activations, single_kv_cache); + } + + // NOTE: we unconditionally call StreamToken, even if EOS. + for (size_t ti = 0; ti < tbatch_size; ++ti) { + const size_t pos = queries_pos[qi] + ti; + const size_t pos_in_prompt = tbatch_start + ti; + const int token = queries_prompt[qi][pos_in_prompt]; + if (pos_in_prompt < prompt_size - 1) { + runtime_config.StreamToken(query_idx_start + qi, pos, token, 0.0f); + } else { + // The last token will be streamed later and we should only get here + // if we need to attend to the last token because it is in the prefix. 
+ HWY_ASSERT(attend_to_last_token); + } + } + + queries_pos[qi] += tbatch_size; + } // for tbatch_start + if (attend_to_last_token) { + // We need to rewind the position for the last token that we only + // attended to to make sure the prefix LM sees everything. + // This means we duplicate work on the last prompt token in autoregressive + // decoding. Alternatives: (1) real masking; (2) always prefill the last + // token and only generate the next one from the already prefilled + // activations. + queries_pos[qi] -= 1; + } + } +} + +// Generates one token for each query. `queries_token` is the previous token +// from each query, and `queries_pos` are their position in the sequence. +static HWY_NOINLINE void Transformer( + const QueriesToken& queries_token, const QueriesMutablePos& queries_pos, + const QueriesPos& queries_prefix_end, const ModelConfig& config, + const ModelWeightsPtrs& weights, Activations& activations, + const hwy::Divisor& div_seq_len, const KVCaches& kv_caches, + const LayersOutputFunc& layers_output, + const ActivationsObserverFunc& activations_observer) { + const size_t num_queries = queries_token.size(); + HWY_DASSERT(queries_pos.size() == num_queries); + HWY_DASSERT(queries_prefix_end.size() == num_queries); + + if (layers_output) { + for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { + const float token_f = queries_token[query_idx]; + layers_output(query_idx, queries_pos[query_idx], "tokens", -1, &token_f, + 1); + } + } + + for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { + EmbedMMToken(queries_token[query_idx], query_idx, queries_pos[query_idx], + /*pos_in_prompt=*/0, config, weights, activations.x); + } + + for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) { + TransformerLayer(/*num_tokens=*/1, queries_pos, queries_prefix_end, + div_seq_len, layer_idx, *weights.GetLayer(layer_idx), + activations, kv_caches); + + if (activations_observer) { + activations_observer(queries_pos, layer_idx, activations); + } + } + + RMSNormInplaceBatched(weights.final_norm_scale, activations.x); + + if (activations_observer) { + activations_observer(queries_pos, -1, activations); + } + for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { + queries_pos[query_idx] += 1; + } +} + +void RangeChecks(const ModelConfig& weights_config, + size_t& max_generated_tokens, const size_t prompt_size) { + if (!weights_config.use_local_attention) { + if (max_generated_tokens > weights_config.seq_len) { + HWY_WARN("max_generated_tokens %zu > kSeqLen %u, truncating.", + max_generated_tokens, weights_config.seq_len); + max_generated_tokens = weights_config.seq_len; + } + } + HWY_ASSERT(prompt_size > 0); +} + +// Holds "is at end of stream" state for each query. +class TokenStreamer { + public: + TokenStreamer(const RuntimeConfig& runtime_config, + const ModelConfig& model_config) + : runtime_config_(runtime_config), model_config_(model_config) {} + + // Returns whether the query was already at, or has just reached, the end of + // the stream: either via token == eos_id, or StreamToken returning false. 
+ bool operator()(size_t query_idx, size_t pos, int token, float prob) { + if (HWY_UNLIKELY(is_eos_.Get(query_idx))) return true; + + if (!runtime_config_.StreamToken(query_idx, pos, token, prob) || + model_config_.IsEOS(token)) { + is_eos_.Set(query_idx); + return true; + } + + return false; + } + + private: + const RuntimeConfig& runtime_config_; + const ModelConfig& model_config_; + hwy::BitSet4096<> is_eos_; +}; + +// Runs one decode step for all the queries in the batch. Returns true if all +// queries are at . +static bool DecodeStepT(const ModelConfig& config, + const ModelWeightsPtrs& weights, + const RuntimeConfig& runtime_config, + const QueriesPromptTokens& queries_prompt, + const size_t query_idx_start, const KVCaches& kv_caches, + const QueriesPos& queries_prefix_end, + const hwy::Divisor div_seq_len, const size_t vocab_size, + const SampleFunc& sample_token, + Activations& activations, TokenStreamer& token_streamer, + std::vector& gen_tokens, TimingInfo& timing_info, + const QueriesMutablePos& queries_mutable_pos) { + const size_t num_queries = queries_prompt.size(); + // Decode generates one token per query and increments + // queries_mutable_pos. + Transformer(QueriesToken(gen_tokens.data(), num_queries), queries_mutable_pos, + queries_prefix_end, config, weights, activations, div_seq_len, + kv_caches, runtime_config.layers_output, + runtime_config.activations_observer); + // queries_pos are incremented by Transformer. + + HWY_DASSERT(num_queries == activations.x.Rows()); + bool all_queries_eos = true; + { + PROFILER_ZONE("Gen.EmbeddingMatmul"); + // Compute logits from last layer activations. + CallMatMul(activations.x, weights.embedder_input_embedding, + /*add=*/nullptr, *activations.env, activations.logits); + } + PROFILER_ZONE("Gen.Softcap+Sample+Stream"); + for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { + float* HWY_RESTRICT logits = activations.logits.Row(query_idx); + MaybeLogitsSoftCap(config.final_cap, logits, vocab_size); + const TokenAndProb tp = sample_token(logits, vocab_size); + timing_info.NotifyGenerated(); + + const bool is_eos = + token_streamer(query_idx_start + query_idx, + queries_mutable_pos[query_idx], tp.token, tp.prob); + all_queries_eos &= is_eos; + gen_tokens[query_idx] = is_eos ? config.eos_id : tp.token; + } + return all_queries_eos; +} + +static HWY_INLINE SampleFunc +ChooseSampleFunc(const RuntimeConfig& runtime_config) { + // If user provided a sample_func, use it. + if (runtime_config.sample_func) return runtime_config.sample_func; + + // Fast path for top-1 with no accept_token. + if (runtime_config.top_k == 1 && !runtime_config.accept_token) { + return [](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb { + PROFILER_ZONE("Gen.Sample Top1"); + return Top1OfSoftmax(logits, vocab_size); + }; + } + + // General case: Softmax with top-k sampling. + return [&runtime_config](float* logits, + size_t vocab_size) HWY_ATTR -> TokenAndProb { + PROFILER_ZONE("Gen.Sample general"); + return FusedSoftmaxAndSampleTopK( + logits, runtime_config.top_k, vocab_size, *runtime_config.gen, + runtime_config.temperature, runtime_config.accept_token); + }; +} + +// Returns the min and max number of tokens for all queries. 
+static size_t MaxQueryLength(const QueriesPromptTokens& queries_prompt) { + size_t max_prompt_size = 0; + for (size_t i = 0; i < queries_prompt.size(); ++i) { + max_prompt_size = HWY_MAX(max_prompt_size, queries_prompt[i].size()); + } + return max_prompt_size; +} + +// Generates one continuation for each query in `queries_prompt`, which is one +// qbatch whose size is at most the `batch_size` passed to +// `activations.Allocate`. +// +// `queries_pos` stores the KV cache position for each query. In the first turn +// of a chat, pos = 0; we increment each query's position after each token. +// +// `query_idx_start` is the query_idx of the first query in the batch, so that +// `StreamFunc` gets the global query index, not relative to the batch. +// +// `kv_caches` is for the batch, size must match `queries_prompt`. +static void GenerateT(const ModelConfig& config, + const ModelWeightsPtrs& weights, Activations& activations, + const RuntimeConfig& runtime_config, + const QueriesPromptTokens& queries_prompt, + const QueriesPos& queries_pos_in, + const QueriesPos& queries_prefix_end, + const size_t query_idx_start, const KVCaches& kv_caches, + TimingInfo& timing_info) { + HWY_ASSERT(queries_pos_in.size() == kv_caches.size()); + + // Griffin assumes that the recurrent block cache is zero-initialized. + for (size_t i = 0; i < kv_caches.size(); ++i) { + if (queries_pos_in[i] == 0) { + kv_caches[i].ZeroGriffinCache(); // No-op for non-Griffin models. + } + } + + // Copy so we can increment without requiring users to pass in a mutable span. + std::vector queries_pos_copy(queries_pos_in.cbegin(), + queries_pos_in.cend()); + const QueriesMutablePos queries_mutable_pos(queries_pos_copy.data(), + queries_pos_copy.size()); + + // Sanity check: prompts should not be empty, nor start with EOS. + for (size_t query_idx = 0; query_idx < queries_prompt.size(); ++query_idx) { + const PromptTokens& prompt = queries_prompt[query_idx]; + HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id); + } + + const size_t num_queries = queries_prompt.size(); + HWY_ASSERT(num_queries <= 4096); // TokenStreamer uses BitSet4096. + HWY_ASSERT(num_queries <= activations.x.Rows()); + HWY_ASSERT(queries_pos_in.size() == num_queries); + HWY_ASSERT(kv_caches.size() == num_queries); + const hwy::Divisor div_seq_len(static_cast(kv_caches[0].seq_len)); + size_t max_prompt_size = MaxQueryLength(queries_prompt); + size_t max_generated_tokens = runtime_config.max_generated_tokens; + RangeChecks(config, max_generated_tokens, max_prompt_size); + const SampleFunc sample_token = ChooseSampleFunc(runtime_config); + + // Prefill stops before min_prompt_size - 1 because the last prompt + // token is the first input token for generation. + timing_info.prefill_start = hwy::platform::Now(); + // Note that Prefill calls activations.SetBatchSize, so we reset it below. + Prefill(queries_prompt, queries_mutable_pos, queries_prefix_end, + query_idx_start, config, weights, activations, runtime_config, + div_seq_len, kv_caches); + // Compute the number of tokens that were prefilled and notify timing_info. + size_t prefilled_tokens = 0; + for (size_t qi = 0; qi < num_queries; ++qi) { + prefilled_tokens += queries_prompt[qi].size() - 1; + } + timing_info.NotifyPrefill(prefilled_tokens); + // queries_pos are incremented by Prefill. + activations.SetBatchSize(num_queries); + + // Storage for the last generated token from each query, passed to the next + // Transformer() call. 
+ std::vector gen_tokens(num_queries); + + // Stream the last prompt token from each query and fill gen_tokens. + TokenStreamer token_streamer(runtime_config, config); + for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { + size_t last_token_pos_in_prompt = + queries_mutable_pos[query_idx] - queries_pos_in[query_idx]; + gen_tokens[query_idx] = queries_prompt[query_idx][last_token_pos_in_prompt]; + (void)token_streamer(query_idx_start + query_idx, + queries_mutable_pos[query_idx], gen_tokens[query_idx], + 0.0f); + } + + { + const size_t vocab_size = config.vocab_size; + timing_info.generate_start = hwy::platform::Now(); + for (size_t gen = 0; gen < max_generated_tokens; ++gen) { + bool all_queries_eos = DecodeStepT( + config, weights, runtime_config, queries_prompt, query_idx_start, + kv_caches, queries_prefix_end, div_seq_len, vocab_size, sample_token, + activations, token_streamer, gen_tokens, timing_info, + queries_mutable_pos); + if (all_queries_eos) break; + } // foreach token to generate + timing_info.NotifyGenerateDone(); + } +} + +void GenerateSingleT(const ModelConfig& config, const ModelWeightsPtrs& weights, + const RuntimeConfig& runtime_config, + const PromptTokens& prompt, size_t pos, size_t prefix_end, + KVCache& kv_cache, MatMulEnv* env, + TimingInfo& timing_info) { + constexpr size_t kNumQueries = 1; + const size_t qbatch_start = 0; + + const size_t max_batch_size = + HWY_MAX(kNumQueries, runtime_config.prefill_tbatch_size); + // TODO: move into Gemma? + Activations activations(config, max_batch_size, env); + + const QueriesPromptTokens queries_prompt(&prompt, kNumQueries); + QueriesPos queries_pos(&pos, kNumQueries); + const QueriesPos queries_prefix_end(&prefix_end, kNumQueries); + const KVCaches kv_caches{&kv_cache, kNumQueries}; + + GenerateT(config, weights, activations, runtime_config, queries_prompt, + queries_pos, queries_prefix_end, qbatch_start, kv_caches, + timing_info); +} + +void GenerateBatchT(const ModelConfig& config, const ModelWeightsPtrs& weights, + const RuntimeConfig& runtime_config, + const QueriesPromptTokens& queries_prompt, + const QueriesPos& queries_pos, + const QueriesPos& queries_prefix_end, + const KVCaches& kv_caches, MatMulEnv* env, + TimingInfo& timing_info) { + const size_t num_queries = queries_prompt.size(); + HWY_ASSERT(queries_pos.size() == num_queries); + HWY_ASSERT(kv_caches.size() >= num_queries); + const size_t max_qbatch_size = runtime_config.decode_qbatch_size; + const size_t max_batch_size = + HWY_MAX(max_qbatch_size, runtime_config.prefill_tbatch_size); + + Activations activations(config, max_batch_size, env); + + for (size_t qbatch_start = 0; qbatch_start < num_queries; + qbatch_start += max_qbatch_size) { + // Generate one batch of tokens from `qbatch_size` queries. 
+ const size_t qbatch_size = + HWY_MIN(num_queries - qbatch_start, max_qbatch_size); + const QueriesPromptTokens qbatch_prompts(&queries_prompt[qbatch_start], + qbatch_size); + QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size); + const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start], + qbatch_size); + const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size); + GenerateT(config, weights, activations, runtime_config, qbatch_prompts, + qbatch_pos, qbatch_prefix_end, qbatch_start, qbatch_kv, + timing_info); + } +} + +void GenerateImageTokensT(const ModelConfig& config, + const ModelWeightsPtrs& weights, + const RuntimeConfig& runtime_config, + const Image& image, ImageTokens& image_tokens, + MatMulEnv* env) { + if (config.vit_config.layer_configs.empty()) { + HWY_ABORT("Model does not support generating image tokens."); + } + RuntimeConfig prefill_runtime_config = runtime_config; + ModelConfig vit_config = GetVitConfig(config); + prefill_runtime_config.prefill_tbatch_size = + vit_config.seq_len / (vit_config.pool_dim * vit_config.pool_dim); + Activations prefill_activations(vit_config, vit_config.seq_len, env); + // Weights are for the full PaliGemma model, not just the ViT part. + PrefillVit(config, weights, prefill_runtime_config, image, image_tokens, + prefill_activations); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + #if HWY_ONCE namespace gcpp { HWY_EXPORT(GenerateSingleT); diff --git a/gemma/gemma.h b/gemma/gemma.h index 18018c8..99936f5 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -160,9 +160,6 @@ class Gemma { GemmaChatTemplate chat_template_; }; -void RangeChecks(const ModelConfig& weights_config, - size_t& max_generated_tokens, size_t prompt_size); - } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_ diff --git a/gemma/griffin.cc b/gemma/griffin.cc new file mode 100644 index 0000000..46606b1 --- /dev/null +++ b/gemma/griffin.cc @@ -0,0 +1,193 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "gemma/activations.h" +#include "gemma/gemma.h" +#include "gemma/gemma_args.h" +#include "gemma/weights.h" +#include "hwy/contrib/thread_pool/thread_pool.h" +#include "hwy/profiler.h" + +// Compiles this file for multiple architectures via "foreach_target.h", to +// which we pass the filename via macro 'argument'. +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "gemma/griffin.cc" // NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +// After highway.h +#include "ops/matvec-inl.h" +#include "ops/ops-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { + +// Different functions use different naming conventions for the number of +// tokens. 
Functions that are query-independent, such as RMSNorm*, call the +// count `num_interleaved`. Functions that are query-dependent, such as +// `Attention`, use separate `num_tokens` and `num_queries`. `num_tokens` is the +// number of tokens from one query: 1 for decode, otherwise prefill_tbatch_size. + +void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens, + size_t griffin_layer, Activations& activations, + const LayerWeightsPtrs* layer_weights, + const KVCaches& kv_caches) { + PROFILER_ZONE("Gen.Griffin"); + hwy::ThreadPool& pool = activations.env->ctx.pools.Pool(0); + namespace hn = hwy::HWY_NAMESPACE; + using D = hn::ScalableTag; + const D df; + + const size_t model_dim = layer_weights->layer_config.model_dim; + HWY_DASSERT(model_dim % hn::Lanes(df) == 0); + + const size_t heads = layer_weights->layer_config.heads; + const size_t conv_1d_width = layer_weights->layer_config.conv1d_width; + HWY_ASSERT_M(conv_1d_width % 2 == 0, "Conv width must be even"); + const size_t kHeadDim = model_dim / heads; + const size_t kMatrixSize = kHeadDim * kHeadDim; + + const size_t num_queries = queries_pos.size(); + const hwy::Divisor div_num_q(static_cast(num_queries)); + const size_t num_interleaved = num_tokens * num_queries; + + // X / Y linear layers. + // TODO: MatMul + HWY_DASSERT(activations.griffin_y.Rows() == activations.griffin_x.Rows()); + HWY_DASSERT(num_interleaved == activations.griffin_y.Rows()); + CallUpcastedSame( + &layer_weights->griffin.linear_x_w, &layer_weights->griffin.linear_y_w, + [&](const auto* wx, const auto* wy) { + for (size_t r = 0; r < num_interleaved; ++r) { + float* HWY_RESTRICT y = activations.griffin_y.Row(r); + float* HWY_RESTRICT x = activations.griffin_x.Row(r); + TwoMatVecAdd( + *wx, *wy, 0, model_dim, model_dim, + activations.pre_att_rms_out.Row(r), + /*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(), + /*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(), + /*out0=*/x, /*out1=*/y, pool); + Gelu(y, model_dim); + } + }); + + // Conv1D. + for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved; + ++interleaved_idx) { + const size_t query_idx = div_num_q.Remainder(interleaved_idx); + const size_t batch_idx = div_num_q.Divide(interleaved_idx); + const size_t pos = queries_pos[query_idx] + batch_idx; + float* HWY_RESTRICT x = activations.griffin_x.Row(query_idx); + + // cache[i] = input at time t-i. 
+ float* HWY_RESTRICT cache[kMaxConv1DWidth]; + cache[0] = x; + for (size_t i = 1; i < conv_1d_width; i++) { + cache[i] = + kv_caches[query_idx].conv1d_cache.Row(griffin_layer) + + ((pos + conv_1d_width - 1 - i) % (conv_1d_width - 1)) * model_dim; + } + for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) { + auto xv = hn::Load(df, x + i); + auto accum0 = + hn::Load(df, layer_weights->griffin.conv_biases.PackedScale1() + i); + auto accum1 = hn::Zero(df); + for (size_t l = 0; 2 * l < conv_1d_width; l++) { + auto wv0 = + hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() + + (conv_1d_width - 1 - 2 * l) * model_dim + i); + auto wv1 = + hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() + + (conv_1d_width - 2 - 2 * l) * model_dim + i); + accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0); + accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1); + } + hn::Store(hn::Add(accum0, accum1), df, x + i); + hn::Store(xv, df, cache[HWY_MAX(conv_1d_width, 1) - 1] + i); + } + } + + // RGLRU + for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved; + ++interleaved_idx) { + const size_t query_idx = div_num_q.Remainder(interleaved_idx); + const size_t batch_idx = div_num_q.Divide(interleaved_idx); + const size_t pos = queries_pos[query_idx] + batch_idx; + + float* HWY_RESTRICT x = activations.griffin_x.Row(query_idx); + float* HWY_RESTRICT y = activations.griffin_y.Row(query_idx); + float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Row(query_idx); + float* HWY_RESTRICT a = activations.griffin_multiplier.Row(query_idx); + float* HWY_RESTRICT rnn_state = + kv_caches[query_idx].rglru_cache.Row(griffin_layer); + + pool.Run(0, heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { + size_t head_offset = head * kHeadDim; + CallUpcasted(&layer_weights->griffin.gate_w, [&](const auto* gate_w) { + TwoOfsMatVecAddLoop( + *gate_w, kMatrixSize * head, kMatrixSize * (heads + head), kHeadDim, + kHeadDim, x + head_offset, + /*add0=*/layer_weights->griffin.gate_biases.PackedScale1() + + head_offset, + /*add1=*/layer_weights->griffin.gate_biases.PackedScale1() + + model_dim + head_offset, + /*out0=*/gate_x + head_offset, /*out1=*/a + head_offset); + }); + Sigmoid(gate_x + head_offset, kHeadDim); + Sigmoid(a + head_offset, kHeadDim); + const auto fn_mul = [](D d, hn::Vec x, hn::Vec gate_x) + HWY_ATTR { return hn::Mul(x, gate_x); }; + hn::Transform1(D(), a + head_offset, kHeadDim, + layer_weights->griffin.a.PackedScale1() + head_offset, + fn_mul); + hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset, + fn_mul); + // RNN scan + HWY_DASSERT(kHeadDim % hn::Lanes(df) == 0); + for (size_t i = 0; i < kHeadDim; i += hn::Lanes(df)) { + auto log_a = hn::Load(df, a + head_offset + i); + auto gated_x = hn::Load(df, x + head_offset + i); + auto rnn = hn::Load(df, rnn_state + head_offset + i); + auto a = hn::Exp(df, log_a); + auto x_multiplier = hn::Sqrt(hn::NegMulAdd(a, a, hn::Set(df, 1.0f))); + if (pos == 0) { + x_multiplier = hn::Set(df, 1.0f); + } + auto new_x = hn::MulAdd(x_multiplier, gated_x, hn::Mul(a, rnn)); + hn::Store(new_x, df, rnn_state + head_offset + i); + + // Join branches. + auto yv = hn::Load(df, y + head_offset + i); + auto pre_out = hn::Mul(yv, new_x); + hn::Store(pre_out, df, x + head_offset + i); + } + }); + } // interleaved_idx + + // Final linear layer. 
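// Per element, the recurrence implemented above (the RG-LRU): with input gate
// i_t = sigmoid(gate_x), recurrence gate r_t = sigmoid(gate_a) and a learned
// log-decay c (the `griffin.a` tensor), it computes
//   a_t = exp(c * r_t)
//   h_t = a_t * h_{t-1} + sqrt(1 - a_t^2) * (i_t * x_t)   (no input scaling at pos 0)
// and the block output is y_t * h_t, which the matmul below projects back to
// model_dim. A scalar sketch of one step:
#include <math.h>
static float RgLruStepSketch(float x_gated, float y, float gate_a_sigmoid,
                             float log_decay_c, float* h, bool first_pos) {
  const float a = expf(log_decay_c * gate_a_sigmoid);
  const float in_scale = first_pos ? 1.0f : sqrtf(1.0f - a * a);
  *h = a * (*h) + in_scale * x_gated;  // update recurrent state
  return y * (*h);                     // join with the gelu (y) branch
}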
+ CallMatMul(activations.griffin_x, layer_weights->griffin.linear_out_w, + layer_weights->griffin.linear_out_biases.PackedScale1(), + *activations.env, activations.att_sums); +} // GriffinRecurrent + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); diff --git a/gemma/griffin.h b/gemma/griffin.h new file mode 100644 index 0000000..77011a3 --- /dev/null +++ b/gemma/griffin.h @@ -0,0 +1,47 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GRIFFIN_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_GRIFFIN_H_ + +// Declares GriffinRecurrent for all SIMD targets. + +#include + +#include "gemma/gemma.h" +#include "hwy/highway.h" + +namespace gcpp { + +// Passed to HWY_VISIT_TARGETS; declares for one target. +#define GEMMA_DECL_GRIFFIN(TARGET, NAMESPACE) \ + namespace NAMESPACE { \ + void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens, \ + size_t griffin_layer, Activations& activations, \ + const LayerWeightsPtrs* layer_weights, \ + const KVCaches& kv_caches); \ + /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ + } // namespace NAMESPACE + +// Function declarations for each SIMD target. Allows direct call from the +// per-target namespace. We may later replace this with dynamic dispatch if +// the overhead is acceptable. +HWY_VISIT_TARGETS(GEMMA_DECL_GRIFFIN) + +#undef GEMMA_DECL_GRIFFIN + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GRIFFIN_H_ diff --git a/gemma/vit.cc b/gemma/vit.cc new file mode 100644 index 0000000..f9c50ff --- /dev/null +++ b/gemma/vit.cc @@ -0,0 +1,339 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include // sqrtf +#include +#include + +#include + +#include "gemma/activations.h" +#include "gemma/gemma.h" +#include "gemma/gemma_args.h" +#include "gemma/weights.h" +#include "paligemma/image.h" +#include "hwy/contrib/thread_pool/thread_pool.h" +#include "hwy/profiler.h" + +// Compiles this file for multiple architectures via "foreach_target.h", to +// which we pass the filename via macro 'argument'. 
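+// "foreach_target.h" re-includes this translation unit once per enabled SIMD
+// target; HWY_TARGET_INCLUDE names the file to re-include, and including
+// gemma-inl.h / ops-inl.h after highway.h ensures their SIMD code is also
+// recompiled for each target.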
+// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "gemma/vit.cc" // NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +// After highway.h +#include "gemma/gemma-inl.h" +#include "ops/ops-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { + +// Wrapper class; holds arguments in member variables to shorten call sites. +// The main differences to GemmaAttention are: +// - no KV Cache necessary, attention is always all-to-all and not causal. +// - no potential wrap-around, attention always goes from 0 to kSeqLen. +// - no need for batching, as we are always computing attention for kSeqLen +// tokens. +// This results in a much simpler implementation. However, to avoid duplicating +// code, we should still consider merging the two classes. +// TODO(keysers): Refactor to share code with GemmaAttention. +class VitAttention { + // Computes Q, K, V for all heads, stored in activations_.q. + HWY_NOINLINE void ComputeQKV() { + PROFILER_ZONE("Gen.VitAttention.QKV"); + auto& qkv = activations_.q; + HWY_ASSERT(qkv.Rows() == num_tokens_); + HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim); + CallMatMul(activations_.pre_att_rms_out, layer_.vit.qkv_einsum_w, + layer_.vit.qkv_einsum_b.PackedScale1(), *activations_.env, qkv); + } + + // TODO(philculliton): transition fully to MatMul. + HWY_NOINLINE void DotSoftmaxWeightedSumMatrix() { + const size_t qkv_dim = layer_config_.qkv_dim; + const size_t heads = layer_config_.heads; + HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA"); + const size_t seq_len = activations_.seq_len; + const float query_scale = 1.0f / sqrtf(static_cast(qkv_dim)); + PROFILER_ZONE("Gen.VitAttention.DotSoftmax"); + + // Shift Q, K, VT to MatStorageT. + MatStorageT Q("Q2", Extents2D(num_tokens_, qkv_dim), + MatPadding::kPacked); + MatStorageT K("K2", Extents2D(seq_len, qkv_dim), + MatPadding::kPacked); + MatStorageT C("C2", Extents2D(num_tokens_, seq_len), + MatPadding::kPacked); + + // Initialize att_out to zero prior to head loop. 
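+    // Each head accumulates its weighted V rows into its own qkv_dim slice of
+    // att_out via MulByConstAndAdd below, so the whole buffer must start out
+    // zeroed rather than holding stale values.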
+ ZeroInit(activations_.att_out); + + for (size_t head = 0; head < heads; ++head) { + pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { + const size_t token = task; + float* HWY_RESTRICT q = activations_.q.Row(token) + head * 3 * qkv_dim; + // TODO: shift to MatMul with A.scale once MatMul is confirmed working + MulByConst(query_scale, q, qkv_dim); + hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float)); + }); + + pool_.Run(0, seq_len, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { + const size_t seq_idx = task; + float* HWY_RESTRICT k = + activations_.q.Row(seq_idx) + head * 3 * qkv_dim + qkv_dim; + hwy::CopyBytes(k, K.Row(seq_idx), qkv_dim * sizeof(float)); + }); + + // this produces C, a (num_tokens_, seq_len) matrix of dot products + CallMatMul(Q, K, nullptr, *activations_.env, C); + + pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { + float* HWY_RESTRICT c = C.Row(task); + Softmax(c, C.Cols()); + }); + + pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { + size_t token = task; + float* HWY_RESTRICT att_out = + activations_.att_out.Row(token) + head * qkv_dim; + for (size_t i = 0; i < seq_len; ++i) { + float* HWY_RESTRICT v = + activations_.q.Row(i) + head * 3 * qkv_dim + 2 * qkv_dim; + MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim); + } + }); + } + } + + HWY_NOINLINE void DotSoftmaxWeightedSum() { + const size_t qkv_dim = layer_config_.qkv_dim; + const size_t heads = layer_config_.heads; + HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA"); + const size_t seq_len = activations_.seq_len; + const float query_scale = 1.0f / sqrtf(static_cast(qkv_dim)); + PROFILER_ZONE("Gen.VitAttention.DotSoftmax"); + + // Compute Q.K, softmax, and weighted V. + pool_.Run(0, layer_config_.heads * num_tokens_, + [&](uint64_t task, size_t /*thread*/) HWY_ATTR { + const size_t head = task % layer_config_.heads; + const size_t token = task / layer_config_.heads; + // Compute Q.K scores, which are "logits" stored in head_att. + float* HWY_RESTRICT q = + activations_.q.Row(token) + head * 3 * qkv_dim; + MulByConst(query_scale, q, qkv_dim); + float* HWY_RESTRICT head_att = + activations_.att.Row(token) + head * activations_.seq_len; + for (size_t i = 0; i < seq_len; ++i) { + float* HWY_RESTRICT k = + activations_.q.Row(i) + head * 3 * qkv_dim + qkv_dim; + head_att[i] = Dot(q, k, qkv_dim); // score = q.k + } + // SoftMax yields "probabilities" in head_att. + Softmax(head_att, seq_len); + // Compute weighted sum of v into att_out. + float* HWY_RESTRICT att_out = + activations_.att_out.Row(token) + head * qkv_dim; + hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out)); + for (size_t i = 0; i < seq_len; ++i) { + float* HWY_RESTRICT v = + activations_.q.Row(i) + head * 3 * qkv_dim + 2 * qkv_dim; + MulByConstAndAdd(head_att[i], v, att_out, qkv_dim); + } + }); + } + + // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and + // head_dim (`qkv_dim`) into output (`att_sums`). + HWY_NOINLINE void SumHeads() { + PROFILER_ZONE("Gen.VitAttention.SumHeads"); + auto* bias = layer_.vit.attn_out_b.PackedScale1(); + // att_weights and att_out are concatenated heads, each of length + // qkv_dim. Thus the [num_tokens_, layer_config_.model_dim] + // matmul output is the sum over heads. 
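+    // Because attn_out_w maps the concatenated heads * qkv_dim axis to
+    // model_dim, this single MatMul both applies the per-head output
+    // projection and sums the head contributions; no explicit per-head
+    // reduction loop is needed.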
+ CallMatMul(activations_.att_out, layer_.vit.attn_out_w, bias, + *activations_.env, activations_.att_sums); + } + + public: + VitAttention(size_t num_tokens, size_t layer_idx, Activations& activations, + const LayerWeightsPtrs& layer) + : num_tokens_(num_tokens), + activations_(activations), + layer_(layer), + layer_config_(layer.layer_config), + pool_(activations.env->ctx.pools.Pool(0)) {} + + HWY_INLINE void operator()() { + ComputeQKV(); + if (activations_.weights_config.wrapping == PromptWrapping::GEMMA_VLM) { + DotSoftmaxWeightedSumMatrix(); + } else { + DotSoftmaxWeightedSum(); + } + SumHeads(); + } + + private: + const size_t num_tokens_; + Activations& activations_; + const LayerWeightsPtrs& layer_; + const LayerConfig& layer_config_; + hwy::ThreadPool& pool_; +}; + +// Same as FFWNoVit, but with different layer members and no second +// gating matrix. +void FFWVit(Activations& activations, const LayerWeightsPtrs& layer) { + PROFILER_ZONE("Gen.FFW.ViT"); + const LayerConfig& layer_config = layer.layer_config; + + const bool add_bias = layer_config.ff_biases; + const float* bias1 = add_bias ? layer.vit.linear_0_b.PackedScale1() : nullptr; + const float* output_bias = + add_bias ? layer.vit.linear_1_b.PackedScale1() : nullptr; + + // Compute the hidden layer activations. + CallMatMul(activations.pre_ffw_rms_out, layer.vit.linear_0_w, bias1, + *activations.env, activations.C1); + + // Activation (Gelu), store in C1. + ActivationBatched(layer_config.activation, activations.C1); + + // Hidden layer -> output layer. + CallMatMul(activations.C1, layer.vit.linear_1_w, output_bias, + *activations.env, activations.ffw_out); +} + +// Vit transformer layer. Some comments below refer to the Vit implementation in +// the Big Vision codebase. See +// github.com/google-research/big_vision/blob/main/big_vision/models/vit.py +// TODO(keysers): consider adding a wrapper for both LayerNorm with RMSNorm and +// try merging this with TransformerLayer. +void VitTransformerLayer(size_t num_tokens, const size_t layer_idx, + const LayerWeightsPtrs& layer, + Activations& activations) { + const size_t model_dim = activations.weights_config.model_dim; + auto type = layer.layer_config.type; + HWY_DASSERT(type == LayerAttentionType::kVit); + (void)type; + (void)model_dim; + + auto& x = activations.x; + HWY_DASSERT(x.Rows() == num_tokens); + HWY_DASSERT(x.Cols() == model_dim); + + // y = nn.LayerNorm()(x) + // y ~ pre_att_rms_out + LayerNormBatched(x, layer.vit.layer_norm_0_scale, layer.vit.layer_norm_0_bias, + activations.pre_att_rms_out); + + // y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y) + // y ~ att_sums + VitAttention(num_tokens, layer_idx, activations, layer)(); + + // x = out["+sa"] = x + y + AddFromBatched(activations.att_sums, x); + + // y = nn.LayerNorm()(x) + // y ~ pre_ffw_rms_out + LayerNormBatched(x, layer.vit.layer_norm_1_scale, layer.vit.layer_norm_1_bias, + activations.pre_ffw_rms_out); + + // y = out["mlp"] = MlpBlock(...)(y) + // y ~ ffw_out + FFWVit(activations, layer); + + // x = out["+mlp"] = x + y + AddFromBatched(activations.ffw_out, x); +} + +// Gets the patches of the image and embeds them with the image embedding +// kernel. The result is stored in activations.x. 
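+// Each of the seq_len patches is a flattened patch_width x patch_width x 3
+// block of pixels; one MatMul against vit_img_embedding_kernel plus the bias
+// turns every patch into a model_dim row of activations.x, after which the
+// learned position embeddings are added.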
+static HWY_NOINLINE void EmbedImagePatches(const Image& image, + const ModelConfig& model_config, + const ModelWeightsPtrs& weights, + Activations& activations) { + const size_t model_dim = model_config.vit_config.model_dim; + const size_t patch_width = model_config.vit_config.patch_width; + const size_t seq_len = model_config.vit_config.seq_len; + const size_t patch_size = patch_width * patch_width * 3; + HWY_DASSERT(weights.vit_img_embedding_kernel.Rows() == model_dim); + HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_size); + HWY_DASSERT(activations.x.Cols() == model_dim); + (void)model_dim; + // img/embedding/kernel has original shape (14, 14, 3, 1152) + // H x W x C x D transposed to D x (H x W x C) so here (1152, 14 * 14 * 3) + // image_patches is (256, 14 * 14 * 3) + // Must be padded, see `DoDecompressA`. + MatStorageT image_patches("patches", Extents2D(seq_len, patch_size), + MatPadding::kOdd); + for (size_t i = 0; i < seq_len; ++i) { + image.GetPatch(i, image_patches.Row(i)); + } + CallMatMul(image_patches, weights.vit_img_embedding_kernel, + weights.vit_img_embedding_bias.PackedScale1(), *activations.env, + activations.x); + // Add position embeddings. + CallUpcastedActivation(&weights.vit_img_pos_embedding, + [&](const auto* weights_t) { + AddFromBatched(*weights_t, activations.x); + }); +} + +// Prefills the image tokens with the ViT encoder. +void PrefillVit(const ModelConfig& model_config, + const ModelWeightsPtrs& weights, + const RuntimeConfig& runtime_config, const Image& image, + ImageTokens& image_tokens, Activations& activations) { + PROFILER_ZONE("Gen.PrefillVit"); + const size_t num_tokens = model_config.vit_config.seq_len; + const size_t vit_model_dim = model_config.vit_config.model_dim; + HWY_ASSERT(num_tokens == activations.x.Rows()); + // Embed the image patches. + EmbedImagePatches(image, model_config, weights, activations); + // Go through all layers. + for (size_t layer_idx = 0; + layer_idx < model_config.vit_config.layer_configs.size(); ++layer_idx) { + VitTransformerLayer(num_tokens, layer_idx, *weights.VitLayer(layer_idx), + activations); + } + // Final Layernorm. + LayerNormBatched(activations.x, weights.vit_encoder_norm_scale, + weights.vit_encoder_norm_bias, activations.x); + + if (model_config.wrapping == PromptWrapping::GEMMA_VLM) { + activations.x = AvgPool4x4(activations.x); + + // Apply soft embedding norm before input projection. + CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) { + RMSNormInplace(weights_t->PackedScale1(), 0, activations.x.Row(0), + vit_model_dim); + }); + } + + // Apply head embedding into image_tokens of size of the LLM kModelDim. + CallMatMul(activations.x, weights.vit_img_head_kernel, + weights.vit_img_head_bias.PackedScale1(), *activations.env, + image_tokens); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); diff --git a/gemma/vit.h b/gemma/vit.h new file mode 100644 index 0000000..085081d --- /dev/null +++ b/gemma/vit.h @@ -0,0 +1,49 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_VIT_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_VIT_H_ + +// Declares vision transformer FFW/Prefill for all SIMD targets. + +#include + +#include "gemma/gemma.h" +#include "hwy/highway.h" + +namespace gcpp { + +// Passed to HWY_VISIT_TARGETS; declares for one target. +#define GEMMA_DECL_VIT(TARGET, NAMESPACE) \ + namespace NAMESPACE { \ + void FFWVit(Activations& activations, const LayerWeightsPtrs& layer); \ + \ + void PrefillVit(const ModelConfig& model_config, \ + const ModelWeightsPtrs& weights, \ + const RuntimeConfig& runtime_config, const Image& image, \ + ImageTokens& image_tokens, Activations& activations); \ + /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ + } // namespace NAMESPACE + +// Function declarations for each SIMD target. Allows direct call from the +// per-target namespace. We may later replace this with dynamic dispatch if +// the overhead is acceptable. +HWY_VISIT_TARGETS(GEMMA_DECL_VIT) + +#undef GEMMA_DECL_VIT + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_VIT_H_ diff --git a/gemma/weights.cc b/gemma/weights.cc index d8f1491..c2e47f2 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -448,8 +448,10 @@ static std::vector MakeBatches( HWY_ASSERT(batches.back().Add(row_bytes, file_bytes_per_row)); } offset += file_bytes_per_row; + // Must zero-initialize the in-memory row padding, see MatMul. + hwy::ZeroBytes(row_bytes + file_bytes_per_row, + mem_stride_bytes - file_bytes_per_row); row_bytes += mem_stride_bytes; - // Keep the in-memory row padding uninitialized so msan detects any use. } HWY_ASSERT(offset == range.End()); } diff --git a/gemma/weights.h b/gemma/weights.h index ac26340..250450a 100644 --- a/gemma/weights.h +++ b/gemma/weights.h @@ -93,6 +93,8 @@ class MatFinder { // `WeightsOwner`. struct LayerWeightsPtrs { // Initializes tensor metadata without allocating. + // NOTE: do not store layer_idx, TransformerLayer and Attention may use + // other values for purposes of the KV cache. LayerWeightsPtrs(size_t layer_idx, const LayerConfig& config, const TensorInfoRegistry& tensors) : finder_(LayerSuffix(layer_idx), tensors), diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 9be6f4c..9454b66 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -1294,6 +1294,10 @@ struct MMImpl { // are no other restrictions on shape, though performance is better when `M % 4 // == 0` or `M <= 4`. // +// NOTE: if A and/or B are BF16 and padded, the interval `[Cols(), +// hwy::RoundUpTo(Cols(), hn::Lanes(dbf))` must be zero-initialized to match +// the behavior of `DecompressAndZeroPad`. We check this in debug builds. +// // If `add` is non-null, the row-vector `add` is added to each of the `M` rows // of `C`, which is a row-major matrix with arbitrary stride. A scale for // `add` is not supported, so make sure its scale is 1. 
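The zero-initialization requirement above (and the matching changes in gemma/weights.cc and util/mat.cc) boils down to clearing the tail of every row between Cols() and the padded stride. A minimal standalone sketch, using a hypothetical ZeroRowPadding helper rather than the repo's MatPtr API:

#include <cstddef>
#include <cstring>

// Zeroes the in-memory padding of a row-major matrix whose rows hold `cols`
// valid elements but are `stride` elements apart (stride >= cols). After this,
// reads of the padded tail behave like DecompressAndZeroPad.
template <typename T>
void ZeroRowPadding(T* data, size_t rows, size_t cols, size_t stride) {
  if (stride == cols) return;  // Packed layout: nothing to clear.
  for (size_t r = 0; r < rows; ++r) {
    T* row = data + r * stride;
    std::memset(row + cols, 0, (stride - cols) * sizeof(T));
  }
}

Any code that writes rows directly, as MakeBatches in weights.cc does, must maintain this invariant whenever the stride exceeds the number of columns.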
diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 219c006..c7e5758 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -27,6 +27,7 @@ #include // std::enable_if_t #include +#include "ops/matmul.h" #include "util/allocator.h" #include "util/basics.h" // TokenAndProb #include "util/mat.h" @@ -48,6 +49,7 @@ #include "compression/compress-inl.h" #include "ops/dot-inl.h" +#include "ops/matmul_static.h" // includes highway.h #include "ops/sum-inl.h" #include "hwy/contrib/algo/transform-inl.h" #include "hwy/contrib/math/math-inl.h" @@ -57,6 +59,14 @@ namespace gcpp { namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; +template +MMPerKey* CallMatMul(const MatPtrT& A, const MatPtr& B, + const float* HWY_RESTRICT add, MatMulEnv& env, + MatPtrT& C) { + return CallUpcasted( + &B, [&](const auto* B_t) { return MatMulStatic(A, *B_t, add, env, C); }); +} + HWY_INLINE double PackTokenAndProb(int32_t token, float prob) { // casting prob from float to double just makes some changes to the // exponent bias and pads zeros in the mantissa. diff --git a/ops/ops_test.cc b/ops/ops_test.cc index e6c2d3e..67890ec 100644 --- a/ops/ops_test.cc +++ b/ops/ops_test.cc @@ -31,7 +31,7 @@ #include #include -#include "gemma/common.h" // ChooseQueryScale +#include "gemma/activations.h" // ChooseQueryScale #include "util/allocator.h" #include "util/basics.h" // BF16 #include "util/mat.h" // MatStorageT diff --git a/util/mat.cc b/util/mat.cc index 3344cad..44d62ec 100644 --- a/util/mat.cc +++ b/util/mat.cc @@ -54,10 +54,9 @@ void ZeroInit(MatPtr& mat) { hwy::ZeroBytes(mat.Packed(), mat.PackedBytes()); return; } - const size_t row_bytes = mat.Cols() * mat.ElementBytes(); - for (size_t r = 0; r < mat.Rows(); ++r) { - hwy::ZeroBytes(mat.RowBytes(r), row_bytes); - } + // Also zero-initialize padding (required by MatMul). + hwy::ZeroBytes(mat.RowBytes(0), + mat.Stride() * mat.ElementBytes() * mat.Rows()); } size_t Stride(MatPadding padding, size_t cols, size_t element_bytes,