Split gemma-inl into separate source files

weights, mat: zero-initialize padding; this is required since the MatMul "avoid B decompress" optimization was introduced.
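The padding referred to here is the tail of each row between its logical width and its allocated stride. A minimal sketch of what zero-initializing it means, with illustrative names only (not the actual weights/mat code):

#include <cstddef>
#include <cstring>

// Zeroes the per-row padding of a row-major matrix whose rows are allocated
// with `stride` >= `cols` floats. Kernels that read entire padded rows, such
// as a MatMul that skips decompressing B, then see zeros rather than
// uninitialized memory.
void ZeroRowPadding(float* data, size_t rows, size_t cols, size_t stride) {
  for (size_t r = 0; r < rows; ++r) {
    float* row = data + r * stride;
    std::memset(row + cols, 0, (stride - cols) * sizeof(float));
  }
}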

PiperOrigin-RevId: 767562313
Jan Wassenberg 2025-06-05 05:36:08 -07:00 committed by Copybara-Service
parent dd7d4a7717
commit 3a266c662c
21 changed files with 1736 additions and 1466 deletions

BUILD.bazel

@ -216,8 +216,8 @@ cc_library(
":configs",
":gemma_args",
":mat",
":matmul",
":model_store",
":ops",
":tensor_info",
":threading_context",
"//compression:compress",
@ -246,11 +246,7 @@ cc_library(
name = "common",
srcs = ["gemma/common.cc"],
hdrs = ["gemma/common.h"],
deps = [
":basics",
":configs",
"@highway//:hwy", # base.h
],
deps = [":configs"],
)
# For building all tests in one command, so we can test several.
@ -260,42 +256,25 @@ test_suite(
)
cc_library(
name = "ops",
srcs = [
"ops/matmul.cc",
],
hdrs = [
"ops/matmul.h",
"ops/ops.h",
],
textual_hdrs = [
"ops/dot-inl.h",
"ops/sum-inl.h",
"ops/fp_arith-inl.h",
"ops/matmul-inl.h",
"ops/matvec-inl.h",
"ops/ops-inl.h",
],
name = "matmul",
srcs = ["ops/matmul.cc"],
hdrs = ["ops/matmul.h"],
textual_hdrs = ["ops/matmul-inl.h"],
deps = [
":allocator",
":basics",
":mat",
":threading_context",
"//compression:compress",
"@highway//:algo",
"@highway//:bit_set",
"@highway//:hwy",
"@highway//:math",
"@highway//:matvec",
"@highway//:nanobenchmark",
"@highway//:profiler",
"@highway//:thread_pool",
"@highway//hwy/contrib/sort:vqsort",
],
)
cc_library(
name = "matmul",
name = "matmul_static",
srcs = [
# single-file build time is ~30sec for msan, hence shard.
"ops/matmul_static_bf16.cc",
@ -313,7 +292,7 @@ cc_library(
deps = [
":allocator",
":basics",
":ops",
":matmul",
":threading_context",
"//compression:compress",
"//compression:types",
@ -323,6 +302,34 @@ cc_library(
],
)
cc_library(
name = "ops",
hdrs = ["ops/ops.h"],
textual_hdrs = [
"ops/dot-inl.h",
"ops/sum-inl.h",
"ops/fp_arith-inl.h",
"ops/matvec-inl.h",
"ops/ops-inl.h",
],
deps = [
":allocator",
":basics",
":mat",
":matmul",
":matmul_static",
":threading_context",
"//compression:compress",
"@highway//:algo",
"@highway//:hwy",
"@highway//:math",
"@highway//:matvec",
"@highway//:profiler",
"@highway//:thread_pool",
"@highway//hwy/contrib/sort:vqsort",
],
)
cc_test(
name = "dot_test",
size = "small",
@ -359,7 +366,7 @@ cc_test(
deps = [
":allocator",
":basics",
":common",
":gemma_lib",
":mat",
":ops",
":test_util",
@ -402,6 +409,7 @@ cc_test(
":basics",
":mat",
":matmul",
":matmul_static",
":ops",
":threading_context",
"@googletest//:gtest_main", # buildcleaner: keep
@ -458,7 +466,7 @@ cc_library(
":args",
":basics",
":mat",
":ops", # matmul.h
":matmul",
"//io",
"@highway//:hwy",
],
@ -467,11 +475,17 @@ cc_library(
cc_library(
name = "gemma_lib",
srcs = [
"gemma/attention.cc",
"gemma/gemma.cc",
"gemma/griffin.cc",
"gemma/vit.cc",
],
hdrs = [
"gemma/activations.h",
"gemma/attention.h",
"gemma/gemma.h",
"gemma/griffin.h",
"gemma/vit.h",
],
exec_properties = {
# Avoid linker OOMs when building with sanitizer instrumentation.
@ -479,12 +493,10 @@ cc_library(
},
textual_hdrs = [
"gemma/gemma-inl.h",
# Placeholder for internal file2, do not remove,
],
deps = [
":allocator",
":basics",
":common",
":configs",
":gemma_args",
":kv_cache",
@ -527,6 +539,7 @@ cc_library(
":cross_entropy",
":gemma_args",
":gemma_lib",
":matmul",
":ops",
":threading_context",
":tokenizer",
@ -616,7 +629,7 @@ cc_binary(
":benchmark_helper",
":gemma_args",
":gemma_lib",
":ops",
":matmul",
":tokenizer",
"//compression:types",
"//paligemma:image",

CMakeLists.txt

@ -53,6 +53,8 @@ set(SOURCES
evals/cross_entropy.cc
evals/cross_entropy.h
gemma/activations.h
gemma/attention.cc
gemma/attention.h
gemma/common.cc
gemma/common.h
gemma/configs.cc
@ -61,6 +63,8 @@ set(SOURCES
gemma/gemma-inl.h
gemma/gemma.cc
gemma/gemma.h
gemma/griffin.cc
gemma/griffin.h
gemma/kv_cache.cc
gemma/kv_cache.h
gemma/model_store.cc
@ -69,6 +73,8 @@ set(SOURCES
gemma/tensor_info.h
gemma/tokenizer.cc
gemma/tokenizer.h
gemma/vit.cc
gemma/vit.h
gemma/weights.cc
gemma/weights.h
io/blob_store.cc

View File

@ -12,7 +12,7 @@ cc_library(
deps = [
"//:gemma_args",
"//:gemma_lib",
"//:ops",
"//:matmul",
"//:threading_context",
"//:tokenizer",
"@highway//:hwy",

gemma/activations.h

@ -16,6 +16,7 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
#include <math.h> // sqrtf
#include <stddef.h>
#include <vector>
@ -29,6 +30,16 @@
namespace gcpp {
// Returns the scale value to use for the query in the attention computation.
// Also called by ops_test.
static inline float ChooseQueryScale(const ModelConfig& config) {
if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads)
return 1.0f / sqrtf(static_cast<float>(config.model_dim /
config.layer_configs[0].heads));
// QueryScaleType::SqrtKeySize
return 1.0f / sqrtf(static_cast<float>(config.layer_configs[0].qkv_dim));
}
struct Activations {
Activations(const ModelConfig& config, size_t batch_size, MatMulEnv* env)
: weights_config(config),
@ -36,6 +47,7 @@ struct Activations {
seq_len(config.seq_len),
cache_pos_size(config.CachePosSize()),
is_griffin(config.model == Model::GRIFFIN_2B),
query_scale(ChooseQueryScale(config)),
x("x", Extents2D(batch_size, config.model_dim), pad_),
// `vocab_size == 0` means it is for Vit part, VitAttention is still MHA
@ -129,6 +141,7 @@ struct Activations {
size_t seq_len;
size_t cache_pos_size = 0; // TODO: after moving KVCache to MatStorageT.
bool is_griffin = false;
float query_scale;
const Extents2D none_ = Extents2D();
const MatPadding pad_ = MatPadding::kOdd;

gemma/attention.cc (new file, 346 lines)

@ -0,0 +1,346 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stddef.h>
#include <stdint.h>
#include <vector>
#include "gemma/activations.h"
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
#include "gemma/weights.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/attention.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "ops/ops-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
// Computes Q.K scores (also called "logits"), which are stored to att.
// `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
const hwy::Divisor& div_seq_len,
const float* HWY_RESTRICT q,
const MatPtrT<float>& k, float* HWY_RESTRICT att) {
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
// Slightly faster: no wraparound.
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
const float score = Dot(q, k.Row(pos), k.Cols());
att[pos] = score;
}
} else {
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
const size_t pos_modulo = div_seq_len.Remainder(pos);
const float score = Dot(q, k.Row(pos_modulo), k.Cols());
att[pos_modulo] = score;
}
}
}
template <typename U>
static void PositionalEncodingQK(U* qk, const size_t qkv_dim,
const size_t layer_idx,
const LayerWeightsPtrs& layer,
const Activations& activations,
const size_t pos, const float mul = 1.0f) {
const PostQKType& post_qk = layer.layer_config.post_qk;
// qk is either q or k, so qkv_dim is the length we operate on.
const float* inv_timescale = activations.inv_timescale.PackedScale1();
bool is_global_layer =
activations.weights_config.attention_window_sizes[layer_idx] ==
activations.seq_len;
// TODO: add a config flag instead of hardcoding the model.
if (is_global_layer && IsVLM(activations.weights_config.model)) {
inv_timescale = activations.inv_timescale_global.PackedScale1();
}
// PostQKType::Rope
if (post_qk == PostQKType::HalfRope) {
Rope(qk, qkv_dim / 2, inv_timescale, pos);
if (mul != 1.0f) MulByConst(mul, qk, qkv_dim);
} else {
RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos);
}
}
// Accumulates the sum of v (from `kv_cache`) * probability (`att`) into
// `att_out`. Equivalent in gemma/modules.py:
// encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
// `v` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
static HWY_INLINE void WeightedSumV(const size_t start_pos,
const size_t last_pos,
const hwy::Divisor& div_seq_len,
const float* HWY_RESTRICT att,
const MatPtrT<float>& v,
float* HWY_RESTRICT att_out) {
const size_t qkv_dim = v.Cols();
hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
// Slightly faster: no wraparound.
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols());
}
} else {
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
const size_t pos_modulo = div_seq_len.Remainder(pos);
const float* HWY_RESTRICT v_ptr = v.Row(pos_modulo);
MulByConstAndAdd(att[pos_modulo], v_ptr, att_out, v.Cols());
}
}
}
// Calculates the attention outputs for a single q.
void SingleDotSoftmaxWeightedSum(
const size_t pos, const size_t start_pos, const size_t last_pos,
const hwy::Divisor& div_seq_len, float* HWY_RESTRICT q,
const MatPtrT<float>& k, const MatPtrT<float>& v, const size_t layer_idx,
const LayerWeightsPtrs& layer, const Activations& activations,
float* HWY_RESTRICT att, float* HWY_RESTRICT att_out) {
const size_t qkv_dim = layer.layer_config.qkv_dim;
const float att_cap = activations.weights_config.att_cap;
const float query_scale = activations.query_scale;
// Apply rope and scaling to Q.
if (layer.query_norm_scale.HasPtr()) {
CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), 0, q, qkv_dim);
});
}
PositionalEncodingQK(q, qkv_dim, layer_idx, layer, activations, pos,
query_scale);
QDotK(start_pos, last_pos, div_seq_len, q, k, att);
// SoftMax with optional SoftCap yields "probabilities" in att.
const size_t att_len =
HWY_MIN(last_pos + 1, static_cast<size_t>(div_seq_len.GetDivisor()));
MaybeLogitsSoftCap(att_cap, att, att_len);
Softmax(att, att_len);
WeightedSumV(start_pos, last_pos, div_seq_len, att, v, att_out);
}
// The attention window usually starts at 0, unless `pos` is larger than the
// attention window size, in which case it starts at `pos - window_size + 1`.
static HWY_INLINE size_t StartPos(size_t pos, const ModelConfig& config,
size_t layer_idx) {
const size_t att_window_size = config.attention_window_sizes[layer_idx];
return pos - HWY_MIN(att_window_size - 1, pos);
}
void DotSoftmaxWeightedSum(
const size_t num_tokens, const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end, const hwy::Divisor& div_seq_len,
const size_t layer_idx, const LayerWeightsPtrs& layer,
Activations& activations, const KVCaches& kv_caches, NestedPools& pools) {
const size_t num_queries = queries_pos.size();
const LayerConfig& layer_config = layer.layer_config;
PROFILER_ZONE("Gen.Attention.DotSoftmax");
// A "head group" in the context of GQA refers to a collection of query
// heads that share the same key and value heads.
const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
const size_t cache_layer_size = layer_config.CacheLayerSize();
const size_t cache_pos_size = activations.cache_pos_size;
// For each head (token, query), compute Q.K, softmax, and weighted V.
// TODO: nested parallelism to use more threads.
pools.Pool(0).Run(
0, layer_config.heads * num_tokens * num_queries,
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
const size_t head = task % layer_config.heads;
const size_t interleaved_idx = task / layer_config.heads;
const size_t query_idx = interleaved_idx % num_queries;
const size_t batch_idx = interleaved_idx / num_queries;
const size_t qkv_dim = layer_config.qkv_dim;
const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
float* HWY_RESTRICT q =
activations.q.Row(interleaved_idx) + head * qkv_dim;
float* HWY_RESTRICT att =
activations.att.Row(interleaved_idx) + head * activations.seq_len;
float* HWY_RESTRICT att_out =
activations.att_out.Row(interleaved_idx) + head * qkv_dim;
// Make strided views into the kv cache entries for the current
// query and head.
KVCache& kv_cache = kv_caches[query_idx];
const size_t kv_head_offset =
layer_idx * cache_layer_size + head_offset;
MatPtrT<float> k("k_view", Extents2D(kv_cache.seq_len, qkv_dim));
k.SetPtr(kv_cache.kv_cache.get() + kv_head_offset,
/*stride=*/cache_pos_size);
MatPtrT<float> v("v_view", Extents2D(kv_cache.seq_len, qkv_dim));
v.SetPtr(kv_cache.kv_cache.get() + kv_head_offset + qkv_dim,
/*stride=*/cache_pos_size);
// Find the token position in the query and calculate the range
// of cache positions to attend to.
const size_t pos = queries_pos[query_idx] + batch_idx;
const size_t start_pos =
StartPos(pos, activations.weights_config, layer_idx);
size_t last_pos = pos;
const size_t prefix_end = queries_prefix_end[query_idx];
if (prefix_end > 0 && prefix_end - 1 > last_pos) {
// last_pos in QDotK and WeightedSumV is inclusive.
last_pos = prefix_end - 1;
}
SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, div_seq_len, q, k,
v, layer_idx, layer, activations, att,
att_out);
});
}
// Fills activations.q and writes to KV cache.
static HWY_INLINE void ComputeQKV(
size_t num_tokens, const QueriesPos& queries_pos,
const hwy::Divisor& div_seq_len, const size_t layer_idx,
const LayerWeightsPtrs& layer, Activations& activations,
const KVCaches& kv_caches, const int flags, NestedPools& pools) {
PROFILER_ZONE("Gen.Attention.QKV");
const size_t num_queries = queries_pos.size();
const size_t num_interleaved = num_tokens * num_queries;
const LayerConfig& layer_config = layer.layer_config;
const size_t qkv_dim = layer_config.qkv_dim;
const size_t kv_heads = layer_config.kv_heads;
const size_t cache_layer_size = layer_config.CacheLayerSize();
const size_t cache_pos_size = activations.cache_pos_size;
// The original qkv_einsum_w has shape [(heads + kv_heads * 2), qkv_dim,
// model_dim], which we reshaped to (heads + kv_heads * 2) * qkv_dim rows.
CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w1,
/*add=*/nullptr, *activations.env, activations.q);
// Set up MatMul row pointers for writing to KV, which consists of
// `kv_heads` pairs of (k, v) vectors. This safely handles wraparound
// because rows are computed modulo seq_len.
MatPtrT<float> kv_rows("kv", Extents2D(activations.pre_att_rms_out.Rows(),
layer.qkv_einsum_w2.Rows()));
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) {
const size_t query_idx = interleaved_idx % num_queries;
const size_t batch_idx = interleaved_idx / num_queries;
const size_t cache_pos =
div_seq_len.Remainder(queries_pos[query_idx] + batch_idx);
const size_t kv_offset =
cache_pos * cache_pos_size + layer_idx * cache_layer_size;
activations.env->storage.OutRow(interleaved_idx) =
reinterpret_cast<uint8_t*>(kv_caches[query_idx].kv_cache.get() +
kv_offset);
}
kv_rows.AttachRowPtrs(&activations.env->storage.OutRow(0));
CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
/*add=*/nullptr, *activations.env, kv_rows);
// Apply positional encodings for K.
// TODO: 2D parallelism to use more threads.
pools.Pool(0).Run(
0, kv_heads * num_interleaved,
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
const size_t head = task % kv_heads;
const size_t interleaved_idx = task / kv_heads;
const size_t query_idx = interleaved_idx % num_queries;
const size_t batch_idx = interleaved_idx / num_queries;
const size_t pos = queries_pos[query_idx] + batch_idx;
const size_t cache_pos = div_seq_len.Remainder(pos);
const size_t kv_offset = cache_pos * cache_pos_size +
layer_idx * cache_layer_size +
head * qkv_dim * 2;
KVCache& kv_cache = kv_caches[query_idx];
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
// Apply further processing to K.
if (layer.key_norm_scale.HasPtr()) {
CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), 0, kv, qkv_dim);
});
}
PositionalEncodingQK(kv, qkv_dim, layer_idx, layer, activations, pos);
});
}
// Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and
// head_dim (`qkv_dim`) into output (`layer_out`).
static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
Activations& activations) {
PROFILER_ZONE("Gen.Attention.SumHeads");
const LayerConfig& layer_config = layer.layer_config;
// att_weights and att_out are concatenated heads, each of length
// layer_config.qkv_dim. Thus the [num_interleaved,
// layer_config.model_dim] matmul output is the sum over heads. Compare
// gemma/modules.py: attn_output = self.attn_vec_einsum('BTNH,NHD->BTD',
// encoded)
HWY_DASSERT(layer_config.model_dim != 0 && layer_config.heads != 0 &&
layer_config.qkv_dim != 0);
const float* add = layer_config.softmax_attn_output_biases
? layer.attention_output_biases.PackedScale1()
: nullptr;
CallMatMul(activations.att_out, layer.att_weights, add, *activations.env,
activations.att_sums);
}
// `queries_prefix_end` can be null (interpreted as all-zero) for standard
// causal attention, and must be non-null for prefix-LM style attention.
void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos,
const QueriesPos* queries_prefix_end,
const hwy::Divisor& div_seq_len, const size_t layer_idx,
const LayerWeightsPtrs& layer, Activations& activations,
const KVCaches& kv_caches, int flags) {
const size_t num_queries = queries_pos.size();
HWY_DASSERT(num_queries <= kv_caches.size());
const LayerConfig& layer_config = layer.layer_config;
HWY_DASSERT(!layer_config.IsMHA()); // No longer supported.
HWY_DASSERT_M((layer_config.heads % layer_config.kv_heads) == 0,
"query heads must be a multiple of key-value heads");
(void)layer_config; // only used in HWY_DASSERT
std::vector<size_t> queries_prefix_end_vec;
QueriesPos queries_prefix_end_span;
if (queries_prefix_end == nullptr) {
queries_prefix_end_vec.assign(num_queries, 0);
queries_prefix_end_span = QueriesPos(queries_prefix_end_vec.data(),
queries_prefix_end_vec.size());
queries_prefix_end = &queries_prefix_end_span;
}
NestedPools& pools = activations.env->ctx.pools;
ComputeQKV(num_tokens, queries_pos, div_seq_len, layer_idx, layer,
activations, kv_caches, flags, pools);
DotSoftmaxWeightedSum(num_tokens, queries_pos, *queries_prefix_end,
div_seq_len, layer_idx, layer, activations, kv_caches,
pools);
SumHeads(layer, activations);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
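For reference, the per-head work done by QDotK, Softmax and WeightedSumV above reduces to the following scalar form. This is a sketch for exposition only; it omits RoPE, query scaling, soft-capping, cache wraparound and the GQA head mapping handled in the real code:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// q: [qkv_dim], k/v: one row per position, att_out: [qkv_dim].
void ScalarAttention(const float* q, const std::vector<std::vector<float>>& k,
                     const std::vector<std::vector<float>>& v,
                     size_t qkv_dim, size_t last_pos, float* att_out) {
  // Q.K "logits" for positions [0, last_pos].
  std::vector<float> att(last_pos + 1);
  for (size_t pos = 0; pos <= last_pos; ++pos) {
    float dot = 0.0f;
    for (size_t i = 0; i < qkv_dim; ++i) dot += q[i] * k[pos][i];
    att[pos] = dot;
  }
  // Softmax over the attended positions.
  const float max = *std::max_element(att.begin(), att.end());
  float sum = 0.0f;
  for (float& a : att) { a = std::exp(a - max); sum += a; }
  for (float& a : att) a /= sum;
  // Probability-weighted sum of V rows.
  std::fill(att_out, att_out + qkv_dim, 0.0f);
  for (size_t pos = 0; pos <= last_pos; ++pos) {
    for (size_t i = 0; i < qkv_dim; ++i) att_out[i] += att[pos] * v[pos][i];
  }
}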

gemma/attention.h (new file, 63 lines)

@ -0,0 +1,63 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ATTENTION_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_ATTENTION_H_
// Declares GemmaAttention for all SIMD targets.
#include <stddef.h>
#include "gemma/gemma.h"
#include "hwy/highway.h"
namespace gcpp {
// Passed to HWY_VISIT_TARGETS; declares for one target.
#define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \
namespace NAMESPACE { \
void SingleDotSoftmaxWeightedSum( \
const size_t pos, const size_t start_pos, const size_t last_pos, \
const hwy::Divisor& div_seq_len, float* HWY_RESTRICT q, \
const MatPtrT<float>& k, const MatPtrT<float>& v, size_t layer_idx, \
const LayerWeightsPtrs& layer, const Activations& activations, \
float* HWY_RESTRICT att, float* HWY_RESTRICT att_out); \
\
void DotSoftmaxWeightedSum(const size_t num_tokens, \
const QueriesPos& queries_pos, \
const QueriesPos& queries_prefix_end, \
const hwy::Divisor& div_seq_len, \
size_t layer_idx, const LayerWeightsPtrs& layer, \
Activations& activations, \
const KVCaches& kv_caches, NestedPools& pools); \
\
void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos, \
const QueriesPos* queries_prefix_end, \
const hwy::Divisor& div_seq_len, const size_t layer_idx, \
const LayerWeightsPtrs& layer, Activations& activations, \
const KVCaches& kv_caches, int flags); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE
// Function declarations for each SIMD target. Allows direct call from the
// per-target namespace. We may later replace this with dynamic dispatch if
// the overhead is acceptable.
HWY_VISIT_TARGETS(GEMMA_DECL_ATTENTION)
#undef GEMMA_DECL_ATTENTION
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ATTENTION_H_
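To illustrate the direct-call pattern these declarations enable: a translation unit that is itself compiled per target via foreach_target.h can call the matching-target GemmaAttention without dynamic dispatch, as gemma.cc does below. Sketch only, assuming the repo's headers; not a complete translation unit:

#include "gemma/attention.h"  // brings in gemma/gemma.h and hwy/highway.h

namespace gcpp {
namespace HWY_NAMESPACE {  // expanded once per SIMD target by foreach_target.h

void CallerInSameTarget(size_t num_tokens, const QueriesPos& queries_pos,
                        const QueriesPos& queries_prefix_end,
                        const hwy::Divisor& div_seq_len, size_t layer_idx,
                        const LayerWeightsPtrs& layer, Activations& activations,
                        const KVCaches& kv_caches) {
  // Resolves to the GemmaAttention compiled for the same target as this
  // caller; no HWY_DYNAMIC_DISPATCH indirection.
  GemmaAttention(num_tokens, queries_pos, &queries_prefix_end, div_seq_len,
                 layer_idx, layer, activations, kv_caches, /*flags=*/0);
}

}  // namespace HWY_NAMESPACE
}  // namespace gcpp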

gemma/common.cc

@ -15,15 +15,11 @@
#include "gemma/common.h"
#include <math.h> // sqrtf
#include <stddef.h>
#include <string>
#include <vector>
#include "gemma/configs.h"
#include "util/basics.h" // BF16
#include "hwy/base.h" // ConvertScalarTo
namespace gcpp {
@ -39,30 +35,4 @@ void Wrap(const ModelConfig& config, size_t pos, std::string& prompt) {
}
}
float EmbeddingScaling(size_t model_dim) {
// Round to bf16 to match Gemma's Embedder, which casts before mul.
return hwy::ConvertScalarTo<float>(
hwy::ConvertScalarTo<BF16>(sqrtf(static_cast<float>(model_dim))));
}
float ChooseQueryScale(const ModelConfig& config) {
if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads)
return 1.0f / sqrtf(static_cast<float>(config.model_dim /
config.layer_configs[0].heads));
// QueryScaleType::SqrtKeySize
return 1.0f / sqrtf(static_cast<float>(config.layer_configs[0].qkv_dim));
}
void RangeChecks(const ModelConfig& weights_config,
size_t& max_generated_tokens, const size_t prompt_size) {
if (!weights_config.use_local_attention) {
if (max_generated_tokens > weights_config.seq_len) {
HWY_WARN("max_generated_tokens %zu > kSeqLen %u, truncating.",
max_generated_tokens, weights_config.seq_len);
max_generated_tokens = weights_config.seq_len;
}
}
HWY_ASSERT(prompt_size > 0);
}
} // namespace gcpp

gemma/common.h

@ -28,15 +28,6 @@ namespace gcpp {
// DEPRECATED, use WrapAndTokenize instead if a tokenized return value is fine.
void Wrap(const ModelConfig& config, size_t pos, std::string& prompt);
// Returns the scale value to use for the embedding (basically sqrt model_dim).
float EmbeddingScaling(size_t model_dim);
// Returns the scale value to use for the query in the attention computation.
float ChooseQueryScale(const ModelConfig& config);
void RangeChecks(const ModelConfig& weights_config,
size_t& max_generated_tokens, size_t prompt_size);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_

File diff suppressed because it is too large.

gemma/gemma.cc

@ -27,11 +27,15 @@
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "gemma/attention.h" // includes highway.h
#include "gemma/gemma-inl.h"
#include "gemma/griffin.h" // includes highway.h
#include "gemma/vit.h" // includes highway.h
#ifndef GEMMA_CC_ONCE
#define GEMMA_CC_ONCE
#include <math.h> // sqrtf
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
@ -49,10 +53,586 @@
#include "ops/matmul.h"
#include "paligemma/image.h"
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h"
#include "hwy/timer.h"
#endif // GEMMA_CC_ONCE
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
void Attention(LayerAttentionType type, size_t num_tokens,
const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end,
const hwy::Divisor& div_seq_len, const size_t layer_idx,
const LayerWeightsPtrs& layer, Activations& activations,
const KVCaches& kv_caches) {
if (type == LayerAttentionType::kGemma) {
GemmaAttention(num_tokens, queries_pos, &queries_prefix_end, div_seq_len,
layer_idx, layer, activations, kv_caches,
/*flags=*/0);
} else {
HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock);
// KVCache conv1d_cache and rglru_cache have one row per *Griffin* layer,
// so map `layer` to the Griffin layer index.
const size_t griffin_layer =
activations.weights_config.NumLayersOfTypeBefore(type, layer_idx);
GriffinRecurrent(queries_pos, num_tokens, griffin_layer, activations,
&layer, kv_caches);
}
}
static HWY_NOINLINE void TransformerLayer(
const size_t num_tokens, const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end, const hwy::Divisor& div_seq_len,
const size_t layer_idx, const LayerWeightsPtrs& layer,
Activations& activations, const KVCaches& kv_caches) {
const LayerConfig& layer_config = layer.layer_config;
RMSNormBatched(activations.x, layer.pre_attention_norm_scale,
activations.pre_att_rms_out);
Attention(layer_config.type, num_tokens, queries_pos, queries_prefix_end,
div_seq_len, layer_idx, layer, activations, kv_caches);
PostNorm(layer_config.post_norm, layer.post_attention_norm_scale,
activations.att_sums);
ResidualConnection(activations.att_sums, activations.x, layer,
/*is_attention=*/true);
RMSNormBatched(activations.x, layer.pre_ffw_norm_scale,
activations.pre_ffw_rms_out);
if (layer_config.type == LayerAttentionType::kVit) {
FFWVit(activations, layer);
} else {
FFWNoVit(activations, layer);
}
PostNorm(layer_config.post_norm, layer.post_ffw_norm_scale,
activations.ffw_out);
ResidualConnection(activations.ffw_out, activations.x, layer,
/*is_attention=*/false);
}
// Returns the scale value to use for the embedding (basically sqrt model_dim).
static float EmbeddingScaling(size_t model_dim) {
// Round to bf16 to match Gemma's Embedder, which casts before mul.
return hwy::ConvertScalarTo<float>(
hwy::ConvertScalarTo<BF16>(sqrtf(static_cast<float>(model_dim))));
}
// `batch_idx` indicates which row of `x` to write to.
// `pos` is the *token*'s position, not the start of the batch, because this is
// called for batches of tokens in prefill, but batches of queries in decode.
//
// For GEMMA_VLM, image tokens are copied into -2 locations (per the Gemma 3
// spec) until we run out of image tokens. This allows for a multi-image prompt
// if -2 locations with appropriate begin/end image tokens are created by the
// calling application.
// Returns new image_token_position.
static HWY_NOINLINE size_t
EmbedMMToken(int token, size_t batch_idx, size_t pos, size_t pos_in_prompt,
const ModelConfig& model_config, const ModelWeightsPtrs& weights,
MatStorageT<float>& x, const ImageTokens* image_tokens = nullptr,
size_t image_token_position = 0) {
// Image tokens just need to be copied.
if (model_config.wrapping == PromptWrapping::GEMMA_VLM &&
image_tokens != nullptr && token == -2 &&
image_token_position < image_tokens->Rows()) {
hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(batch_idx),
x.Cols() * x.ElementBytes());
return image_token_position + 1;
}
if (model_config.wrapping == PromptWrapping::PALIGEMMA &&
image_tokens != nullptr && pos_in_prompt < image_tokens->Rows()) {
hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(batch_idx),
x.Cols() * x.ElementBytes());
return image_token_position;
}
const size_t model_dim = model_config.model_dim;
const float emb_scaling = EmbeddingScaling(model_dim);
HWY_DASSERT(token >= 0);
HWY_DASSERT(token < static_cast<int>(model_config.vocab_size));
CallUpcasted(&weights.embedder_input_embedding, [&](const auto* weights_t) {
// Using `Stride` to compute the offset works for both NUQ (because we use
// an offset and NUQ is never padded) and padded, because non-NUQ types are
// seekable, hence the offset can also skip any padding.
const size_t embedding_ofs = token * weights_t->Stride();
HWY_ASSERT(weights_t->Cols() == model_dim);
const auto embedding_span =
MakeSpan(weights_t->Row(0), embedding_ofs + model_dim);
const hn::ScalableTag<float> df;
DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(batch_idx),
model_dim);
MulByConst(emb_scaling * weights_t->Scale(), x.Row(batch_idx), model_dim);
});
if (model_config.absolute_pe) {
AddAbsolutePositionalEmbeddings(x.Row(batch_idx), model_dim, pos);
}
return image_token_position;
}
// Prefill() and Transformer() increment positions in-place.
using QueriesMutablePos = hwy::Span<size_t>;
// Populates KV cache for batches of tokens from one query at a time.
static HWY_NOINLINE void Prefill(
const QueriesPromptTokens& queries_prompt,
const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
const size_t query_idx_start, const ModelConfig& config,
const ModelWeightsPtrs& weights, Activations& activations,
const RuntimeConfig& runtime_config, const hwy::Divisor& div_seq_len,
const KVCaches& kv_caches) {
PROFILER_ZONE("Gen.Prefill");
const size_t num_queries = queries_prompt.size();
HWY_DASSERT(queries_pos.size() == num_queries);
HWY_DASSERT(queries_prefix_end.size() == num_queries);
HWY_DASSERT(kv_caches.size() == num_queries);
// Batches are important for amortizing loading weights over multiple tokens.
// This is possible in prefill because we know all tokens beforehand, whereas
// decode depends on the previous output token. However, each prefill batch of
// a query requires that preceding batches already wrote to the KV cache,
// hence we sequentially loop over token batches. We can reduce the number of
// iterations by increasing the batch size, but this also increases arithmetic
// intensity, and so we are eventually compute-limited. We could devote some
// threads to parallelizing over queries, but for simplicity we assign them
// all to MatMul.
const size_t max_tbatch_size = runtime_config.prefill_tbatch_size;
// For each query. `qi` is within the batch, not the global query index.
for (size_t qi = 0; qi < num_queries; ++qi) {
// Single query at a time, so pass slices of the spans because
// GemmaAttention will only access the first KV cache and position.
QueriesPos single_query_pos(&queries_pos[qi], 1);
QueriesPos single_query_prefix_end(&queries_prefix_end[qi], 1);
KVCaches single_kv_cache(&kv_caches[qi], 1);
const size_t prompt_size = queries_prompt[qi].size();
// In autoregressive mode, we don't need to prefill the last token, so - 1.
size_t prefill_this_query = prompt_size - 1;
const size_t prefix_end_this_query = queries_prefix_end[qi];
// We can't attend beyond the prompt_size.
HWY_ASSERT(prefix_end_this_query <= prompt_size);
// Special case: if the prefix includes the last token, we need to prefill
// the last token, too. However, we need to rewind this for the generation
// of the first token. So we need to keep track of this.
// TODO: consider implementing masking instead of this logic?
const bool attend_to_last_token =
(prefill_this_query < prefix_end_this_query);
if (attend_to_last_token) {
// The difference can be at most 1.
prefill_this_query += 1;
HWY_ASSERT(prefill_this_query == prefix_end_this_query);
}
// In prefix-LM mode, we need to look at all the tokens for the prefix in
// one iteration through the layers, so we need a large enough batch size.
HWY_ASSERT(prefix_end_this_query == 0 ||
max_tbatch_size >= prefill_this_query);
// For each batch of tokens in the query:
for (size_t tbatch_start = 0; tbatch_start < prefill_this_query;
tbatch_start += max_tbatch_size) {
const size_t tbatch_size =
HWY_MIN(max_tbatch_size, prefill_this_query - tbatch_start);
activations.SetBatchSize(tbatch_size);
// Fill activations.x (much faster than TransformerLayer).
size_t image_token_position = 0;
for (size_t ti = 0; ti < tbatch_size; ++ti) {
const size_t pos = queries_pos[qi] + ti;
const size_t pos_in_prompt = tbatch_start + ti;
const int token = queries_prompt[qi][pos_in_prompt];
image_token_position = EmbedMMToken(
token, ti, pos, pos_in_prompt, config, weights, activations.x,
runtime_config.image_tokens, image_token_position);
}
// Transformer with one batch of tokens from a single query.
for (size_t layer_idx = 0; layer_idx < config.layer_configs.size();
++layer_idx) {
TransformerLayer(tbatch_size, single_query_pos, single_query_prefix_end,
div_seq_len, layer_idx, *weights.GetLayer(layer_idx),
activations, single_kv_cache);
}
// NOTE: we unconditionally call StreamToken, even if EOS.
for (size_t ti = 0; ti < tbatch_size; ++ti) {
const size_t pos = queries_pos[qi] + ti;
const size_t pos_in_prompt = tbatch_start + ti;
const int token = queries_prompt[qi][pos_in_prompt];
if (pos_in_prompt < prompt_size - 1) {
runtime_config.StreamToken(query_idx_start + qi, pos, token, 0.0f);
} else {
// The last token will be streamed later and we should only get here
// if we need to attend to the last token because it is in the prefix.
HWY_ASSERT(attend_to_last_token);
}
}
queries_pos[qi] += tbatch_size;
} // for tbatch_start
if (attend_to_last_token) {
// We need to rewind the position for the last token, which we only
// attended to so that the prefix LM sees everything.
// This means we duplicate work on the last prompt token in autoregressive
// decoding. Alternatives: (1) real masking; (2) always prefill the last
// token and only generate the next one from the already prefilled
// activations.
queries_pos[qi] -= 1;
}
}
}
// Generates one token for each query. `queries_token` is the previous token
// from each query, and `queries_pos` are their position in the sequence.
static HWY_NOINLINE void Transformer(
const QueriesToken& queries_token, const QueriesMutablePos& queries_pos,
const QueriesPos& queries_prefix_end, const ModelConfig& config,
const ModelWeightsPtrs& weights, Activations& activations,
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
const LayersOutputFunc& layers_output,
const ActivationsObserverFunc& activations_observer) {
const size_t num_queries = queries_token.size();
HWY_DASSERT(queries_pos.size() == num_queries);
HWY_DASSERT(queries_prefix_end.size() == num_queries);
if (layers_output) {
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
const float token_f = queries_token[query_idx];
layers_output(query_idx, queries_pos[query_idx], "tokens", -1, &token_f,
1);
}
}
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
EmbedMMToken(queries_token[query_idx], query_idx, queries_pos[query_idx],
/*pos_in_prompt=*/0, config, weights, activations.x);
}
for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) {
TransformerLayer(/*num_tokens=*/1, queries_pos, queries_prefix_end,
div_seq_len, layer_idx, *weights.GetLayer(layer_idx),
activations, kv_caches);
if (activations_observer) {
activations_observer(queries_pos, layer_idx, activations);
}
}
RMSNormInplaceBatched(weights.final_norm_scale, activations.x);
if (activations_observer) {
activations_observer(queries_pos, -1, activations);
}
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
queries_pos[query_idx] += 1;
}
}
void RangeChecks(const ModelConfig& weights_config,
size_t& max_generated_tokens, const size_t prompt_size) {
if (!weights_config.use_local_attention) {
if (max_generated_tokens > weights_config.seq_len) {
HWY_WARN("max_generated_tokens %zu > kSeqLen %u, truncating.",
max_generated_tokens, weights_config.seq_len);
max_generated_tokens = weights_config.seq_len;
}
}
HWY_ASSERT(prompt_size > 0);
}
// Holds "is at end of stream" state for each query.
class TokenStreamer {
public:
TokenStreamer(const RuntimeConfig& runtime_config,
const ModelConfig& model_config)
: runtime_config_(runtime_config), model_config_(model_config) {}
// Returns whether the query was already at, or has just reached, the end of
// the stream: either via token == eos_id, or StreamToken returning false.
bool operator()(size_t query_idx, size_t pos, int token, float prob) {
if (HWY_UNLIKELY(is_eos_.Get(query_idx))) return true;
if (!runtime_config_.StreamToken(query_idx, pos, token, prob) ||
model_config_.IsEOS(token)) {
is_eos_.Set(query_idx);
return true;
}
return false;
}
private:
const RuntimeConfig& runtime_config_;
const ModelConfig& model_config_;
hwy::BitSet4096<> is_eos_;
};
// Runs one decode step for all the queries in the batch. Returns true if all
// queries are at <end_of_sentence>.
static bool DecodeStepT(const ModelConfig& config,
const ModelWeightsPtrs& weights,
const RuntimeConfig& runtime_config,
const QueriesPromptTokens& queries_prompt,
const size_t query_idx_start, const KVCaches& kv_caches,
const QueriesPos& queries_prefix_end,
const hwy::Divisor div_seq_len, const size_t vocab_size,
const SampleFunc& sample_token,
Activations& activations, TokenStreamer& token_streamer,
std::vector<int>& gen_tokens, TimingInfo& timing_info,
const QueriesMutablePos& queries_mutable_pos) {
const size_t num_queries = queries_prompt.size();
// Decode generates one token per query and increments
// queries_mutable_pos.
Transformer(QueriesToken(gen_tokens.data(), num_queries), queries_mutable_pos,
queries_prefix_end, config, weights, activations, div_seq_len,
kv_caches, runtime_config.layers_output,
runtime_config.activations_observer);
// queries_pos are incremented by Transformer.
HWY_DASSERT(num_queries == activations.x.Rows());
bool all_queries_eos = true;
{
PROFILER_ZONE("Gen.EmbeddingMatmul");
// Compute logits from last layer activations.
CallMatMul(activations.x, weights.embedder_input_embedding,
/*add=*/nullptr, *activations.env, activations.logits);
}
PROFILER_ZONE("Gen.Softcap+Sample+Stream");
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
float* HWY_RESTRICT logits = activations.logits.Row(query_idx);
MaybeLogitsSoftCap(config.final_cap, logits, vocab_size);
const TokenAndProb tp = sample_token(logits, vocab_size);
timing_info.NotifyGenerated();
const bool is_eos =
token_streamer(query_idx_start + query_idx,
queries_mutable_pos[query_idx], tp.token, tp.prob);
all_queries_eos &= is_eos;
gen_tokens[query_idx] = is_eos ? config.eos_id : tp.token;
}
return all_queries_eos;
}
static HWY_INLINE SampleFunc
ChooseSampleFunc(const RuntimeConfig& runtime_config) {
// If user provided a sample_func, use it.
if (runtime_config.sample_func) return runtime_config.sample_func;
// Fast path for top-1 with no accept_token.
if (runtime_config.top_k == 1 && !runtime_config.accept_token) {
return [](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb {
PROFILER_ZONE("Gen.Sample Top1");
return Top1OfSoftmax(logits, vocab_size);
};
}
// General case: Softmax with top-k sampling.
return [&runtime_config](float* logits,
size_t vocab_size) HWY_ATTR -> TokenAndProb {
PROFILER_ZONE("Gen.Sample general");
return FusedSoftmaxAndSampleTopK(
logits, runtime_config.top_k, vocab_size, *runtime_config.gen,
runtime_config.temperature, runtime_config.accept_token);
};
}
// Returns the maximum prompt length (in tokens) across all queries.
static size_t MaxQueryLength(const QueriesPromptTokens& queries_prompt) {
size_t max_prompt_size = 0;
for (size_t i = 0; i < queries_prompt.size(); ++i) {
max_prompt_size = HWY_MAX(max_prompt_size, queries_prompt[i].size());
}
return max_prompt_size;
}
// Generates one continuation for each query in `queries_prompt`, which is one
// qbatch whose size is at most the `batch_size` passed to
// `activations.Allocate`.
//
// `queries_pos` stores the KV cache position for each query. In the first turn
// of a chat, pos = 0; we increment each query's position after each token.
//
// `query_idx_start` is the query_idx of the first query in the batch, so that
// `StreamFunc` gets the global query index, not relative to the batch.
//
// `kv_caches` is for the batch, size must match `queries_prompt`.
static void GenerateT(const ModelConfig& config,
const ModelWeightsPtrs& weights, Activations& activations,
const RuntimeConfig& runtime_config,
const QueriesPromptTokens& queries_prompt,
const QueriesPos& queries_pos_in,
const QueriesPos& queries_prefix_end,
const size_t query_idx_start, const KVCaches& kv_caches,
TimingInfo& timing_info) {
HWY_ASSERT(queries_pos_in.size() == kv_caches.size());
// Griffin assumes that the recurrent block cache is zero-initialized.
for (size_t i = 0; i < kv_caches.size(); ++i) {
if (queries_pos_in[i] == 0) {
kv_caches[i].ZeroGriffinCache(); // No-op for non-Griffin models.
}
}
// Copy so we can increment without requiring users to pass in a mutable span.
std::vector<size_t> queries_pos_copy(queries_pos_in.cbegin(),
queries_pos_in.cend());
const QueriesMutablePos queries_mutable_pos(queries_pos_copy.data(),
queries_pos_copy.size());
// Sanity check: prompts should not be empty, nor start with EOS.
for (size_t query_idx = 0; query_idx < queries_prompt.size(); ++query_idx) {
const PromptTokens& prompt = queries_prompt[query_idx];
HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id);
}
const size_t num_queries = queries_prompt.size();
HWY_ASSERT(num_queries <= 4096); // TokenStreamer uses BitSet4096.
HWY_ASSERT(num_queries <= activations.x.Rows());
HWY_ASSERT(queries_pos_in.size() == num_queries);
HWY_ASSERT(kv_caches.size() == num_queries);
const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
size_t max_prompt_size = MaxQueryLength(queries_prompt);
size_t max_generated_tokens = runtime_config.max_generated_tokens;
RangeChecks(config, max_generated_tokens, max_prompt_size);
const SampleFunc sample_token = ChooseSampleFunc(runtime_config);
// Prefill stops before each query's last prompt token, because that token
// is the first input token for generation.
timing_info.prefill_start = hwy::platform::Now();
// Note that Prefill calls activations.SetBatchSize, so we reset it below.
Prefill(queries_prompt, queries_mutable_pos, queries_prefix_end,
query_idx_start, config, weights, activations, runtime_config,
div_seq_len, kv_caches);
// Compute the number of tokens that were prefilled and notify timing_info.
size_t prefilled_tokens = 0;
for (size_t qi = 0; qi < num_queries; ++qi) {
prefilled_tokens += queries_prompt[qi].size() - 1;
}
timing_info.NotifyPrefill(prefilled_tokens);
// queries_pos are incremented by Prefill.
activations.SetBatchSize(num_queries);
// Storage for the last generated token from each query, passed to the next
// Transformer() call.
std::vector<int> gen_tokens(num_queries);
// Stream the last prompt token from each query and fill gen_tokens.
TokenStreamer token_streamer(runtime_config, config);
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
size_t last_token_pos_in_prompt =
queries_mutable_pos[query_idx] - queries_pos_in[query_idx];
gen_tokens[query_idx] = queries_prompt[query_idx][last_token_pos_in_prompt];
(void)token_streamer(query_idx_start + query_idx,
queries_mutable_pos[query_idx], gen_tokens[query_idx],
0.0f);
}
{
const size_t vocab_size = config.vocab_size;
timing_info.generate_start = hwy::platform::Now();
for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
bool all_queries_eos = DecodeStepT(
config, weights, runtime_config, queries_prompt, query_idx_start,
kv_caches, queries_prefix_end, div_seq_len, vocab_size, sample_token,
activations, token_streamer, gen_tokens, timing_info,
queries_mutable_pos);
if (all_queries_eos) break;
} // foreach token to generate
timing_info.NotifyGenerateDone();
}
}
void GenerateSingleT(const ModelConfig& config, const ModelWeightsPtrs& weights,
const RuntimeConfig& runtime_config,
const PromptTokens& prompt, size_t pos, size_t prefix_end,
KVCache& kv_cache, MatMulEnv* env,
TimingInfo& timing_info) {
constexpr size_t kNumQueries = 1;
const size_t qbatch_start = 0;
const size_t max_batch_size =
HWY_MAX(kNumQueries, runtime_config.prefill_tbatch_size);
// TODO: move into Gemma?
Activations activations(config, max_batch_size, env);
const QueriesPromptTokens queries_prompt(&prompt, kNumQueries);
QueriesPos queries_pos(&pos, kNumQueries);
const QueriesPos queries_prefix_end(&prefix_end, kNumQueries);
const KVCaches kv_caches{&kv_cache, kNumQueries};
GenerateT(config, weights, activations, runtime_config, queries_prompt,
queries_pos, queries_prefix_end, qbatch_start, kv_caches,
timing_info);
}
void GenerateBatchT(const ModelConfig& config, const ModelWeightsPtrs& weights,
const RuntimeConfig& runtime_config,
const QueriesPromptTokens& queries_prompt,
const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end,
const KVCaches& kv_caches, MatMulEnv* env,
TimingInfo& timing_info) {
const size_t num_queries = queries_prompt.size();
HWY_ASSERT(queries_pos.size() == num_queries);
HWY_ASSERT(kv_caches.size() >= num_queries);
const size_t max_qbatch_size = runtime_config.decode_qbatch_size;
const size_t max_batch_size =
HWY_MAX(max_qbatch_size, runtime_config.prefill_tbatch_size);
Activations activations(config, max_batch_size, env);
for (size_t qbatch_start = 0; qbatch_start < num_queries;
qbatch_start += max_qbatch_size) {
// Generate one batch of tokens from `qbatch_size` queries.
const size_t qbatch_size =
HWY_MIN(num_queries - qbatch_start, max_qbatch_size);
const QueriesPromptTokens qbatch_prompts(&queries_prompt[qbatch_start],
qbatch_size);
QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start],
qbatch_size);
const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
GenerateT(config, weights, activations, runtime_config, qbatch_prompts,
qbatch_pos, qbatch_prefix_end, qbatch_start, qbatch_kv,
timing_info);
}
}
void GenerateImageTokensT(const ModelConfig& config,
const ModelWeightsPtrs& weights,
const RuntimeConfig& runtime_config,
const Image& image, ImageTokens& image_tokens,
MatMulEnv* env) {
if (config.vit_config.layer_configs.empty()) {
HWY_ABORT("Model does not support generating image tokens.");
}
RuntimeConfig prefill_runtime_config = runtime_config;
ModelConfig vit_config = GetVitConfig(config);
prefill_runtime_config.prefill_tbatch_size =
vit_config.seq_len / (vit_config.pool_dim * vit_config.pool_dim);
Activations prefill_activations(vit_config, vit_config.seq_len, env);
// Weights are for the full PaliGemma model, not just the ViT part.
PrefillVit(config, weights, prefill_runtime_config, image, image_tokens,
prefill_activations);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_EXPORT(GenerateSingleT);
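// Note (not part of the diff): HWY_EXPORT(GenerateSingleT) registers the
// per-target instantiations in a dispatch table; a caller outside the
// per-target namespace then selects the best one at runtime, roughly:
//   HWY_DYNAMIC_DISPATCH(GenerateSingleT)(config, weights, runtime_config,
//                                         prompt, pos, prefix_end, kv_cache,
//                                         env, timing_info);
// The remaining exports and wrappers in gemma.cc lie outside this excerpt.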

gemma/gemma.h

@ -160,9 +160,6 @@ class Gemma {
GemmaChatTemplate chat_template_;
};
void RangeChecks(const ModelConfig& weights_config,
size_t& max_generated_tokens, size_t prompt_size);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_

gemma/griffin.cc (new file, 193 lines)

@ -0,0 +1,193 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stddef.h>
#include <stdint.h>
#include "gemma/activations.h"
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
#include "gemma/weights.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/griffin.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "ops/matvec-inl.h"
#include "ops/ops-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
// Different functions use different naming conventions for the number of
// tokens. Functions that are query-independent, such as RMSNorm*, call the
// count `num_interleaved`. Functions that are query-dependent, such as
// `Attention`, use separate `num_tokens` and `num_queries`. `num_tokens` is the
// number of tokens from one query: 1 for decode, otherwise prefill_tbatch_size.
void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens,
size_t griffin_layer, Activations& activations,
const LayerWeightsPtrs* layer_weights,
const KVCaches& kv_caches) {
PROFILER_ZONE("Gen.Griffin");
hwy::ThreadPool& pool = activations.env->ctx.pools.Pool(0);
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
const D df;
const size_t model_dim = layer_weights->layer_config.model_dim;
HWY_DASSERT(model_dim % hn::Lanes(df) == 0);
const size_t heads = layer_weights->layer_config.heads;
const size_t conv_1d_width = layer_weights->layer_config.conv1d_width;
HWY_ASSERT_M(conv_1d_width % 2 == 0, "Conv width must be even");
const size_t kHeadDim = model_dim / heads;
const size_t kMatrixSize = kHeadDim * kHeadDim;
const size_t num_queries = queries_pos.size();
const hwy::Divisor div_num_q(static_cast<uint32_t>(num_queries));
const size_t num_interleaved = num_tokens * num_queries;
// X / Y linear layers.
// TODO: MatMul
HWY_DASSERT(activations.griffin_y.Rows() == activations.griffin_x.Rows());
HWY_DASSERT(num_interleaved == activations.griffin_y.Rows());
CallUpcastedSame(
&layer_weights->griffin.linear_x_w, &layer_weights->griffin.linear_y_w,
[&](const auto* wx, const auto* wy) {
for (size_t r = 0; r < num_interleaved; ++r) {
float* HWY_RESTRICT y = activations.griffin_y.Row(r);
float* HWY_RESTRICT x = activations.griffin_x.Row(r);
TwoMatVecAdd(
*wx, *wy, 0, model_dim, model_dim,
activations.pre_att_rms_out.Row(r),
/*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(),
/*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(),
/*out0=*/x, /*out1=*/y, pool);
Gelu(y, model_dim);
}
});
// Conv1D.
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) {
const size_t query_idx = div_num_q.Remainder(interleaved_idx);
const size_t batch_idx = div_num_q.Divide(interleaved_idx);
const size_t pos = queries_pos[query_idx] + batch_idx;
float* HWY_RESTRICT x = activations.griffin_x.Row(query_idx);
// cache[i] = input at time t-i.
float* HWY_RESTRICT cache[kMaxConv1DWidth];
cache[0] = x;
for (size_t i = 1; i < conv_1d_width; i++) {
cache[i] =
kv_caches[query_idx].conv1d_cache.Row(griffin_layer) +
((pos + conv_1d_width - 1 - i) % (conv_1d_width - 1)) * model_dim;
}
for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) {
auto xv = hn::Load(df, x + i);
auto accum0 =
hn::Load(df, layer_weights->griffin.conv_biases.PackedScale1() + i);
auto accum1 = hn::Zero(df);
for (size_t l = 0; 2 * l < conv_1d_width; l++) {
auto wv0 =
hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() +
(conv_1d_width - 1 - 2 * l) * model_dim + i);
auto wv1 =
hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() +
(conv_1d_width - 2 - 2 * l) * model_dim + i);
accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0);
accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1);
}
hn::Store(hn::Add(accum0, accum1), df, x + i);
hn::Store(xv, df, cache[HWY_MAX(conv_1d_width, 1) - 1] + i);
}
}
// RGLRU
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) {
const size_t query_idx = div_num_q.Remainder(interleaved_idx);
const size_t batch_idx = div_num_q.Divide(interleaved_idx);
const size_t pos = queries_pos[query_idx] + batch_idx;
float* HWY_RESTRICT x = activations.griffin_x.Row(query_idx);
float* HWY_RESTRICT y = activations.griffin_y.Row(query_idx);
float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Row(query_idx);
float* HWY_RESTRICT a = activations.griffin_multiplier.Row(query_idx);
float* HWY_RESTRICT rnn_state =
kv_caches[query_idx].rglru_cache.Row(griffin_layer);
pool.Run(0, heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
size_t head_offset = head * kHeadDim;
CallUpcasted(&layer_weights->griffin.gate_w, [&](const auto* gate_w) {
TwoOfsMatVecAddLoop(
*gate_w, kMatrixSize * head, kMatrixSize * (heads + head), kHeadDim,
kHeadDim, x + head_offset,
/*add0=*/layer_weights->griffin.gate_biases.PackedScale1() +
head_offset,
/*add1=*/layer_weights->griffin.gate_biases.PackedScale1() +
model_dim + head_offset,
/*out0=*/gate_x + head_offset, /*out1=*/a + head_offset);
});
Sigmoid(gate_x + head_offset, kHeadDim);
Sigmoid(a + head_offset, kHeadDim);
const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
HWY_ATTR { return hn::Mul(x, gate_x); };
hn::Transform1(D(), a + head_offset, kHeadDim,
layer_weights->griffin.a.PackedScale1() + head_offset,
fn_mul);
hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
fn_mul);
// RNN scan
HWY_DASSERT(kHeadDim % hn::Lanes(df) == 0);
for (size_t i = 0; i < kHeadDim; i += hn::Lanes(df)) {
auto log_a = hn::Load(df, a + head_offset + i);
auto gated_x = hn::Load(df, x + head_offset + i);
auto rnn = hn::Load(df, rnn_state + head_offset + i);
auto a = hn::Exp(df, log_a);
auto x_multiplier = hn::Sqrt(hn::NegMulAdd(a, a, hn::Set(df, 1.0f)));
if (pos == 0) {
x_multiplier = hn::Set(df, 1.0f);
}
auto new_x = hn::MulAdd(x_multiplier, gated_x, hn::Mul(a, rnn));
hn::Store(new_x, df, rnn_state + head_offset + i);
// Join branches.
auto yv = hn::Load(df, y + head_offset + i);
auto pre_out = hn::Mul(yv, new_x);
hn::Store(pre_out, df, x + head_offset + i);
}
});
} // interleaved_idx
// Final linear layer.
CallMatMul(activations.griffin_x, layer_weights->griffin.linear_out_w,
layer_weights->griffin.linear_out_biases.PackedScale1(),
*activations.env, activations.att_sums);
} // GriffinRecurrent
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
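The even/odd accumulator loop above is a vectorized form of a depthwise 1D convolution over the last conv_1d_width inputs, with older inputs read from the per-layer conv1d_cache ring buffer (indexed modulo conv_1d_width - 1). A scalar sketch of the per-channel computation, with illustrative names only:

#include <cstddef>

// out[i] = bias[i] + sum_{l=0}^{width-1} w[width-1-l][i] * x_at(pos - l)[i],
// where past_x[l] points at the input from `l` steps ago (past_x[0] is the
// current input) and w is laid out as [width][model_dim].
float DepthwiseConv1DChannel(const float* w, const float* bias,
                             const float* const* past_x, size_t width,
                             size_t model_dim, size_t i) {
  float acc = bias[i];
  for (size_t l = 0; l < width; ++l) {
    acc += w[(width - 1 - l) * model_dim + i] * past_x[l][i];
  }
  return acc;
}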

gemma/griffin.h (new file, 47 lines)

@ -0,0 +1,47 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GRIFFIN_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_GRIFFIN_H_
// Declares GriffinRecurrent for all SIMD targets.
#include <stddef.h>
#include "gemma/gemma.h"
#include "hwy/highway.h"
namespace gcpp {
// Passed to HWY_VISIT_TARGETS; declares for one target.
#define GEMMA_DECL_GRIFFIN(TARGET, NAMESPACE) \
namespace NAMESPACE { \
void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens, \
size_t griffin_layer, Activations& activations, \
const LayerWeightsPtrs* layer_weights, \
const KVCaches& kv_caches); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE
// Function declarations for each SIMD target. Allows direct call from the
// per-target namespace. We may later replace this with dynamic dispatch if
// the overhead is acceptable.
HWY_VISIT_TARGETS(GEMMA_DECL_GRIFFIN)
#undef GEMMA_DECL_GRIFFIN
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GRIFFIN_H_

gemma/vit.cc (new file, 339 lines)

@ -0,0 +1,339 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <math.h> // sqrtf
#include <stddef.h>
#include <stdint.h>
#include <vector>
#include "gemma/activations.h"
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
#include "gemma/weights.h"
#include "paligemma/image.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/vit.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "gemma/gemma-inl.h"
#include "ops/ops-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
// Wrapper class; holds arguments in member variables to shorten call sites.
// The main differences to GemmaAttention are:
// - no KV Cache necessary, attention is always all-to-all and not causal.
// - no potential wrap-around, attention always goes from 0 to kSeqLen.
// - no need for batching, as we are always computing attention for kSeqLen
// tokens.
// This results in a much simpler implementation. However, to avoid duplicating
// code, we should still consider merging the two classes.
// TODO(keysers): Refactor to share code with GemmaAttention.
class VitAttention {
// Computes Q, K, V for all heads, stored in activations_.q.
HWY_NOINLINE void ComputeQKV() {
PROFILER_ZONE("Gen.VitAttention.QKV");
auto& qkv = activations_.q;
HWY_ASSERT(qkv.Rows() == num_tokens_);
HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim);
CallMatMul(activations_.pre_att_rms_out, layer_.vit.qkv_einsum_w,
layer_.vit.qkv_einsum_b.PackedScale1(), *activations_.env, qkv);
}
// TODO(philculliton): transition fully to MatMul.
HWY_NOINLINE void DotSoftmaxWeightedSumMatrix() {
const size_t qkv_dim = layer_config_.qkv_dim;
const size_t heads = layer_config_.heads;
HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
const size_t seq_len = activations_.seq_len;
const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
// Stage Q and K in MatStorageT; C receives the resulting Q.K^T scores.
MatStorageT<float> Q("Q2", Extents2D(num_tokens_, qkv_dim),
MatPadding::kPacked);
MatStorageT<float> K("K2", Extents2D(seq_len, qkv_dim),
MatPadding::kPacked);
MatStorageT<float> C("C2", Extents2D(num_tokens_, seq_len),
MatPadding::kPacked);
// Initialize att_out to zero prior to the head loop.
ZeroInit(activations_.att_out);
for (size_t head = 0; head < heads; ++head) {
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
const size_t token = task;
float* HWY_RESTRICT q = activations_.q.Row(token) + head * 3 * qkv_dim;
// TODO: shift to MatMul with A.scale once MatMul is confirmed working
MulByConst(query_scale, q, qkv_dim);
hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float));
});
pool_.Run(0, seq_len, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
const size_t seq_idx = task;
float* HWY_RESTRICT k =
activations_.q.Row(seq_idx) + head * 3 * qkv_dim + qkv_dim;
hwy::CopyBytes(k, K.Row(seq_idx), qkv_dim * sizeof(float));
});
// this produces C, a (num_tokens_, seq_len) matrix of dot products
CallMatMul(Q, K, nullptr, *activations_.env, C);
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
float* HWY_RESTRICT c = C.Row(task);
Softmax(c, C.Cols());
});
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
size_t token = task;
float* HWY_RESTRICT att_out =
activations_.att_out.Row(token) + head * qkv_dim;
for (size_t i = 0; i < seq_len; ++i) {
float* HWY_RESTRICT v =
activations_.q.Row(i) + head * 3 * qkv_dim + 2 * qkv_dim;
MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim);
}
});
}
}
HWY_NOINLINE void DotSoftmaxWeightedSum() {
const size_t qkv_dim = layer_config_.qkv_dim;
const size_t heads = layer_config_.heads;
HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
const size_t seq_len = activations_.seq_len;
const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
// Compute Q.K, softmax, and weighted V.
pool_.Run(0, layer_config_.heads * num_tokens_,
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
const size_t head = task % layer_config_.heads;
const size_t token = task / layer_config_.heads;
// Compute Q.K scores, which are "logits" stored in head_att.
float* HWY_RESTRICT q =
activations_.q.Row(token) + head * 3 * qkv_dim;
MulByConst(query_scale, q, qkv_dim);
float* HWY_RESTRICT head_att =
activations_.att.Row(token) + head * activations_.seq_len;
for (size_t i = 0; i < seq_len; ++i) {
float* HWY_RESTRICT k =
activations_.q.Row(i) + head * 3 * qkv_dim + qkv_dim;
head_att[i] = Dot(q, k, qkv_dim); // score = q.k
}
// SoftMax yields "probabilities" in head_att.
Softmax(head_att, seq_len);
// Compute weighted sum of v into att_out.
float* HWY_RESTRICT att_out =
activations_.att_out.Row(token) + head * qkv_dim;
hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
for (size_t i = 0; i < seq_len; ++i) {
float* HWY_RESTRICT v =
activations_.q.Row(i) + head * 3 * qkv_dim + 2 * qkv_dim;
MulByConstAndAdd(head_att[i], v, att_out, qkv_dim);
}
});
}
// Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
// head_dim (`qkv_dim`) into output (`att_sums`).
HWY_NOINLINE void SumHeads() {
PROFILER_ZONE("Gen.VitAttention.SumHeads");
auto* bias = layer_.vit.attn_out_b.PackedScale1();
// att_weights and att_out are concatenated heads, each of length
// qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
// matmul output is the sum over heads.
CallMatMul(activations_.att_out, layer_.vit.attn_out_w, bias,
*activations_.env, activations_.att_sums);
}
public:
VitAttention(size_t num_tokens, size_t layer_idx, Activations& activations,
const LayerWeightsPtrs& layer)
: num_tokens_(num_tokens),
activations_(activations),
layer_(layer),
layer_config_(layer.layer_config),
pool_(activations.env->ctx.pools.Pool(0)) {}
HWY_INLINE void operator()() {
ComputeQKV();
if (activations_.weights_config.wrapping == PromptWrapping::GEMMA_VLM) {
DotSoftmaxWeightedSumMatrix();
} else {
DotSoftmaxWeightedSum();
}
SumHeads();
}
private:
const size_t num_tokens_;
Activations& activations_;
const LayerWeightsPtrs& layer_;
const LayerConfig& layer_config_;
hwy::ThreadPool& pool_;
};
// Same as FFWNoVit, but with different layer members and no second
// gating matrix.
void FFWVit(Activations& activations, const LayerWeightsPtrs& layer) {
PROFILER_ZONE("Gen.FFW.ViT");
const LayerConfig& layer_config = layer.layer_config;
const bool add_bias = layer_config.ff_biases;
const float* bias1 = add_bias ? layer.vit.linear_0_b.PackedScale1() : nullptr;
const float* output_bias =
add_bias ? layer.vit.linear_1_b.PackedScale1() : nullptr;
// Compute the hidden layer activations.
CallMatMul(activations.pre_ffw_rms_out, layer.vit.linear_0_w, bias1,
*activations.env, activations.C1);
// Activation (Gelu), store in C1.
ActivationBatched(layer_config.activation, activations.C1);
// Hidden layer -> output layer.
CallMatMul(activations.C1, layer.vit.linear_1_w, output_bias,
*activations.env, activations.ffw_out);
}
// Vit transformer layer. Some comments below refer to the Vit implementation in
// the Big Vision codebase. See
// github.com/google-research/big_vision/blob/main/big_vision/models/vit.py
// TODO(keysers): consider adding a wrapper for both LayerNorm with RMSNorm and
// try merging this with TransformerLayer.
void VitTransformerLayer(size_t num_tokens, const size_t layer_idx,
const LayerWeightsPtrs& layer,
Activations& activations) {
const size_t model_dim = activations.weights_config.model_dim;
auto type = layer.layer_config.type;
HWY_DASSERT(type == LayerAttentionType::kVit);
(void)type;
(void)model_dim;
auto& x = activations.x;
HWY_DASSERT(x.Rows() == num_tokens);
HWY_DASSERT(x.Cols() == model_dim);
// y = nn.LayerNorm()(x)
// y ~ pre_att_rms_out
LayerNormBatched(x, layer.vit.layer_norm_0_scale, layer.vit.layer_norm_0_bias,
activations.pre_att_rms_out);
// y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y)
// y ~ att_sums
VitAttention(num_tokens, layer_idx, activations, layer)();
// x = out["+sa"] = x + y
AddFromBatched(activations.att_sums, x);
// y = nn.LayerNorm()(x)
// y ~ pre_ffw_rms_out
LayerNormBatched(x, layer.vit.layer_norm_1_scale, layer.vit.layer_norm_1_bias,
activations.pre_ffw_rms_out);
// y = out["mlp"] = MlpBlock(...)(y)
// y ~ ffw_out
FFWVit(activations, layer);
// x = out["+mlp"] = x + y
AddFromBatched(activations.ffw_out, x);
}
// Gets the patches of the image and embeds them with the image embedding
// kernel. The result is stored in activations.x.
static HWY_NOINLINE void EmbedImagePatches(const Image& image,
const ModelConfig& model_config,
const ModelWeightsPtrs& weights,
Activations& activations) {
const size_t model_dim = model_config.vit_config.model_dim;
const size_t patch_width = model_config.vit_config.patch_width;
const size_t seq_len = model_config.vit_config.seq_len;
const size_t patch_size = patch_width * patch_width * 3;
HWY_DASSERT(weights.vit_img_embedding_kernel.Rows() == model_dim);
HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_size);
HWY_DASSERT(activations.x.Cols() == model_dim);
(void)model_dim;
// img/embedding/kernel has original shape (14, 14, 3, 1152)
// H x W x C x D transposed to D x (H x W x C) so here (1152, 14 * 14 * 3)
// image_patches is (256, 14 * 14 * 3)
// Must be padded, see `DoDecompressA`.
MatStorageT<float> image_patches("patches", Extents2D(seq_len, patch_size),
MatPadding::kOdd);
for (size_t i = 0; i < seq_len; ++i) {
image.GetPatch(i, image_patches.Row(i));
}
CallMatMul(image_patches, weights.vit_img_embedding_kernel,
weights.vit_img_embedding_bias.PackedScale1(), *activations.env,
activations.x);
// Add position embeddings.
CallUpcastedActivation(&weights.vit_img_pos_embedding,
[&](const auto* weights_t) {
AddFromBatched(*weights_t, activations.x);
});
}
// Prefills the image tokens with the ViT encoder.
void PrefillVit(const ModelConfig& model_config,
const ModelWeightsPtrs& weights,
const RuntimeConfig& runtime_config, const Image& image,
ImageTokens& image_tokens, Activations& activations) {
PROFILER_ZONE("Gen.PrefillVit");
const size_t num_tokens = model_config.vit_config.seq_len;
const size_t vit_model_dim = model_config.vit_config.model_dim;
HWY_ASSERT(num_tokens == activations.x.Rows());
// Embed the image patches.
EmbedImagePatches(image, model_config, weights, activations);
// Go through all layers.
for (size_t layer_idx = 0;
layer_idx < model_config.vit_config.layer_configs.size(); ++layer_idx) {
VitTransformerLayer(num_tokens, layer_idx, *weights.VitLayer(layer_idx),
activations);
}
// Final Layernorm.
LayerNormBatched(activations.x, weights.vit_encoder_norm_scale,
weights.vit_encoder_norm_bias, activations.x);
if (model_config.wrapping == PromptWrapping::GEMMA_VLM) {
activations.x = AvgPool4x4(activations.x);
// Apply soft embedding norm before input projection.
CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), 0, activations.x.Row(0),
vit_model_dim);
});
}
// Apply head embedding into image_tokens of size of the LLM kModelDim.
CallMatMul(activations.x, weights.vit_img_head_kernel,
weights.vit_img_head_bias.PackedScale1(), *activations.env,
image_tokens);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
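Both DotSoftmaxWeightedSumMatrix and DotSoftmaxWeightedSum in vit.cc above compute, per head and per query token, softmax(q . K^T / sqrt(d)) . V; they differ only in whether the scores come from MatMul or from per-token Dot calls. A minimal scalar reference for that quantity (illustrative, not the gemma.cpp API; the source scales q by 1/sqrt(d) before the dot, which is equivalent):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// out[j] = sum_i softmax_i(q . k_i / sqrt(d)) * v_i[j] for one head and one
// query token. k and v each hold seq_len rows of length d.
std::vector<float> AttendOneHead(const float* q, const float* k, const float* v,
                                 size_t seq_len, size_t d) {
  const float scale = 1.0f / std::sqrt(static_cast<float>(d));
  std::vector<float> scores(seq_len);
  float max_score = -1e30f;
  for (size_t i = 0; i < seq_len; ++i) {
    float dot = 0.0f;
    for (size_t j = 0; j < d; ++j) dot += q[j] * k[i * d + j];
    scores[i] = dot * scale;
    max_score = std::max(max_score, scores[i]);
  }
  float sum = 0.0f;
  for (size_t i = 0; i < seq_len; ++i) {
    scores[i] = std::exp(scores[i] - max_score);  // stable softmax numerator
    sum += scores[i];
  }
  std::vector<float> out(d, 0.0f);
  for (size_t i = 0; i < seq_len; ++i) {
    const float w = scores[i] / sum;
    for (size_t j = 0; j < d; ++j) out[j] += w * v[i * d + j];
  }
  return out;
}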

49
gemma/vit.h Normal file
View File

@ -0,0 +1,49 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_VIT_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_VIT_H_
// Declares vision transformer FFW/Prefill for all SIMD targets.
#include <stddef.h>
#include "gemma/gemma.h"
#include "hwy/highway.h"
namespace gcpp {
// Passed to HWY_VISIT_TARGETS; declares for one target.
#define GEMMA_DECL_VIT(TARGET, NAMESPACE) \
namespace NAMESPACE { \
void FFWVit(Activations& activations, const LayerWeightsPtrs& layer); \
\
void PrefillVit(const ModelConfig& model_config, \
const ModelWeightsPtrs& weights, \
const RuntimeConfig& runtime_config, const Image& image, \
ImageTokens& image_tokens, Activations& activations); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE
// Function declarations for each SIMD target. Allows direct call from the
// per-target namespace. We may later replace this with dynamic dispatch if
// the overhead is acceptable.
HWY_VISIT_TARGETS(GEMMA_DECL_VIT)
#undef GEMMA_DECL_VIT
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_VIT_H_
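For orientation, the shapes flowing through EmbedImagePatches and PrefillVit in vit.cc above, using the example dimensions from its comments (14x14x3 patches, 256 patches, ViT model_dim 1152; treat these as illustrative of the PaliGemma configuration rather than definitive):

#include <cstddef>
#include <cstdio>

int main() {
  const size_t patch_width = 14, channels = 3;
  const size_t patch_size = patch_width * patch_width * channels;  // 588
  const size_t seq_len = 256;     // number of patches per image
  const size_t model_dim = 1152;  // ViT embedding width
  // image_patches: (seq_len, patch_size), embedding kernel: (model_dim,
  // patch_size). CallMatMul contracts over patch_size, so activations.x
  // becomes (seq_len, model_dim) before the transformer layers run.
  std::printf("patches (%zu, %zu) x kernel (%zu, %zu) -> x (%zu, %zu)\n",
              seq_len, patch_size, model_dim, patch_size, seq_len, model_dim);
  return 0;
}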

View File

@ -448,8 +448,10 @@ static std::vector<IOBatch> MakeBatches(
HWY_ASSERT(batches.back().Add(row_bytes, file_bytes_per_row));
}
offset += file_bytes_per_row;
// Must zero-initialize the in-memory row padding, see MatMul.
hwy::ZeroBytes(row_bytes + file_bytes_per_row,
mem_stride_bytes - file_bytes_per_row);
row_bytes += mem_stride_bytes;
// Keep the in-memory row padding uninitialized so msan detects any use.
}
HWY_ASSERT(offset == range.End());
}

View File

@ -93,6 +93,8 @@ class MatFinder {
// `WeightsOwner`.
struct LayerWeightsPtrs {
// Initializes tensor metadata without allocating.
// NOTE: do not store layer_idx; TransformerLayer and Attention may use
// other values for KV-cache purposes.
LayerWeightsPtrs(size_t layer_idx, const LayerConfig& config,
const TensorInfoRegistry& tensors)
: finder_(LayerSuffix(layer_idx), tensors),

View File

@ -1294,6 +1294,10 @@ struct MMImpl {
// are no other restrictions on shape, though performance is better when `M % 4
// == 0` or `M <= 4`.
//
// NOTE: if A and/or B are BF16 and padded, the interval `[Cols(),
// hwy::RoundUpTo(Cols(), hn::Lanes(dbf)))` must be zero-initialized to match
// the behavior of `DecompressAndZeroPad`. We check this in debug builds.
//
// If `add` is non-null, the row-vector `add` is added to each of the `M` rows
// of `C`, which is a row-major matrix with arbitrary stride. A scale for
// `add` is not supported, so make sure its scale is 1.
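One way to see why the NOTE above requires zeroed padding: the kernels read whole vectors up to the padded width instead of masking a partial final vector, so any garbage in [Cols(), padded_cols) would be accumulated into the results. A scalar sketch of the invariant (names illustrative):

#include <cstddef>

// Dot product over a row padded from `cols` up to `padded_cols`. When the
// tail [cols, padded_cols) of both rows is zero-initialized, including it in
// the loop contributes nothing, so the kernel never needs a masked tail.
float DotPadded(const float* a, const float* b, size_t padded_cols) {
  float sum = 0.0f;
  for (size_t i = 0; i < padded_cols; ++i) sum += a[i] * b[i];  // tail adds 0
  return sum;
}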

View File

@ -27,6 +27,7 @@
#include <type_traits> // std::enable_if_t
#include <vector>
#include "ops/matmul.h"
#include "util/allocator.h"
#include "util/basics.h" // TokenAndProb
#include "util/mat.h"
@ -48,6 +49,7 @@
#include "compression/compress-inl.h"
#include "ops/dot-inl.h"
#include "ops/matmul_static.h" // includes highway.h
#include "ops/sum-inl.h"
#include "hwy/contrib/algo/transform-inl.h"
#include "hwy/contrib/math/math-inl.h"
@ -57,6 +59,14 @@ namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
template <typename TA, typename TC>
MMPerKey* CallMatMul(const MatPtrT<TA>& A, const MatPtr& B,
const float* HWY_RESTRICT add, MatMulEnv& env,
MatPtrT<TC>& C) {
return CallUpcasted(
&B, [&](const auto* B_t) { return MatMulStatic(A, *B_t, add, env, C); });
}
HWY_INLINE double PackTokenAndProb(int32_t token, float prob) {
// casting prob from float to double just makes some changes to the
// exponent bias and pads zeros in the mantissa.
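The CallMatMul helper added above keeps B type-erased at the call site; CallUpcasted recovers B's stored element type at runtime and forwards to the matching precompiled MatMulStatic instantiation. A minimal, self-contained sketch of that dispatch-by-type pattern (the enum and helper here are illustrative, not the gemma.cpp API):

#include <cstdint>
#include <cstdio>

enum class ElemType { kF32, kBF16 };

// Calls `func` with a value whose static type matches the runtime tag, in the
// spirit of CallUpcasted; the real code passes typed MatPtrT<> views instead.
template <class Func>
void WithElementType(ElemType type, const Func& func) {
  if (type == ElemType::kF32) {
    func(float{});
  } else {
    func(uint16_t{});  // stand-in for BF16 storage
  }
}

int main() {
  WithElementType(ElemType::kBF16, [](auto tag) {
    // CallMatMul's lambda forwards to the MatMulStatic instantiation whose B
    // element type matches `tag`.
    std::printf("element size: %zu bytes\n", sizeof(tag));
  });
  return 0;
}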

View File

@ -31,7 +31,7 @@
#include <random>
#include <vector>
#include "gemma/common.h" // ChooseQueryScale
#include "gemma/activations.h" // ChooseQueryScale
#include "util/allocator.h"
#include "util/basics.h" // BF16
#include "util/mat.h" // MatStorageT

View File

@ -54,10 +54,9 @@ void ZeroInit(MatPtr& mat) {
hwy::ZeroBytes(mat.Packed(), mat.PackedBytes());
return;
}
const size_t row_bytes = mat.Cols() * mat.ElementBytes();
for (size_t r = 0; r < mat.Rows(); ++r) {
hwy::ZeroBytes(mat.RowBytes(r), row_bytes);
}
// Also zero-initialize padding (required by MatMul).
hwy::ZeroBytes(mat.RowBytes(0),
mat.Stride() * mat.ElementBytes() * mat.Rows());
}
size_t Stride(MatPadding padding, size_t cols, size_t element_bytes,